Coverage for src / local_deep_research / news / subscription_manager / storage.py: 14%

122 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2026-01-11 00:51 +0000

1""" 

2SQLAlchemy storage implementation for subscriptions. 

3""" 

4 

5from typing import List, Optional, Dict, Any 

6from datetime import datetime, timedelta, timezone 

7from sqlalchemy.orm import Session 

8from loguru import logger 

9 

10from ..core.storage import SubscriptionStorage 

11from ...database.models.news import ( 

12 NewsSubscription, 

13 SubscriptionType, 

14 SubscriptionStatus, 

15) 

16 

17 

class SQLSubscriptionStorage(SubscriptionStorage):
    """SQLAlchemy implementation of subscription storage.

    Every public method enters the injected session with
    ``with self.session as session:``. In SQLAlchemy, leaving that block
    calls ``Session.close()``, so ORM objects loaded by one method are not
    kept attached afterwards. Each method commits its own writes.
    """

    def __init__(self, session: Session):
        """Initialize with a database session from the user's encrypted database.

        Raises:
            ValueError: if ``session`` is falsy.
        """
        if not session:
            raise ValueError("Session is required for SQLSubscriptionStorage")
        self._session = session

    @property
    def session(self):
        """Get database session"""
        return self._session

    @staticmethod
    def _find(session: Session, subscription_id: str) -> Optional[NewsSubscription]:
        """Return the subscription with ``subscription_id``, or ``None`` if absent."""
        return (
            session.query(NewsSubscription)
            .filter_by(id=subscription_id)
            .first()
        )

    def create(self, data: Dict[str, Any]) -> str:
        """Create a new subscription and return its ID.

        ``data`` must contain ``user_id``, ``subscription_type``,
        ``query_or_topic`` and ``refresh_interval_minutes``; other keys are
        optional and fall back to the defaults below. If ``data["id"]`` is
        missing or falsy, an ID is produced by ``self.generate_id()``.
        The first refresh is scheduled ``refresh_interval_minutes`` from now.
        """
        subscription_id = data.get("id") or self.generate_id()

        with self.session as session:
            subscription = NewsSubscription(
                id=subscription_id,
                user_id=data["user_id"],
                name=data.get("name"),
                subscription_type=data["subscription_type"],
                query_or_topic=data["query_or_topic"],
                refresh_interval_minutes=data["refresh_interval_minutes"],
                frequency=data.get("frequency", "daily"),
                source_type=data.get("source_type"),
                source_id=data.get("source_id"),
                created_from=data.get("created_from"),
                folder=data.get("folder"),
                folder_id=data.get("folder_id"),
                notes=data.get("notes"),
                status=data.get("status", "active"),
                is_active=data.get("is_active", True),
                model_provider=data.get("model_provider"),
                model=data.get("model"),
                search_strategy=data.get("search_strategy"),
                custom_endpoint=data.get("custom_endpoint"),
                search_engine=data.get("search_engine"),
                search_iterations=data.get("search_iterations", 3),
                questions_per_iteration=data.get("questions_per_iteration", 5),
                next_refresh=datetime.now(timezone.utc)
                + timedelta(minutes=data["refresh_interval_minutes"]),
            )

            session.add(subscription)
            session.commit()

            logger.info(
                f"Created subscription {subscription_id} for user {data['user_id']}"
            )
            return subscription_id

    def get(self, id: str) -> Optional[Dict[str, Any]]:
        """Return a subscription as a plain dict, or ``None`` if it does not exist.

        The dict is built while the session is still open so that no
        attribute access happens on a closed/detached instance.
        """
        with self.session as session:
            subscription = self._find(session, id)
            if not subscription:
                return None

            return {
                "id": subscription.id,
                "user_id": subscription.user_id,
                "name": subscription.name,
                "subscription_type": subscription.subscription_type,
                "query_or_topic": subscription.query_or_topic,
                "refresh_interval_minutes": subscription.refresh_interval_minutes,
                "created_at": subscription.created_at,
                "updated_at": subscription.updated_at,
                "last_refresh": subscription.last_refresh,
                "next_refresh": subscription.next_refresh,
                "expires_at": subscription.expires_at,
                "source_type": subscription.source_type,
                "source_id": subscription.source_id,
                "created_from": subscription.created_from,
                "folder": subscription.folder,
                "folder_id": subscription.folder_id,
                "notes": subscription.notes,
                "status": subscription.status,
                # getattr defaults tolerate models/rows lacking the newer columns
                "is_active": getattr(subscription, "is_active", True),
                "refresh_count": subscription.refresh_count,
                "results_count": subscription.results_count,
                "last_error": subscription.last_error,
                "error_count": subscription.error_count,
                "model_provider": getattr(subscription, "model_provider", None),
                "model": getattr(subscription, "model", None),
                "search_strategy": getattr(subscription, "search_strategy", None),
                "custom_endpoint": getattr(subscription, "custom_endpoint", None),
                "search_engine": getattr(subscription, "search_engine", None),
                "search_iterations": getattr(subscription, "search_iterations", 3),
                "questions_per_iteration": getattr(
                    subscription, "questions_per_iteration", 5
                ),
            }

    def update(self, id: str, data: Dict[str, Any]) -> bool:
        """Update the whitelisted fields present in ``data``.

        Keys outside the whitelist are ignored. If
        ``refresh_interval_minutes`` is supplied, ``next_refresh`` is
        rescheduled relative to now.

        Returns:
            ``True`` if the subscription was found and updated, else ``False``.
        """
        with self.session as session:
            subscription = self._find(session, id)
            if not subscription:
                return False

            updateable_fields = [
                "name",
                "refresh_interval_minutes",
                "status",
                "is_active",
                "expires_at",
                "folder_id",
                "model_provider",
                "model",
                "search_strategy",
                "custom_endpoint",
                "search_engine",
                "search_iterations",
                "questions_per_iteration",
            ]
            for field in updateable_fields:
                if field in data:
                    setattr(subscription, field, data[field])

            # Recalculate next refresh if interval changed
            if "refresh_interval_minutes" in data:
                subscription.next_refresh = datetime.now(timezone.utc) + timedelta(
                    minutes=data["refresh_interval_minutes"]
                )

            session.commit()
            return True

    def delete(self, id: str) -> bool:
        """Delete a subscription; return ``False`` if it does not exist."""
        with self.session as session:
            subscription = self._find(session, id)
            if not subscription:
                return False

            session.delete(subscription)
            session.commit()
            return True

    def list(
        self,
        filters: Optional[Dict[str, Any]] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """List subscriptions as summary dicts, with optional filtering.

        Supported filter keys: ``user_id``, ``status`` (converted via
        ``SubscriptionStatus``) and ``subscription_type`` (converted via
        ``SubscriptionType``). Other keys are ignored.
        """
        with self.session as session:
            query = session.query(NewsSubscription)

            if filters:
                if "user_id" in filters:
                    query = query.filter_by(user_id=filters["user_id"])
                if "status" in filters:
                    query = query.filter_by(
                        status=SubscriptionStatus(filters["status"])
                    )
                if "subscription_type" in filters:
                    query = query.filter_by(
                        subscription_type=SubscriptionType(
                            filters["subscription_type"]
                        )
                    )

            subscriptions = query.limit(limit).offset(offset).all()

            result = []
            for sub in subscriptions:
                # Read every attribute BEFORE detaching: touching an unloaded
                # or expired attribute on a detached instance raises
                # DetachedInstanceError.
                result.append(
                    {
                        "id": sub.id,
                        "user_id": sub.user_id,
                        "name": sub.name,
                        "subscription_type": sub.subscription_type,
                        "query_or_topic": sub.query_or_topic,
                        "refresh_interval_minutes": sub.refresh_interval_minutes,
                        "created_at": sub.created_at,
                        "updated_at": sub.updated_at,
                        "last_refresh": sub.last_refresh,
                        "next_refresh": sub.next_refresh,
                        "status": sub.status,
                        "folder": sub.folder,
                        "notes": sub.notes,
                    }
                )
                session.expunge(sub)
            return result

    def get_active_subscriptions(
        self, user_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Get all subscriptions with ACTIVE status, optionally for one user.

        Serialization is delegated to the model's ``to_dict()``.
        """
        with self.session as session:
            query = session.query(NewsSubscription).filter_by(
                status=SubscriptionStatus.ACTIVE
            )

            if user_id:
                query = query.filter_by(user_id=user_id)

            subscriptions = query.all()
            return [sub.to_dict() for sub in subscriptions]

    def get_due_subscriptions(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get up to ``limit`` ACTIVE subscriptions whose ``next_refresh`` has passed.

        Results are ordered by ``next_refresh`` (most overdue first). Without
        an ORDER BY, LIMIT would pick an arbitrary subset and could starve
        some due subscriptions when more than ``limit`` are due.
        """
        with self.session as session:
            now = datetime.now(timezone.utc)

            subscriptions = (
                session.query(NewsSubscription)
                .filter(
                    NewsSubscription.status == SubscriptionStatus.ACTIVE,
                    NewsSubscription.next_refresh <= now,
                )
                .order_by(NewsSubscription.next_refresh)
                .limit(limit)
                .all()
            )

            return [sub.to_dict() for sub in subscriptions]

    def update_refresh_time(
        self,
        subscription_id: str,
        last_refresh: datetime,
        next_refresh: datetime,
    ) -> bool:
        """Store the refresh timestamps after processing; ``False`` if not found."""
        with self.session as session:
            subscription = self._find(session, subscription_id)
            if not subscription:
                return False

            subscription.last_refresh = last_refresh
            subscription.next_refresh = next_refresh
            session.commit()
            return True

    def increment_stats(self, subscription_id: str, results_count: int) -> bool:
        """Increment the refresh count and overwrite ``results_count``.

        ``total_runs`` is kept equal to ``refresh_count``.

        Returns:
            ``True`` if the subscription was found, else ``False``.
        """
        with self.session as session:
            subscription = self._find(session, subscription_id)
            if not subscription:
                return False

            # Treat a NULL count as 0 so the first increment cannot fail with
            # ``TypeError: unsupported operand ... NoneType`` (e.g. rows
            # created before the column had a default — TODO confirm schema).
            subscription.refresh_count = (subscription.refresh_count or 0) + 1
            subscription.total_runs = subscription.refresh_count  # Keep in sync
            subscription.results_count = results_count
            session.commit()
            return True

    def pause_subscription(self, subscription_id: str) -> bool:
        """Set a subscription's status to PAUSED; ``False`` if not found."""
        with self.session as session:
            subscription = self._find(session, subscription_id)
            if not subscription:
                return False

            subscription.status = SubscriptionStatus.PAUSED
            session.commit()
            return True

    def resume_subscription(self, subscription_id: str) -> bool:
        """Resume a paused subscription.

        Only subscriptions currently PAUSED are resumed; the next refresh is
        rescheduled one interval from now.

        Returns:
            ``False`` if the subscription is missing or not paused.
        """
        with self.session as session:
            subscription = self._find(session, subscription_id)
            if (
                not subscription
                or subscription.status != SubscriptionStatus.PAUSED
            ):
                return False

            subscription.status = SubscriptionStatus.ACTIVE
            subscription.next_refresh = datetime.now(timezone.utc) + timedelta(
                minutes=subscription.refresh_interval_minutes
            )
            session.commit()
            return True

    def expire_subscription(self, subscription_id: str) -> bool:
        """Mark a subscription EXPIRED with ``expires_at`` set to now (UTC)."""
        with self.session as session:
            subscription = self._find(session, subscription_id)
            if not subscription:
                return False

            subscription.status = SubscriptionStatus.EXPIRED
            subscription.expires_at = datetime.now(timezone.utc)
            session.commit()
            return True