Coverage for src/local_deep_research/database/thread_local_session.py: 95%
82 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 01:07 +0000
1"""
2Thread-local database session management.
3Each thread gets its own database session that persists for the thread's lifetime.
4"""
6import threading
7from typing import Optional, Dict, Tuple
8from sqlalchemy import text
9from sqlalchemy.orm import Session
10from loguru import logger
12from .encrypted_db import db_manager
class ThreadLocalSessionManager:
    """
    Manages database sessions per thread.

    Each thread gets its own session that is reused throughout the thread's
    lifetime. The cached session is tied to the user it was created for:
    asking for a different user's session on the same thread replaces the
    cached one rather than returning a session bound to another user's
    database.
    """

    def __init__(self):
        # Thread-local storage; per thread it may hold ``session`` and
        # ``username`` attributes.
        self._local = threading.local()
        # Track credentials per thread ID (for cleanup)
        self._thread_credentials: Dict[int, Tuple[str, str]] = {}
        # Guards ``_thread_credentials``, which is shared by all threads.
        self._lock = threading.Lock()

    def get_session(self, username: str, password: str) -> Optional[Session]:
        """
        Get or create a database session for the current thread.

        The session is created once per thread and reused for all subsequent
        calls. This avoids the expensive SQLCipher decryption on every
        database access.

        Args:
            username: User whose database the session must be bound to.
            password: Password used to open the user's database when a new
                session has to be created.

        Returns:
            The thread's session, or ``None`` if the user's database could
            not be opened or a session could not be created.
        """
        thread_id = threading.get_ident()

        # Check if we already have a session for this thread
        cached = getattr(self._local, "session", None)
        if cached:
            cached_user = getattr(self._local, "username", None)
            if cached_user is not None and cached_user != username:
                # Threads (e.g. pooled workers) can serve several users over
                # their lifetime; never hand back a session that is bound to
                # another user's database.
                logger.debug(
                    f"Thread {thread_id}: Existing session belongs to a different user, creating new one"
                )
                self._cleanup_thread_session()
            else:
                # Verify it's still valid with a simple query
                try:
                    cached.execute(text("SELECT 1"))
                except Exception:
                    # Session is invalid, will create a new one
                    logger.debug(
                        f"Thread {thread_id}: Existing session invalid, creating new one"
                    )
                    self._cleanup_thread_session()
                else:
                    return cached

        # Create new session for this thread
        logger.debug(
            f"Thread {thread_id}: Creating new database session for user {username}"
        )

        # Ensure database is open
        engine = db_manager.open_user_database(username, password)
        if not engine:
            logger.error(
                f"Thread {thread_id}: Failed to open database for user {username}"
            )
            return None

        # Create session for this thread
        session = db_manager.create_thread_safe_session_for_metrics(
            username, password
        )
        if not session:
            logger.error(
                f"Thread {thread_id}: Failed to create session for user {username}"
            )
            return None

        # Store in thread-local storage, remembering which user it belongs to
        self._local.session = session
        self._local.username = username

        # Track credentials for cleanup
        with self._lock:
            self._thread_credentials[thread_id] = (username, password)

        return session

    def get_current_session(self) -> Optional[Session]:
        """Return the current thread's cached session, or ``None`` if there is none."""
        return getattr(self._local, "session", None)

    def _cleanup_thread_session(self):
        """
        Clean up the current thread's session and engine.

        Closes the cached session (errors while closing are logged), then
        asks ``db_manager`` to clean up this thread's engines for the cached
        user, and finally removes the thread from credential tracking.
        Thread-local state and tracking are reset even if the engine cleanup
        raises; that exception still propagates to the caller.
        """
        thread_id = threading.get_ident()

        session = getattr(self._local, "session", None)
        if session:
            try:
                session.close()
                logger.debug(f"Thread {thread_id}: Closed database session")
            except Exception:
                logger.exception(f"Thread {thread_id}: Error closing session")
            finally:
                self._local.session = None

        try:
            # Clean up the thread engine too
            username = getattr(self._local, "username", None)
            if username:
                try:
                    db_manager.cleanup_thread_engines(
                        username=username, thread_id=thread_id
                    )
                finally:
                    self._local.username = None
        finally:
            # Remove from tracking even if engine cleanup failed
            with self._lock:
                self._thread_credentials.pop(thread_id, None)

    def cleanup_thread(self, thread_id: Optional[int] = None):
        """
        Clean up session for a specific thread or the current thread.

        Called when a thread is finishing.

        Args:
            thread_id: Identifier from ``threading.get_ident()``; defaults to
                the calling thread.

        For the calling thread the session is closed and its engines are
        cleaned up. For any other thread only the tracking entry is removed.
        That thread's session lives in its own thread-local storage and is
        not closed here.
        """
        if thread_id is None:
            thread_id = threading.get_ident()

        # If it's the current thread, we can clean up directly
        if thread_id == threading.get_ident():
            self._cleanup_thread_session()
        else:
            # Another thread's thread-local storage is not reachable from
            # here, so just remove it from tracking.
            with self._lock:
                self._thread_credentials.pop(thread_id, None)

    def cleanup_all(self):
        """
        Clean up all tracked sessions (for shutdown).

        Runs ``cleanup_thread`` for every tracked thread ID, then calls
        ``db_manager.cleanup_all_thread_engines()`` to release any remaining
        thread engines.
        """
        # Snapshot the IDs under the lock; cleanup_thread takes the lock itself.
        with self._lock:
            thread_ids = list(self._thread_credentials.keys())

        for thread_id in thread_ids:
            self.cleanup_thread(thread_id)

        # Also cleanup any remaining thread engines
        db_manager.cleanup_all_thread_engines()
# Module-level manager instance shared by the helper functions below.
thread_session_manager: ThreadLocalSessionManager = ThreadLocalSessionManager()
def get_metrics_session(username: str, password: str) -> Optional[Session]:
    """
    Return the calling thread's metrics database session.

    Delegates to the module-level ``thread_session_manager``. The first call
    on a thread opens the user's database and builds a session through
    ``create_thread_safe_session_for_metrics``. Later calls on the same
    thread reuse that session while it remains usable.

    Only use this for metrics-related database operations.

    Returns:
        The session, or ``None`` if one could not be created.
    """
    manager = thread_session_manager
    session = manager.get_session(username, password)
    return session
def get_current_thread_session() -> Optional[Session]:
    """Return the session cached for the calling thread, or ``None`` if there is none."""
    manager = thread_session_manager
    return manager.get_current_session()
def cleanup_current_thread():
    """
    Release the calling thread's cached database session.

    Closes the session, cleans up this thread's engines for the cached user,
    and stops tracking the thread in ``thread_session_manager``.
    """
    manager = thread_session_manager
    manager.cleanup_thread()
# Context manager for obtaining the current thread's metrics session
class ThreadSessionContext:
    """
    Context manager that yields the current thread's metrics session.

    ``__enter__`` returns whatever ``get_metrics_session`` returns, which is
    ``None`` if no session could be created. ``__exit__`` deliberately does
    NOT close or clean up the session: it stays cached on the thread for
    reuse. Call ``cleanup_current_thread()`` when the thread is finishing to
    release it. Exceptions raised inside the ``with`` body are not
    suppressed.

    Usage:
        with ThreadSessionContext(username, password) as session:
            if session is not None:
                ...  # use session
    """

    def __init__(self, username: str, password: str):
        self.username = username
        self.password = password
        # Set by __enter__; stays None until then, or if creation failed.
        self.session: Optional[Session] = None

    def __enter__(self) -> Optional[Session]:
        self.session = get_metrics_session(self.username, self.password)
        return self.session

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Don't clean up here: the thread keeps its session for reuse and it
        # is released only when the thread ends (cleanup_current_thread).
        # Returning None lets exceptions from the with-body propagate.
        return None