Coverage for src / local_deep_research / database / thread_local_session.py: 97%
139 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
1"""
2Thread-local database session management.
3Each thread gets its own database session that persists for the thread's lifetime.
4"""
6import functools
7import threading
8from contextlib import ContextDecorator
9from typing import Optional, Dict, Tuple
10from sqlalchemy import text
11from sqlalchemy.exc import PendingRollbackError
12from sqlalchemy.orm import Session
13from loguru import logger
15from .encrypted_db import db_manager
class ThreadLocalSessionManager:
    """
    Per-thread cache of database sessions.

    The first request made from a thread opens a session; every later
    request from that same thread gets the same object back. This keeps the
    expensive SQLCipher decryption to once per thread instead of once per
    database access.
    """

    def __init__(self):
        # Per-thread attribute storage; holds ``session`` and ``username``.
        self._local = threading.local()
        # thread ident -> (username, password) for threads that created a
        # session. Within this class only the keys are ever read.
        self._thread_credentials: Dict[int, Tuple[str, str]] = {}
        # Serialises every access to ``_thread_credentials``.
        self._lock = threading.Lock()

    def get_session(self, username: str, password: str) -> Optional[Session]:
        """
        Return this thread's session for ``username``, opening one if needed.

        A cached session is handed back only when it was opened for the same
        ``username`` and still answers ``SELECT 1``. If that probe raises
        ``PendingRollbackError`` the session is rolled back and probed once
        more. A cached session that belongs to another user, or that cannot
        be revived, is torn down and replaced.

        Args:
            username: Owner of the encrypted database to use.
            password: Password used to open that database.

        Returns:
            A working ``Session``, or ``None`` if the database could not be
            opened or no session could be created for it.
        """
        ident = threading.get_ident()
        cached = getattr(self._local, "session", None)

        if cached:
            if getattr(self._local, "username", None) == username:
                # Same owner: probe the connection before reusing it.
                try:
                    cached.execute(text("SELECT 1"))
                    return cached
                except PendingRollbackError:
                    # An earlier failure (e.g. a database lock error) left the
                    # session needing a rollback; try to recover in place.
                    logger.debug(
                        f"Thread {ident}: PendingRollbackError, attempting rollback recovery"
                    )
                    try:
                        cached.rollback()
                        cached.execute(text("SELECT 1"))
                        return cached
                    except Exception:
                        logger.warning(
                            f"Thread {ident}: Rollback recovery failed, creating new session"
                        )
                        self._cleanup_thread_session()
                except Exception:
                    # Unusable session: discard it and build a new one below.
                    logger.debug(
                        f"Thread {ident}: Existing session invalid, creating new one"
                    )
                    self._cleanup_thread_session()
            else:
                # SECURITY: a session cached for one user must never be
                # returned to a different user on the same thread.
                logger.warning(
                    f"Thread {ident}: Session username mismatch "
                    f"(cached={self._local.username!r}, requested={username!r}), "
                    "clearing stale cross-user session"
                )
                self._cleanup_thread_session()

        logger.debug(
            f"Thread {ident}: Creating new database session for user {username}"
        )

        # The per-user database must be open before a session can be made.
        if not db_manager.open_user_database(username, password):
            logger.error(
                f"Thread {ident}: Failed to open database for user {username}"
            )
            return None

        fresh = db_manager.create_thread_safe_session_for_metrics(
            username, password
        )
        if not fresh:
            logger.error(
                f"Thread {ident}: Failed to create session for user {username}"
            )
            return None

        self._local.session = fresh
        self._local.username = username
        with self._lock:
            self._thread_credentials[ident] = (username, password)
        return fresh

    def get_current_session(self) -> Optional[Session]:
        """Return this thread's cached session, or ``None`` if it has none."""
        return getattr(self._local, "session", None)

    def _cleanup_thread_session(self):
        """
        Discard the calling thread's session and stop tracking the thread.

        The session is rolled back and then closed; a failure in either step
        is logged as a warning and does not stop the rest of the cleanup.
        Sessions are bound to the shared per-user QueuePool engine, so
        closing one returns its connection to the pool — there is no
        per-thread engine to dispose.
        """
        ident = threading.get_ident()
        current = getattr(self._local, "session", None)

        if current:
            try:
                current.rollback()
            except Exception:
                logger.warning(
                    f"Thread {ident}: Error rolling back session during cleanup"
                )
            try:
                current.close()
                logger.debug(f"Thread {ident}: Closed database session")
            except Exception:
                logger.warning(f"Thread {ident}: Error closing session")
            finally:
                # Drop the reference even if close() failed.
                self._local.session = None

        if hasattr(self._local, "username"):
            self._local.username = None

        with self._lock:
            self._thread_credentials.pop(ident, None)

    def cleanup_thread(self, thread_id: Optional[int] = None):
        """
        Release a thread's tracking entry (and, for the caller, its session).

        Args:
            thread_id: Ident of the thread to clean up; defaults to the
                calling thread.

        Only the calling thread's session can actually be closed, because
        ``threading.local`` attributes always resolve to the current
        thread's values. For any other thread, only its credential entry is
        dropped.
        """
        caller = threading.get_ident()
        target = caller if thread_id is None else thread_id

        if target == caller:
            self._cleanup_thread_session()
            return

        with self._lock:
            self._thread_credentials.pop(target, None)

    def cleanup_dead_threads(self):
        """
        Forget credential entries whose thread is no longer running.

        Covers threads that exited without running their cleanup handler.
        A thread counts as alive if it appears in ``threading.enumerate()``.
        Sessions left on dead threads are garbage-collected normally and
        their connections go back to the shared per-user QueuePool.

        Called from:
        - processor_v2.py: every ~60s in the queue loop
        - app_factory.py: in teardown_appcontext
        - connection_cleanup.py: in cleanup_idle_connections (every ~300s)
        """
        live = {thread.ident for thread in threading.enumerate()}
        with self._lock:
            # keys() - set builds a new set, so deleting below is safe.
            gone = self._thread_credentials.keys() - live
            for tid in gone:
                del self._thread_credentials[tid]
        if gone:
            logger.debug(f"Swept {len(gone)} dead thread credential(s)")

    def cleanup_all(self):
        """
        Clean up every tracked thread (intended for shutdown).

        The list of tracked idents is snapshotted under the lock, but
        ``cleanup_thread`` is called outside it because that method acquires
        the same (non-reentrant) lock itself.
        """
        with self._lock:
            tracked = [*self._thread_credentials]
        for tid in tracked:
            self.cleanup_thread(tid)
# Process-wide manager instance used by the module-level helpers in this file.
thread_session_manager = ThreadLocalSessionManager()
def get_metrics_session(username: str, password: str) -> Optional[Session]:
    """
    Return the calling thread's database session for metrics work.

    Thin wrapper around the module-wide ``thread_session_manager``: the
    session is created on the thread's first call and reused afterwards.

    Note: the manager builds the session with
    ``create_thread_safe_session_for_metrics``, so this helper should only be
    used for metrics-related database operations.
    """
    manager = thread_session_manager
    return manager.get_session(username, password)
def get_current_thread_session() -> Optional[Session]:
    """Return the session already cached for this thread, or ``None``."""
    manager = thread_session_manager
    return manager.get_current_session()
def cleanup_current_thread():
    """Close and forget the database session owned by the calling thread."""
    manager = thread_session_manager
    manager.cleanup_thread()
def cleanup_dead_threads():
    """
    Sweep credential entries left behind by threads that have exited.

    Best-effort: any exception raised by the sweep is logged as a warning
    and swallowed, so callers never see it.
    """
    try:
        thread_session_manager.cleanup_dead_threads()
    except Exception:
        # Housekeeping failure should not propagate into the caller.
        logger.warning("Dead-thread session sweep failed")
class _ThreadCleanup(ContextDecorator):
    """
    Context manager / decorator that releases thread-local resources on exit.

    On exit it runs three independent steps, in order:

    1. ``cleanup_current_thread()`` — closes this thread's DB session;
    2. ``clear_settings_context()`` from ``..config.thread_settings``;
    3. ``clear_search_context()`` from ``..utilities.thread_context``.

    A failure in one step is logged at debug level (with traceback) and does
    not prevent the remaining steps. ``__exit__`` returns ``False``, so an
    exception raised inside the managed block is never suppressed.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # NOTE: ``logger`` is loguru, which does not understand the stdlib
        # ``exc_info=True`` keyword (it is just captured into ``extra``), so
        # the traceback must be attached via ``opt(exception=True)``.
        try:
            cleanup_current_thread()
        except Exception:
            logger.opt(exception=True).debug(
                "thread_cleanup: error during DB session cleanup"
            )
        try:
            # Imported lazily so an import failure is handled like any other
            # cleanup error.
            from ..config.thread_settings import clear_settings_context

            clear_settings_context()
        except Exception:
            logger.opt(exception=True).debug(
                "thread_cleanup: error clearing settings context"
            )
        try:
            from ..utilities.thread_context import clear_search_context

            clear_search_context()
        except Exception:
            logger.opt(exception=True).debug(
                "thread_cleanup: error clearing search context"
            )
        # Never swallow the managed block's exception.
        return False
def thread_cleanup(func=None):
    """
    Release every thread-local resource when a function or block finishes.

    Usable as a bare decorator, a decorator factory, or a context manager::

        @thread_cleanup
        def worker(): ...

        @thread_cleanup()
        def worker(): ...

        with thread_cleanup():
            ...

        executor.submit(thread_cleanup(func), arg)

    Args:
        func: Function to wrap. When omitted (or ``None``), a
            ``_ThreadCleanup`` instance is returned instead; it works both as
            a context manager and as a decorator.

    Returns:
        A wrapper that runs ``func`` inside ``_ThreadCleanup()``, or a
        ``_ThreadCleanup`` instance when no function was given.
    """
    if func is None:
        return _ThreadCleanup()

    @functools.wraps(func)
    def _run_with_cleanup(*args, **kwargs):
        # A fresh cleanup context per call, so each invocation cleans up.
        with _ThreadCleanup():
            return func(*args, **kwargs)

    return _run_with_cleanup
# Context manager that hands out the thread's session without closing it.
class ThreadSessionContext:
    """
    Context manager yielding the calling thread's metrics session.

    Leaving the ``with`` block does NOT close the session: it stays cached
    on the thread so later calls can reuse it. Call
    ``cleanup_current_thread()`` (or use ``thread_cleanup``) to release it.

    Usage:
        with ThreadSessionContext(username, password) as session:
            ...  # session may be None if no session could be obtained
    """

    def __init__(self, username: str, password: str):
        self.username = username
        self.password = password
        # Populated by __enter__.
        self.session = None

    def __enter__(self) -> Optional[Session]:
        self.session = get_metrics_session(self.username, self.password)
        return self.session

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Intentionally a no-op: the session lives for the thread's lifetime
        # and is only released by explicit thread cleanup. Returning None
        # lets any exception from the block propagate.
        return None