Coverage for src / local_deep_research / news / subscription_manager / storage.py: 97%

122 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-25 01:07 +0000

1""" 

2SQLAlchemy storage implementation for subscriptions. 

3""" 

4 

5from typing import List, Optional, Dict, Any 

6from datetime import datetime, timedelta, timezone 

7from sqlalchemy.orm import Session 

8from loguru import logger 

9 

10from ..core.storage import SubscriptionStorage 

11from ...database.models.news import ( 

12 NewsSubscription, 

13 SubscriptionType, 

14 SubscriptionStatus, 

15) 

16 

17 

class SQLSubscriptionStorage(SubscriptionStorage):
    """SQLAlchemy implementation of subscription storage.

    Every public method runs inside ``with self.session as session:``. For a
    SQLAlchemy ``Session`` that form calls ``close()`` when the block exits,
    which releases the connection and expunges every object the session holds.

    NOTE(review): the session is injected by the caller, so any ORM objects
    the caller keeps in it become detached after each call here. Confirm that
    is intended.
    """

    def __init__(self, session: Session):
        """Initialize with a database session from the user's encrypted database.

        Args:
            session: SQLAlchemy session used for every query in this storage.

        Raises:
            ValueError: if ``session`` is falsy (e.g. ``None``).
        """
        if not session:
            raise ValueError("Session is required for SQLSubscriptionStorage")
        self._session = session

    @property
    def session(self):
        """Return the database session this storage was constructed with."""
        return self._session

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _find_by_id(
        session: Session, subscription_id: str
    ) -> Optional[NewsSubscription]:
        """Return the subscription row whose id is ``subscription_id``, or None."""
        return (
            session.query(NewsSubscription)
            .filter_by(id=subscription_id)
            .first()
        )

    @staticmethod
    def _refresh_time_after(minutes) -> datetime:
        """Return the current UTC time (timezone-aware) plus ``minutes`` minutes."""
        return datetime.now(timezone.utc) + timedelta(minutes=minutes)

    # ------------------------------------------------------------------
    # CRUD
    # ------------------------------------------------------------------

    def create(self, data: Dict[str, Any]) -> str:
        """Create a new subscription and commit it.

        Args:
            data: Subscription fields.
                - ``user_id``, ``subscription_type``, ``query_or_topic`` and
                  ``refresh_interval_minutes`` are required (read with
                  ``data[...]``).
                - All other fields are optional and fall back to the defaults
                  shown below.
                - ``id`` is generated with ``self.generate_id()`` when it is
                  missing or falsy.

        Returns:
            The id of the created subscription.

        Raises:
            KeyError: if one of the required fields is missing.
        """
        subscription_id = data.get("id") or self.generate_id()

        with self.session as session:
            subscription = NewsSubscription(
                id=subscription_id,
                user_id=data["user_id"],
                name=data.get("name"),
                subscription_type=data["subscription_type"],
                query_or_topic=data["query_or_topic"],
                refresh_interval_minutes=data["refresh_interval_minutes"],
                frequency=data.get("frequency", "daily"),
                source_type=data.get("source_type"),
                source_id=data.get("source_id"),
                created_from=data.get("created_from"),
                folder=data.get("folder"),
                folder_id=data.get("folder_id"),
                notes=data.get("notes"),
                # NOTE(review): the raw value is stored here (default: the
                # string "active"), while other methods assign
                # SubscriptionStatus members. Confirm the column accepts both.
                status=data.get("status", "active"),
                is_active=data.get("is_active", True),
                model_provider=data.get("model_provider"),
                model=data.get("model"),
                search_strategy=data.get("search_strategy"),
                custom_endpoint=data.get("custom_endpoint"),
                search_engine=data.get("search_engine"),
                search_iterations=data.get("search_iterations", 3),
                questions_per_iteration=data.get("questions_per_iteration", 5),
                # The first refresh is scheduled one interval from now.
                next_refresh=self._refresh_time_after(
                    data["refresh_interval_minutes"]
                ),
            )

            session.add(subscription)
            session.commit()

        logger.info(
            f"Created subscription {subscription_id} for user {data['user_id']}"
        )
        return subscription_id

    def get(self, id: str) -> Optional[Dict[str, Any]]:
        """Return a subscription as a plain dict, or None if ``id`` is unknown.

        Optional model/search columns are read with ``getattr`` and fall back
        to None (or 3 / 5 for the iteration settings) if the attribute is absent.
        """
        with self.session as session:
            subscription = self._find_by_id(session, id)
            if not subscription:
                return None

            # Build the dict while the session is still open.
            return {
                "id": subscription.id,
                "user_id": subscription.user_id,
                "name": subscription.name,
                "subscription_type": subscription.subscription_type,
                "query_or_topic": subscription.query_or_topic,
                "refresh_interval_minutes": subscription.refresh_interval_minutes,
                "created_at": subscription.created_at,
                "updated_at": subscription.updated_at,
                "last_refresh": subscription.last_refresh,
                "next_refresh": subscription.next_refresh,
                "expires_at": subscription.expires_at,
                "source_type": subscription.source_type,
                "source_id": subscription.source_id,
                "created_from": subscription.created_from,
                "folder": subscription.folder,
                "folder_id": subscription.folder_id,
                "notes": subscription.notes,
                "status": subscription.status,
                "is_active": getattr(subscription, "is_active", True),
                "refresh_count": subscription.refresh_count,
                "results_count": subscription.results_count,
                "last_error": subscription.last_error,
                "error_count": subscription.error_count,
                "model_provider": getattr(subscription, "model_provider", None),
                "model": getattr(subscription, "model", None),
                "search_strategy": getattr(
                    subscription, "search_strategy", None
                ),
                "custom_endpoint": getattr(
                    subscription, "custom_endpoint", None
                ),
                "search_engine": getattr(subscription, "search_engine", None),
                "search_iterations": getattr(
                    subscription, "search_iterations", 3
                ),
                "questions_per_iteration": getattr(
                    subscription, "questions_per_iteration", 5
                ),
            }

    def update(self, id: str, data: Dict[str, Any]) -> bool:
        """Update whitelisted fields of a subscription and commit.

        Only keys in ``updateable_fields`` are applied; any other keys in
        ``data`` are ignored. If ``refresh_interval_minutes`` is given,
        ``next_refresh`` is rescheduled to one new interval from now.

        Returns:
            True if the subscription existed and was updated, else False.
        """
        with self.session as session:
            subscription = self._find_by_id(session, id)
            if not subscription:
                return False

            # Whitelist of columns callers are allowed to change.
            updateable_fields = [
                "name",
                "refresh_interval_minutes",
                "status",
                "is_active",
                "expires_at",
                "folder_id",
                "model_provider",
                "model",
                "search_strategy",
                "custom_endpoint",
                "search_engine",
                "search_iterations",
                "questions_per_iteration",
            ]
            for field in updateable_fields:
                if field in data:
                    # bearer:disable python_lang_code_injection
                    setattr(subscription, field, data[field])

            # Recalculate next refresh if the interval changed.
            if "refresh_interval_minutes" in data:
                subscription.next_refresh = self._refresh_time_after(
                    data["refresh_interval_minutes"]
                )

            session.commit()
            return True

    def delete(self, id: str) -> bool:
        """Delete a subscription and commit.

        Returns:
            True if a row was deleted, False if ``id`` was unknown.
        """
        with self.session as session:
            subscription = self._find_by_id(session, id)
            if not subscription:
                return False

            session.delete(subscription)
            session.commit()
            return True

    def list(
        self,
        filters: Optional[Dict[str, Any]] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """List subscriptions as summary dicts, with optional filtering.

        Args:
            filters: Optional equality filters. The recognised keys are:
                - ``user_id``
                - ``status``, converted with ``SubscriptionStatus(...)``
                - ``subscription_type``, converted with
                  ``SubscriptionType(...)``
                Any other keys are ignored.
            limit: Maximum number of rows returned.
            offset: Number of rows skipped before the first returned row.

        Returns:
            One dict per subscription, containing a subset of the columns
            returned by ``get``.

        Raises:
            ValueError: if a ``status``/``subscription_type`` filter value is
                not a valid member value of its enum.
        """
        with self.session as session:
            query = session.query(NewsSubscription)

            if filters:
                if "user_id" in filters:
                    query = query.filter_by(user_id=filters["user_id"])
                if "status" in filters:
                    query = query.filter_by(
                        status=SubscriptionStatus(filters["status"])
                    )
                if "subscription_type" in filters:
                    query = query.filter_by(
                        subscription_type=SubscriptionType(
                            filters["subscription_type"]
                        )
                    )

            # NOTE(review): there is no ORDER BY, so LIMIT/OFFSET pages are not
            # guaranteed to be stable between calls.
            subscriptions = query.limit(limit).offset(offset).all()

            # Detach each row from the session, then convert it to a dict.
            result = []
            for sub in subscriptions:
                session.expunge(sub)
                result.append(
                    {
                        "id": sub.id,
                        "user_id": sub.user_id,
                        "name": sub.name,
                        "subscription_type": sub.subscription_type,
                        "query_or_topic": sub.query_or_topic,
                        "refresh_interval_minutes": sub.refresh_interval_minutes,
                        "created_at": sub.created_at,
                        "updated_at": sub.updated_at,
                        "last_refresh": sub.last_refresh,
                        "next_refresh": sub.next_refresh,
                        "status": sub.status,
                        "folder": sub.folder,
                        "notes": sub.notes,
                    }
                )
            return result

    # ------------------------------------------------------------------
    # Scheduling queries
    # ------------------------------------------------------------------

    def get_active_subscriptions(
        self, user_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Return ``to_dict()`` of every subscription whose status is ACTIVE.

        The results are restricted to ``user_id`` when it is truthy.
        """
        with self.session as session:
            query = session.query(NewsSubscription).filter_by(
                status=SubscriptionStatus.ACTIVE
            )

            if user_id:
                query = query.filter_by(user_id=user_id)

            subscriptions = query.all()
            return [sub.to_dict() for sub in subscriptions]

    def get_due_subscriptions(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` active subscriptions due for refresh.

        A subscription is due when its ``next_refresh`` is at or before the
        current UTC time.
        """
        with self.session as session:
            now = datetime.now(timezone.utc)

            subscriptions = (
                session.query(NewsSubscription)
                .filter(
                    NewsSubscription.status == SubscriptionStatus.ACTIVE,
                    NewsSubscription.next_refresh <= now,
                )
                .limit(limit)
                .all()
            )

            return [sub.to_dict() for sub in subscriptions]

    # ------------------------------------------------------------------
    # Post-processing bookkeeping and state transitions
    # ------------------------------------------------------------------

    def update_refresh_time(
        self,
        subscription_id: str,
        last_refresh: datetime,
        next_refresh: datetime,
    ) -> bool:
        """Store the given refresh timestamps and commit.

        Returns:
            True on success, False if the subscription does not exist.
        """
        with self.session as session:
            subscription = self._find_by_id(session, subscription_id)
            if not subscription:
                return False

            subscription.last_refresh = last_refresh
            subscription.next_refresh = next_refresh
            session.commit()
            return True

    def increment_stats(self, subscription_id: str, results_count: int) -> bool:
        """Increment the refresh counter and record the latest results count.

        ``total_runs`` is set equal to the new ``refresh_count``, and
        ``results_count`` is overwritten rather than accumulated.

        Returns:
            True on success, False if the subscription does not exist.
        """
        with self.session as session:
            subscription = self._find_by_id(session, subscription_id)
            if not subscription:
                return False

            # A NULL counter counts as 0. Otherwise `None + 1` would raise
            # TypeError.
            subscription.refresh_count = (subscription.refresh_count or 0) + 1
            subscription.total_runs = subscription.refresh_count  # Keep in sync
            subscription.results_count = results_count
            session.commit()
            return True

    def pause_subscription(self, subscription_id: str) -> bool:
        """Set a subscription's status to PAUSED and commit.

        Returns:
            True on success, False if the subscription does not exist.
        """
        with self.session as session:
            subscription = self._find_by_id(session, subscription_id)
            if not subscription:
                return False

            subscription.status = SubscriptionStatus.PAUSED
            session.commit()
            return True

    def resume_subscription(self, subscription_id: str) -> bool:
        """Reactivate a PAUSED subscription and reschedule its next refresh.

        ``next_refresh`` is reset to one ``refresh_interval_minutes`` from now.

        Returns:
            True on success. False if the subscription does not exist or is
            not currently PAUSED.
        """
        with self.session as session:
            subscription = self._find_by_id(session, subscription_id)
            if (
                not subscription
                or subscription.status != SubscriptionStatus.PAUSED
            ):
                return False

            subscription.status = SubscriptionStatus.ACTIVE
            # Reset the next refresh time relative to the moment of resuming.
            subscription.next_refresh = self._refresh_time_after(
                subscription.refresh_interval_minutes
            )
            session.commit()
            return True

    def expire_subscription(self, subscription_id: str) -> bool:
        """Mark a subscription as EXPIRED, stamping ``expires_at`` with now (UTC).

        Returns:
            True on success, False if the subscription does not exist.
        """
        with self.session as session:
            subscription = self._find_by_id(session, subscription_id)
            if not subscription:
                return False

            subscription.status = SubscriptionStatus.EXPIRED
            subscription.expires_at = datetime.now(timezone.utc)
            session.commit()
            return True