Coverage for src / local_deep_research / news / subscription_manager / storage.py: 97%

120 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:55 +0000

1""" 

2SQLAlchemy storage implementation for subscriptions. 

3""" 

4 

5from typing import List, Optional, Dict, Any 

6from datetime import datetime, timedelta, timezone 

7from sqlalchemy.orm import Session 

8from loguru import logger 

9 

10from ..core.storage import SubscriptionStorage 

11from ...database.models.news import ( 

12 NewsSubscription, 

13 SubscriptionType, 

14 SubscriptionStatus, 

15) 

16 

17 

class SQLSubscriptionStorage(SubscriptionStorage):
    """SQLAlchemy implementation of subscription storage.

    Each public method enters the configured session with ``with``; for a
    SQLAlchemy ``Session`` this closes the session when the block exits, so
    read methods hand back plain dicts rather than live ORM objects.
    """

    # Fields that ``update`` may overwrite. The allow-list keeps callers from
    # setting arbitrary attributes (ids, counters, timestamps) on the row.
    _UPDATABLE_FIELDS = (
        "name",
        "refresh_interval_minutes",
        "status",
        "is_active",
        "expires_at",
        "folder_id",
        "model_provider",
        "model",
        "search_strategy",
        "custom_endpoint",
        "search_engine",
        "search_iterations",
        "questions_per_iteration",
    )

    def __init__(self, session: Optional[Session] = None):
        """Initialize with a database session from the user's encrypted database.

        Args:
            session: Session used by every storage operation. ``None`` is
                accepted here, but operations then raise ``RuntimeError``.
        """
        self._session = session

    @property
    def session(self) -> Optional[Session]:
        """Get the database session (``None`` if none was supplied)."""
        return self._session

    def _open_session(self) -> Session:
        """Return the session to use, failing clearly if none is configured.

        Raises:
            RuntimeError: If no session was supplied.
        """
        session = self.session
        if session is None:
            # Without this check ``with None`` fails with an opaque
            # context-manager protocol error.
            raise RuntimeError(
                "SQLSubscriptionStorage has no database session configured"
            )
        return session

    @staticmethod
    def _find(
        session: Session, subscription_id: str
    ) -> Optional[NewsSubscription]:
        """Return the subscription with ``subscription_id``, or ``None``."""
        return (
            session.query(NewsSubscription)
            .filter_by(id=subscription_id)
            .first()
        )

    @staticmethod
    def _to_detail_dict(subscription: NewsSubscription) -> Dict[str, Any]:
        """Convert a subscription row into the full dict returned by ``get``."""
        return {
            "id": subscription.id,
            "user_id": subscription.user_id,
            "name": subscription.name,
            "subscription_type": subscription.subscription_type,
            "query_or_topic": subscription.query_or_topic,
            "refresh_interval_minutes": subscription.refresh_interval_minutes,
            "created_at": subscription.created_at,
            "updated_at": subscription.updated_at,
            "last_refresh": subscription.last_refresh,
            "next_refresh": subscription.next_refresh,
            "expires_at": subscription.expires_at,
            "source_type": subscription.source_type,
            "source_id": subscription.source_id,
            "created_from": subscription.created_from,
            "folder": subscription.folder,
            "folder_id": subscription.folder_id,
            "notes": subscription.notes,
            "status": subscription.status,
            # getattr() with a default keeps these keys present (with the
            # same defaults ``create`` uses) even if the attribute is absent
            # on the object.
            "is_active": getattr(subscription, "is_active", True),
            "refresh_count": subscription.refresh_count,
            "results_count": subscription.results_count,
            "last_error": subscription.last_error,
            "error_count": subscription.error_count,
            "model_provider": getattr(subscription, "model_provider", None),
            "model": getattr(subscription, "model", None),
            "search_strategy": getattr(subscription, "search_strategy", None),
            "custom_endpoint": getattr(subscription, "custom_endpoint", None),
            "search_engine": getattr(subscription, "search_engine", None),
            "search_iterations": getattr(subscription, "search_iterations", 3),
            "questions_per_iteration": getattr(
                subscription, "questions_per_iteration", 5
            ),
        }

    @staticmethod
    def _to_summary_dict(subscription: NewsSubscription) -> Dict[str, Any]:
        """Convert a subscription row into the compact dict returned by ``list``."""
        return {
            "id": subscription.id,
            "user_id": subscription.user_id,
            "name": subscription.name,
            "subscription_type": subscription.subscription_type,
            "query_or_topic": subscription.query_or_topic,
            "refresh_interval_minutes": subscription.refresh_interval_minutes,
            "created_at": subscription.created_at,
            "updated_at": subscription.updated_at,
            "last_refresh": subscription.last_refresh,
            "next_refresh": subscription.next_refresh,
            "status": subscription.status,
            "folder": subscription.folder,
            "notes": subscription.notes,
        }

    def create(self, data: Dict[str, Any]) -> str:
        """Create a new subscription.

        Args:
            data: Subscription fields. ``user_id``, ``subscription_type``,
                ``query_or_topic`` and ``refresh_interval_minutes`` are
                required; ``id`` is generated when missing or falsy, and the
                remaining fields fall back to the defaults below.

        Returns:
            The id of the created subscription.

        Raises:
            KeyError: If a required field is missing from ``data``.
            RuntimeError: If no session is configured.
        """
        subscription_id = data.get("id") or self.generate_id()

        with self._open_session() as session:
            interval = data["refresh_interval_minutes"]
            subscription = NewsSubscription(
                id=subscription_id,
                user_id=data["user_id"],
                name=data.get("name"),
                subscription_type=data["subscription_type"],
                query_or_topic=data["query_or_topic"],
                refresh_interval_minutes=interval,
                frequency=data.get("frequency", "daily"),
                source_type=data.get("source_type"),
                source_id=data.get("source_id"),
                created_from=data.get("created_from"),
                folder=data.get("folder"),
                folder_id=data.get("folder_id"),
                notes=data.get("notes"),
                # NOTE(review): the default here is the plain string "active"
                # while other methods compare against SubscriptionStatus
                # members — confirm the column type accepts both forms.
                status=data.get("status", "active"),
                is_active=data.get("is_active", True),
                model_provider=data.get("model_provider"),
                model=data.get("model"),
                search_strategy=data.get("search_strategy"),
                custom_endpoint=data.get("custom_endpoint"),
                search_engine=data.get("search_engine"),
                search_iterations=data.get("search_iterations", 3),
                questions_per_iteration=data.get("questions_per_iteration", 5),
                # First refresh is one full interval after creation.
                next_refresh=datetime.now(timezone.utc)
                + timedelta(minutes=interval),
            )

            session.add(subscription)
            session.commit()

            logger.info(
                f"Created subscription {subscription_id} for user {data['user_id']}"
            )
            return subscription_id

    def get(self, id: str) -> Optional[Dict[str, Any]]:
        """Get a subscription by ID.

        Returns:
            A dict of the subscription's fields, or ``None`` if no
            subscription has this id.
        """
        with self._open_session() as session:
            subscription = self._find(session, id)
            if subscription is None:
                return None
            return self._to_detail_dict(subscription)

    def update(self, id: str, data: Dict[str, Any]) -> bool:
        """Update a subscription's allow-listed fields from ``data``.

        Keys in ``data`` that are not in the allow-list are ignored. Changing
        ``refresh_interval_minutes`` also reschedules ``next_refresh`` to one
        new interval from now.

        Returns:
            ``True`` if the subscription was updated, ``False`` if no
            subscription has this id.
        """
        with self._open_session() as session:
            subscription = self._find(session, id)
            if subscription is None:
                return False

            for field in self._UPDATABLE_FIELDS:
                if field in data:
                    # Field names come from the fixed allow-list above, not
                    # from caller-controlled input.
                    # bearer:disable python_lang_code_injection
                    setattr(subscription, field, data[field])

            if "refresh_interval_minutes" in data:
                subscription.next_refresh = datetime.now(
                    timezone.utc
                ) + timedelta(minutes=data["refresh_interval_minutes"])

            session.commit()
            return True

    def delete(self, id: str) -> bool:
        """Delete a subscription.

        Returns:
            ``True`` if it was deleted, ``False`` if no subscription has this id.
        """
        with self._open_session() as session:
            subscription = self._find(session, id)
            if subscription is None:
                return False

            session.delete(subscription)
            session.commit()
            return True

    def list(
        self,
        filters: Optional[Dict[str, Any]] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """List subscriptions with optional filtering.

        Args:
            filters: Optional ``user_id``, ``status`` and/or
                ``subscription_type`` constraints. ``status`` and
                ``subscription_type`` are converted through their enum
                classes by value.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            Summary dicts (a subset of the fields ``get`` returns).

        Raises:
            ValueError: If ``status`` or ``subscription_type`` is not a valid
                enum value.
        """
        with self._open_session() as session:
            query = session.query(NewsSubscription)

            if filters:
                if "user_id" in filters:
                    query = query.filter_by(user_id=filters["user_id"])
                if "status" in filters:
                    query = query.filter_by(
                        status=SubscriptionStatus(filters["status"])
                    )
                if "subscription_type" in filters:
                    query = query.filter_by(
                        subscription_type=SubscriptionType(
                            filters["subscription_type"]
                        )
                    )

            # NOTE(review): no ORDER BY, so limit/offset pages follow the
            # database's natural order — consider a stable sort key.
            subscriptions = query.limit(limit).offset(offset).all()

            # Read the attributes before detaching each row so nothing has to
            # be loaded from a detached instance.
            result = [self._to_summary_dict(sub) for sub in subscriptions]
            for sub in subscriptions:
                session.expunge(sub)
            return result

    def get_active_subscriptions(
        self, user_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Get all subscriptions with ACTIVE status.

        Args:
            user_id: When truthy, restrict results to this user.
        """
        with self._open_session() as session:
            query = session.query(NewsSubscription).filter_by(
                status=SubscriptionStatus.ACTIVE
            )

            if user_id:
                query = query.filter_by(user_id=user_id)

            return [sub.to_dict() for sub in query.all()]

    def get_due_subscriptions(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get active subscriptions whose ``next_refresh`` is now or earlier.

        Results are ordered most-overdue first, so when more than ``limit``
        subscriptions are due, the longest-waiting ones are returned instead
        of an arbitrary database-dependent subset.
        """
        now = datetime.now(timezone.utc)
        with self._open_session() as session:
            subscriptions = (
                session.query(NewsSubscription)
                .filter(
                    NewsSubscription.status == SubscriptionStatus.ACTIVE,
                    NewsSubscription.next_refresh <= now,
                )
                .order_by(NewsSubscription.next_refresh)
                .limit(limit)
                .all()
            )
            return [sub.to_dict() for sub in subscriptions]

    def update_refresh_time(
        self,
        subscription_id: str,
        last_refresh: datetime,
        next_refresh: datetime,
    ) -> bool:
        """Update refresh timestamps after processing.

        Returns:
            ``True`` on success, ``False`` if the subscription does not exist.
        """
        with self._open_session() as session:
            subscription = self._find(session, subscription_id)
            if subscription is None:
                return False

            subscription.last_refresh = last_refresh
            subscription.next_refresh = next_refresh
            session.commit()
            return True

    def increment_stats(self, subscription_id: str, results_count: int) -> bool:
        """Increment the refresh count and record the latest results count.

        Returns:
            ``True`` on success, ``False`` if the subscription does not exist.
        """
        with self._open_session() as session:
            subscription = self._find(session, subscription_id)
            if subscription is None:
                return False

            # Treat a NULL count as zero instead of failing on ``None + 1``.
            subscription.refresh_count = (subscription.refresh_count or 0) + 1
            subscription.total_runs = subscription.refresh_count  # Keep in sync
            subscription.results_count = results_count
            session.commit()
            return True

    def pause_subscription(self, subscription_id: str) -> bool:
        """Pause a subscription.

        Returns:
            ``True`` on success, ``False`` if the subscription does not exist.
        """
        with self._open_session() as session:
            subscription = self._find(session, subscription_id)
            if subscription is None:
                return False

            subscription.status = SubscriptionStatus.PAUSED
            session.commit()
            return True

    def resume_subscription(self, subscription_id: str) -> bool:
        """Resume a paused subscription.

        The next refresh is rescheduled to one interval from now.

        Returns:
            ``True`` on success, ``False`` if the subscription does not exist
            or is not currently PAUSED.
        """
        with self._open_session() as session:
            subscription = self._find(session, subscription_id)
            if (
                subscription is None
                or subscription.status != SubscriptionStatus.PAUSED
            ):
                return False

            subscription.status = SubscriptionStatus.ACTIVE
            subscription.next_refresh = datetime.now(timezone.utc) + timedelta(
                minutes=subscription.refresh_interval_minutes
            )
            session.commit()
            return True

    def expire_subscription(self, subscription_id: str) -> bool:
        """Mark a subscription as expired, stamping ``expires_at`` with now.

        Returns:
            ``True`` on success, ``False`` if the subscription does not exist.
        """
        with self._open_session() as session:
            subscription = self._find(session, subscription_id)
            if subscription is None:
                return False

            subscription.status = SubscriptionStatus.EXPIRED
            subscription.expires_at = datetime.now(timezone.utc)
            session.commit()
            return True