Coverage for src / local_deep_research / news / subscription_manager / storage.py: 97%
122 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 01:07 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 01:07 +0000
1"""
2SQLAlchemy storage implementation for subscriptions.
3"""
5from typing import List, Optional, Dict, Any
6from datetime import datetime, timedelta, timezone
7from sqlalchemy.orm import Session
8from loguru import logger
10from ..core.storage import SubscriptionStorage
11from ...database.models.news import (
12 NewsSubscription,
13 SubscriptionType,
14 SubscriptionStatus,
15)
class SQLSubscriptionStorage(SubscriptionStorage):
    """SQLAlchemy implementation of subscription storage.

    Every public method performs its work inside ``with self.session as
    session`` and, when it modifies a row, commits before returning.
    Lookups by id that find nothing are reported as ``None`` / ``False``
    rather than raised.
    """

    # Attributes that ``update`` may overwrite from caller-supplied data.
    # Any other key passed to ``update`` is ignored.
    _UPDATEABLE_FIELDS = (
        "name",
        "refresh_interval_minutes",
        "status",
        "is_active",
        "expires_at",
        "folder_id",
        "model_provider",
        "model",
        "search_strategy",
        "custom_endpoint",
        "search_engine",
        "search_iterations",
        "questions_per_iteration",
    )

    def __init__(self, session: Session):
        """Initialize with a database session from the user's encrypted database.

        Args:
            session: The session every storage operation will run against.

        Raises:
            ValueError: If ``session`` is falsy (for example ``None``).
        """
        if not session:
            raise ValueError("Session is required for SQLSubscriptionStorage")
        self._session = session

    @property
    def session(self):
        """Get the database session supplied at construction time."""
        return self._session

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _find_by_id(session, subscription_id):
        """Return the first NewsSubscription whose id matches, or None."""
        return (
            session.query(NewsSubscription)
            .filter_by(id=subscription_id)
            .first()
        )

    @staticmethod
    def _full_record(subscription) -> Dict[str, Any]:
        """Build the detailed dict returned by ``get`` for one row."""
        return {
            "id": subscription.id,
            "user_id": subscription.user_id,
            "name": subscription.name,
            "subscription_type": subscription.subscription_type,
            "query_or_topic": subscription.query_or_topic,
            "refresh_interval_minutes": subscription.refresh_interval_minutes,
            "created_at": subscription.created_at,
            "updated_at": subscription.updated_at,
            "last_refresh": subscription.last_refresh,
            "next_refresh": subscription.next_refresh,
            "expires_at": subscription.expires_at,
            "source_type": subscription.source_type,
            "source_id": subscription.source_id,
            "created_from": subscription.created_from,
            "folder": subscription.folder,
            "folder_id": subscription.folder_id,
            "notes": subscription.notes,
            "status": subscription.status,
            # The attributes below are read with getattr and a fallback, so
            # a row object lacking them still serialises.
            "is_active": getattr(subscription, "is_active", True),
            "refresh_count": subscription.refresh_count,
            "results_count": subscription.results_count,
            "last_error": subscription.last_error,
            "error_count": subscription.error_count,
            "model_provider": getattr(subscription, "model_provider", None),
            "model": getattr(subscription, "model", None),
            "search_strategy": getattr(subscription, "search_strategy", None),
            "custom_endpoint": getattr(subscription, "custom_endpoint", None),
            "search_engine": getattr(subscription, "search_engine", None),
            "search_iterations": getattr(subscription, "search_iterations", 3),
            "questions_per_iteration": getattr(
                subscription, "questions_per_iteration", 5
            ),
        }

    @staticmethod
    def _summary_record(subscription) -> Dict[str, Any]:
        """Build the shorter dict returned by ``list`` for one row."""
        return {
            "id": subscription.id,
            "user_id": subscription.user_id,
            "name": subscription.name,
            "subscription_type": subscription.subscription_type,
            "query_or_topic": subscription.query_or_topic,
            "refresh_interval_minutes": subscription.refresh_interval_minutes,
            "created_at": subscription.created_at,
            "updated_at": subscription.updated_at,
            "last_refresh": subscription.last_refresh,
            "next_refresh": subscription.next_refresh,
            "status": subscription.status,
            "folder": subscription.folder,
            "notes": subscription.notes,
        }

    # ------------------------------------------------------------------
    # CRUD
    # ------------------------------------------------------------------

    def create(self, data: Dict[str, Any]) -> str:
        """Create a new subscription.

        Args:
            data: Field values for the new row. ``user_id``,
                ``subscription_type``, ``query_or_topic`` and
                ``refresh_interval_minutes`` are required; the remaining
                keys are optional and fall back to the defaults below.

        Returns:
            The id of the created subscription: ``data["id"]`` when it is
            truthy, otherwise the value of ``self.generate_id()`` (not
            defined in this class; presumably inherited from the base).

        Raises:
            KeyError: If one of the required keys is missing from ``data``.
        """
        subscription_id = data.get("id") or self.generate_id()

        with self.session as session:
            subscription = NewsSubscription(
                id=subscription_id,
                user_id=data["user_id"],
                name=data.get("name"),
                subscription_type=data["subscription_type"],
                query_or_topic=data["query_or_topic"],
                refresh_interval_minutes=data["refresh_interval_minutes"],
                frequency=data.get("frequency", "daily"),
                source_type=data.get("source_type"),
                source_id=data.get("source_id"),
                created_from=data.get("created_from"),
                folder=data.get("folder"),
                folder_id=data.get("folder_id"),
                notes=data.get("notes"),
                status=data.get("status", "active"),
                is_active=data.get("is_active", True),
                model_provider=data.get("model_provider"),
                model=data.get("model"),
                search_strategy=data.get("search_strategy"),
                custom_endpoint=data.get("custom_endpoint"),
                search_engine=data.get("search_engine"),
                search_iterations=data.get("search_iterations", 3),
                questions_per_iteration=data.get("questions_per_iteration", 5),
                # First refresh is scheduled one full interval from now (UTC).
                next_refresh=datetime.now(timezone.utc)
                + timedelta(minutes=data["refresh_interval_minutes"]),
            )

            session.add(subscription)
            session.commit()

            logger.info(
                f"Created subscription {subscription_id} for user {data['user_id']}"
            )
            return subscription_id

    def get(self, id: str) -> Optional[Dict[str, Any]]:
        """Get a subscription by ID.

        Returns:
            A dict of the row's fields, or ``None`` when no row has this id.
        """
        with self.session as session:
            subscription = self._find_by_id(session, id)
            if not subscription:
                return None

            # Converted to a plain dict while the session is still in use.
            return self._full_record(subscription)

    def update(self, id: str, data: Dict[str, Any]) -> bool:
        """Update a subscription.

        Only the keys listed in ``_UPDATEABLE_FIELDS`` are applied; other
        keys in ``data`` are ignored.

        Returns:
            ``True`` after committing, ``False`` when no row has this id.
        """
        with self.session as session:
            subscription = self._find_by_id(session, id)
            if not subscription:
                return False

            for field in self._UPDATEABLE_FIELDS:
                if field in data:
                    # bearer:disable python_lang_code_injection
                    setattr(subscription, field, data[field])

            # Recalculate next refresh if the interval changed.
            if "refresh_interval_minutes" in data:
                subscription.next_refresh = datetime.now(
                    timezone.utc
                ) + timedelta(minutes=data["refresh_interval_minutes"])

            session.commit()
            return True

    def delete(self, id: str) -> bool:
        """Delete a subscription.

        Returns:
            ``True`` after committing, ``False`` when no row has this id.
        """
        with self.session as session:
            subscription = self._find_by_id(session, id)
            if not subscription:
                return False

            session.delete(subscription)
            session.commit()
            return True

    def list(
        self,
        filters: Optional[Dict[str, Any]] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """List subscriptions with optional filtering.

        Args:
            filters: Optional mapping; the keys ``user_id``, ``status`` and
                ``subscription_type`` are honoured, anything else is
                ignored. ``status`` and ``subscription_type`` values are
                converted with ``SubscriptionStatus(...)`` and
                ``SubscriptionType(...)`` respectively.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            One summary dict per matching row.
        """
        with self.session as session:
            query = session.query(NewsSubscription)

            if filters:
                if "user_id" in filters:
                    query = query.filter_by(user_id=filters["user_id"])
                if "status" in filters:
                    query = query.filter_by(
                        status=SubscriptionStatus(filters["status"])
                    )
                if "subscription_type" in filters:
                    query = query.filter_by(
                        subscription_type=SubscriptionType(
                            filters["subscription_type"]
                        )
                    )

            # NOTE(review): no ORDER BY is applied before limit/offset, so
            # page contents depend on whatever order the database returns.
            subscriptions = query.limit(limit).offset(offset).all()

            # Detach each row from the session, then convert it to a dict.
            result = []
            for sub in subscriptions:
                session.expunge(sub)
                result.append(self._summary_record(sub))
            return result

    # ------------------------------------------------------------------
    # Scheduling queries
    # ------------------------------------------------------------------

    def get_active_subscriptions(
        self, user_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Get all subscriptions whose status is ACTIVE.

        Args:
            user_id: When truthy, restrict the result to this user.

        Returns:
            ``to_dict()`` of every matching row.
        """
        with self.session as session:
            query = session.query(NewsSubscription).filter_by(
                status=SubscriptionStatus.ACTIVE
            )

            if user_id:
                query = query.filter_by(user_id=user_id)

            subscriptions = query.all()
            return [sub.to_dict() for sub in subscriptions]

    def get_due_subscriptions(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get ACTIVE subscriptions whose ``next_refresh`` is not in the future.

        Args:
            limit: Maximum number of rows to return.

        Returns:
            ``to_dict()`` of every matching row.
        """
        with self.session as session:
            now = datetime.now(timezone.utc)

            # NOTE(review): rows are not ordered, so when more than `limit`
            # rows are due, which ones are returned is up to the database.
            subscriptions = (
                session.query(NewsSubscription)
                .filter(
                    NewsSubscription.status == SubscriptionStatus.ACTIVE,
                    NewsSubscription.next_refresh <= now,
                )
                .limit(limit)
                .all()
            )

            return [sub.to_dict() for sub in subscriptions]

    def update_refresh_time(
        self,
        subscription_id: str,
        last_refresh: datetime,
        next_refresh: datetime,
    ) -> bool:
        """Update refresh timestamps after processing.

        Returns:
            ``True`` after committing, ``False`` when no row has this id.
        """
        with self.session as session:
            subscription = self._find_by_id(session, subscription_id)
            if not subscription:
                return False

            subscription.last_refresh = last_refresh
            subscription.next_refresh = next_refresh
            session.commit()
            return True

    def increment_stats(self, subscription_id: str, results_count: int) -> bool:
        """Increment refresh count and update results count.

        ``results_count`` replaces the stored value; it is not added to it.

        Returns:
            ``True`` after committing, ``False`` when no row has this id.
        """
        with self.session as session:
            subscription = self._find_by_id(session, subscription_id)
            if not subscription:
                return False

            # Treat a missing (None) counter as zero so the first increment
            # cannot fail with a TypeError.
            subscription.refresh_count = (subscription.refresh_count or 0) + 1
            subscription.total_runs = subscription.refresh_count  # Keep in sync
            subscription.results_count = results_count
            session.commit()
            return True

    # ------------------------------------------------------------------
    # Status transitions
    # ------------------------------------------------------------------

    def pause_subscription(self, subscription_id: str) -> bool:
        """Pause a subscription by setting its status to PAUSED.

        Returns:
            ``True`` after committing, ``False`` when no row has this id.
        """
        with self.session as session:
            subscription = self._find_by_id(session, subscription_id)
            if not subscription:
                return False

            subscription.status = SubscriptionStatus.PAUSED
            session.commit()
            return True

    def resume_subscription(self, subscription_id: str) -> bool:
        """Resume a paused subscription.

        Returns:
            ``True`` after committing; ``False`` when no row has this id or
            the row's status is not PAUSED.
        """
        with self.session as session:
            subscription = self._find_by_id(session, subscription_id)
            if (
                not subscription
                or subscription.status != SubscriptionStatus.PAUSED
            ):
                return False

            subscription.status = SubscriptionStatus.ACTIVE
            # Reset next refresh time to one full interval from now (UTC).
            subscription.next_refresh = datetime.now(timezone.utc) + timedelta(
                minutes=subscription.refresh_interval_minutes
            )
            session.commit()
            return True

    def expire_subscription(self, subscription_id: str) -> bool:
        """Mark a subscription as expired and stamp ``expires_at`` with now (UTC).

        Returns:
            ``True`` after committing, ``False`` when no row has this id.
        """
        with self.session as session:
            subscription = self._find_by_id(session, subscription_id)
            if not subscription:
                return False

            subscription.status = SubscriptionStatus.EXPIRED
            subscription.expires_at = datetime.now(timezone.utc)
            session.commit()
            return True