Coverage for src / local_deep_research / database / thread_local_session.py: 95%

82 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-25 01:07 +0000

1""" 

2Thread-local database session management. 

3Each thread gets its own database session that persists for the thread's lifetime. 

4""" 

5 

6import threading 

7from typing import Optional, Dict, Tuple 

8from sqlalchemy import text 

9from sqlalchemy.orm import Session 

10from loguru import logger 

11 

12from .encrypted_db import db_manager 

13 

14 

class ThreadLocalSessionManager:
    """
    Manages database sessions per thread.

    Each thread gets its own session, stored in ``threading.local`` storage,
    that is reused for later requests from the same thread for the same user.
    A lock-protected dict keyed by thread ID records which threads currently
    own a session so that ``cleanup_all`` can find them.
    """

    def __init__(self):
        # Thread-local storage for sessions (attributes: ``session``, ``username``)
        self._local = threading.local()
        # Track credentials per thread ID (for cleanup). The cleanup methods
        # in this class only use the keys.
        self._thread_credentials: Dict[int, Tuple[str, str]] = {}
        # Guards _thread_credentials, which is shared by all threads.
        self._lock = threading.Lock()

    def get_session(self, username: str, password: str) -> Optional[Session]:
        """
        Get or create a database session for the current thread.

        The session is created once per thread and reused for subsequent calls
        made with the same ``username``. This avoids the expensive SQLCipher
        decryption on every database access. A cached session is discarded
        and replaced when it belongs to a different user or fails a
        ``SELECT 1`` liveness check.

        Args:
            username: User whose database the session should be bound to.
            password: Password used to open that user's database.

        Returns:
            The thread's session, or ``None`` if the database could not be
            opened or a session could not be created.
        """
        thread_id = threading.get_ident()

        existing = getattr(self._local, "session", None)
        if existing:
            if getattr(self._local, "username", None) != username:
                # Never hand one user's session to a different user when a
                # thread is reused; drop it (and its engine) first.
                logger.debug(
                    f"Thread {thread_id}: Cached session belongs to a different user, creating new one"
                )
                self._cleanup_thread_session()
            else:
                try:
                    # Simple query to test the connection is still usable
                    existing.execute(text("SELECT 1"))
                    return existing
                except Exception:
                    # Session is invalid, will create a new one
                    logger.debug(
                        f"Thread {thread_id}: Existing session invalid, creating new one"
                    )
                    self._cleanup_thread_session()

        # Create new session for this thread
        logger.debug(
            f"Thread {thread_id}: Creating new database session for user {username}"
        )

        # Ensure database is open
        engine = db_manager.open_user_database(username, password)
        if not engine:
            logger.error(
                f"Thread {thread_id}: Failed to open database for user {username}"
            )
            return None

        session = db_manager.create_thread_safe_session_for_metrics(
            username, password
        )
        if not session:
            logger.error(
                f"Thread {thread_id}: Failed to create session for user {username}"
            )
            return None

        # Store in thread-local storage; username is what later calls
        # compare against to decide whether the session may be reused.
        self._local.session = session
        self._local.username = username

        # Track the thread so cleanup_all can find it
        with self._lock:
            self._thread_credentials[thread_id] = (username, password)

        return session

    def get_current_session(self) -> Optional[Session]:
        """Get the current thread's session if it exists (never creates one)."""
        return getattr(self._local, "session", None)

    def _cleanup_thread_session(self):
        """
        Clean up the current thread's session and engine.

        Errors from ``session.close()`` are logged and swallowed. Errors from
        ``db_manager.cleanup_thread_engines`` propagate to the caller. In every
        case the thread-local state and the tracking entry are reset, so
        nothing stale is left behind.
        """
        thread_id = threading.get_ident()

        try:
            session = getattr(self._local, "session", None)
            if session:
                try:
                    session.close()
                    logger.debug(f"Thread {thread_id}: Closed database session")
                except Exception:
                    logger.exception(f"Thread {thread_id}: Error closing session")
                finally:
                    self._local.session = None

            # Clean up the thread engine too
            username = getattr(self._local, "username", None)
            if username:
                try:
                    db_manager.cleanup_thread_engines(
                        username=username, thread_id=thread_id
                    )
                finally:
                    self._local.username = None
        finally:
            # Remove from tracking even if engine cleanup raised
            with self._lock:
                self._thread_credentials.pop(thread_id, None)

    def cleanup_thread(self, thread_id: Optional[int] = None):
        """
        Clean up session for a specific thread or the current thread.

        Called when a thread is finishing. Only the current thread's session
        can actually be closed, because thread-local storage is not reachable
        from other threads. For any other ``thread_id``, only the tracking
        entry is removed.

        Args:
            thread_id: Thread identifier (as from ``threading.get_ident()``);
                ``None`` means the current thread.
        """
        if thread_id is None or thread_id == threading.get_ident():
            self._cleanup_thread_session()
            return

        # The other thread's thread-local storage is released when it ends
        with self._lock:
            self._thread_credentials.pop(thread_id, None)

    def cleanup_all(self):
        """
        Clean up all tracked sessions (for shutdown).

        Closes the calling thread's session if it is tracked. For other
        threads, drops the tracking entries, then asks ``db_manager`` to clean
        up all remaining thread engines.
        """
        # Snapshot the keys under the lock; cleanup_thread re-acquires it.
        with self._lock:
            thread_ids = list(self._thread_credentials)

        for tid in thread_ids:
            self.cleanup_thread(tid)

        # Also cleanup any remaining thread engines
        db_manager.cleanup_all_thread_engines()

141 

142 

# Process-wide manager instance used by the module-level helper functions.
thread_session_manager: ThreadLocalSessionManager = ThreadLocalSessionManager()

145 

146 

def get_metrics_session(username: str, password: str) -> Optional[Session]:
    """
    Return the calling thread's metrics database session, creating it if needed.

    Delegates to the module-level ``thread_session_manager``. The session is
    opened once per thread and reused on later calls from that thread; see
    ``ThreadLocalSessionManager.get_session`` for the exact reuse rules.

    Sessions are produced via ``create_thread_safe_session_for_metrics``, so
    this helper is intended only for metrics-related database operations.

    Returns:
        The session, or ``None`` if it could not be created.
    """
    manager = thread_session_manager
    session = manager.get_session(username, password)
    return session

156 

157 

def get_current_thread_session() -> Optional[Session]:
    """
    Return the session already cached for the calling thread, or ``None``.

    Unlike ``get_metrics_session``, this only looks up existing state and
    never opens a database.
    """
    manager = thread_session_manager
    cached = manager.get_current_session()
    return cached

161 

162 

def cleanup_current_thread():
    """
    Close and forget the calling thread's database session.

    Passing no thread ID to ``cleanup_thread`` targets the current thread.
    """
    thread_session_manager.cleanup_thread(thread_id=None)

166 

167 

# Context manager giving scoped access to the thread's metrics session
# (it does not clean the session up on exit).
class ThreadSessionContext:
    """
    Context manager that yields the current thread's metrics session.

    Despite being a context manager, it does NOT close or clean up the session
    on exit. The session stays cached for the thread so later calls can reuse
    it. Call ``cleanup_current_thread()`` when the thread is finishing.

    ``__enter__`` may return ``None`` if the session could not be created.

    Usage:
        with ThreadSessionContext(username, password) as session:
            if session is not None:
                ...  # use session
    """

    def __init__(self, username: str, password: str):
        self.username = username
        self.password = password
        # Set by __enter__; stays None until then or if creation failed.
        self.session: Optional[Session] = None

    def __enter__(self) -> Optional[Session]:
        """Fetch (or create) the thread's metrics session; may be ``None``."""
        self.session = get_metrics_session(self.username, self.password)
        return self.session

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Don't cleanup here - let the thread keep its session.
        # Only cleanup when the thread ends (see cleanup_current_thread).
        # Returning None (falsy) lets exceptions from the with-body propagate.
        return None