Coverage for src/local_deep_research/database/thread_local_session.py: 97%

139 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:55 +0000

1""" 

2Thread-local database session management. 

3Each thread gets its own database session that persists for the thread's lifetime. 

4""" 

5 

6import functools 

7import threading 

8from contextlib import ContextDecorator 

9from typing import Optional, Dict, Tuple 

10from sqlalchemy import text 

11from sqlalchemy.exc import PendingRollbackError 

12from sqlalchemy.orm import Session 

13from loguru import logger 

14 

15from .encrypted_db import db_manager 

16 

17 

class ThreadLocalSessionManager:
    """
    Manages database sessions per thread.

    Each thread gets its own session that is reused throughout the thread's
    lifetime. The session and the username it was opened for are kept in
    ``threading.local`` storage; a lock-protected dict keyed by thread ident
    records which threads currently hold a session.
    """

    def __init__(self):
        # Thread-local storage for sessions (attributes: ``session``, ``username``)
        self._local = threading.local()
        # Track credentials per thread ID (for cleanup).
        # NOTE(review): the plaintext password is stored here, yet this class
        # only ever uses the dict's keys — confirm whether any other module
        # reads the stored values before keeping the password in memory.
        self._thread_credentials: Dict[int, Tuple[str, str]] = {}
        # Guards _thread_credentials, which is mutated from multiple threads
        self._lock = threading.Lock()

    def get_session(self, username: str, password: str) -> Optional[Session]:
        """
        Get or create a database session for the current thread.

        The session is created once per thread and reused for all subsequent calls.
        This avoids the expensive SQLCipher decryption on every database access.

        Args:
            username: User whose database the session must belong to.
            password: Password passed to ``db_manager`` when a new session
                has to be opened.

        Returns:
            The cached or newly created session, or ``None`` if the database
            could not be opened or the session could not be created.
        """
        thread_id = threading.get_ident()

        # Check if we already have a session for this thread
        cached_session = getattr(self._local, "session", None)
        if cached_session:
            # Read the cached username once, with a default, so the warning
            # below cannot raise AttributeError when the attribute is missing.
            cached_username = getattr(self._local, "username", None)

            # SECURITY: ensure cached session belongs to the requesting user
            if cached_username != username:
                logger.warning(
                    f"Thread {thread_id}: Session username mismatch "
                    f"(cached={cached_username!r}, requested={username!r}), "
                    "clearing stale cross-user session"
                )
                self._cleanup_thread_session()
            else:
                # Verify it's still valid
                try:
                    cached_session.execute(text("SELECT 1"))
                    return cached_session
                except PendingRollbackError:
                    # Session has a pending rollback (e.g. from a previous
                    # database lock error). Attempt rollback to recover
                    # without destroying the session.
                    logger.debug(
                        f"Thread {thread_id}: PendingRollbackError, attempting rollback recovery"
                    )
                    try:
                        cached_session.rollback()
                        cached_session.execute(text("SELECT 1"))
                        return cached_session
                    except Exception:
                        logger.warning(
                            f"Thread {thread_id}: Rollback recovery failed, creating new session"
                        )
                        self._cleanup_thread_session()
                except Exception:
                    # Session is invalid, will create a new one
                    logger.debug(
                        f"Thread {thread_id}: Existing session invalid, creating new one"
                    )
                    self._cleanup_thread_session()

        # Create new session for this thread
        logger.debug(
            f"Thread {thread_id}: Creating new database session for user {username}"
        )

        # Ensure database is open
        engine = db_manager.open_user_database(username, password)
        if not engine:
            logger.error(
                f"Thread {thread_id}: Failed to open database for user {username}"
            )
            return None

        # Create session for this thread
        session = db_manager.create_thread_safe_session_for_metrics(
            username, password
        )
        if not session:
            logger.error(
                f"Thread {thread_id}: Failed to create session for user {username}"
            )
            return None

        # Store in thread-local storage
        self._local.session = session
        self._local.username = username

        # Track credentials for cleanup
        with self._lock:
            self._thread_credentials[thread_id] = (username, password)

        return session

    def get_current_session(self) -> Optional[Session]:
        """Get the current thread's session if it exists.

        Returns:
            The value stored for this thread (which may be ``None`` after a
            cleanup), or ``None`` if no session was ever stored.
        """
        if hasattr(self._local, "session"):
            return self._local.session
        return None

    def _cleanup_thread_session(self):
        """Clean up the current thread's session.

        Sessions are bound to the shared per-user QueuePool engine, so
        closing the session returns its connection to the pool — there
        are no per-thread engines to dispose.

        Rollback and close failures are logged and swallowed; the
        thread-local slots and the tracking entry are cleared regardless.
        """
        thread_id = threading.get_ident()

        if hasattr(self._local, "session") and self._local.session:
            try:
                self._local.session.rollback()
            except Exception:
                logger.warning(
                    f"Thread {thread_id}: Error rolling back session during cleanup"
                )
            try:
                self._local.session.close()
                logger.debug(f"Thread {thread_id}: Closed database session")
            except Exception:
                logger.warning(f"Thread {thread_id}: Error closing session")
            finally:
                # Drop the reference even when close() failed
                self._local.session = None

        if hasattr(self._local, "username"):
            self._local.username = None

        # Remove from tracking
        with self._lock:
            self._thread_credentials.pop(thread_id, None)

    def cleanup_thread(self, thread_id: Optional[int] = None):
        """
        Clean up session for a specific thread or current thread.
        Called when a thread is finishing.

        Args:
            thread_id: Ident of the thread to clean up; ``None`` means the
                calling thread.
        """
        if thread_id is None:
            thread_id = threading.get_ident()

        # If it's the current thread, we can clean up directly
        if thread_id == threading.get_ident():
            self._cleanup_thread_session()
        else:
            # For other threads, just remove from tracking
            # The thread-local storage will be cleaned up when the thread ends
            with self._lock:
                self._thread_credentials.pop(thread_id, None)

    def cleanup_dead_threads(self):
        """Remove credential entries for threads that are no longer alive.

        Handles the abnormal case of threads that died without triggering
        their cleanup handler. Uses threading.enumerate() to identify
        alive threads and removes credential entries for dead ones.
        Sessions on dead threads are garbage-collected normally and
        their connections return to the shared per-user QueuePool.

        Called from:
        - processor_v2.py: every ~60s in the queue loop
        - app_factory.py: in teardown_appcontext
        - connection_cleanup.py: in cleanup_idle_connections (every ~300s)
        """
        alive_ids = {t.ident for t in threading.enumerate()}
        with self._lock:
            # Collect first, then delete: the dict must not change size
            # while it is being iterated.
            dead_ids = [
                tid for tid in self._thread_credentials if tid not in alive_ids
            ]
            for tid in dead_ids:
                del self._thread_credentials[tid]
        if dead_ids:
            logger.debug(f"Swept {len(dead_ids)} dead thread credential(s)")

    def cleanup_all(self):
        """Clean up all tracked sessions (for shutdown).

        Only the calling thread's session is actually closed; for every
        other tracked thread the tracking entry is removed (see
        ``cleanup_thread``).
        """
        # Snapshot the keys under the lock; cleanup_thread re-acquires it.
        with self._lock:
            thread_ids = list(self._thread_credentials.keys())

        for thread_id in thread_ids:
            self.cleanup_thread(thread_id)

195 

196 

# Module-wide singleton manager instance
thread_session_manager = ThreadLocalSessionManager()

199 

200 

def get_metrics_session(username: str, password: str) -> Optional[Session]:
    """
    Return the calling thread's database session for metrics work.

    Delegates to the module-level ``thread_session_manager``, forwarding
    both credentials unchanged and returning whatever its ``get_session``
    returns.

    Note: intended for metrics-related database operations only.
    """
    manager = thread_session_manager
    session = manager.get_session(username, password)
    return session

210 

211 

def get_current_thread_session() -> Optional[Session]:
    """Return the session held for the calling thread, if there is one."""
    manager = thread_session_manager
    current = manager.get_current_session()
    return current

215 

216 

def cleanup_current_thread():
    """Release the calling thread's database session via the global manager."""
    manager = thread_session_manager
    # No thread id is passed, so the manager targets the calling thread.
    manager.cleanup_thread()

220 

221 

def cleanup_dead_threads():
    """Best-effort sweep of credential entries left behind by dead threads.

    Any exception raised by the manager's sweep is logged as a warning
    and never propagated to the caller.
    """
    try:
        thread_session_manager.cleanup_dead_threads()
    except Exception:
        # The sweep is housekeeping; a failure must not disturb the caller.
        logger.warning("Dead-thread session sweep failed")
    return None

228 

229 

class _ThreadCleanup(ContextDecorator):
    """Context manager / decorator for thread-local resource cleanup.

    On exit it runs three independent cleanup steps: the DB session, the
    settings context and the search context. Each step is guarded
    separately so a failure in one does not prevent the others from running.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Run all cleanup steps; never suppress the body's exception."""
        # ``logger`` is loguru: the traceback must be requested through
        # ``opt(exception=True)``. A stdlib-style ``exc_info=True`` keyword
        # is treated as a plain format/extra argument and logs no traceback.
        try:
            cleanup_current_thread()
        except Exception:
            logger.opt(exception=True).debug(
                "thread_cleanup: error during DB session cleanup"
            )
        try:
            # Imported at call time rather than at module import
            from ..config.thread_settings import clear_settings_context

            clear_settings_context()
        except Exception:
            logger.opt(exception=True).debug(
                "thread_cleanup: error clearing settings context"
            )
        try:
            # Imported at call time rather than at module import
            from ..utilities.thread_context import clear_search_context

            clear_search_context()
        except Exception:
            logger.opt(exception=True).debug(
                "thread_cleanup: error clearing search context"
            )
        # False => an exception raised inside the managed block propagates
        return False

263 

264 

def thread_cleanup(func=None):
    """Guarantee thread-local resource cleanup when a function or block exits.

    Usable as a bare decorator, a decorator factory, or a context manager::

        @thread_cleanup
        def worker(): ...

        @thread_cleanup()
        def worker(): ...

        with thread_cleanup():
            ...

        executor.submit(thread_cleanup(func), arg)
    """
    # Called without a function: hand back the context manager / decorator
    # object itself (covers ``@thread_cleanup()`` and ``with thread_cleanup():``).
    if func is None:
        return _ThreadCleanup()

    # Called with a function: wrap it so cleanup runs after every call.
    @functools.wraps(func)
    def _cleaned_up(*call_args, **call_kwargs):
        with _ThreadCleanup():
            return func(*call_args, **call_kwargs)

    return _cleaned_up

290 

291 

# Context manager that hands out the thread's session
class ThreadSessionContext:
    """
    Context manager yielding the calling thread's metrics session.

    Leaving the block does NOT close the session; it stays cached for the
    thread and is released only by the thread's own cleanup.

    Usage:
        with ThreadSessionContext(username, password) as session:
            # Use session
    """

    def __init__(self, username: str, password: str):
        self.session = None
        self.username = username
        self.password = password

    def __enter__(self) -> Optional[Session]:
        acquired = get_metrics_session(self.username, self.password)
        self.session = acquired
        return acquired

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Intentionally a no-op: the thread keeps its session until the
        # thread itself ends. Returning None lets exceptions propagate.
        return None