Coverage for src/local_deep_research/database/thread_local_session.py: 94%

151 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-03 23:15 +0000

1""" 

2Thread-local database session management. 

3Each thread gets its own database session that persists for the thread's lifetime. 

4""" 

5 

6import functools 

7import threading 

8from contextlib import ContextDecorator 

9from typing import Optional, Dict, Tuple 

10from sqlalchemy import text 

11from sqlalchemy.exc import PendingRollbackError 

12from sqlalchemy.orm import Session 

13from loguru import logger 

14 

15from .encrypted_db import db_manager 

16 

17 

class ThreadLocalSessionManager:
    """
    Manages database sessions per thread.

    Each thread gets its own session, created on first use and reused for the
    rest of the thread's lifetime (or until the thread's session is cleaned
    up). A cached session is only returned to a caller requesting the same
    username it was created for.
    """

    def __init__(self):
        # Thread-local storage; holds ``session`` and ``username`` attributes.
        self._local = threading.local()
        # (username, password) per thread ident, guarded by ``_lock``.
        # Within this class only the keys are read (by the cleanup methods).
        # NOTE(review): the value keeps the plaintext password alive for as
        # long as the entry exists although it is never read here — consider
        # storing only the username.
        self._thread_credentials: Dict[int, Tuple[str, str]] = {}
        self._lock = threading.Lock()

    def get_session(self, username: str, password: str) -> Optional[Session]:
        """
        Get or create a database session for the current thread.

        The session is created once per thread and reused for all subsequent
        calls. This avoids the expensive SQLCipher decryption on every
        database access.

        A cached session is discarded and replaced when it was created for a
        different username, or when it fails a ``SELECT 1`` probe that a
        rollback cannot recover from.

        Args:
            username: User whose database the session should target.
            password: Password passed to ``db_manager`` to open that database.

        Returns:
            The thread's session, or ``None`` when the database could not be
            opened or initialised, or no session could be created.
        """
        thread_id = threading.get_ident()

        # Check if we already have a session for this thread
        if getattr(self._local, "session", None):
            # SECURITY: ensure cached session belongs to the requesting user.
            # Read the cached username once (defaulting to None) so that the
            # warning below cannot itself raise AttributeError.
            cached_username = getattr(self._local, "username", None)
            if cached_username != username:
                logger.warning(
                    f"Thread {thread_id}: Session username mismatch "
                    f"(cached={cached_username!r}, requested={username!r}), "
                    "clearing stale cross-user session"
                )
                self._cleanup_thread_session()
            else:
                # Verify it's still valid
                try:
                    self._local.session.execute(text("SELECT 1"))
                    # Under DEFERRED isolation the validation SELECT opens
                    # a transaction that holds a SHARED lock on SQLite
                    # until an explicit commit/rollback. A long-lived
                    # thread-local session reused across requests would
                    # keep that lock held and block the first writer.
                    # Roll it back so subsequent callers start fresh.
                    self._local.session.rollback()
                    return self._local.session
                except PendingRollbackError:
                    # Session has a pending rollback (e.g. from a previous
                    # database lock error). Attempt rollback to recover
                    # without destroying the session.
                    logger.debug(
                        f"Thread {thread_id}: PendingRollbackError, attempting rollback recovery"
                    )
                    try:
                        self._local.session.rollback()
                        self._local.session.execute(text("SELECT 1"))
                        self._local.session.rollback()
                        return self._local.session
                    except Exception:
                        logger.warning(
                            f"Thread {thread_id}: Rollback recovery failed, creating new session"
                        )
                        self._cleanup_thread_session()
                except Exception:
                    # Session is invalid, will create a new one
                    logger.debug(
                        f"Thread {thread_id}: Existing session invalid, creating new one"
                    )
                    self._cleanup_thread_session()

        # Create new session for this thread
        logger.debug(
            f"Thread {thread_id}: Creating new database session for user {username}"
        )

        # Ensure database is open. open_user_database returns None for
        # credential failures and raises DatabaseInitializationError when
        # the schema can't be initialised; from a worker-thread caller
        # both mean "no usable session right now" so collapse them.
        from .encrypted_db import DatabaseInitializationError

        try:
            engine = db_manager.open_user_database(username, password)
        except DatabaseInitializationError:
            logger.exception(
                f"Thread {thread_id}: database init failed for user {username}"
            )
            return None
        if not engine:
            logger.error(
                f"Thread {thread_id}: Failed to open database for user {username}"
            )
            return None

        # Create session for this thread
        session = db_manager.create_thread_safe_session_for_metrics(
            username, password
        )
        if not session:
            logger.error(
                f"Thread {thread_id}: Failed to create session for user {username}"
            )
            return None

        # Store in thread-local storage
        self._local.session = session
        self._local.username = username

        # Track credentials for cleanup
        with self._lock:
            self._thread_credentials[thread_id] = (username, password)

        return session

    def get_current_session(self) -> Optional[Session]:
        """Return the current thread's session, or ``None`` if it has none."""
        return getattr(self._local, "session", None)

    def _cleanup_thread_session(self):
        """Clean up the current thread's session.

        Rolls back and closes the session (each step best-effort), clears the
        thread-local ``session``/``username`` and drops the thread's
        credential entry.

        Sessions are bound to the shared per-user QueuePool engine, so
        closing the session returns its connection to the pool — there
        are no per-thread engines to dispose.
        """
        thread_id = threading.get_ident()

        if hasattr(self._local, "session") and self._local.session:
            try:
                self._local.session.rollback()
            except Exception:
                logger.warning(
                    f"Thread {thread_id}: Error rolling back session during cleanup"
                )
            try:
                self._local.session.close()
                logger.debug(f"Thread {thread_id}: Closed database session")
            except Exception:
                logger.warning(f"Thread {thread_id}: Error closing session")
            finally:
                self._local.session = None

        if hasattr(self._local, "username"):
            self._local.username = None

        # Remove from tracking
        with self._lock:
            self._thread_credentials.pop(thread_id, None)

    def cleanup_thread(self, thread_id: Optional[int] = None):
        """
        Clean up session state for a specific thread or the current thread.

        Called when a thread is finishing.

        Args:
            thread_id: Ident of the thread to clean up; defaults to the
                calling thread. For the calling thread the session is closed.
                For any other thread only the credential entry is removed,
                because another thread's thread-local storage is not
                reachable from here.
        """
        if thread_id is None:
            thread_id = threading.get_ident()

        # If it's the current thread, we can clean up directly
        if thread_id == threading.get_ident():
            self._cleanup_thread_session()
        else:
            # For other threads, just remove from tracking
            # The thread-local storage will be cleaned up when the thread ends
            with self._lock:
                self._thread_credentials.pop(thread_id, None)

    def cleanup_dead_threads(self):
        """Remove credential entries for threads that are no longer alive.

        Handles the abnormal case of threads that died without triggering
        their cleanup handler. Uses threading.enumerate() to identify
        alive threads and removes credential entries for dead ones.
        Sessions on dead threads are garbage-collected normally and
        their connections return to the shared per-user QueuePool.

        Called from:
        - processor_v2.py: every ~60s in the queue loop
        - app_factory.py: in teardown_appcontext
        - connection_cleanup.py: in cleanup_idle_connections (every ~300s)
        """
        alive_ids = {t.ident for t in threading.enumerate()}
        with self._lock:
            dead_ids = [
                tid for tid in self._thread_credentials if tid not in alive_ids
            ]
            for tid in dead_ids:
                del self._thread_credentials[tid]
        if dead_ids:
            logger.debug(f"Swept {len(dead_ids)} dead thread credential(s)")

    def cleanup_all(self):
        """Clean up all tracked threads (for shutdown).

        Every tracked credential entry is removed. Only the calling thread's
        own session is actually rolled back and closed; sessions held in
        other threads' thread-local storage are not reachable from here
        (see ``cleanup_thread``).
        """
        # Snapshot under the lock; cleanup_thread re-acquires it per entry.
        with self._lock:
            thread_ids = list(self._thread_credentials.keys())

        for thread_id in thread_ids:
            self.cleanup_thread(thread_id)

214 

215 

# Process-wide session manager shared by the module-level helpers.
thread_session_manager: ThreadLocalSessionManager = ThreadLocalSessionManager()

218 

219 

def get_metrics_session(username: str, password: str) -> Optional[Session]:
    """
    Return the current thread's metrics database session.

    Delegates to the module-wide ``thread_session_manager``, which creates the
    session once per thread (through
    ``create_thread_safe_session_for_metrics``) and reuses it afterwards.
    Use it only for metrics-related database operations.

    Returns ``None`` when no session could be obtained.
    """
    manager = thread_session_manager
    session = manager.get_session(username, password)
    return session

229 

230 

def get_current_thread_session() -> Optional[Session]:
    """Return the session already bound to this thread, or ``None``."""
    manager = thread_session_manager
    current = manager.get_current_session()
    return current

234 

235 

def cleanup_current_thread():
    """Clean up the current thread's database session and cached credentials.

    Also clears any passwords cached by ``metrics_writer`` on this thread —
    pooled worker threads must not retain plaintext credentials across tasks.
    The password clearing runs even if the session cleanup raises; that
    exception is then propagated to the caller unchanged.
    """
    try:
        thread_session_manager.cleanup_thread()
    finally:
        # Always clear cached passwords, even when session cleanup failed.
        try:
            from .thread_metrics import metrics_writer

            metrics_writer.clear_passwords()
        except Exception:
            # loguru does not honour the stdlib ``exc_info=`` keyword (it is
            # only bound into the record's extras); ``opt(exception=True)``
            # attaches the active traceback to the log record.
            logger.opt(exception=True).debug(
                "cleanup_current_thread: error clearing metrics_writer passwords"
            )

252 

253 

def cleanup_dead_threads():
    """Sweep dead-thread credential entries from the session manager.

    Best-effort: any failure is logged together with its traceback and then
    swallowed, so callers are never interrupted by the sweep.
    """
    try:
        thread_session_manager.cleanup_dead_threads()
    except Exception:
        # Include the traceback; the bare message alone made sweep failures
        # impossible to diagnose.
        logger.opt(exception=True).warning("Dead-thread session sweep failed")

260 

261 

class _ThreadCleanup(ContextDecorator):
    """Context manager / decorator for thread-local resource cleanup.

    On exit it runs three best-effort steps in order: the DB session and
    credential cleanup (``cleanup_current_thread``), clearing the settings
    context, and clearing the search context. A failure in one step is
    logged and does not prevent the later steps. Exceptions raised by the
    wrapped code are never suppressed.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # loguru does not honour the stdlib ``exc_info=`` keyword (it is only
        # bound into the record's extras), so ``opt(exception=True)`` is used
        # to actually log the traceback of each swallowed error.
        try:
            cleanup_current_thread()
        except Exception:
            logger.opt(exception=True).debug(
                "thread_cleanup: error during DB session cleanup"
            )
        try:
            from ..config.thread_settings import clear_settings_context

            clear_settings_context()
        except Exception:
            logger.opt(exception=True).debug(
                "thread_cleanup: error clearing settings context"
            )
        try:
            from ..utilities.thread_context import clear_search_context

            clear_search_context()
        except Exception:
            logger.opt(exception=True).debug(
                "thread_cleanup: error clearing search context"
            )
        # Never suppress an exception raised inside the managed block.
        return False

295 

296 

def thread_cleanup(func=None):
    """Run all thread-local cleanup steps when a function or block exits.

    Usable as a bare decorator, a decorator factory, or a context manager::

        @thread_cleanup
        def worker(): ...

        @thread_cleanup()
        def worker(): ...

        with thread_cleanup():
            ...

        executor.submit(thread_cleanup(func), arg)

    Args:
        func: Callable to wrap. When omitted (``None``) a ``_ThreadCleanup``
            instance is returned instead, which works both as a context
            manager and as a decorator.

    Returns:
        The wrapped callable when ``func`` is given, otherwise a
        ``_ThreadCleanup`` instance.
    """
    if func is None:
        return _ThreadCleanup()

    @functools.wraps(func)
    def _run_with_cleanup(*args, **kwargs):
        cleanup_ctx = _ThreadCleanup()
        with cleanup_ctx:
            return func(*args, **kwargs)

    return _run_with_cleanup

322 

323 

# Context manager that hands out the thread's session (it does NOT clean up).
class ThreadSessionContext:
    """
    Context manager that yields the current thread's metrics session.

    Entering calls ``get_metrics_session(username, password)`` and yields the
    result, which may be ``None`` if no session could be obtained.

    Despite being a context manager, it performs no cleanup on exit: the
    session stays cached on the thread for reuse and is released only by an
    explicit thread cleanup (e.g. ``cleanup_current_thread``). Exceptions
    raised inside the block propagate unchanged.

    Usage::

        with ThreadSessionContext(username, password) as session:
            ...  # use session
    """

    def __init__(self, username: str, password: str):
        self.username = username
        self.password = password
        self.session = None

    def __enter__(self) -> Optional[Session]:
        self.session = get_metrics_session(self.username, self.password)
        return self.session

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Intentionally no cleanup here: the thread keeps its session and it
        # is only released when the thread-level cleanup runs. Returning a
        # falsy value lets any exception from the block propagate.
        return None