Coverage for src / local_deep_research / news / subscription_manager / storage.py: 14%
122 statements
« prev ^ index » next coverage.py v7.12.0, created at 2026-01-11 00:51 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2026-01-11 00:51 +0000
1"""
2SQLAlchemy storage implementation for subscriptions.
3"""
5from typing import List, Optional, Dict, Any
6from datetime import datetime, timedelta, timezone
7from sqlalchemy.orm import Session
8from loguru import logger
10from ..core.storage import SubscriptionStorage
11from ...database.models.news import (
12 NewsSubscription,
13 SubscriptionType,
14 SubscriptionStatus,
15)
class SQLSubscriptionStorage(SubscriptionStorage):
    """SQLAlchemy implementation of subscription storage.

    Every public method enters ``self.session`` as a context manager, runs
    its query or mutation against ``NewsSubscription`` rows, and commits
    when it changed something.  Operations addressed at a subscription ID
    that does not exist return ``None`` / ``False`` instead of raising.
    """

    def __init__(self, session: Session):
        """Initialize with a database session from the user's encrypted database.

        Args:
            session: The session all storage operations will run on.

        Raises:
            ValueError: If ``session`` is falsy (for example ``None``).
        """
        if not session:
            raise ValueError("Session is required for SQLSubscriptionStorage")
        self._session = session

    @property
    def session(self):
        """Get database session"""
        return self._session

    @staticmethod
    def _load_subscription(session, subscription_id: str):
        """Return the ``NewsSubscription`` row with this ID, or ``None``."""
        return (
            session.query(NewsSubscription)
            .filter_by(id=subscription_id)
            .first()
        )

    @staticmethod
    def _refresh_due_in(minutes) -> datetime:
        """Return the timezone-aware UTC time ``minutes`` minutes from now."""
        return datetime.now(timezone.utc) + timedelta(minutes=minutes)

    def create(self, data: Dict[str, Any]) -> str:
        """Create a new subscription and return its ID.

        Args:
            data: Field values for the new row.  ``user_id``,
                ``subscription_type``, ``query_or_topic`` and
                ``refresh_interval_minutes`` are required; every other key
                is optional.  When ``id`` is missing or falsy a new one is
                obtained from ``self.generate_id()``.

        Returns:
            The ID of the created subscription.

        Raises:
            KeyError: If one of the required keys is missing from ``data``.
        """
        subscription_id = data.get("id") or self.generate_id()

        with self.session as session:
            subscription = NewsSubscription(
                id=subscription_id,
                user_id=data["user_id"],
                name=data.get("name"),
                subscription_type=data["subscription_type"],
                query_or_topic=data["query_or_topic"],
                refresh_interval_minutes=data["refresh_interval_minutes"],
                frequency=data.get("frequency", "daily"),
                source_type=data.get("source_type"),
                source_id=data.get("source_id"),
                created_from=data.get("created_from"),
                folder=data.get("folder"),
                folder_id=data.get("folder_id"),
                notes=data.get("notes"),
                status=data.get("status", "active"),
                is_active=data.get("is_active", True),
                model_provider=data.get("model_provider"),
                model=data.get("model"),
                search_strategy=data.get("search_strategy"),
                custom_endpoint=data.get("custom_endpoint"),
                search_engine=data.get("search_engine"),
                search_iterations=data.get("search_iterations", 3),
                questions_per_iteration=data.get("questions_per_iteration", 5),
                # The first refresh is due one full interval after creation.
                next_refresh=self._refresh_due_in(
                    data["refresh_interval_minutes"]
                ),
            )

            session.add(subscription)
            session.commit()

            logger.info(
                f"Created subscription {subscription_id} for user {data['user_id']}"
            )
            return subscription_id

    def get(self, id: str) -> Optional[Dict[str, Any]]:
        """Get a subscription by ID.

        Args:
            id: ID of the subscription to look up.

        Returns:
            A plain dict of the subscription's fields, or ``None`` when no
            row has this ID.
        """
        with self.session as session:
            subscription = self._load_subscription(session, id)
            if not subscription:
                return None

            # Convert to dict manually.  The getattr() fallbacks supply a
            # default when the attribute is absent from the model object.
            return {
                "id": subscription.id,
                "user_id": subscription.user_id,
                "name": subscription.name,
                "subscription_type": subscription.subscription_type,
                "query_or_topic": subscription.query_or_topic,
                "refresh_interval_minutes": subscription.refresh_interval_minutes,
                "created_at": subscription.created_at,
                "updated_at": subscription.updated_at,
                "last_refresh": subscription.last_refresh,
                "next_refresh": subscription.next_refresh,
                "expires_at": subscription.expires_at,
                "source_type": subscription.source_type,
                "source_id": subscription.source_id,
                "created_from": subscription.created_from,
                "folder": subscription.folder,
                "folder_id": subscription.folder_id,
                "notes": subscription.notes,
                "status": subscription.status,
                "is_active": getattr(subscription, "is_active", True),
                "refresh_count": subscription.refresh_count,
                "results_count": subscription.results_count,
                "last_error": subscription.last_error,
                "error_count": subscription.error_count,
                "model_provider": getattr(subscription, "model_provider", None),
                "model": getattr(subscription, "model", None),
                "search_strategy": getattr(
                    subscription, "search_strategy", None
                ),
                "custom_endpoint": getattr(
                    subscription, "custom_endpoint", None
                ),
                "search_engine": getattr(subscription, "search_engine", None),
                "search_iterations": getattr(
                    subscription, "search_iterations", 3
                ),
                "questions_per_iteration": getattr(
                    subscription, "questions_per_iteration", 5
                ),
            }

    def update(self, id: str, data: Dict[str, Any]) -> bool:
        """Update a subscription.

        Only the whitelisted fields below are copied from ``data``; any
        other key is ignored.  Changing ``refresh_interval_minutes`` also
        reschedules ``next_refresh`` to one new interval from now.

        Args:
            id: ID of the subscription to modify.
            data: Mapping of field name to new value.

        Returns:
            ``True`` if the subscription exists (the change is committed),
            ``False`` if no row has this ID.
        """
        updateable_fields = (
            "name",
            "refresh_interval_minutes",
            "status",
            "is_active",
            "expires_at",
            "folder_id",
            "model_provider",
            "model",
            "search_strategy",
            "custom_endpoint",
            "search_engine",
            "search_iterations",
            "questions_per_iteration",
        )

        with self.session as session:
            subscription = self._load_subscription(session, id)
            if not subscription:
                return False

            for field in updateable_fields:
                if field in data:
                    setattr(subscription, field, data[field])

            # Recalculate next refresh if interval changed
            if "refresh_interval_minutes" in data:
                subscription.next_refresh = self._refresh_due_in(
                    data["refresh_interval_minutes"]
                )

            session.commit()
            return True

    def delete(self, id: str) -> bool:
        """Delete a subscription.

        Returns:
            ``True`` if a row was deleted and committed, ``False`` if no row
            has this ID.
        """
        with self.session as session:
            subscription = self._load_subscription(session, id)
            if not subscription:
                return False

            session.delete(subscription)
            session.commit()
            return True

    def list(
        self,
        filters: Optional[Dict[str, Any]] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """List subscriptions with optional filtering.

        Args:
            filters: Optional mapping supporting the keys ``user_id``,
                ``status`` (converted with ``SubscriptionStatus(...)``) and
                ``subscription_type`` (converted with
                ``SubscriptionType(...)``).  Other keys are ignored.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip before the first returned row.

        Returns:
            Summary dicts (a subset of the fields returned by ``get``),
            ordered by creation time and then ID.
        """
        with self.session as session:
            query = session.query(NewsSubscription)

            if filters:
                if "user_id" in filters:
                    query = query.filter_by(user_id=filters["user_id"])
                if "status" in filters:
                    query = query.filter_by(
                        status=SubscriptionStatus(filters["status"])
                    )
                if "subscription_type" in filters:
                    query = query.filter_by(
                        subscription_type=SubscriptionType(
                            filters["subscription_type"]
                        )
                    )

            # LIMIT/OFFSET without ORDER BY leaves row order up to the
            # database, so consecutive pages could overlap or skip rows.
            # A stable sort key makes pagination deterministic.
            query = query.order_by(
                NewsSubscription.created_at, NewsSubscription.id
            )

            subscriptions = query.limit(limit).offset(offset).all()

            # Detach from session and convert to dicts
            result = []
            for sub in subscriptions:
                session.expunge(sub)
                result.append(
                    {
                        "id": sub.id,
                        "user_id": sub.user_id,
                        "name": sub.name,
                        "subscription_type": sub.subscription_type,
                        "query_or_topic": sub.query_or_topic,
                        "refresh_interval_minutes": sub.refresh_interval_minutes,
                        "created_at": sub.created_at,
                        "updated_at": sub.updated_at,
                        "last_refresh": sub.last_refresh,
                        "next_refresh": sub.next_refresh,
                        "status": sub.status,
                        "folder": sub.folder,
                        "notes": sub.notes,
                    }
                )
            return result

    def get_active_subscriptions(
        self, user_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Get all active subscriptions.

        Args:
            user_id: When truthy, restrict the result to this user's
                subscriptions.

        Returns:
            One dict per subscription whose status is
            ``SubscriptionStatus.ACTIVE``.
        """
        with self.session as session:
            query = session.query(NewsSubscription).filter_by(
                status=SubscriptionStatus.ACTIVE
            )

            if user_id:
                query = query.filter_by(user_id=user_id)

            # NOTE(review): relies on NewsSubscription.to_dict(), which is
            # not visible from this module, while get() and list() build
            # their dicts by hand -- confirm the model provides it.
            return [sub.to_dict() for sub in query.all()]

    def get_due_subscriptions(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get subscriptions that are due for refresh.

        A subscription is due when it is active and its ``next_refresh`` is
        not later than the current UTC time.

        Args:
            limit: Maximum number of subscriptions to return.

        Returns:
            One dict per due subscription, most overdue first.
        """
        with self.session as session:
            now = datetime.now(timezone.utc)

            subscriptions = (
                session.query(NewsSubscription)
                .filter(
                    NewsSubscription.status == SubscriptionStatus.ACTIVE,
                    NewsSubscription.next_refresh <= now,
                )
                # Oldest due date first: with a LIMIT and no ordering the
                # database may return an arbitrary subset, so long-overdue
                # subscriptions could be skipped indefinitely.
                .order_by(NewsSubscription.next_refresh)
                .limit(limit)
                .all()
            )

            # NOTE(review): relies on NewsSubscription.to_dict(), which is
            # not visible from this module -- confirm the model provides it.
            return [sub.to_dict() for sub in subscriptions]

    def update_refresh_time(
        self,
        subscription_id: str,
        last_refresh: datetime,
        next_refresh: datetime,
    ) -> bool:
        """Update refresh timestamps after processing.

        Returns:
            ``True`` if the subscription exists (timestamps are committed),
            ``False`` if no row has this ID.
        """
        with self.session as session:
            subscription = self._load_subscription(session, subscription_id)
            if not subscription:
                return False

            subscription.last_refresh = last_refresh
            subscription.next_refresh = next_refresh
            session.commit()
            return True

    def increment_stats(self, subscription_id: str, results_count: int) -> bool:
        """Increment refresh count and update results count.

        ``refresh_count`` goes up by one, ``total_runs`` is set to the same
        value, and ``results_count`` is overwritten with the given number.

        Returns:
            ``True`` if the subscription exists (stats are committed),
            ``False`` if no row has this ID.
        """
        with self.session as session:
            subscription = self._load_subscription(session, subscription_id)
            if not subscription:
                return False

            # Treat an unset (None) counter as zero so the first increment
            # cannot fail with a TypeError.
            subscription.refresh_count = (subscription.refresh_count or 0) + 1
            subscription.total_runs = subscription.refresh_count  # Keep in sync
            subscription.results_count = results_count
            session.commit()
            return True

    def pause_subscription(self, subscription_id: str) -> bool:
        """Pause a subscription.

        Returns:
            ``True`` if the subscription exists (its status is set to
            ``PAUSED`` and committed), ``False`` if no row has this ID.
        """
        with self.session as session:
            subscription = self._load_subscription(session, subscription_id)
            if not subscription:
                return False

            subscription.status = SubscriptionStatus.PAUSED
            session.commit()
            return True

    def resume_subscription(self, subscription_id: str) -> bool:
        """Resume a paused subscription.

        Only a subscription that is currently ``PAUSED`` can be resumed; it
        becomes ``ACTIVE`` and its next refresh is scheduled one interval
        from now.

        Returns:
            ``True`` if the subscription was resumed, ``False`` if it does
            not exist or is not paused.
        """
        with self.session as session:
            subscription = self._load_subscription(session, subscription_id)
            if (
                not subscription
                or subscription.status != SubscriptionStatus.PAUSED
            ):
                return False

            subscription.status = SubscriptionStatus.ACTIVE
            # Reset next refresh time
            subscription.next_refresh = self._refresh_due_in(
                subscription.refresh_interval_minutes
            )
            session.commit()
            return True

    def expire_subscription(self, subscription_id: str) -> bool:
        """Mark a subscription as expired.

        Sets the status to ``EXPIRED`` and stamps ``expires_at`` with the
        current UTC time.

        Returns:
            ``True`` if the subscription exists (the change is committed),
            ``False`` if no row has this ID.
        """
        with self.session as session:
            subscription = self._load_subscription(session, subscription_id)
            if not subscription:
                return False

            subscription.status = SubscriptionStatus.EXPIRED
            subscription.expires_at = datetime.now(timezone.utc)
            session.commit()
            return True