Coverage for src/local_deep_research/database/thread_local_session.py: 94%
151 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 23:15 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 23:15 +0000
1"""
2Thread-local database session management.
3Each thread gets its own database session that persists for the thread's lifetime.
4"""
6import functools
7import threading
8from contextlib import ContextDecorator
9from typing import Optional, Dict, Tuple
10from sqlalchemy import text
11from sqlalchemy.exc import PendingRollbackError
12from sqlalchemy.orm import Session
13from loguru import logger
15from .encrypted_db import db_manager
class ThreadLocalSessionManager:
    """
    Manages database sessions per thread.
    Each thread gets its own session that is reused throughout the thread's lifetime.

    Sessions are kept in ``threading.local()`` storage, so a thread can only
    see (and close) its own session. A lock-protected dict keyed by thread
    ID records which threads have created a session, so that
    ``cleanup_dead_threads`` and ``cleanup_all`` can sweep those entries.
    """

    def __init__(self):
        # Thread-local storage for sessions
        self._local = threading.local()
        # Track credentials per thread ID (for cleanup).
        # NOTE(review): only the keys of this dict are ever read in this
        # class; the stored (username, password) values are never used here.
        # Confirm no other module depends on the plaintext password before
        # dropping it from the tuple.
        self._thread_credentials: Dict[int, Tuple[str, str]] = {}
        # Guards _thread_credentials, which is read and written from many threads.
        self._lock = threading.Lock()

    def get_session(self, username: str, password: str) -> Optional[Session]:
        """
        Get or create a database session for the current thread.

        The session is created once per thread and reused for all subsequent calls.
        This avoids the expensive SQLCipher decryption on every database access.

        A cached session is reused only if it was created for the same
        ``username`` and still answers ``SELECT 1``. A session stuck in a
        pending-rollback state is rolled back and re-checked once. In every
        other case the cached session is cleaned up and a new one is created.

        Args:
            username: User whose database the session should be bound to.
            password: Password used to open that user's database.

        Returns:
            The thread's Session, or None if the database could not be
            opened or initialised, or if no session could be created.
        """
        thread_id = threading.get_ident()

        # Check if we already have a session for this thread
        if hasattr(self._local, "session") and self._local.session:
            # SECURITY: ensure cached session belongs to the requesting user
            if getattr(self._local, "username", None) != username:
                logger.warning(
                    f"Thread {thread_id}: Session username mismatch "
                    f"(cached={self._local.username!r}, requested={username!r}), "
                    "clearing stale cross-user session"
                )
                # Falls through to the "create new session" path below.
                self._cleanup_thread_session()
            else:
                # Verify it's still valid
                try:
                    self._local.session.execute(text("SELECT 1"))
                    # Under DEFERRED isolation the validation SELECT opens
                    # a transaction that holds a SHARED lock on SQLite
                    # until an explicit commit/rollback. A long-lived
                    # thread-local session reused across requests would
                    # keep that lock held and block the first writer.
                    # Roll it back so subsequent callers start fresh.
                    self._local.session.rollback()
                    return self._local.session
                except PendingRollbackError:
                    # Session has a pending rollback (e.g. from a previous database lock error).
                    # Attempt rollback to recover without destroying the session.
                    logger.debug(
                        f"Thread {thread_id}: PendingRollbackError, attempting rollback recovery"
                    )
                    try:
                        self._local.session.rollback()
                        # Re-validate, then release the validation transaction
                        # for the same SHARED-lock reason as above.
                        self._local.session.execute(text("SELECT 1"))
                        self._local.session.rollback()
                        return self._local.session
                    except Exception:
                        logger.warning(
                            f"Thread {thread_id}: Rollback recovery failed, creating new session"
                        )
                        self._cleanup_thread_session()
                except Exception:
                    # Session is invalid, will create a new one
                    logger.debug(
                        f"Thread {thread_id}: Existing session invalid, creating new one"
                    )
                    self._cleanup_thread_session()

        # Create new session for this thread
        logger.debug(
            f"Thread {thread_id}: Creating new database session for user {username}"
        )

        # Ensure database is open. open_user_database returns None for
        # credential failures and raises DatabaseInitializationError when
        # the schema can't be initialised; from a worker-thread caller
        # both mean "no usable session right now" so collapse them.
        from .encrypted_db import DatabaseInitializationError

        try:
            engine = db_manager.open_user_database(username, password)
        except DatabaseInitializationError:
            logger.exception(
                f"Thread {thread_id}: database init failed for user {username}"
            )
            return None
        if not engine:
            logger.error(
                f"Thread {thread_id}: Failed to open database for user {username}"
            )
            return None

        # Create session for this thread. The engine above is only checked
        # for truthiness; the session itself comes from a separate factory.
        session = db_manager.create_thread_safe_session_for_metrics(
            username, password
        )
        if not session:
            logger.error(
                f"Thread {thread_id}: Failed to create session for user {username}"
            )
            return None

        # Store in thread-local storage
        self._local.session = session
        # Remembered so a later call for a different user is detected above.
        self._local.username = username

        # Track credentials for cleanup
        with self._lock:
            self._thread_credentials[thread_id] = (username, password)

        return session

    def get_current_session(self) -> Optional[Session]:
        """Get the current thread's session if it exists.

        Does not validate or create a session. Returns None if this thread
        never created one, or if it has since been cleaned up.
        """
        if hasattr(self._local, "session"):
            return self._local.session
        return None

    def _cleanup_thread_session(self):
        """Clean up the current thread's session.

        Sessions are bound to the shared per-user QueuePool engine, so
        closing the session returns its connection to the pool — there
        are no per-thread engines to dispose.

        Rollback and close errors are logged as warnings and swallowed.
        The thread-local session is always reset to None, the cached
        username is cleared, and the thread's tracking entry is removed.
        """
        thread_id = threading.get_ident()

        if hasattr(self._local, "session") and self._local.session:
            try:
                self._local.session.rollback()
            except Exception:
                logger.warning(
                    f"Thread {thread_id}: Error rolling back session during cleanup"
                )
            try:
                self._local.session.close()
                logger.debug(f"Thread {thread_id}: Closed database session")
            except Exception:
                logger.warning(f"Thread {thread_id}: Error closing session")
            finally:
                # Drop the reference even if close() failed.
                self._local.session = None

        if hasattr(self._local, "username"):
            self._local.username = None

        # Remove from tracking
        with self._lock:
            self._thread_credentials.pop(thread_id, None)

    def cleanup_thread(self, thread_id: Optional[int] = None):
        """
        Clean up session for a specific thread or current thread.
        Called when a thread is finishing.

        Args:
            thread_id: Thread ident to clean up; defaults to the calling thread.
                Only the calling thread's session can actually be closed. For
                any other ident, this only removes the tracking entry.
        """
        if thread_id is None:
            thread_id = threading.get_ident()

        # If it's the current thread, we can clean up directly
        if thread_id == threading.get_ident():
            self._cleanup_thread_session()
        else:
            # For other threads, just remove from tracking
            # The thread-local storage will be cleaned up when the thread ends
            with self._lock:
                self._thread_credentials.pop(thread_id, None)

    def cleanup_dead_threads(self):
        """Remove credential entries for threads that are no longer alive.

        Handles the abnormal case of threads that died without triggering
        their cleanup handler. Uses threading.enumerate() to identify
        alive threads and removes credential entries for dead ones.
        Sessions on dead threads are garbage-collected normally and
        their connections return to the shared per-user QueuePool.

        Called from:
        - processor_v2.py: every ~60s in the queue loop
        - app_factory.py: in teardown_appcontext
        - connection_cleanup.py: in cleanup_idle_connections (every ~300s)
        """
        alive_ids = {t.ident for t in threading.enumerate()}
        with self._lock:
            # Materialise the list first: the dict cannot be mutated while
            # it is being iterated.
            dead_ids = [
                tid for tid in self._thread_credentials if tid not in alive_ids
            ]
            for tid in dead_ids:
                del self._thread_credentials[tid]
        if dead_ids:
            logger.debug(f"Swept {len(dead_ids)} dead thread credential(s)")

    def cleanup_all(self):
        """Clean up all tracked sessions (for shutdown).

        Only the calling thread's session is actually closed. For every
        other tracked thread, ``cleanup_thread`` can only drop the tracking
        entry.
        """
        # Snapshot under the lock, then release it: cleanup_thread takes the
        # same non-reentrant lock itself.
        with self._lock:
            thread_ids = list(self._thread_credentials.keys())

        for thread_id in thread_ids:
            self.cleanup_thread(thread_id)
# Single process-wide manager; the module-level helper functions delegate to it.
thread_session_manager: ThreadLocalSessionManager = ThreadLocalSessionManager()
def get_metrics_session(username: str, password: str) -> Optional[Session]:
    """
    Return this thread's metrics database session, creating it on first use.

    Delegates to the module-wide ``ThreadLocalSessionManager``. After the
    first successful call, the same session is returned to every later call
    made from the same thread.

    Note: the session is built with create_thread_safe_session_for_metrics,
    so it should only be used for metrics-related database operations.

    Returns:
        The thread's Session, or None if no usable session could be obtained.
    """
    manager = thread_session_manager
    return manager.get_session(username=username, password=password)
def get_current_thread_session() -> Optional[Session]:
    """Return the session already cached for the calling thread, or None.

    This never opens a database; use ``get_metrics_session`` to create one.
    """
    manager = thread_session_manager
    return manager.get_current_session()
def cleanup_current_thread():
    """Clean up the current thread's database session and cached credentials.

    Also clears any passwords cached by ``metrics_writer`` on this thread —
    pooled worker threads must not retain plaintext credentials across tasks.

    The password clearing runs in a ``finally`` block, so it happens even if
    the session cleanup raises. That exception still propagates to the
    caller. Failures while clearing the ``metrics_writer`` passwords are
    logged at DEBUG level, with traceback, and swallowed.
    """
    try:
        thread_session_manager.cleanup_thread()
    finally:
        try:
            from .thread_metrics import metrics_writer

            metrics_writer.clear_passwords()
        except Exception:
            # loguru ignores the stdlib-style ``exc_info=True`` keyword (it is
            # only captured into the record's ``extra``), so the traceback
            # must be requested explicitly with ``opt(exception=True)``.
            logger.opt(exception=True).debug(
                "cleanup_current_thread: error clearing metrics_writer passwords"
            )
def cleanup_dead_threads():
    """Sweep dead-thread credential entries from the session manager.

    Best-effort: any exception from the sweep is logged as a warning,
    including its traceback, and swallowed so the caller is never
    interrupted.
    """
    try:
        thread_session_manager.cleanup_dead_threads()
    except Exception:
        # Attach the traceback; the bare message alone gives no hint as to
        # why the sweep failed.
        logger.opt(exception=True).warning("Dead-thread session sweep failed")
class _ThreadCleanup(ContextDecorator):
    """Context manager / decorator for thread-local resource cleanup.

    On exit it runs three independent, best-effort steps:

    1. DB session and credential cleanup.
    2. Clearing the settings context.
    3. Clearing the search context.

    A failure in one step is logged at DEBUG level, with traceback, and does
    not prevent the remaining steps. Exceptions raised inside the managed
    block are never suppressed.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # loguru does not honour the stdlib ``exc_info=True`` keyword (it is
        # silently stored in the record's ``extra``), so tracebacks are
        # requested via ``logger.opt(exception=True)`` instead.
        try:
            cleanup_current_thread()
        except Exception:
            logger.opt(exception=True).debug(
                "thread_cleanup: error during DB session cleanup"
            )
        try:
            from ..config.thread_settings import clear_settings_context

            clear_settings_context()
        except Exception:
            logger.opt(exception=True).debug(
                "thread_cleanup: error clearing settings context"
            )
        try:
            from ..utilities.thread_context import clear_search_context

            clear_search_context()
        except Exception:
            logger.opt(exception=True).debug(
                "thread_cleanup: error clearing search context"
            )
        # Falsy return: never swallow an exception from the managed block.
        return False
def thread_cleanup(func=None):
    """Run thread-local resource cleanup when a function or block exits.

    Usable as a bare decorator, as a decorator factory, or as a context
    manager::

        @thread_cleanup
        def worker(): ...

        @thread_cleanup()
        def worker(): ...

        with thread_cleanup():
            ...

        executor.submit(thread_cleanup(func), arg)

    Args:
        func: The function to wrap when used as a bare decorator or called
            directly; omit it to get a context manager / decorator instead.
    """
    if func is None:
        # ``thread_cleanup()`` form: a ContextDecorator works both as a
        # ``with`` target and as a decorator.
        return _ThreadCleanup()

    @functools.wraps(func)
    def _run_with_cleanup(*args, **kwargs):
        with _ThreadCleanup():
            return func(*args, **kwargs)

    return _run_with_cleanup
# Context manager giving scoped access to the thread's cached session
class ThreadSessionContext:
    """
    Context manager that hands out the current thread's metrics session.

    Leaving the ``with`` block does NOT close or discard the session. It
    stays cached for the thread and is released only when the thread-level
    cleanup runs.

    Usage:
        with ThreadSessionContext(username, password) as session:
            # Use session
    """

    def __init__(self, username: str, password: str):
        self.username, self.password = username, password
        self.session = None

    def __enter__(self) -> Optional[Session]:
        session = get_metrics_session(self.username, self.password)
        self.session = session
        return session

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Deliberately a no-op: the session outlives the block so later calls
        # on this thread can reuse it. A falsy return lets any exception raised
        # inside the block propagate.
        return None