Coverage for src/local_deep_research/metrics/token_counter.py: 86%

500 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-03 23:15 +0000

1"""Token counting functionality for LLM usage tracking.""" 

2 

3import inspect 

4import json 

5import time 

6from datetime import datetime, timedelta, UTC 

7from pathlib import Path 

8from typing import Any, Dict, List, Optional 

9 

10from langchain_core.callbacks import BaseCallbackHandler 

11from langchain_core.outputs import LLMResult 

12from loguru import logger 

13from sqlalchemy import func, text 

14 

15from ..database.models import ModelUsage, TokenUsage 

16from .query_utils import get_research_mode_condition, get_time_filter_condition 

17 

18 

class TokenCountingCallback(BaseCallbackHandler):
    """Callback handler for counting tokens across different models.

    The handler keeps running totals in memory (``self.counts``) and, when a
    ``research_id`` is set, persists one usage record per LLM call through
    ``_save_to_db``.  A single instance is reused for every LLM call of a
    research session, so all per-call state is reset in ``on_llm_start``.
    """

    # Keys of ``research_context`` whose values must never be written to logs.
    _SENSITIVE_CONTEXT_KEYS = frozenset({"user_password"})

    # Path markers tried in order when shortening a caller's file name.
    _PROJECT_PATH_MARKERS = (
        "src/local_deep_research",
        "local_deep_research/src",
        "local_deep_research",
    )

    # Substring of the serialized ``_type`` -> provider name, tried in order.
    _PROVIDER_BY_TYPE = (
        ("ChatOllama", "ollama"),
        ("ChatOpenAI", "openai"),
        ("ChatAnthropic", "anthropic"),
    )

    def __init__(
        self,
        research_id: Optional[str] = None,
        research_context: Optional[Dict[str, Any]] = None,
    ):
        """Initialize the token counting callback.

        Args:
            research_id: The ID of the research to track tokens for
            research_context: Additional research context for enhanced tracking
        """
        super().__init__()
        self.research_id = research_id
        self.research_context = research_context or {}
        self.current_model = None
        self.current_provider = None
        self.preset_model = None  # Model name set during callback creation
        self.preset_provider = None  # Provider set during callback creation

        # Phase 1 Enhancement: Track timing and context
        self.start_time = None
        self.response_time_ms = None
        self.success_status = "success"
        self.error_type = None

        # Call stack tracking
        self.calling_file = None
        self.calling_function = None
        self.call_stack = None

        # Context overflow tracking
        self.context_limit = None
        self.context_truncated = False
        self.tokens_truncated = 0
        self.truncation_ratio = 0.0
        self.original_prompt_estimate = 0

        # Raw Ollama response metrics
        self.ollama_metrics = {}

        # Track token counts in memory
        self.counts = {
            "total_tokens": 0,
            "total_prompt_tokens": 0,
            "total_completion_tokens": 0,
            "by_model": {},
        }

    # ------------------------------------------------------------------
    # LLM lifecycle hooks
    # ------------------------------------------------------------------

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Called when LLM starts running.

        Starts the response timer, resets all per-call state, estimates the
        prompt size, records the calling project frame and resolves the model
        and provider names for this call.

        Args:
            serialized: Serialized description of the LLM (may carry
                ``kwargs``, ``name`` and ``_type`` entries).
            prompts: Prompt strings sent to the model.
            **kwargs: Extra callback arguments; ``invocation_params``,
                ``model``, ``model_name`` and ``provider`` are consulted.
        """
        # Phase 1 Enhancement: Start timing
        self.start_time = time.time()

        # Reset per-call state. The callback instance is shared across every
        # LLM call in a research session (see llm_config.py wrap_llm), so
        # without this reset:
        #   * the post-loop estimation block's `if not self.context_truncated`
        #     guard would silently disable [estimated] /
        #     [estimated-total-context] detection on every call after the
        #     first one that truncates;
        #   * a single failed call would mark every later successful call as
        #     "error" in the saved records;
        #   * raw Ollama metrics of an earlier call would be saved again for
        #     calls that did not report any.
        self.context_truncated = False
        self.tokens_truncated = 0
        self.truncation_ratio = 0.0
        self.success_status = "success"
        self.error_type = None
        self.ollama_metrics = {}

        # Estimate original prompt size (rough estimate: ~4 chars per token)
        if prompts:
            total_chars = sum(len(prompt) for prompt in prompts)
            self.original_prompt_estimate = total_chars // 4
            logger.debug(
                f"Estimated prompt tokens: {self.original_prompt_estimate} (from {total_chars} chars)"
            )

        # Get context limit from research context (will be set from settings)
        self.context_limit = self.research_context.get("context_limit")

        # Phase 1 Enhancement: Capture call stack information
        try:
            # context=0 skips loading source lines for every frame; only the
            # file name, function name and line number are used below.
            # Frame 0 is this method itself, so it is dropped.
            self._record_call_site(inspect.stack(0)[1:])
        except Exception:
            logger.warning("Error capturing call stack")
            # Continue without call stack info if there's an error

        # Tolerate a missing serialized payload instead of raising.
        serialized = serialized or {}

        # First, use preset values if available
        if self.preset_model:
            self.current_model = self.preset_model
        else:
            self.current_model = self._resolve_model_name(serialized, kwargs)

        # Use preset provider if available
        if self.preset_provider:
            self.current_provider = self.preset_provider
        else:
            self.current_provider = self._resolve_provider(serialized, kwargs)

        # Initialize model tracking if needed
        if self.current_model not in self.counts["by_model"]:
            self.counts["by_model"][self.current_model] = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
                "calls": 0,
                "provider": self.current_provider,
            }

        # Increment call count
        self.counts["by_model"][self.current_model]["calls"] += 1

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Called when LLM ends running.

        Computes the response time, extracts token usage from the response,
        runs the overflow checks, updates the in-memory counters and saves a
        record when ``research_id`` is set.  Nothing is counted or saved when
        no token usage dict could be found.
        """
        # Phase 1 Enhancement: Calculate response time
        if self.start_time:
            self.response_time_ms = int((time.time() - self.start_time) * 1000)

        token_usage = self._extract_token_usage(response)
        usage_is_dict = bool(token_usage) and isinstance(token_usage, dict)

        self._check_estimated_overflow(token_usage if usage_is_dict else None)

        if not usage_is_dict:
            return

        # `or 0` guards against providers that report an explicit None.
        prompt_tokens = token_usage.get("prompt_tokens", 0) or 0
        completion_tokens = token_usage.get("completion_tokens", 0) or 0
        total_tokens = token_usage.get("total_tokens")
        if total_tokens is None:
            total_tokens = prompt_tokens + completion_tokens

        # Update in-memory counts
        self.counts["total_prompt_tokens"] += prompt_tokens
        self.counts["total_completion_tokens"] += completion_tokens
        self.counts["total_tokens"] += total_tokens

        if self.current_model:
            model_counts = self.counts["by_model"][self.current_model]
            model_counts["prompt_tokens"] += prompt_tokens
            model_counts["completion_tokens"] += completion_tokens
            model_counts["total_tokens"] += total_tokens

        # Save to database if we have a research_id
        if self.research_id:
            self._save_to_db(prompt_tokens, completion_tokens)

    def on_llm_error(self, error, **kwargs: Any) -> None:
        """Called when LLM encounters an error.

        Records the elapsed time and the exception class name, and still
        saves a zero-token record so failed calls are tracked.
        """
        # Phase 1 Enhancement: Track errors
        if self.start_time:
            self.response_time_ms = int((time.time() - self.start_time) * 1000)

        self.success_status = "error"
        self.error_type = str(type(error).__name__)

        # Still save to database to track failed calls
        if self.research_id:
            self._save_to_db(0, 0)

    # ------------------------------------------------------------------
    # on_llm_start helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_project_file(filename: str) -> bool:
        """Return True for files of this project that are not installed copies."""
        return (
            "local_deep_research" in filename
            and "site-packages" not in filename
            and "venv" not in filename
        )

    @classmethod
    def _relative_project_path(cls, file_path: str) -> str:
        """Return ``file_path`` shortened to the part after the project root.

        The first marker of ``_PROJECT_PATH_MARKERS`` found in the path wins;
        when none matches, only the file name is returned.
        """
        for marker in cls._PROJECT_PATH_MARKERS:
            if marker in file_path:
                return file_path.split(marker)[-1].lstrip("/")
        return Path(file_path).name

    def _record_call_site(self, frames: List[Any]) -> None:
        """Store the first project frame of ``frames`` as the calling site.

        Args:
            frames: Frame records ordered from innermost to outermost, with
                the callback's own frame already removed.
        """
        for frame_info in frames:
            if not self._is_project_file(frame_info.filename):
                continue

            self.calling_file = self._relative_project_path(
                frame_info.filename
            )
            self.calling_function = frame_info.function

            # Capture a simplified call stack (just the relevant frames),
            # limited to the 5 innermost frames.
            call_stack_frames = [
                f"{Path(frame.filename).name}:{frame.function}:{frame.lineno}"
                for frame in frames[:5]
                if self._is_project_file(frame.filename)
            ]
            self.call_stack = (
                " -> ".join(call_stack_frames) if call_stack_frames else None
            )
            break

    @staticmethod
    def _resolve_model_name(
        serialized: Dict[str, Any], kwargs: Dict[str, Any]
    ) -> str:
        """Find the model name, trying several locations in priority order."""
        # First check invocation_params (tolerating an explicit None)
        invocation_params = kwargs.get("invocation_params") or {}
        model_name = invocation_params.get("model") or invocation_params.get(
            "model_name"
        )

        # Check kwargs directly
        if not model_name:
            model_name = kwargs.get("model") or kwargs.get("model_name")

        # Check serialized data
        if not model_name and "kwargs" in serialized:
            model_name = serialized["kwargs"].get("model") or serialized[
                "kwargs"
            ].get("model_name")

        # Check for name in serialized data
        if not model_name and "name" in serialized:
            model_name = serialized["name"]

        type_str = serialized.get("_type") or ""

        # If still not found and we have Ollama, try to extract from the instance
        if not model_name and "ChatOllama" in type_str:
            # For Ollama, the model name might be in the serialized kwargs
            if "kwargs" in serialized and "model" in serialized["kwargs"]:
                model_name = serialized["kwargs"]["model"]
            else:
                # Default to the type if we can't find the actual model
                model_name = "ollama"

        # Final fallback
        if not model_name:
            model_name = type_str or "unknown"

        return model_name

    @classmethod
    def _resolve_provider(
        cls, serialized: Dict[str, Any], kwargs: Dict[str, Any]
    ) -> str:
        """Derive the provider from the serialized type, else from kwargs."""
        type_str = serialized.get("_type") or ""
        for marker, provider in cls._PROVIDER_BY_TYPE:
            if marker in type_str:
                return provider
        return kwargs.get("provider", "unknown")

    # ------------------------------------------------------------------
    # Context overflow detection
    # ------------------------------------------------------------------

    def _check_context_overflow(
        self,
        prompt_eval_count: int,
        completion_tokens: int = 0,
        source: str = "",
    ) -> None:
        """Check for context overflow based on prompt and total token usage.

        Sets ``context_truncated``, ``tokens_truncated`` and
        ``truncation_ratio`` when an overflow is detected.  Does nothing when
        no context limit is configured or no prompt tokens were reported.

        Args:
            prompt_eval_count: Number of tokens the model actually processed.
            completion_tokens: Number of tokens generated (for total-context check).
            source: Which branch provided the data (for logging).
        """
        logger.debug(
            f"Context overflow check [{source}]: "
            f"prompt_eval_count={prompt_eval_count}, "
            f"completion_tokens={completion_tokens}, "
            f"context_limit={self.context_limit}"
        )

        # Treat a reported None like "no data" instead of raising TypeError.
        prompt_eval_count = prompt_eval_count or 0
        completion_tokens = completion_tokens or 0

        if not self.context_limit or prompt_eval_count <= 0:
            return

        # Input-only overflow: prompt at >= 80% of context limit. Matches the
        # chart-warning threshold and PR #3840's deliberate choice (PR #3792
        # lowered from 95% → 80%). The total-context branch below uses 95%
        # because it's a stricter condition (input+output combined).
        if prompt_eval_count >= self.context_limit * 0.80:
            self.context_truncated = True

            if self.original_prompt_estimate > prompt_eval_count:
                # The model processed fewer tokens than we estimated sending.
                self.tokens_truncated = max(
                    0,
                    self.original_prompt_estimate - prompt_eval_count,
                )
                if (
                    self.tokens_truncated > 0
                    and self.original_prompt_estimate > 0
                ):
                    self.truncation_ratio = (
                        self.tokens_truncated / self.original_prompt_estimate
                    )
            elif prompt_eval_count > self.context_limit:
                self.tokens_truncated = prompt_eval_count - self.context_limit
                if self.tokens_truncated > 0 and prompt_eval_count > 0:
                    self.truncation_ratio = (
                        self.tokens_truncated / prompt_eval_count
                    )

            logger.warning(
                f"Context overflow detected [provider-confirmed] "
                f"research_id={self.research_id} "
                f"model={self.current_model} "
                f"provider={self.current_provider} "
                f"source={source} "
                f"prompt_tokens={prompt_eval_count} "
                f"context_limit={self.context_limit} "
                f"tokens_truncated={self.tokens_truncated} "
                f"truncation_ratio={self.truncation_ratio:.1%}"
            )

        # Total-context overflow: input + output exceeds 95% of context limit
        elif (
            completion_tokens > 0
            and prompt_eval_count + completion_tokens
            >= self.context_limit * 0.95
        ):
            total = prompt_eval_count + completion_tokens
            self.context_truncated = True
            self.tokens_truncated = max(0, total - self.context_limit)
            self.truncation_ratio = (
                self.tokens_truncated / total if total > 0 else 0
            )
            logger.warning(
                f"Context overflow detected [total-context] "
                f"research_id={self.research_id} "
                f"model={self.current_model} "
                f"provider={self.current_provider} "
                f"source={source} "
                f"prompt_tokens={prompt_eval_count} "
                f"completion_tokens={completion_tokens} "
                f"total_tokens={total} "
                f"context_limit={self.context_limit} "
                f"tokens_truncated={self.tokens_truncated} "
                f"truncation_ratio={self.truncation_ratio:.1%}"
            )
        else:
            logger.debug(
                f"Context OK [{source}]: "
                f"{prompt_eval_count}/{self.context_limit} "
                f"({prompt_eval_count / self.context_limit:.1%})"
            )

    def _check_estimated_overflow(
        self, token_usage: Optional[Dict[str, Any]]
    ) -> None:
        """Fallback overflow detection for providers without prompt_eval_count.

        Estimation-based overflow detection for providers that don't echo
        prompt_eval_count (OpenAI, Anthropic, OpenRouter, etc.). The
        provider-confirmed path sets context_truncated=True if it detected
        truncation; this only fires if it didn't. For hosted providers the
        API typically rejects oversize prompts rather than silently
        truncating, so this signal flags "would-overflow per our estimate" —
        not the same as actual truncation, hence the [estimated] tag in the
        log message.

        Args:
            token_usage: Token usage dict of the call, or None when the
                response carried none.
        """
        if self.context_truncated or not self.context_limit:
            return

        # Input-only overflow: prompt estimate exceeds context limit
        if self.original_prompt_estimate > self.context_limit:
            self.context_truncated = True
            self.tokens_truncated = max(
                0, self.original_prompt_estimate - self.context_limit
            )
            self.truncation_ratio = (
                self.tokens_truncated / self.original_prompt_estimate
                if self.original_prompt_estimate > 0
                else 0
            )
            logger.warning(
                "Context overflow detected [estimated] "
                f"research_id={self.research_id} "
                f"model={self.current_model} "
                f"provider={self.current_provider} "
                f"estimated_prompt_tokens={self.original_prompt_estimate} "
                f"context_limit={self.context_limit} "
                f"tokens_truncated={self.tokens_truncated} "
                f"truncation_ratio={self.truncation_ratio:.1%}"
            )
            return

        if not token_usage:
            return

        # Total-context overflow: input + output exceeds context limit.
        # The "[estimated-total-context]" tag below refers to the
        # *detection method* (post-loop fallback path that runs when
        # _check_context_overflow couldn't fire), not the data source —
        # the prompt_tokens / completion_tokens here come from the
        # provider via response.llm_output.token_usage and are actual
        # counts, not character-based estimates.
        prompt_tokens = token_usage.get("prompt_tokens", 0) or 0
        completion_tokens = token_usage.get("completion_tokens", 0) or 0
        total = prompt_tokens + completion_tokens
        if total >= self.context_limit * 0.95:
            self.context_truncated = True
            self.tokens_truncated = max(0, total - self.context_limit)
            self.truncation_ratio = (
                self.tokens_truncated / total if total > 0 else 0
            )
            logger.warning(
                "Context overflow detected [estimated-total-context] "
                f"research_id={self.research_id} "
                f"model={self.current_model} "
                f"provider={self.current_provider} "
                f"prompt_tokens={prompt_tokens} "
                f"completion_tokens={completion_tokens} "
                f"total_tokens={total} "
                f"context_limit={self.context_limit} "
                f"tokens_truncated={self.tokens_truncated} "
                f"truncation_ratio={self.truncation_ratio:.1%}"
            )

    # ------------------------------------------------------------------
    # on_llm_end helpers
    # ------------------------------------------------------------------

    def _extract_token_usage(self, response: LLMResult) -> Any:
        """Locate token usage in ``response``.

        Checks ``response.llm_output`` first (``token_usage`` then ``usage``),
        then the first generation message carrying ``usage_metadata`` or
        Ollama-style ``response_metadata``.  The generation branches also run
        the provider-confirmed overflow check, and the ``response_metadata``
        branch stores the raw Ollama metrics.

        Returns:
            The usage mapping that was found, or a falsy value (None or an
            empty dict) when the response carried no usage information.
        """
        token_usage = None

        # Check multiple locations for token usage
        if hasattr(response, "llm_output") and response.llm_output:
            token_usage = response.llm_output.get(
                "token_usage"
            ) or response.llm_output.get("usage", {})

        if token_usage or not hasattr(response, "generations"):
            return token_usage

        # Check for usage metadata in generations (Ollama specific)
        for generation_list in response.generations:
            for generation in generation_list:
                message = getattr(generation, "message", None)

                usage_meta = getattr(message, "usage_metadata", None)
                if usage_meta:  # Check if usage_metadata is not None
                    prompt_tokens = usage_meta.get("input_tokens", 0) or 0
                    completion_tokens = usage_meta.get("output_tokens", 0) or 0
                    # (input_tokens == prompt_eval_count for Ollama)
                    self._check_context_overflow(
                        prompt_tokens,
                        completion_tokens=completion_tokens,
                        source="usage_metadata",
                    )
                    return {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": usage_meta.get("total_tokens", 0) or 0,
                    }

                # Also check response_metadata
                resp_meta = getattr(message, "response_metadata", None)
                if resp_meta and (
                    resp_meta.get("prompt_eval_count")
                    or resp_meta.get("eval_count")
                ):
                    # Capture raw Ollama metrics
                    self.ollama_metrics = {
                        "prompt_eval_count": resp_meta.get("prompt_eval_count"),
                        "eval_count": resp_meta.get("eval_count"),
                        "total_duration": resp_meta.get("total_duration"),
                        "load_duration": resp_meta.get("load_duration"),
                        "prompt_eval_duration": resp_meta.get(
                            "prompt_eval_duration"
                        ),
                        "eval_duration": resp_meta.get("eval_duration"),
                    }

                    prompt_tokens = resp_meta.get("prompt_eval_count", 0) or 0
                    completion_tokens = resp_meta.get("eval_count", 0) or 0
                    self._check_context_overflow(
                        prompt_tokens,
                        completion_tokens=completion_tokens,
                        source="response_metadata",
                    )
                    return {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": prompt_tokens + completion_tokens,
                    }

        return token_usage

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _get_context_overflow_fields(self) -> Dict[str, Any]:
        """Get context overflow detection fields for database saving.

        ``tokens_truncated`` and ``truncation_ratio`` are reported as None
        unless ``context_truncated`` is set.
        """
        return {
            "context_limit": self.context_limit,
            "context_truncated": self.context_truncated,  # Now Boolean
            "tokens_truncated": self.tokens_truncated
            if self.context_truncated
            else None,
            "truncation_ratio": self.truncation_ratio
            if self.context_truncated
            else None,
            # Raw Ollama metrics
            "ollama_prompt_eval_count": self.ollama_metrics.get(
                "prompt_eval_count"
            ),
            "ollama_eval_count": self.ollama_metrics.get("eval_count"),
            "ollama_total_duration": self.ollama_metrics.get("total_duration"),
            "ollama_load_duration": self.ollama_metrics.get("load_duration"),
            "ollama_prompt_eval_duration": self.ollama_metrics.get(
                "prompt_eval_duration"
            ),
            "ollama_eval_duration": self.ollama_metrics.get("eval_duration"),
        }

    def _loggable_research_context(self) -> Dict[str, Any]:
        """Return a copy of the research context that is safe to log."""
        return {
            key: (
                "<redacted>" if key in self._SENSITIVE_CONTEXT_KEYS else value
            )
            for key, value in (self.research_context or {}).items()
        }

    def _save_to_db(self, prompt_tokens: int, completion_tokens: int) -> None:
        """Save token usage to the database.

        Calls made outside the thread named "MainThread" are written through
        the thread metrics writer using credentials from the research
        context; calls on the main thread use the flask session's user.

        Args:
            prompt_tokens: Prompt tokens of the call being recorded.
            completion_tokens: Completion tokens of the call being recorded.
        """
        import threading

        if threading.current_thread().name != "MainThread":
            self._save_from_worker_thread(prompt_tokens, completion_tokens)
        else:
            self._save_from_main_thread(prompt_tokens, completion_tokens)

    def _save_from_worker_thread(
        self, prompt_tokens: int, completion_tokens: int
    ) -> None:
        """Write one usage record through the thread metrics writer.

        Skips the write (with a warning) when the research context lacks a
        username or password; write failures are logged, not raised.
        """
        username = (
            self.research_context.get("username")
            if self.research_context
            else None
        )

        if not username:
            # The context may hold the user's password, so log a redacted copy.
            logger.warning(
                f"Cannot save token metrics - no username in research context. "
                f"Token usage: prompt={prompt_tokens}, completion={completion_tokens}, "
                f"Research context: {self._loggable_research_context()}"
            )
            return

        # Prepare token data
        token_data = {
            "model_name": self.current_model,
            "provider": self.current_provider,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "research_query": self.research_context.get("research_query"),
            "research_mode": self.research_context.get("research_mode"),
            "research_phase": self.research_context.get("research_phase"),
            "search_iteration": self.research_context.get("search_iteration"),
            "response_time_ms": self.response_time_ms,
            "success_status": self.success_status,
            "error_type": self.error_type,
            "search_engines_planned": self.research_context.get(
                "search_engines_planned"
            ),
            "search_engine_selected": self.research_context.get(
                "search_engine_selected"
            ),
            "calling_file": self.calling_file,
            "calling_function": self.calling_function,
            "call_stack": self.call_stack,
            # Add context overflow fields using helper method
            **self._get_context_overflow_fields(),
        }

        # Convert list to JSON string if needed
        if isinstance(token_data.get("search_engines_planned"), list):
            token_data["search_engines_planned"] = json.dumps(
                token_data["search_engines_planned"]
            )

        # Get password from research context
        password = self.research_context.get("user_password")
        if not password:
            logger.warning(
                f"Cannot save token metrics - no password in research context. "
                f"Username: {username}, Token usage: prompt={prompt_tokens}, completion={completion_tokens}"
            )
            return

        # Write metrics directly using thread-safe database
        try:
            from ..database.thread_metrics import metrics_writer

            # Set password for this thread
            metrics_writer.set_user_password(username, password)

            # Write metrics to encrypted database
            metrics_writer.write_token_metrics(
                username, self.research_id, token_data
            )
        except Exception:
            logger.exception("Failed to write metrics from thread")

    def _save_from_main_thread(
        self, prompt_tokens: int, completion_tokens: int
    ) -> None:
        """Write one usage record and update ModelUsage via the user session.

        Skips silently when the flask session has no username; any failure is
        logged, not raised.
        """
        try:
            from flask import session as flask_session
            from ..database.session_context import get_user_db_session

            username = flask_session.get("username")
            if not username:
                logger.debug("No user session, skipping token metrics save")
                return

            with get_user_db_session(username) as session:
                # Phase 1 Enhancement: Prepare additional context
                research_query = self.research_context.get("research_query")
                research_mode = self.research_context.get("research_mode")
                research_phase = self.research_context.get("research_phase")
                search_iteration = self.research_context.get("search_iteration")
                search_engines_planned = self.research_context.get(
                    "search_engines_planned"
                )
                search_engine_selected = self.research_context.get(
                    "search_engine_selected"
                )

                # Debug logging for search engine context
                if search_engines_planned or search_engine_selected:
                    logger.info(
                        f"Token tracking - Search context: planned={search_engines_planned}, selected={search_engine_selected}, phase={research_phase}"
                    )
                else:
                    logger.debug(
                        f"Token tracking - No search engine context yet, phase={research_phase}"
                    )

                # Convert list to JSON string if needed
                if isinstance(search_engines_planned, list):
                    search_engines_planned = json.dumps(search_engines_planned)

                # Log context overflow detection values before saving
                logger.debug(
                    f"Saving TokenUsage - context_limit: {self.context_limit}, "
                    f"context_truncated: {self.context_truncated}, "
                    f"tokens_truncated: {self.tokens_truncated}, "
                    f"ollama_prompt_eval_count: {self.ollama_metrics.get('prompt_eval_count')}, "
                    f"prompt_tokens: {prompt_tokens}, "
                    f"completion_tokens: {completion_tokens}"
                )

                # Add token usage record with enhanced fields
                token_usage = TokenUsage(
                    research_id=self.research_id,
                    model_name=self.current_model,
                    # Provider is stored for accurate cost tracking
                    model_provider=self.current_provider,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                    # Phase 1 Enhancement: Research context
                    research_query=research_query,
                    research_mode=research_mode,
                    research_phase=research_phase,
                    search_iteration=search_iteration,
                    # Phase 1 Enhancement: Performance metrics
                    response_time_ms=self.response_time_ms,
                    success_status=self.success_status,
                    error_type=self.error_type,
                    # Phase 1 Enhancement: Search engine context
                    search_engines_planned=search_engines_planned,
                    search_engine_selected=search_engine_selected,
                    # Phase 1 Enhancement: Call stack tracking
                    calling_file=self.calling_file,
                    calling_function=self.calling_function,
                    call_stack=self.call_stack,
                    # Add context overflow fields using helper method
                    **self._get_context_overflow_fields(),
                )
                session.add(token_usage)

                # Update or create model usage statistics
                model_usage = (
                    session.query(ModelUsage)
                    .filter_by(
                        model_name=self.current_model,
                    )
                    .first()
                )

                if model_usage:
                    model_usage.total_tokens += (
                        prompt_tokens + completion_tokens
                    )
                    model_usage.total_calls += 1
                else:
                    model_usage = ModelUsage(
                        model_name=self.current_model,
                        model_provider=self.current_provider,
                        total_tokens=prompt_tokens + completion_tokens,
                        total_calls=1,
                    )
                    session.add(model_usage)

                # Commit the transaction
                session.commit()

        except Exception:
            logger.exception("Error saving token usage to database")

    def get_counts(self) -> Dict[str, Any]:
        """Get the current token counts.

        Returns:
            The live ``counts`` dict (not a copy), holding the overall totals
            and a ``by_model`` breakdown.
        """
        return self.counts

737 

738 

739class TokenCounter: 

740 """Manager class for token counting across the application.""" 

741 

742 def __init__(self): 

743 """Initialize the token counter.""" 

744 

745 def create_callback( 

746 self, 

747 research_id: Optional[str] = None, 

748 research_context: Optional[Dict[str, Any]] = None, 

749 ) -> TokenCountingCallback: 

750 """Create a new token counting callback. 

751 

752 Args: 

753 research_id: The ID of the research to track tokens for 

754 research_context: Additional research context for enhanced tracking 

755 

756 Returns: 

757 A new TokenCountingCallback instance 

758 """ 

759 return TokenCountingCallback( 

760 research_id=research_id, research_context=research_context 

761 ) 

762 

763 def get_research_metrics(self, research_id: str) -> Dict[str, Any]: 

764 """Get token metrics for a specific research. 

765 

766 Args: 

767 research_id: The ID of the research 

768 

769 Returns: 

770 Dictionary containing token usage metrics 

771 """ 

772 from flask import session as flask_session 

773 

774 from ..database.session_context import get_user_db_session 

775 

776 username = flask_session.get("username") 

777 if not username: 

778 return { 

779 "research_id": research_id, 

780 "total_tokens": 0, 

781 "total_calls": 0, 

782 "model_usage": [], 

783 } 

784 

785 with get_user_db_session(username) as session: 

786 # Get token usage for this research from TokenUsage table 

787 from sqlalchemy import func 

788 

789 token_usages = ( 

790 session.query( 

791 TokenUsage.model_name, 

792 TokenUsage.model_provider, 

793 func.sum(TokenUsage.prompt_tokens).label("prompt_tokens"), 

794 func.sum(TokenUsage.completion_tokens).label( 

795 "completion_tokens" 

796 ), 

797 func.sum(TokenUsage.total_tokens).label("total_tokens"), 

798 func.count().label("calls"), 

799 ) 

800 .filter_by(research_id=research_id) 

801 .group_by(TokenUsage.model_name, TokenUsage.model_provider) 

802 .order_by(func.sum(TokenUsage.total_tokens).desc()) 

803 .all() 

804 ) 

805 

806 model_usage = [] 

807 total_tokens = 0 

808 total_calls = 0 

809 

810 for usage in token_usages: 

811 model_usage.append( 

812 { 

813 "model": usage.model_name, 

814 "provider": usage.model_provider, 

815 "tokens": usage.total_tokens or 0, 

816 "calls": usage.calls or 0, 

817 "prompt_tokens": usage.prompt_tokens or 0, 

818 "completion_tokens": usage.completion_tokens or 0, 

819 } 

820 ) 

821 total_tokens += usage.total_tokens or 0 

822 total_calls += usage.calls or 0 

823 

824 return { 

825 "research_id": research_id, 

826 "total_tokens": total_tokens, 

827 "total_calls": total_calls, 

828 "model_usage": model_usage, 

829 } 

830 

831 def get_overall_metrics( 

832 self, period: str = "30d", research_mode: str = "all" 

833 ) -> Dict[str, Any]: 

834 """Get overall token metrics across all researches. 

835 

836 Args: 

837 period: Time period to filter by ('7d', '30d', '3m', '1y', 'all') 

838 research_mode: Research mode to filter by ('quick', 'detailed', 'all') 

839 

840 Returns: 

841 Dictionary containing overall metrics 

842 """ 

843 return self._get_metrics_from_encrypted_db(period, research_mode) 

844 

def _get_metrics_from_encrypted_db(
    self, period: str, research_mode: str
) -> Dict[str, Any]:
    """Collect overall token and rate-limit metrics from the user's encrypted database.

    The username is read from the Flask session and used to open that
    user's database session. Token totals, per-model statistics, the ten
    most recently active researches, token breakdown averages and
    rate-limiting statistics are then aggregated.

    Args:
        period: Time period to filter by ('7d', '30d', '3m', '1y', 'all').
            The legacy keywords 'today', 'week' and 'month' are also
            honoured when counting researches.
        research_mode: Research mode to filter by ('quick', 'detailed', 'all').

    Returns:
        Dictionary with the keys ``total_tokens``, ``total_researches``,
        ``by_model``, ``recent_researches``, ``token_breakdown`` and
        ``rate_limiting``. When no username is present in the Flask
        session, or when any exception is raised while querying (it is
        logged), the result of ``self._get_empty_metrics()`` is returned
        instead.
    """
    from flask import session as flask_session

    from ..database.session_context import get_user_db_session

    username = flask_session.get("username")
    if not username:
        return self._get_empty_metrics()

    # Look-back window, in days, for each documented period keyword.
    period_days = {"7d": 7, "30d": 30, "3m": 90, "1y": 365}
    seconds_per_day = 24 * 3600

    try:
        with get_user_db_session(username) as session:
            # Build base query with filters
            query = session.query(TokenUsage)

            # Apply time filter
            time_condition = get_time_filter_condition(
                period, TokenUsage.timestamp
            )
            if time_condition is not None:
                query = query.filter(time_condition)

            # Apply research mode filter
            mode_condition = get_research_mode_condition(
                research_mode, TokenUsage.research_mode
            )
            if mode_condition is not None:
                query = query.filter(mode_condition)

            # Total tokens from TokenUsage
            total_tokens = (
                query.with_entities(
                    func.sum(TokenUsage.total_tokens)
                ).scalar()
                or 0
            )

            # Import ResearchHistory model
            from ..database.models.research import ResearchHistory

            # Count researches from ResearchHistory table
            research_query = session.query(func.count(ResearchHistory.id))

            # Debug: Check if any research history records exist at all
            all_research_count = (
                session.query(func.count(ResearchHistory.id)).scalar() or 0
            )
            logger.debug(
                f"Total ResearchHistory records in database: {all_research_count}"
            )

            # Debug: List first few research IDs and their timestamps
            sample_researches = (
                session.query(
                    ResearchHistory.id,
                    ResearchHistory.created_at,
                    ResearchHistory.mode,
                )
                .limit(5)
                .all()
            )
            if sample_researches:
                logger.debug("Sample ResearchHistory records:")
                for r_id, r_created, r_mode in sample_researches:
                    logger.debug(
                        f"  - ID: {r_id}, Created: {r_created}, Mode: {r_mode}"
                    )
            else:
                logger.debug("No ResearchHistory records found in database")

            # Work out the time window for the ResearchHistory count.
            # The documented period keywords are mapped to a look-back
            # window so the research count is filtered by the same
            # period as the token statistics; 'week' and 'month' are
            # legacy aliases.
            start_time, end_time = None, None
            if period != "all":
                now = datetime.now(UTC)
                research_period_days = {**period_days, "week": 7, "month": 30}
                if period == "today":
                    start_time = now.replace(
                        hour=0, minute=0, second=0, microsecond=0
                    )
                elif period in research_period_days:
                    start_time = now - timedelta(
                        days=research_period_days[period]
                    )

                if start_time:
                    end_time = datetime.now(UTC)

            # Apply time filter if specified
            if start_time and end_time:
                # NOTE(review): created_at is compared against ISO-8601
                # strings, which presumes it is stored as such - confirm
                # against the ResearchHistory model.
                research_query = research_query.filter(
                    ResearchHistory.created_at >= start_time.isoformat(),
                    ResearchHistory.created_at <= end_time.isoformat(),
                )

            # Apply mode filter if specified
            mode_filter = research_mode if research_mode != "all" else None
            if mode_filter:
                logger.debug(f"Applying mode filter: {mode_filter}")
                research_query = research_query.filter(
                    ResearchHistory.mode == mode_filter
                )

            total_researches = research_query.scalar() or 0
            logger.debug(
                f"Final filtered research count: {total_researches}"
            )

            # Also check distinct research_ids in TokenUsage for comparison
            token_research_count = (
                session.query(
                    func.count(func.distinct(TokenUsage.research_id))
                ).scalar()
                or 0
            )
            logger.debug(
                f"Distinct research_ids in TokenUsage: {token_research_count}"
            )

            # Model statistics using ORM aggregation
            model_stats_query = session.query(
                TokenUsage.model_name,
                func.sum(TokenUsage.total_tokens).label("tokens"),
                func.count().label("calls"),
                func.sum(TokenUsage.prompt_tokens).label("prompt_tokens"),
                func.sum(TokenUsage.completion_tokens).label(
                    "completion_tokens"
                ),
            ).filter(TokenUsage.model_name.isnot(None))

            # Apply same filters to model stats
            if time_condition is not None:
                model_stats_query = model_stats_query.filter(time_condition)
            if mode_condition is not None:
                model_stats_query = model_stats_query.filter(mode_condition)

            model_stats = (
                model_stats_query.group_by(TokenUsage.model_name)
                .order_by(func.sum(TokenUsage.total_tokens).desc())
                .all()
            )

            # Batch load provider info from ModelUsage table (avoids N+1).
            # Rows are ordered by id, so setdefault keeps the provider of
            # the lowest-id row for each model.
            model_names = [stat.model_name for stat in model_stats]
            provider_map = {}
            if model_names:
                provider_results = (
                    session.query(
                        ModelUsage.model_name, ModelUsage.model_provider
                    )
                    .filter(ModelUsage.model_name.in_(model_names))
                    .order_by(ModelUsage.id)
                    .all()
                )
                for model_name, model_provider in provider_results:
                    provider_map.setdefault(model_name, model_provider)

            by_model = [
                {
                    "model": stat.model_name,
                    "provider": provider_map.get(stat.model_name, "unknown"),
                    "tokens": stat.tokens,
                    "calls": stat.calls,
                    "prompt_tokens": stat.prompt_tokens,
                    "completion_tokens": stat.completion_tokens,
                }
                for stat in model_stats
            ]

            # Get recent researches with token usage
            recent_research_query = session.query(
                TokenUsage.research_id,
                func.sum(TokenUsage.total_tokens).label("token_count"),
                func.max(TokenUsage.timestamp).label("latest_timestamp"),
            ).filter(TokenUsage.research_id.isnot(None))

            if time_condition is not None:
                recent_research_query = recent_research_query.filter(
                    time_condition
                )
            if mode_condition is not None:
                recent_research_query = recent_research_query.filter(
                    mode_condition
                )

            recent_research_data = (
                recent_research_query.group_by(TokenUsage.research_id)
                .order_by(func.max(TokenUsage.timestamp).desc())
                .limit(10)
                .all()
            )

            # Batch load research queries for recent researches (avoids N+1)
            recent_research_ids = [
                r.research_id for r in recent_research_data
            ]
            research_query_map = {}
            if recent_research_ids:
                # First non-null research_query (lowest id) per research_id
                query_results = (
                    session.query(
                        TokenUsage.research_id, TokenUsage.research_query
                    )
                    .filter(
                        TokenUsage.research_id.in_(recent_research_ids),
                        TokenUsage.research_query.isnot(None),
                    )
                    .order_by(TokenUsage.id)
                    .all()
                )
                for research_id, query_text in query_results:
                    research_query_map.setdefault(research_id, query_text)

            recent_researches = [
                {
                    "id": research_data.research_id,
                    "query": research_query_map.get(
                        research_data.research_id,
                        f"Research {research_data.research_id}",
                    ),
                    "tokens": research_data.token_count or 0,
                    "created_at": research_data.latest_timestamp,
                }
                for research_data in recent_research_data
            ]

            # Token breakdown statistics
            breakdown_query = query.with_entities(
                func.sum(TokenUsage.prompt_tokens).label(
                    "total_input_tokens"
                ),
                func.sum(TokenUsage.completion_tokens).label(
                    "total_output_tokens"
                ),
                func.avg(TokenUsage.prompt_tokens).label(
                    "avg_input_tokens"
                ),
                func.avg(TokenUsage.completion_tokens).label(
                    "avg_output_tokens"
                ),
                func.avg(TokenUsage.total_tokens).label("avg_total_tokens"),
            )
            token_breakdown = breakdown_query.first()

            # Get rate limiting metrics
            from ..database.models import (
                RateLimitAttempt,
                RateLimitEstimate,
            )

            # Get rate limit attempts
            rate_limit_query = session.query(RateLimitAttempt)

            # RateLimitAttempt uses timestamp as float, not datetime, so
            # the window is expressed as an epoch cutoff. cutoff_time is
            # initialised up front so it is always bound, including when
            # no time condition applies; 0 means "no cutoff".
            cutoff_time = 0
            if time_condition is not None:
                window_days = period_days.get(period)
                if window_days is not None:
                    cutoff_time = time.time() - window_days * seconds_per_day

            if cutoff_time > 0:
                rate_limit_query = rate_limit_query.filter(
                    RateLimitAttempt.timestamp >= cutoff_time
                )

            # Get rate limit statistics
            total_attempts = rate_limit_query.count()
            successful_attempts = rate_limit_query.filter(
                RateLimitAttempt.success
            ).count()
            failed_attempts = total_attempts - successful_attempts

            # Count rate limiting events (failures with RateLimitError)
            rate_limit_events = rate_limit_query.filter(
                ~RateLimitAttempt.success,
                RateLimitAttempt.error_type == "RateLimitError",
            ).count()

            logger.debug(
                f"Rate limit attempts in database: total={total_attempts}, successful={successful_attempts}"
            )

            # Get all attempts for detailed calculations
            attempts = rate_limit_query.all()

            # Calculate average wait times
            if attempts:
                avg_wait_time = sum(a.wait_time for a in attempts) / len(
                    attempts
                )
                successful_wait_times = [
                    a.wait_time for a in attempts if a.success
                ]
                avg_successful_wait = (
                    sum(successful_wait_times) / len(successful_wait_times)
                    if successful_wait_times
                    else 0
                )
            else:
                avg_wait_time = 0
                avg_successful_wait = 0

            # Get tracked engines - count distinct engine types from attempts
            tracked_engines_query = session.query(
                func.count(func.distinct(RateLimitAttempt.engine_type))
            )
            if cutoff_time > 0:
                tracked_engines_query = tracked_engines_query.filter(
                    RateLimitAttempt.timestamp >= cutoff_time
                )
            tracked_engines = tracked_engines_query.scalar() or 0

            # Get distinct engine types from attempts
            engine_types_query = session.query(
                RateLimitAttempt.engine_type
            ).distinct()
            if cutoff_time > 0:
                engine_types_query = engine_types_query.filter(
                    RateLimitAttempt.timestamp >= cutoff_time
                )
            engine_types = [
                row.engine_type for row in engine_types_query.all()
            ]

            # Batch-load all estimates (avoids N+1 query)
            estimates_map = {}
            if engine_types:
                all_estimates = (
                    session.query(RateLimitEstimate)
                    .filter(RateLimitEstimate.engine_type.in_(engine_types))
                    .all()
                )
                estimates_map = {e.engine_type: e for e in all_estimates}

            # Group the already-loaded attempts by engine once instead of
            # rescanning the full list for every engine type.
            attempts_by_engine = {}
            for attempt in attempts:
                attempts_by_engine.setdefault(
                    attempt.engine_type, []
                ).append(attempt)

            # Get engine-specific stats from attempts
            engine_stats = []
            for engine_type in engine_types:
                engine_attempts_list = attempts_by_engine.get(engine_type, [])
                engine_attempts = len(engine_attempts_list)
                engine_success = sum(
                    1 for a in engine_attempts_list if a.success
                )

                # Get estimate if exists
                estimate = estimates_map.get(engine_type)

                # Calculate recent success rate (percentage)
                recent_success_rate = (
                    (engine_success / engine_attempts * 100)
                    if engine_attempts > 0
                    else 0
                )

                # Determine status. The estimate's success_rate is compared
                # on a 0-1 scale, the recent rate on a 0-100 scale.
                if estimate:
                    rate = estimate.success_rate
                    healthy_above, degraded_above = 0.8, 0.5
                else:
                    rate = recent_success_rate
                    healthy_above, degraded_above = 80, 50

                if rate > healthy_above:
                    status = "healthy"
                elif rate > degraded_above:
                    status = "degraded"
                else:
                    status = "poor"

                engine_stat = {
                    "engine": engine_type,
                    "base_wait": estimate.base_wait_seconds
                    if estimate
                    else 0.0,
                    "base_wait_seconds": round(
                        estimate.base_wait_seconds if estimate else 0.0, 2
                    ),
                    "min_wait_seconds": round(
                        estimate.min_wait_seconds if estimate else 0.0, 2
                    ),
                    "max_wait_seconds": round(
                        estimate.max_wait_seconds if estimate else 0.0, 2
                    ),
                    "success_rate": round(estimate.success_rate * 100, 1)
                    if estimate
                    else recent_success_rate,
                    "total_attempts": estimate.total_attempts
                    if estimate
                    else engine_attempts,
                    "recent_attempts": engine_attempts,
                    "recent_success_rate": round(recent_success_rate, 1),
                    "attempts": engine_attempts,
                    "status": status,
                }

                if estimate:
                    engine_stat["last_updated"] = datetime.fromtimestamp(
                        estimate.last_updated
                    ).strftime("%Y-%m-%d %H:%M:%S")
                else:
                    engine_stat["last_updated"] = "Never"

                engine_stats.append(engine_stat)

            logger.debug(
                f"Tracked engines: {tracked_engines}, engine_stats: {engine_stats}"
            )

            result = {
                "total_tokens": total_tokens,
                "total_researches": total_researches,
                "by_model": by_model,
                "recent_researches": recent_researches,
                "token_breakdown": {
                    "total_input_tokens": int(
                        token_breakdown.total_input_tokens or 0
                    ),
                    "total_output_tokens": int(
                        token_breakdown.total_output_tokens or 0
                    ),
                    "avg_input_tokens": int(
                        token_breakdown.avg_input_tokens or 0
                    ),
                    "avg_output_tokens": int(
                        token_breakdown.avg_output_tokens or 0
                    ),
                    "avg_total_tokens": int(
                        token_breakdown.avg_total_tokens or 0
                    ),
                },
                "rate_limiting": {
                    "total_attempts": total_attempts,
                    "successful_attempts": successful_attempts,
                    "failed_attempts": failed_attempts,
                    "success_rate": (
                        successful_attempts / total_attempts * 100
                    )
                    if total_attempts > 0
                    else 0,
                    "rate_limit_events": rate_limit_events,
                    "avg_wait_time": round(float(avg_wait_time), 2),
                    "avg_successful_wait": round(
                        float(avg_successful_wait), 2
                    ),
                    "tracked_engines": tracked_engines,
                    "engine_stats": engine_stats,
                    "total_engines_tracked": tracked_engines,
                    "healthy_engines": sum(
                        1 for s in engine_stats if s["status"] == "healthy"
                    ),
                    "degraded_engines": sum(
                        1 for s in engine_stats if s["status"] == "degraded"
                    ),
                    "poor_engines": sum(
                        1 for s in engine_stats if s["status"] == "poor"
                    ),
                },
            }

            logger.debug(
                f"Returning from _get_metrics_from_encrypted_db - total_researches: {result['total_researches']}"
            )
            return result
    except Exception:
        # Deliberate best-effort boundary: log with traceback and fall
        # back to the empty structure rather than propagating.
        logger.exception(
            "CRITICAL ERROR accessing encrypted database for metrics"
        )
        return self._get_empty_metrics()

1340 

def _get_empty_metrics(self) -> Dict[str, Any]:
    """Return a zeroed metrics structure for when no data is available.

    The structure carries the same top-level sections and key names as
    the populated result built by ``_get_metrics_from_encrypted_db``
    (including ``rate_limiting`` and the ``*_input_tokens`` /
    ``*_output_tokens`` breakdown keys), so consumers can read either
    result without special-casing. The older ``prompt_tokens``-style
    breakdown keys are kept for backward compatibility.

    Returns:
        Dictionary with zero counts and empty lists for every metric.
    """
    return {
        "total_tokens": 0,
        "total_researches": 0,
        "by_model": [],
        "recent_researches": [],
        "token_breakdown": {
            # Legacy key names, retained for existing consumers.
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "avg_prompt_tokens": 0,
            "avg_completion_tokens": 0,
            # Key names used by the populated metrics result.
            "total_input_tokens": 0,
            "total_output_tokens": 0,
            "avg_input_tokens": 0,
            "avg_output_tokens": 0,
            "avg_total_tokens": 0,
        },
        "rate_limiting": {
            "total_attempts": 0,
            "successful_attempts": 0,
            "failed_attempts": 0,
            "success_rate": 0,
            "rate_limit_events": 0,
            "avg_wait_time": 0.0,
            "avg_successful_wait": 0.0,
            "tracked_engines": 0,
            "engine_stats": [],
            "total_engines_tracked": 0,
            "healthy_engines": 0,
            "degraded_engines": 0,
            "poor_engines": 0,
        },
    }

1356 

def get_enhanced_metrics(
    self, period: str = "30d", research_mode: str = "all"
) -> Dict[str, Any]:
    """Get enhanced Phase 1 tracking metrics.

    The username is read from the Flask session and used to open that
    user's database session. All statistics are computed over
    ``TokenUsage`` rows matching the period and research-mode filters.

    Args:
        period: Time period to filter by ('7d', '30d', '3m', '1y', 'all')
        research_mode: Research mode to filter by ('quick', 'detailed', 'all')

    Returns:
        Dictionary with the keys ``recent_enhanced_data``,
        ``performance_stats``, ``mode_breakdown``,
        ``search_engine_stats``, ``phase_breakdown``,
        ``time_series_data`` and ``call_stack_analysis``. When no
        username is present in the Flask session, or when any exception
        is raised while querying (it is logged), the same structure is
        returned with zeros and empty lists.
    """
    from flask import session as flask_session

    from ..database.session_context import get_user_db_session

    # Upper bound on time-series points returned for a bounded period.
    max_time_series_points = 200

    def empty_metrics() -> Dict[str, Any]:
        """Build a fresh zeroed result with the full key structure."""
        return {
            "recent_enhanced_data": [],
            "performance_stats": {
                "avg_response_time": 0,
                "min_response_time": 0,
                "max_response_time": 0,
                "success_rate": 0,
                "error_rate": 0,
                "total_enhanced_calls": 0,
            },
            "mode_breakdown": [],
            "search_engine_stats": [],
            "phase_breakdown": [],
            "time_series_data": [],
            "call_stack_analysis": {
                "by_file": [],
                "by_function": [],
            },
        }

    username = flask_session.get("username")
    if not username:
        # Return empty metrics structure when no user session
        return empty_metrics()

    try:
        with get_user_db_session(username) as session:
            # Build base query with filters
            query = session.query(TokenUsage)

            # Apply time filter
            time_condition = get_time_filter_condition(
                period, TokenUsage.timestamp
            )
            if time_condition is not None:
                query = query.filter(time_condition)

            # Apply research mode filter
            mode_condition = get_research_mode_condition(
                research_mode, TokenUsage.research_mode
            )
            if mode_condition is not None:
                query = query.filter(mode_condition)

            # Time series data for the "Token Consumption Over Time" chart
            time_series_query = query.filter(
                TokenUsage.timestamp.isnot(None),
                TokenUsage.total_tokens > 0,
            )

            if period != "all":
                # Limit to recent data for performance: take the newest
                # rows (descending + limit) and then restore chronological
                # order, so the cap keeps the most recent points rather
                # than the oldest ones in the window.
                newest_first = (
                    time_series_query.order_by(TokenUsage.timestamp.desc())
                    .limit(max_time_series_points)
                    .all()
                )
                time_series_data = list(reversed(newest_first))
            else:
                time_series_data = time_series_query.order_by(
                    TokenUsage.timestamp.asc()
                ).all()

            # Format time series data with cumulative calculations
            time_series = []
            cumulative_tokens = 0
            cumulative_prompt_tokens = 0
            cumulative_completion_tokens = 0

            for usage in time_series_data:
                cumulative_tokens += usage.total_tokens or 0
                cumulative_prompt_tokens += usage.prompt_tokens or 0
                cumulative_completion_tokens += usage.completion_tokens or 0

                time_series.append(
                    {
                        "timestamp": str(usage.timestamp)
                        if usage.timestamp
                        else None,
                        "tokens": usage.total_tokens or 0,
                        "prompt_tokens": usage.prompt_tokens or 0,
                        "completion_tokens": usage.completion_tokens or 0,
                        "cumulative_tokens": cumulative_tokens,
                        "cumulative_prompt_tokens": cumulative_prompt_tokens,
                        "cumulative_completion_tokens": cumulative_completion_tokens,
                        "research_id": usage.research_id,
                    }
                )

            # Basic performance stats, restricted to rows that recorded a
            # response time
            performance_query = query.filter(
                TokenUsage.response_time_ms.isnot(None)
            )
            total_calls = performance_query.count()

            if total_calls > 0:
                avg_response_time = (
                    performance_query.with_entities(
                        func.avg(TokenUsage.response_time_ms)
                    ).scalar()
                    or 0
                )
                min_response_time = (
                    performance_query.with_entities(
                        func.min(TokenUsage.response_time_ms)
                    ).scalar()
                    or 0
                )
                max_response_time = (
                    performance_query.with_entities(
                        func.max(TokenUsage.response_time_ms)
                    ).scalar()
                    or 0
                )
                success_count = performance_query.filter(
                    TokenUsage.success_status == "success"
                ).count()
                error_count = performance_query.filter(
                    TokenUsage.success_status == "error"
                ).count()

                perf_stats = {
                    "avg_response_time": round(avg_response_time),
                    "min_response_time": min_response_time,
                    "max_response_time": max_response_time,
                    "success_rate": round(
                        (success_count / total_calls * 100), 1
                    ),
                    "error_rate": round(
                        (error_count / total_calls * 100), 1
                    ),
                    "total_enhanced_calls": total_calls,
                }
            else:
                perf_stats = empty_metrics()["performance_stats"]

            # Research mode breakdown using ORM
            mode_stats = (
                query.filter(TokenUsage.research_mode.isnot(None))
                .with_entities(
                    TokenUsage.research_mode,
                    func.count().label("count"),
                    func.avg(TokenUsage.total_tokens).label("avg_tokens"),
                    func.avg(TokenUsage.response_time_ms).label(
                        "avg_response_time"
                    ),
                )
                .group_by(TokenUsage.research_mode)
                .all()
            )

            modes = [
                {
                    "mode": stat.research_mode,
                    "count": stat.count,
                    "avg_tokens": round(stat.avg_tokens or 0),
                    "avg_response_time": round(stat.avg_response_time or 0),
                }
                for stat in mode_stats
            ]

            # Recent enhanced data: the 50 newest rows that carry a
            # research query
            recent_enhanced_data = (
                query.filter(TokenUsage.research_query.isnot(None))
                .order_by(TokenUsage.timestamp.desc())
                .limit(50)
                .all()
            )
            recent_enhanced = [
                {
                    "research_query": usage.research_query,
                    "research_mode": usage.research_mode,
                    "research_phase": usage.research_phase,
                    "search_iteration": usage.search_iteration,
                    "response_time_ms": usage.response_time_ms,
                    "success_status": usage.success_status,
                    "error_type": usage.error_type,
                    "search_engines_planned": usage.search_engines_planned,
                    "search_engine_selected": usage.search_engine_selected,
                    "total_tokens": usage.total_tokens,
                    "prompt_tokens": usage.prompt_tokens,
                    "completion_tokens": usage.completion_tokens,
                    "timestamp": str(usage.timestamp)
                    if usage.timestamp
                    else None,
                    "research_id": usage.research_id,
                    "calling_file": usage.calling_file,
                    "calling_function": usage.calling_function,
                    "call_stack": usage.call_stack,
                }
                for usage in recent_enhanced_data
            ]

            # Search engine breakdown using ORM
            search_engine_stats = (
                query.filter(TokenUsage.search_engine_selected.isnot(None))
                .with_entities(
                    TokenUsage.search_engine_selected,
                    func.count().label("count"),
                    func.avg(TokenUsage.total_tokens).label("avg_tokens"),
                    func.avg(TokenUsage.response_time_ms).label(
                        "avg_response_time"
                    ),
                )
                .group_by(TokenUsage.search_engine_selected)
                .all()
            )

            search_engines = [
                {
                    "search_engine": stat.search_engine_selected,
                    "count": stat.count,
                    "avg_tokens": round(stat.avg_tokens or 0),
                    "avg_response_time": round(stat.avg_response_time or 0),
                }
                for stat in search_engine_stats
            ]

            # Research phase breakdown using ORM
            phase_stats = (
                query.filter(TokenUsage.research_phase.isnot(None))
                .with_entities(
                    TokenUsage.research_phase,
                    func.count().label("count"),
                    func.avg(TokenUsage.total_tokens).label("avg_tokens"),
                    func.avg(TokenUsage.response_time_ms).label(
                        "avg_response_time"
                    ),
                )
                .group_by(TokenUsage.research_phase)
                .all()
            )

            phases = [
                {
                    "phase": stat.research_phase,
                    "count": stat.count,
                    "avg_tokens": round(stat.avg_tokens or 0),
                    "avg_response_time": round(stat.avg_response_time or 0),
                }
                for stat in phase_stats
            ]

            # Call stack analysis: ten most frequent calling files
            file_stats = (
                query.filter(TokenUsage.calling_file.isnot(None))
                .with_entities(
                    TokenUsage.calling_file,
                    func.count().label("count"),
                    func.avg(TokenUsage.total_tokens).label("avg_tokens"),
                )
                .group_by(TokenUsage.calling_file)
                .order_by(func.count().desc())
                .limit(10)
                .all()
            )

            files = [
                {
                    "file": stat.calling_file,
                    "count": stat.count,
                    "avg_tokens": round(stat.avg_tokens or 0),
                }
                for stat in file_stats
            ]

            # Call stack analysis: ten most frequent calling functions
            function_stats = (
                query.filter(TokenUsage.calling_function.isnot(None))
                .with_entities(
                    TokenUsage.calling_function,
                    func.count().label("count"),
                    func.avg(TokenUsage.total_tokens).label("avg_tokens"),
                )
                .group_by(TokenUsage.calling_function)
                .order_by(func.count().desc())
                .limit(10)
                .all()
            )

            functions = [
                {
                    "function": stat.calling_function,
                    "count": stat.count,
                    "avg_tokens": round(stat.avg_tokens or 0),
                }
                for stat in function_stats
            ]

            return {
                "recent_enhanced_data": recent_enhanced,
                "performance_stats": perf_stats,
                "mode_breakdown": modes,
                "search_engine_stats": search_engines,
                "phase_breakdown": phases,
                "time_series_data": time_series,
                "call_stack_analysis": {
                    "by_file": files,
                    "by_function": functions,
                },
            }
    except Exception:
        # Deliberate best-effort boundary: log with traceback and return
        # the empty structure rather than propagating.
        logger.exception("Error in get_enhanced_metrics")
        return empty_metrics()

1698 

def get_research_timeline_metrics(self, research_id: str) -> Dict[str, Any]:
    """Build the token-usage timeline and summary for one research.

    The username is read from the Flask session and used to open that
    user's database session. Every ``token_usage`` row for the research
    is read in timestamp order, together with the matching
    ``research_history`` row, and folded into a timeline with running
    totals, an overall summary and per-phase statistics.

    Args:
        research_id: The ID of the research

    Returns:
        Dictionary with the keys ``research_id``, ``research_details``,
        ``timeline``, ``summary`` and ``phase_stats``. Without a username
        in the Flask session the same structure is returned with zeros
        and empty containers.
    """
    from flask import session as flask_session

    from ..database.session_context import get_user_db_session

    username = flask_session.get("username")
    if not username:
        return {
            "research_id": research_id,
            "research_details": {},
            "timeline": [],
            "summary": {
                "total_calls": 0,
                "total_tokens": 0,
                "total_prompt_tokens": 0,
                "total_completion_tokens": 0,
                "avg_response_time": 0,
                "success_rate": 0,
            },
            "phase_stats": {},
        }

    # Column positions in the token_usage SELECT below.
    (
        COL_TIMESTAMP,
        COL_TOTAL,
        COL_PROMPT,
        COL_COMPLETION,
        COL_RESPONSE_MS,
        COL_STATUS,
        COL_ERROR,
        COL_PHASE,
        COL_ITERATION,
        COL_ENGINE,
        COL_MODEL,
        COL_FILE,
        COL_FUNCTION,
        COL_STACK,
    ) = range(14)

    bind_params = {"research_id": research_id}

    with get_user_db_session(username) as session:
        # All token usage for this research, oldest first, including the
        # call-stack columns.
        usage_rows = session.execute(
            text(
                """
                SELECT
                    timestamp,
                    total_tokens,
                    prompt_tokens,
                    completion_tokens,
                    response_time_ms,
                    success_status,
                    error_type,
                    research_phase,
                    search_iteration,
                    search_engine_selected,
                    model_name,
                    calling_file,
                    calling_function,
                    call_stack
                FROM token_usage
                WHERE research_id = :research_id
                ORDER BY timestamp ASC
                """
            ),
            bind_params,
        ).fetchall()

        # One pass over the rows builds the timeline entries and
        # accumulates everything the summary and phase breakdown need.
        timeline = []
        phase_stats = {}
        running_total = 0
        running_prompt = 0
        running_completion = 0
        response_ms_sum = 0
        success_hits = 0

        for row in usage_rows:
            tokens = row[COL_TOTAL] or 0
            prompt = row[COL_PROMPT] or 0
            completion = row[COL_COMPLETION] or 0
            response_ms = row[COL_RESPONSE_MS]
            stamp = row[COL_TIMESTAMP]

            running_total += tokens
            running_prompt += prompt
            running_completion += completion

            timeline.append(
                {
                    "timestamp": str(stamp) if stamp else None,
                    "tokens": tokens,
                    "prompt_tokens": prompt,
                    "completion_tokens": completion,
                    "cumulative_tokens": running_total,
                    "cumulative_prompt_tokens": running_prompt,
                    "cumulative_completion_tokens": running_completion,
                    "response_time_ms": response_ms,
                    "success_status": row[COL_STATUS],
                    "error_type": row[COL_ERROR],
                    "research_phase": row[COL_PHASE],
                    "search_iteration": row[COL_ITERATION],
                    "search_engine_selected": row[COL_ENGINE],
                    "model_name": row[COL_MODEL],
                    "calling_file": row[COL_FILE],
                    "calling_function": row[COL_FUNCTION],
                    "call_stack": row[COL_STACK],
                }
            )

            # Missing response times count as 0 in the overall average.
            response_ms_sum += response_ms or 0
            if row[COL_STATUS] == "success":
                success_hits += 1

            # Per-phase tallies; "avg_response_time" holds a running sum
            # here and is turned into an average after the loop.
            bucket = phase_stats.setdefault(
                row[COL_PHASE] or "unknown",
                {"count": 0, "tokens": 0, "avg_response_time": 0},
            )
            bucket["count"] += 1
            bucket["tokens"] += tokens
            if response_ms:
                bucket["avg_response_time"] += response_ms

        # Basic information about the research itself
        info_row = session.execute(
            text(
                """
                SELECT query, mode, status, created_at, completed_at
                FROM research_history
                WHERE id = :research_id
                """
            ),
            bind_params,
        ).fetchone()

        research_details = {}
        if info_row:
            created_at = info_row[3]
            completed_at = info_row[4]
            research_details = {
                "query": info_row[0],
                "mode": info_row[1],
                "status": info_row[2],
                "created_at": str(created_at) if created_at else None,
                "completed_at": str(completed_at) if completed_at else None,
            }

        # Summary figures; the divisor is clamped to 1 so an empty
        # timeline yields zeros instead of dividing by zero.
        total_calls = len(usage_rows)
        divisor = max(total_calls, 1)
        avg_response_time = response_ms_sum / divisor
        success_rate = success_hits / divisor * 100

        # Convert each phase's accumulated response time into an average
        for bucket in phase_stats.values():
            if bucket["count"] > 0:
                bucket["avg_response_time"] = round(
                    bucket["avg_response_time"] / bucket["count"]
                )

        return {
            "research_id": research_id,
            "research_details": research_details,
            "timeline": timeline,
            "summary": {
                "total_calls": total_calls,
                "total_tokens": running_total,
                "total_prompt_tokens": running_prompt,
                "total_completion_tokens": running_completion,
                "avg_response_time": round(avg_response_time),
                "success_rate": round(success_rate, 1),
            },
            "phase_stats": phase_stats,
        }