Coverage for src/local_deep_research/database/thread_metrics.py: 93%

53 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-03 23:15 +0000

"""
Thread-safe metrics database access.

This module lets background threads write metrics to a user's encrypted
database. Each thread stores the user's password in thread-local storage
and uses it to open its own encrypted session whenever it needs to write.
"""

8 

import threading
from contextlib import contextmanager
from typing import Iterator, Optional

from loguru import logger
from sqlalchemy.orm import Session

from .encrypted_db import db_manager

17 

18 

class ThreadSafeMetricsWriter:
    """
    Thread-safe writer for metrics to encrypted user databases.

    Passwords are cached per thread in a ``threading.local`` store. Every
    call to ``get_session`` opens a fresh encrypted session for the
    requested user. That session is committed when the ``with`` block
    succeeds, rolled back when it raises, and always closed.
    """

    def __init__(self):
        # Per-thread storage: each thread only ever sees the ``passwords``
        # dict it created itself via set_user_password().
        self._thread_local = threading.local()

    def set_user_password(self, username: str, password: str):
        """
        Store user password for the current thread.

        This allows the thread to create its own encrypted connection.

        IMPORTANT: This is safe because:
        1. Password is already in memory (user is logged in)
        2. It's only stored thread-locally
        3. It's explicitly cleared by `clear_passwords()` (invoked from
           `cleanup_current_thread()`) so pooled worker threads do not
           retain credentials between tasks.

        Args:
            username: User whose encrypted database the password unlocks.
            password: Plaintext password. It replaces any value previously
                stored for ``username`` on this thread.
        """
        if not hasattr(self._thread_local, "passwords"):
            self._thread_local.passwords = {}
        self._thread_local.passwords[username] = password

    def clear_passwords(self):
        """Remove all passwords cached on the current thread's local store.

        Called by `cleanup_current_thread()` so that pooled worker threads
        don't retain plaintext credentials between tasks. Passwords stored
        by other threads are not affected.
        """
        if hasattr(self._thread_local, "passwords"):
            self._thread_local.passwords.clear()

    @contextmanager
    def get_session(self, username: Optional[str] = None) -> Iterator[Session]:
        """
        Get a database session for metrics in the current thread.

        A new encrypted session is created on every call. If the ``with``
        body completes normally, the session is committed. If anything
        raises, including the caller's body or the commit itself, the error
        is logged, the session is rolled back, and the exception is
        re-raised. In every case the session is closed afterwards.

        Args:
            username: The username for database access. If not provided,
                it is read from the Flask session.

        Yields:
            The session returned by
            ``db_manager.create_thread_safe_session_for_metrics``.

        Raises:
            ValueError: If Flask is unavailable or no request context is
                active (when ``username`` is omitted), if no password is
                stored for the user on this thread, or if the session could
                not be created.
            werkzeug.exceptions.Unauthorized: If a Flask session is active
                but contains no ``username``.
        """
        # If username not provided, try to get it from Flask session
        if username is None:
            try:
                from flask import session as flask_session
                from werkzeug.exceptions import Unauthorized

                username = flask_session.get("username")
                if not username:
                    raise Unauthorized("No username in Flask session")
            except (ImportError, RuntimeError) as e:
                # Flask not importable, or accessed outside a request/app
                # context. Unauthorized is deliberately not caught here.
                raise ValueError(f"Cannot determine username: {e}") from e

        # Get password for this user in this thread
        passwords = getattr(self._thread_local, "passwords", None)
        if passwords is None:
            raise ValueError("No password set for thread metrics access")

        password = passwords.get(username)
        if not password:
            raise ValueError(
                f"No password available for user {username} in this thread"
            )

        # Create a thread-safe session for this user
        session = None
        try:
            session = db_manager.create_thread_safe_session_for_metrics(
                username, password
            )
            if not session:
                raise ValueError(  # noqa: TRY301 — except does session rollback before re-raise
                    f"Failed to create session for user {username}"
                )
            yield session
            # Only reached when the caller's ``with`` body finished cleanly.
            session.commit()
        except Exception:
            logger.exception(f"Session error for {username}")
            if session:
                session.rollback()
            raise
        finally:
            if session:
                # Local import; presumably avoids an import cycle — TODO confirm.
                from ..utilities.resource_utils import safe_close

                safe_close(session, "thread metrics session")

    def write_token_metrics(
        self, username: str, research_id: Optional[int], token_data: dict
    ):
        """
        Write token metrics from any thread.

        The current thread must already have a password stored for
        ``username`` via ``set_user_password``.

        Args:
            username: The username (for database access)
            research_id: The research ID
            token_data: Dictionary with token metrics data. Missing keys
                are stored as None, except that ``prompt_tokens`` and
                ``completion_tokens`` default to 0 (also when present but
                None), ``success_status`` defaults to "success", and
                ``context_truncated`` defaults to False. ``total_tokens``
                is always prompt + completion tokens.

        Raises:
            ValueError: Propagated from ``get_session`` when no session can
                be opened for ``username``.
        """
        with self.get_session(username) as session:
            # Import here to avoid circular imports
            from .models import TokenUsage

            # A key that is present but None would otherwise make the
            # prompt + completion sum raise TypeError and drop the record.
            prompt_tokens = token_data.get("prompt_tokens") or 0
            completion_tokens = token_data.get("completion_tokens") or 0

            # Create TokenUsage record
            token_usage = TokenUsage(
                research_id=research_id,
                model_name=token_data.get("model_name"),
                model_provider=token_data.get("provider"),
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
                # Research context
                research_query=token_data.get("research_query"),
                research_mode=token_data.get("research_mode"),
                research_phase=token_data.get("research_phase"),
                search_iteration=token_data.get("search_iteration"),
                # Performance metrics
                response_time_ms=token_data.get("response_time_ms"),
                success_status=token_data.get("success_status", "success"),
                error_type=token_data.get("error_type"),
                # Search engine context
                search_engines_planned=token_data.get("search_engines_planned"),
                search_engine_selected=token_data.get("search_engine_selected"),
                # Call stack tracking
                calling_file=token_data.get("calling_file"),
                calling_function=token_data.get("calling_function"),
                call_stack=token_data.get("call_stack"),
                # Context overflow detection
                context_limit=token_data.get("context_limit"),
                context_truncated=token_data.get("context_truncated", False),
                tokens_truncated=token_data.get("tokens_truncated"),
                truncation_ratio=token_data.get("truncation_ratio"),
                # Raw Ollama metrics
                ollama_prompt_eval_count=token_data.get(
                    "ollama_prompt_eval_count"
                ),
                ollama_eval_count=token_data.get("ollama_eval_count"),
                ollama_total_duration=token_data.get("ollama_total_duration"),
                ollama_load_duration=token_data.get("ollama_load_duration"),
                ollama_prompt_eval_duration=token_data.get(
                    "ollama_prompt_eval_duration"
                ),
                ollama_eval_duration=token_data.get("ollama_eval_duration"),
            )
            session.add(token_usage)

168 session.add(token_usage) 

169 

170 

# Single shared writer used across the application. Passwords live in the
# instance's threading.local store, so each thread only sees the passwords
# it set for itself.
metrics_writer: ThreadSafeMetricsWriter = ThreadSafeMetricsWriter()