Coverage for src / local_deep_research / news / subscription_manager / storage.py: 97%
120 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
1"""
2SQLAlchemy storage implementation for subscriptions.
3"""
5from typing import List, Optional, Dict, Any
6from datetime import datetime, timedelta, timezone
7from sqlalchemy.orm import Session
8from loguru import logger
10from ..core.storage import SubscriptionStorage
11from ...database.models.news import (
12 NewsSubscription,
13 SubscriptionType,
14 SubscriptionStatus,
15)
class SQLSubscriptionStorage(SubscriptionStorage):
    """SQLAlchemy implementation of subscription storage.

    Every operation runs inside ``with self.session as session:``. A
    SQLAlchemy ``Session`` used as a context manager is closed when the
    ``with`` block exits, so each call ends by closing the session handed
    to ``__init__`` (SQLAlchemy allows the same session to be used again
    afterwards). Lookups return plain dicts rather than ORM instances.
    """

    def __init__(self, session: Optional[Session] = None):
        """Initialize with a database session from the user's encrypted database.

        Args:
            session: Session used by every storage method. It must be set
                before any method is called; entering ``with None`` fails.
        """
        self._session = session

    @property
    def session(self) -> Optional[Session]:
        """Return the session supplied at construction (may be ``None``)."""
        return self._session

    @staticmethod
    def _find_by_id(
        session: Session, subscription_id: str
    ) -> Optional[NewsSubscription]:
        """Return the subscription whose ``id`` equals ``subscription_id``, or ``None``."""
        return (
            session.query(NewsSubscription)
            .filter_by(id=subscription_id)
            .first()
        )

    def create(self, data: Dict[str, Any]) -> str:
        """Create a new subscription and return its id.

        Args:
            data: Subscription fields. ``user_id``, ``subscription_type``,
                ``query_or_topic`` and ``refresh_interval_minutes`` are
                required (a missing key raises ``KeyError``). ``id`` is
                generated with ``generate_id()`` when absent or falsy.

        Returns:
            The id of the newly created subscription.
        """
        subscription_id = data.get("id") or self.generate_id()

        with self.session as session:
            subscription = NewsSubscription(
                id=subscription_id,
                user_id=data["user_id"],
                name=data.get("name"),
                subscription_type=data["subscription_type"],
                query_or_topic=data["query_or_topic"],
                refresh_interval_minutes=data["refresh_interval_minutes"],
                frequency=data.get("frequency", "daily"),
                source_type=data.get("source_type"),
                source_id=data.get("source_id"),
                created_from=data.get("created_from"),
                folder=data.get("folder"),
                folder_id=data.get("folder_id"),
                notes=data.get("notes"),
                # NOTE(review): the default status is the plain string
                # "active", while the rest of this class filters on
                # SubscriptionStatus members; confirm the column type maps
                # both to the same stored value.
                status=data.get("status", "active"),
                is_active=data.get("is_active", True),
                model_provider=data.get("model_provider"),
                model=data.get("model"),
                search_strategy=data.get("search_strategy"),
                custom_endpoint=data.get("custom_endpoint"),
                search_engine=data.get("search_engine"),
                search_iterations=data.get("search_iterations", 3),
                questions_per_iteration=data.get("questions_per_iteration", 5),
                # First refresh is one full interval from now.
                next_refresh=datetime.now(timezone.utc)
                + timedelta(minutes=data["refresh_interval_minutes"]),
            )

            session.add(subscription)
            session.commit()

        logger.info(
            f"Created subscription {subscription_id} for user {data['user_id']}"
        )
        return subscription_id

    def get(self, id: str) -> Optional[Dict[str, Any]]:
        """Get a subscription by ID.

        Returns:
            A dict of the subscription's columns, or ``None`` if no row has
            this id. Optional settings are read with ``getattr`` fallbacks
            (e.g. ``search_iterations`` falls back to 3).
        """
        with self.session as session:
            subscription = self._find_by_id(session, id)
            if not subscription:
                return None

            # Convert to dict manually while the session is still open.
            return {
                "id": subscription.id,
                "user_id": subscription.user_id,
                "name": subscription.name,
                "subscription_type": subscription.subscription_type,
                "query_or_topic": subscription.query_or_topic,
                "refresh_interval_minutes": subscription.refresh_interval_minutes,
                "created_at": subscription.created_at,
                "updated_at": subscription.updated_at,
                "last_refresh": subscription.last_refresh,
                "next_refresh": subscription.next_refresh,
                "expires_at": subscription.expires_at,
                "source_type": subscription.source_type,
                "source_id": subscription.source_id,
                "created_from": subscription.created_from,
                "folder": subscription.folder,
                "folder_id": subscription.folder_id,
                "notes": subscription.notes,
                "status": subscription.status,
                "is_active": getattr(subscription, "is_active", True),
                "refresh_count": subscription.refresh_count,
                "results_count": subscription.results_count,
                "last_error": subscription.last_error,
                "error_count": subscription.error_count,
                "model_provider": getattr(subscription, "model_provider", None),
                "model": getattr(subscription, "model", None),
                "search_strategy": getattr(
                    subscription, "search_strategy", None
                ),
                "custom_endpoint": getattr(
                    subscription, "custom_endpoint", None
                ),
                "search_engine": getattr(subscription, "search_engine", None),
                "search_iterations": getattr(
                    subscription, "search_iterations", 3
                ),
                "questions_per_iteration": getattr(
                    subscription, "questions_per_iteration", 5
                ),
            }

    def update(self, id: str, data: Dict[str, Any]) -> bool:
        """Update a subscription.

        Only the whitelisted fields below are copied from ``data``; other
        keys are ignored. Changing ``refresh_interval_minutes`` also
        reschedules ``next_refresh`` to one new interval from now.

        Returns:
            ``True`` if the subscription existed and was updated, else ``False``.
        """
        with self.session as session:
            subscription = self._find_by_id(session, id)
            if not subscription:
                return False

            # Update allowed fields
            updateable_fields = [
                "name",
                "refresh_interval_minutes",
                "status",
                "is_active",
                "expires_at",
                "folder_id",
                "model_provider",
                "model",
                "search_strategy",
                "custom_endpoint",
                "search_engine",
                "search_iterations",
                "questions_per_iteration",
            ]
            for field in updateable_fields:
                if field in data:
                    # bearer:disable python_lang_code_injection
                    setattr(subscription, field, data[field])

            # Recalculate next refresh if interval changed
            if "refresh_interval_minutes" in data:
                subscription.next_refresh = datetime.now(
                    timezone.utc
                ) + timedelta(minutes=data["refresh_interval_minutes"])

            session.commit()
            return True

    def delete(self, id: str) -> bool:
        """Delete a subscription.

        Returns:
            ``True`` if a row was deleted, ``False`` if the id was not found.
        """
        with self.session as session:
            subscription = self._find_by_id(session, id)
            if not subscription:
                return False

            session.delete(subscription)
            session.commit()
            return True

    def list(
        self,
        filters: Optional[Dict[str, Any]] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """List subscriptions with optional filtering.

        Args:
            filters: Optional ``user_id``, ``status`` and
                ``subscription_type`` keys. ``status`` and
                ``subscription_type`` are coerced through their enum
                constructors, so raw enum values or members are accepted.
            limit: Maximum number of rows returned.
            offset: Number of rows skipped.

        Returns:
            A summary dict (subset of the columns returned by ``get``) per row.
        """
        with self.session as session:
            query = session.query(NewsSubscription)

            if filters:
                if "user_id" in filters:
                    query = query.filter_by(user_id=filters["user_id"])
                if "status" in filters:
                    query = query.filter_by(
                        status=SubscriptionStatus(filters["status"])
                    )
                if "subscription_type" in filters:
                    query = query.filter_by(
                        subscription_type=SubscriptionType(
                            filters["subscription_type"]
                        )
                    )

            # NOTE(review): no ORDER BY is applied, so limit/offset pages are
            # not guaranteed to be stable between calls.
            subscriptions = query.limit(limit).offset(offset).all()
            # Detach from session and convert to dicts
            result = []
            for sub in subscriptions:
                session.expunge(sub)
                result.append(
                    {
                        "id": sub.id,
                        "user_id": sub.user_id,
                        "name": sub.name,
                        "subscription_type": sub.subscription_type,
                        "query_or_topic": sub.query_or_topic,
                        "refresh_interval_minutes": sub.refresh_interval_minutes,
                        "created_at": sub.created_at,
                        "updated_at": sub.updated_at,
                        "last_refresh": sub.last_refresh,
                        "next_refresh": sub.next_refresh,
                        "status": sub.status,
                        "folder": sub.folder,
                        "notes": sub.notes,
                    }
                )
            return result

    def get_active_subscriptions(
        self, user_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Get all subscriptions whose status is ``ACTIVE``.

        Args:
            user_id: When truthy, restrict results to this user.
        """
        with self.session as session:
            query = session.query(NewsSubscription).filter_by(
                status=SubscriptionStatus.ACTIVE
            )

            if user_id:
                query = query.filter_by(user_id=user_id)

            subscriptions = query.all()
            return [sub.to_dict() for sub in subscriptions]

    def get_due_subscriptions(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get active subscriptions whose ``next_refresh`` is now or earlier.

        Results are ordered by ``next_refresh`` ascending, so when more than
        ``limit`` subscriptions are due the most overdue ones come first.
        """
        with self.session as session:
            now = datetime.now(timezone.utc)

            subscriptions = (
                session.query(NewsSubscription)
                .filter(
                    NewsSubscription.status == SubscriptionStatus.ACTIVE,
                    NewsSubscription.next_refresh <= now,
                )
                # Without an ORDER BY, LIMIT returns an arbitrary subset and
                # the oldest due subscriptions could be skipped indefinitely.
                .order_by(NewsSubscription.next_refresh)
                .limit(limit)
                .all()
            )

            return [sub.to_dict() for sub in subscriptions]

    def update_refresh_time(
        self,
        subscription_id: str,
        last_refresh: datetime,
        next_refresh: datetime,
    ) -> bool:
        """Update refresh timestamps after processing.

        Returns:
            ``True`` if the subscription exists, else ``False``.
        """
        with self.session as session:
            subscription = self._find_by_id(session, subscription_id)
            if not subscription:
                return False

            subscription.last_refresh = last_refresh
            subscription.next_refresh = next_refresh
            session.commit()
            return True

    def increment_stats(self, subscription_id: str, results_count: int) -> bool:
        """Increment the refresh counter and store the latest results count.

        Returns:
            ``True`` if the subscription exists, else ``False``.
        """
        with self.session as session:
            subscription = self._find_by_id(session, subscription_id)
            if not subscription:
                return False

            # A NULL counter is treated as 0 instead of failing with
            # TypeError on ``None + 1``.
            subscription.refresh_count = (subscription.refresh_count or 0) + 1
            subscription.total_runs = subscription.refresh_count  # Keep in sync
            subscription.results_count = results_count
            session.commit()
            return True

    def pause_subscription(self, subscription_id: str) -> bool:
        """Set a subscription's status to ``PAUSED``.

        Returns:
            ``True`` if the subscription exists, else ``False``.
        """
        with self.session as session:
            subscription = self._find_by_id(session, subscription_id)
            if not subscription:
                return False

            subscription.status = SubscriptionStatus.PAUSED
            session.commit()
            return True

    def resume_subscription(self, subscription_id: str) -> bool:
        """Resume a paused subscription.

        Sets status back to ``ACTIVE`` and schedules the next refresh one
        interval from now.

        Returns:
            ``False`` if the subscription does not exist or is not currently
            ``PAUSED``; ``True`` otherwise.
        """
        with self.session as session:
            subscription = self._find_by_id(session, subscription_id)
            if (
                not subscription
                or subscription.status != SubscriptionStatus.PAUSED
            ):
                return False

            subscription.status = SubscriptionStatus.ACTIVE
            # Reset next refresh time
            subscription.next_refresh = datetime.now(timezone.utc) + timedelta(
                minutes=subscription.refresh_interval_minutes
            )
            session.commit()
            return True

    def expire_subscription(self, subscription_id: str) -> bool:
        """Mark a subscription as expired, stamping ``expires_at`` with now (UTC).

        Returns:
            ``True`` if the subscription exists, else ``False``.
        """
        with self.session as session:
            subscription = self._find_by_id(session, subscription_id)
            if not subscription:
                return False

            subscription.status = SubscriptionStatus.EXPIRED
            subscription.expires_at = datetime.now(timezone.utc)
            session.commit()
            return True