Coverage for src / local_deep_research / database / thread_local_session.py: 61%
77 statements
« prev ^ index » next coverage.py v7.12.0, created at 2026-01-11 00:51 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2026-01-11 00:51 +0000
1"""
2Thread-local database session management.
3Each thread gets its own database session that persists for the thread's lifetime.
4"""
import threading
from typing import Dict, Optional, Tuple

from loguru import logger
from sqlalchemy import text
from sqlalchemy.orm import Session

from .encrypted_db import db_manager
class ThreadLocalSessionManager:
    """
    Manages database sessions per thread.

    Each thread gets its own session, stored in ``threading.local``. The
    session is reused by later calls from the same thread as long as the
    same user is requested and the session still answers a trivial query.
    Only the cross-thread tracking dict is shared between threads, and it is
    guarded by ``self._lock``.
    """

    def __init__(self):
        # Thread-local storage holding ``session`` and its owning ``username``.
        self._local = threading.local()
        # Credentials tracked per thread ID (for cleanup). Guarded by _lock.
        # NOTE(review): the password is stored here but never read in this
        # module; consider tracking only the username.
        self._thread_credentials: Dict[int, Tuple[str, str]] = {}
        self._lock = threading.Lock()

    def get_session(self, username: str, password: str) -> Optional[Session]:
        """
        Get or create a database session for the current thread.

        The session is created once per thread and reused for subsequent
        calls for the same user. This avoids the expensive SQLCipher
        decryption on every database access. A cached session that belongs
        to a different user, or that fails a ``SELECT 1`` probe, is closed
        and replaced.

        Args:
            username: User whose database should be opened.
            password: Password used to open the user's database.

        Returns:
            The thread's session, or None if the database could not be
            opened or a session could not be created.
        """
        thread_id = threading.get_ident()

        existing = getattr(self._local, "session", None)
        if existing is not None:
            if getattr(self._local, "username", None) != username:
                # Threads may be reused (e.g. pools) for a different user;
                # never hand out another user's session.
                logger.debug(
                    f"Thread {thread_id}: Existing session belongs to a different user, creating new one"
                )
                self._cleanup_thread_session()
            else:
                try:
                    # Cheap liveness probe. SQLAlchemy 2.x refuses raw SQL
                    # strings in Session.execute, so wrap it in text().
                    existing.execute(text("SELECT 1"))
                    return existing
                except Exception:
                    # Session is invalid, will create a new one
                    logger.debug(
                        f"Thread {thread_id}: Existing session invalid, creating new one"
                    )
                    self._cleanup_thread_session()

        logger.debug(
            f"Thread {thread_id}: Creating new database session for user {username}"
        )

        # Ensure the user's database is open before creating a session on it.
        engine = db_manager.open_user_database(username, password)
        if not engine:
            logger.error(
                f"Thread {thread_id}: Failed to open database for user {username}"
            )
            return None

        session = db_manager.create_thread_safe_session_for_metrics(
            username, password
        )
        if not session:
            logger.error(
                f"Thread {thread_id}: Failed to create session for user {username}"
            )
            return None

        # Remember the session and its owner so later calls can reuse it.
        self._local.session = session
        self._local.username = username

        with self._lock:
            self._thread_credentials[thread_id] = (username, password)

        return session

    def get_current_session(self) -> Optional[Session]:
        """Return the current thread's session, or None if it has none."""
        return getattr(self._local, "session", None)

    def _cleanup_thread_session(self):
        """Close the current thread's session (if any) and stop tracking the thread."""
        thread_id = threading.get_ident()

        session = getattr(self._local, "session", None)
        if session is not None:
            try:
                session.close()
                logger.debug(f"Thread {thread_id}: Closed database session")
            except Exception as e:
                logger.exception(
                    f"Thread {thread_id}: Error closing session: {e}"
                )
            finally:
                # Always drop the references so a broken session is not reused.
                self._local.session = None
                self._local.username = None

        with self._lock:
            self._thread_credentials.pop(thread_id, None)

    def cleanup_thread(self, thread_id: Optional[int] = None):
        """
        Clean up the session for a specific thread or the current thread.

        Called when a thread is finishing. Only the calling thread's session
        can actually be closed, because thread-local storage is not
        reachable from other threads. For any other ``thread_id``, the thread
        is only removed from tracking.

        Args:
            thread_id: Thread identifier (as from ``threading.get_ident()``);
                defaults to the calling thread.
        """
        if thread_id is None:
            thread_id = threading.get_ident()

        if thread_id == threading.get_ident():
            self._cleanup_thread_session()
        else:
            # The other thread's thread-local storage is released when that
            # thread ends; here we can only forget it.
            with self._lock:
                self._thread_credentials.pop(thread_id, None)

    def cleanup_all(self):
        """
        Clean up all tracked threads (for shutdown).

        The calling thread's session is closed; other threads are only
        removed from tracking (see ``cleanup_thread``).
        """
        # Snapshot under the lock: cleanup_thread mutates the dict.
        with self._lock:
            thread_ids = list(self._thread_credentials)

        for thread_id in thread_ids:
            self.cleanup_thread(thread_id)
# Single process-wide manager shared by the module-level helper functions.
thread_session_manager = ThreadLocalSessionManager()
def get_metrics_session(username: str, password: str) -> Optional[Session]:
    """
    Return the calling thread's metrics database session, creating it on first use.

    The session is cached on the thread by the global manager and handed back
    to later calls from the same thread.

    Note: the session comes from ``create_thread_safe_session_for_metrics``,
    so this helper should only be used for metrics-related database work.
    """
    manager = thread_session_manager
    return manager.get_session(username=username, password=password)
def get_current_thread_session() -> Optional[Session]:
    """Return the session already bound to the calling thread, or None if there is none."""
    manager = thread_session_manager
    return manager.get_current_session()
def cleanup_current_thread():
    """Close and forget the calling thread's database session, if it has one."""
    manager = thread_session_manager
    manager.cleanup_thread(thread_id=None)
# Context manager that hands out the thread's session without closing it.
class ThreadSessionContext:
    """
    Context manager that yields the current thread's metrics session.

    Despite being a context manager, it does NOT close or clean up the
    session on exit: the session stays cached on the thread so later calls
    can reuse it. Call ``cleanup_current_thread()`` when the thread is
    finishing to release it.

    Usage:
        with ThreadSessionContext(username, password) as session:
            ...  # session may be None if it could not be created
    """

    def __init__(self, username: str, password: str):
        self.username = username
        self.password = password
        # Populated by __enter__.
        self.session = None

    def __enter__(self) -> Optional[Session]:
        self.session = get_metrics_session(self.username, self.password)
        return self.session

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Intentionally no cleanup: the thread keeps its session until it
        # ends. Returning False lets exceptions from the with-body propagate.
        return False