Coverage for src / local_deep_research / database / queue_service.py: 95%
78 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 01:07 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 01:07 +0000
1"""
2Queue service for managing tasks using encrypted user databases.
3Replaces the service_db approach with direct access to user databases.
4"""
6from datetime import datetime, timedelta, UTC
7from typing import Any, Dict, List, Optional
9from loguru import logger
10from sqlalchemy.orm import Session
12from .models import QueueStatus, TaskMetadata
15class UserQueueService:
16 """Manages queue operations within a user's encrypted database."""
18 def __init__(self, session: Session):
19 """
20 Initialize with a database session.
22 Args:
23 session: SQLAlchemy session for the user's encrypted database
24 """
25 self.session = session
27 def _safe_commit(self) -> None:
28 """Commit the current transaction, rolling back on failure."""
29 try:
30 self.session.commit()
31 except Exception:
32 self.session.rollback()
33 logger.exception("Database commit failed, transaction rolled back")
34 raise
36 def update_queue_status(
37 self,
38 active_tasks: int,
39 queued_tasks: int,
40 last_task_id: Optional[str] = None,
41 ) -> None:
42 """Update queue status for the user."""
43 status = self.session.query(QueueStatus).first()
45 if status:
46 status.active_tasks = active_tasks
47 status.queued_tasks = queued_tasks
48 status.last_checked = datetime.now(UTC)
49 if last_task_id: 49 ↛ 59line 49 didn't jump to line 59 because the condition on line 49 was always true
50 status.last_task_id = last_task_id
51 else:
52 status = QueueStatus(
53 active_tasks=active_tasks,
54 queued_tasks=queued_tasks,
55 last_task_id=last_task_id,
56 )
57 self.session.add(status)
59 self._safe_commit()
61 def get_queue_status(self) -> Optional[Dict[str, Any]]:
62 """Get queue status for the user."""
63 status = self.session.query(QueueStatus).first()
65 if status:
66 return {
67 "active_tasks": status.active_tasks,
68 "queued_tasks": status.queued_tasks,
69 "last_checked": status.last_checked,
70 "last_task_id": status.last_task_id,
71 }
72 return None
74 def add_task_metadata(
75 self,
76 task_id: str,
77 task_type: str,
78 priority: int = 0,
79 ) -> None:
80 """Add metadata for a new task."""
81 task = TaskMetadata(
82 task_id=task_id,
83 status="queued",
84 task_type=task_type,
85 priority=priority,
86 )
87 self.session.add(task)
89 # Update queue counts
90 self._increment_queue_count()
92 self._safe_commit()
94 def update_task_status(
95 self, task_id: str, status: str, error_message: Optional[str] = None
96 ) -> None:
97 """Update task status."""
98 task = (
99 self.session.query(TaskMetadata).filter_by(task_id=task_id).first()
100 )
102 if task:
103 old_status = task.status
104 task.status = status
105 task.error_message = error_message
107 if status == "processing" and old_status == "queued":
108 task.started_at = datetime.now(UTC)
109 self._update_queue_counts(-1, 1) # -1 queued, +1 active
111 elif status in ["completed", "failed"]:
112 task.completed_at = datetime.now(UTC)
113 self._update_queue_counts(0, -1) # 0 queued, -1 active
115 self._safe_commit()
117 def get_pending_tasks(self, limit: int = 50) -> List[Dict[str, Any]]:
118 """Get pending tasks for the user."""
119 tasks = (
120 self.session.query(TaskMetadata)
121 .filter_by(status="queued")
122 .order_by(TaskMetadata.priority.desc(), TaskMetadata.created_at)
123 .limit(limit)
124 .all()
125 )
127 return [
128 {
129 "task_id": t.task_id,
130 "task_type": t.task_type,
131 "created_at": t.created_at,
132 "priority": t.priority,
133 }
134 for t in tasks
135 ]
137 def cleanup_old_tasks(self, days: int = 7) -> int:
138 """Clean up old completed/failed tasks."""
139 cutoff_date = datetime.now(UTC) - timedelta(days=days)
141 deleted = (
142 self.session.query(TaskMetadata)
143 .filter(
144 TaskMetadata.status.in_(["completed", "failed"]),
145 TaskMetadata.completed_at < cutoff_date,
146 )
147 .delete()
148 )
150 self._safe_commit()
151 return deleted
153 def get_active_task_count(self) -> int:
154 """Get count of active tasks."""
155 status = self.session.query(QueueStatus).first()
156 return status.active_tasks if status else 0
158 def get_queued_task_count(self) -> int:
159 """Get count of queued tasks."""
160 status = self.session.query(QueueStatus).first()
161 return status.queued_tasks if status else 0
163 def _get_or_create_status(self) -> QueueStatus:
164 """Get existing queue status or create a new one with zero counts."""
165 status = self.session.query(QueueStatus).first()
166 if status is None:
167 status = QueueStatus(queued_tasks=0, active_tasks=0)
168 self.session.add(status)
169 return status
171 def _increment_queue_count(self):
172 """Increment the queued task count."""
173 status = self._get_or_create_status()
174 status.queued_tasks += 1
175 status.last_checked = datetime.now(UTC)
177 def _update_queue_counts(self, queued_delta: int, active_delta: int):
178 """Update queue counts by deltas."""
179 status = self._get_or_create_status()
180 status.queued_tasks = max(0, status.queued_tasks + queued_delta)
181 status.active_tasks = max(0, status.active_tasks + active_delta)
182 status.last_checked = datetime.now(UTC)