Coverage for src / local_deep_research / database / thread_local_session.py: 61%

77 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2026-01-11 00:51 +0000

1""" 

2Thread-local database session management. 

3Each thread gets its own database session that persists for the thread's lifetime. 

4""" 

5 

import threading
from typing import Optional, Dict, Tuple

from loguru import logger
from sqlalchemy import text
from sqlalchemy.orm import Session

from .encrypted_db import db_manager

12 

13 

class ThreadLocalSessionManager:
    """
    Manages database sessions per thread.

    Each thread gets its own session that is reused throughout the thread's
    lifetime. The session and the username it was opened for are kept in
    ``threading.local`` storage, so every thread only ever sees its own
    values. The IDs of threads that currently hold a session are also
    recorded in a lock-protected dict so ``cleanup_all`` can enumerate them.
    """

    def __init__(self):
        # Thread-local storage; the ``session`` and ``username`` attributes
        # set on it are private to each thread.
        self._local = threading.local()
        # Track credentials per thread ID (``threading.get_ident()``) for
        # cleanup. NOTE(review): within this class only the keys are read;
        # the stored password is kept in memory but not used here.
        self._thread_credentials: Dict[int, Tuple[str, str]] = {}
        # Guards ``_thread_credentials``, which every thread mutates.
        self._lock = threading.Lock()

    def get_session(self, username: str, password: str) -> Optional[Session]:
        """
        Get or create a database session for the current thread.

        The session is created once per thread and reused for all subsequent
        calls made for the same user. This avoids the expensive SQLCipher
        decryption on every database access.

        A cached session is replaced when it fails a ``SELECT 1`` liveness
        check, or when it was opened for a different user than ``username``.

        Args:
            username: User whose database the session is bound to.
            password: Password used to open that user's database.

        Returns:
            The thread's session, or ``None`` if the database could not be
            opened or the session could not be created.
        """
        thread_id = threading.get_ident()

        # Reuse this thread's session if it is alive and belongs to the user.
        cached = self._reuse_cached_session(thread_id, username)
        if cached is not None:
            return cached

        # Create new session for this thread
        logger.debug(
            f"Thread {thread_id}: Creating new database session for user {username}"
        )

        # Ensure database is open
        engine = db_manager.open_user_database(username, password)
        if not engine:
            logger.error(
                f"Thread {thread_id}: Failed to open database for user {username}"
            )
            return None

        # Create session for this thread
        session = db_manager.create_thread_safe_session_for_metrics(
            username, password
        )
        if not session:
            logger.error(
                f"Thread {thread_id}: Failed to create session for user {username}"
            )
            return None

        # Store in thread-local storage
        self._local.session = session
        self._local.username = username

        # Track credentials for cleanup
        with self._lock:
            self._thread_credentials[thread_id] = (username, password)

        return session

    def _reuse_cached_session(
        self, thread_id: int, username: str
    ) -> Optional[Session]:
        """Return the thread's cached session if it can serve ``username``.

        Returns ``None`` when no session is cached. It also returns ``None``,
        after cleaning up the cached session, when that session was opened for
        another user or fails the ``SELECT 1`` liveness check. In every
        ``None`` case the caller should create a fresh session.
        """
        session = getattr(self._local, "session", None)
        if session is None:
            return None

        if getattr(self._local, "username", None) != username:
            # A thread may be reused for a different user; never hand one
            # user's database session to another.
            logger.debug(
                f"Thread {thread_id}: Existing session belongs to another user, creating new one"
            )
            self._cleanup_thread_session()
            return None

        try:
            # Simple query to test connection. Textual SQL must be wrapped in
            # text(): SQLAlchemy 2.x rejects a plain string, which would make
            # every cached session look invalid and defeat the reuse.
            session.execute(text("SELECT 1"))
        except Exception:
            # Session is invalid, will create a new one
            logger.debug(
                f"Thread {thread_id}: Existing session invalid, creating new one"
            )
            self._cleanup_thread_session()
            return None

        return session

    def get_current_session(self) -> Optional[Session]:
        """Get the current thread's session if it exists, else ``None``."""
        return getattr(self._local, "session", None)

    def _cleanup_thread_session(self):
        """Close the current thread's session and stop tracking the thread.

        Errors raised while closing are logged rather than propagated. The
        thread-local references are cleared either way.
        """
        thread_id = threading.get_ident()

        session = getattr(self._local, "session", None)
        if session is not None:
            try:
                session.close()
                logger.debug(f"Thread {thread_id}: Closed database session")
            except Exception as e:
                logger.exception(
                    f"Thread {thread_id}: Error closing session: {e}"
                )
            finally:
                self._local.session = None
                self._local.username = None

        # Remove from tracking
        with self._lock:
            self._thread_credentials.pop(thread_id, None)

    def cleanup_thread(self, thread_id: Optional[int] = None):
        """
        Clean up session for a specific thread or current thread.
        Called when a thread is finishing.

        Args:
            thread_id: ``threading.get_ident()`` of the thread to clean up;
                defaults to the calling thread.

        Only the calling thread's session can be closed here. For any other
        thread the entry is merely removed from tracking.
        """
        if thread_id is None:
            thread_id = threading.get_ident()

        # If it's the current thread, we can clean up directly
        if thread_id == threading.get_ident():
            self._cleanup_thread_session()
        else:
            # For other threads, just remove from tracking. Their session
            # lives in that thread's threading.local storage, which this
            # thread cannot reach; it is released when that thread ends.
            with self._lock:
                self._thread_credentials.pop(thread_id, None)

    def cleanup_all(self):
        """Clean up all tracked sessions (for shutdown).

        Only the calling thread's session is actually closed. Every other
        tracked thread is just removed from tracking (see ``cleanup_thread``).
        """
        # Snapshot under the lock; cleanup_thread re-acquires it per entry.
        with self._lock:
            thread_ids = list(self._thread_credentials.keys())

        for thread_id in thread_ids:
            self.cleanup_thread(thread_id)

132 

133 

# Single process-wide manager; the module-level helper functions delegate to it.
thread_session_manager: ThreadLocalSessionManager = ThreadLocalSessionManager()

136 

137 

def get_metrics_session(username: str, password: str) -> Optional[Session]:
    """
    Return the calling thread's session for metrics database work.

    Delegates to the module-wide ``thread_session_manager``. The first call on
    a thread opens the user's database and creates the session via
    ``create_thread_safe_session_for_metrics``. Later calls on that thread
    reuse it while it remains valid.

    Only use this for metrics-related database operations.

    Returns:
        The session, or ``None`` if it could not be created.
    """
    manager = thread_session_manager
    return manager.get_session(username=username, password=password)

147 

148 

def get_current_thread_session() -> Optional[Session]:
    """Return the session cached for the calling thread, or ``None``.

    This never creates a session; use ``get_metrics_session`` for that.
    """
    manager = thread_session_manager
    return manager.get_current_session()

152 

153 

def cleanup_current_thread():
    """Close and forget the calling thread's database session, if any.

    Intended to be called when a thread is finishing its work.
    """
    thread_session_manager.cleanup_thread(thread_id=None)

157 

158 

# Context manager giving scoped access to the thread's cached session
class ThreadSessionContext:
    """
    Context manager that yields the current thread's metrics session.

    Entering calls ``get_metrics_session``, so the value bound by ``as`` is the
    same per-thread session that function returns (created on first use).
    Exiting deliberately does NOT close the session. It stays cached for the
    thread and should be released with ``cleanup_current_thread()`` when the
    thread is finishing.

    Usage:
        with ThreadSessionContext(username, password) as session:
            if session is not None:
                ...  # use session

    ``session`` is ``None`` if the database could not be opened or the
    session could not be created.
    """

    def __init__(self, username: str, password: str):
        self.username = username
        self.password = password
        # Set by __enter__.
        self.session: Optional[Session] = None

    def __enter__(self) -> Optional[Session]:
        self.session = get_metrics_session(self.username, self.password)
        return self.session

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Don't cleanup here - let the thread keep its session.
        # Only cleanup when thread ends. Returning None (falsy) means any
        # exception raised inside the ``with`` body propagates unchanged.
        return None