Coverage for src / local_deep_research / metrics / token_counter.py: 72%

506 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-25 01:07 +0000

1"""Token counting functionality for LLM usage tracking.""" 

2 

3import inspect 

4import json 

5import time 

6from datetime import datetime, timedelta, UTC 

7from pathlib import Path 

8from typing import Any, Dict, List, Optional 

9 

10from langchain_core.callbacks import BaseCallbackHandler 

11from langchain_core.outputs import LLMResult 

12from loguru import logger 

13from sqlalchemy import func, text 

14 

15from ..database.models import ModelUsage, TokenUsage 

16from .query_utils import get_research_mode_condition, get_time_filter_condition 

17 

18 

class TokenCountingCallback(BaseCallbackHandler):
    """Callback handler for counting tokens across different models."""

    def __init__(
        self,
        research_id: Optional[str] = None,
        research_context: Optional[Dict[str, Any]] = None,
    ):
        """Initialize the token counting callback.

        Args:
            research_id: The ID of the research to track tokens for
            research_context: Additional research context for enhanced tracking
        """
        super().__init__()

        self.research_id = research_id
        self.research_context = research_context or {}

        # Attributes that simply start out unset: model/provider resolved per
        # call ("preset_*" may be assigned externally and take precedence),
        # per-call timing, error details, call-site info, and context limit.
        for attr in (
            "current_model",
            "current_provider",
            "preset_model",
            "preset_provider",
            "start_time",
            "response_time_ms",
            "error_type",
            "calling_file",
            "calling_function",
            "call_stack",
            "context_limit",
        ):
            setattr(self, attr, None)

        # Outcome of the most recent call; flipped to "error" in on_llm_error.
        self.success_status = "success"

        # Context-overflow detection state (updated in on_llm_end).
        self.context_truncated = False
        self.tokens_truncated = 0
        self.truncation_ratio = 0.0
        self.original_prompt_estimate = 0

        # Raw metrics reported by Ollama in response metadata.
        self.ollama_metrics = {}

        # In-memory running totals, with a per-model breakdown.
        self.counts = {
            "total_tokens": 0,
            "total_prompt_tokens": 0,
            "total_completion_tokens": 0,
            "by_model": {},
        }

69 

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Called when LLM starts running.

        Starts the per-call timer, estimates the prompt size, captures the
        project-side call site from the Python stack, and resolves the model
        name and provider from (in order) preset values, invocation params,
        kwargs, and the serialized LLM description.

        Args:
            serialized: Serialized description of the LLM being invoked.
            prompts: The prompt strings being sent to the model.
            **kwargs: Extra callback data (may carry invocation_params etc.).
        """
        # Phase 1 Enhancement: Start timing
        self.start_time = time.time()

        # Estimate original prompt size (rough estimate: ~4 chars per token)
        if prompts:
            total_chars = sum(len(prompt) for prompt in prompts)
            self.original_prompt_estimate = total_chars // 4
            logger.debug(
                f"Estimated prompt tokens: {self.original_prompt_estimate} (from {total_chars} chars)"
            )

        # Get context limit from research context (will be set from settings)
        self.context_limit = self.research_context.get("context_limit")

        # Phase 1 Enhancement: Capture call stack information
        try:
            stack = inspect.stack()

            # Skip the first few frames (this method, langchain internals)
            # Look for the first frame that's in our project directory
            for frame_info in stack[1:]:
                file_path = frame_info.filename
                # Look for any frame containing local_deep_research project
                if (
                    "local_deep_research" in file_path
                    and "site-packages" not in file_path
                    and "venv" not in file_path
                ):
                    # Extract relative path from local_deep_research
                    if "src/local_deep_research" in file_path:
                        relative_path = file_path.split(
                            "src/local_deep_research"
                        )[-1].lstrip("/")
                    elif "local_deep_research/src" in file_path:
                        relative_path = file_path.split(
                            "local_deep_research/src"
                        )[-1].lstrip("/")
                    elif "local_deep_research" in file_path:
                        # Get everything after local_deep_research
                        # NOTE(review): the outer `if` already guarantees this
                        # substring is present, so this branch always matches
                        # and the `else` below is unreachable dead code.
                        relative_path = file_path.split("local_deep_research")[
                            -1
                        ].lstrip("/")
                    else:
                        relative_path = Path(file_path).name

                    self.calling_file = relative_path
                    self.calling_function = frame_info.function

                    # Capture a simplified call stack (just the relevant frames)
                    call_stack_frames = []
                    for frame in stack[1:6]:  # Limit to 5 frames
                        if (
                            "local_deep_research" in frame.filename
                            and "site-packages" not in frame.filename
                            and "venv" not in frame.filename
                        ):
                            frame_name = f"{Path(frame.filename).name}:{frame.function}:{frame.lineno}"
                            call_stack_frames.append(frame_name)

                    self.call_stack = (
                        " -> ".join(call_stack_frames)
                        if call_stack_frames
                        else None
                    )
                    break
        except Exception as e:
            logger.debug(f"Error capturing call stack: {e}")
            # Continue without call stack info if there's an error

        # Debug logging removed to reduce log clutter
        # Uncomment below if you need to debug token counting
        # logger.debug(f"on_llm_start serialized: {serialized}")
        # logger.debug(f"on_llm_start kwargs keys: {list(kwargs.keys()) if kwargs else []}")

        # First, use preset values if available
        if self.preset_model:
            self.current_model = self.preset_model
        else:
            # Try multiple locations for model name
            model_name = None

            # First check invocation_params
            invocation_params = kwargs.get("invocation_params", {})
            model_name = invocation_params.get(
                "model"
            ) or invocation_params.get("model_name")

            # Check kwargs directly
            if not model_name:
                model_name = kwargs.get("model") or kwargs.get("model_name")

            # Check serialized data
            if not model_name and "kwargs" in serialized:
                model_name = serialized["kwargs"].get("model") or serialized[
                    "kwargs"
                ].get("model_name")

            # Check for name in serialized data
            if not model_name and "name" in serialized:
                model_name = serialized["name"]

            # If still not found and we have Ollama, try to extract from the instance
            if (
                not model_name
                and "_type" in serialized
                and "ChatOllama" in serialized["_type"]
            ):
                # For Ollama, the model name might be in the serialized kwargs
                if "kwargs" in serialized and "model" in serialized["kwargs"]:
                    model_name = serialized["kwargs"]["model"]
                else:
                    # Default to the type if we can't find the actual model
                    model_name = "ollama"

            # Final fallback
            if not model_name:
                if "_type" in serialized:
                    model_name = serialized["_type"]
                else:
                    model_name = "unknown"

            self.current_model = model_name

        # Use preset provider if available
        if self.preset_provider:
            self.current_provider = self.preset_provider
        else:
            # Extract provider from serialized type or kwargs
            if "_type" in serialized:
                type_str = serialized["_type"]
                if "ChatOllama" in type_str:
                    self.current_provider = "ollama"
                elif "ChatOpenAI" in type_str:
                    self.current_provider = "openai"
                elif "ChatAnthropic" in type_str:
                    self.current_provider = "anthropic"
                else:
                    self.current_provider = kwargs.get("provider", "unknown")
            else:
                self.current_provider = kwargs.get("provider", "unknown")

        # Initialize model tracking if needed
        if self.current_model not in self.counts["by_model"]:
            self.counts["by_model"][self.current_model] = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
                "calls": 0,
                "provider": self.current_provider,
            }

        # Increment call count
        self.counts["by_model"][self.current_model]["calls"] += 1

227 

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Called when LLM ends running.

        Records the response time, extracts token usage from the response
        (checking, in order: llm_output, per-generation usage_metadata, then
        Ollama response_metadata), updates the in-memory tallies, and
        persists the usage if a research_id is set.
        """
        # Phase 1 Enhancement: Calculate response time
        if self.start_time:
            self.response_time_ms = int((time.time() - self.start_time) * 1000)

        # Extract token usage from response
        token_usage = None

        # Check multiple locations for token usage
        if hasattr(response, "llm_output") and response.llm_output:
            token_usage = response.llm_output.get(
                "token_usage"
            ) or response.llm_output.get("usage", {})

        # Check for usage metadata in generations (Ollama specific)
        if not token_usage and hasattr(response, "generations"):
            for generation_list in response.generations:
                for generation in generation_list:
                    if hasattr(generation, "message") and hasattr(
                        generation.message, "usage_metadata"
                    ):
                        usage_meta = generation.message.usage_metadata
                        if usage_meta:  # Check if usage_metadata is not None
                            token_usage = {
                                "prompt_tokens": usage_meta.get(
                                    "input_tokens", 0
                                ),
                                "completion_tokens": usage_meta.get(
                                    "output_tokens", 0
                                ),
                                "total_tokens": usage_meta.get(
                                    "total_tokens", 0
                                ),
                            }
                            # NOTE: breaking here skips the response_metadata
                            # check below, so ollama_metrics stay empty when
                            # usage_metadata already supplied the counts.
                            break
                    # Also check response_metadata
                    if hasattr(generation, "message") and hasattr(
                        generation.message, "response_metadata"
                    ):
                        resp_meta = generation.message.response_metadata
                        if resp_meta.get("prompt_eval_count") or resp_meta.get(
                            "eval_count"
                        ):
                            # Capture raw Ollama metrics
                            self.ollama_metrics = {
                                "prompt_eval_count": resp_meta.get(
                                    "prompt_eval_count"
                                ),
                                "eval_count": resp_meta.get("eval_count"),
                                "total_duration": resp_meta.get(
                                    "total_duration"
                                ),
                                "load_duration": resp_meta.get("load_duration"),
                                "prompt_eval_duration": resp_meta.get(
                                    "prompt_eval_duration"
                                ),
                                "eval_duration": resp_meta.get("eval_duration"),
                            }

                            # Check for context overflow
                            prompt_eval_count = resp_meta.get(
                                "prompt_eval_count", 0
                            )
                            if self.context_limit and prompt_eval_count > 0:
                                # Check if we're near or at the context limit
                                if (
                                    prompt_eval_count
                                    >= self.context_limit * 0.95
                                ):  # 95% threshold
                                    self.context_truncated = True

                                    # Estimate tokens truncated
                                    if (
                                        self.original_prompt_estimate
                                        > prompt_eval_count
                                    ):
                                        self.tokens_truncated = max(
                                            0,
                                            self.original_prompt_estimate
                                            - prompt_eval_count,
                                        )
                                        self.truncation_ratio = (
                                            self.tokens_truncated
                                            / self.original_prompt_estimate
                                            if self.original_prompt_estimate > 0
                                            else 0
                                        )
                                        logger.warning(
                                            f"Context overflow detected! "
                                            f"Prompt tokens: {prompt_eval_count}/{self.context_limit} "
                                            f"(estimated {self.tokens_truncated} tokens truncated, "
                                            f"{self.truncation_ratio:.1%} of prompt)"
                                        )

                            token_usage = {
                                "prompt_tokens": resp_meta.get(
                                    "prompt_eval_count", 0
                                ),
                                "completion_tokens": resp_meta.get(
                                    "eval_count", 0
                                ),
                                "total_tokens": resp_meta.get(
                                    "prompt_eval_count", 0
                                )
                                + resp_meta.get("eval_count", 0),
                            }
                            break
                if token_usage:
                    break

        if token_usage and isinstance(token_usage, dict):
            prompt_tokens = token_usage.get("prompt_tokens", 0)
            completion_tokens = token_usage.get("completion_tokens", 0)
            total_tokens = token_usage.get(
                "total_tokens", prompt_tokens + completion_tokens
            )

            # Update in-memory counts
            self.counts["total_prompt_tokens"] += prompt_tokens
            self.counts["total_completion_tokens"] += completion_tokens
            self.counts["total_tokens"] += total_tokens

            if self.current_model:
                self.counts["by_model"][self.current_model][
                    "prompt_tokens"
                ] += prompt_tokens
                self.counts["by_model"][self.current_model][
                    "completion_tokens"
                ] += completion_tokens
                self.counts["by_model"][self.current_model]["total_tokens"] += (
                    total_tokens
                )

            # Save to database if we have a research_id
            if self.research_id:
                self._save_to_db(prompt_tokens, completion_tokens)

365 

366 def on_llm_error(self, error, **kwargs: Any) -> None: 

367 """Called when LLM encounters an error.""" 

368 # Phase 1 Enhancement: Track errors 

369 if self.start_time: 

370 self.response_time_ms = int((time.time() - self.start_time) * 1000) 

371 

372 self.success_status = "error" 

373 self.error_type = str(type(error).__name__) 

374 

375 # Still save to database to track failed calls 

376 if self.research_id: 376 ↛ 377line 376 didn't jump to line 377 because the condition on line 376 was never true

377 self._save_to_db(0, 0) 

378 

379 def _get_context_overflow_fields(self) -> Dict[str, Any]: 

380 """Get context overflow detection fields for database saving.""" 

381 return { 

382 "context_limit": self.context_limit, 

383 "context_truncated": self.context_truncated, # Now Boolean 

384 "tokens_truncated": self.tokens_truncated 

385 if self.context_truncated 

386 else None, 

387 "truncation_ratio": self.truncation_ratio 

388 if self.context_truncated 

389 else None, 

390 # Raw Ollama metrics 

391 "ollama_prompt_eval_count": self.ollama_metrics.get( 

392 "prompt_eval_count" 

393 ), 

394 "ollama_eval_count": self.ollama_metrics.get("eval_count"), 

395 "ollama_total_duration": self.ollama_metrics.get("total_duration"), 

396 "ollama_load_duration": self.ollama_metrics.get("load_duration"), 

397 "ollama_prompt_eval_duration": self.ollama_metrics.get( 

398 "prompt_eval_duration" 

399 ), 

400 "ollama_eval_duration": self.ollama_metrics.get("eval_duration"), 

401 } 

402 

    def _save_to_db(self, prompt_tokens: int, completion_tokens: int) -> None:
        """Save token usage to the database.

        Two write paths: background threads go through the thread-safe
        encrypted metrics writer (needs username and password from the
        research context); the main thread writes directly into the
        user's database resolved from the Flask session.

        Args:
            prompt_tokens: Prompt-side token count for this call.
            completion_tokens: Completion-side token count for this call.
        """
        # Check if we're in a thread - if so, queue the save for later
        import threading

        if threading.current_thread().name != "MainThread":
            # Use thread-safe metrics database for background threads
            username = (
                self.research_context.get("username")
                if self.research_context
                else None
            )

            if not username:
                logger.warning(
                    f"Cannot save token metrics - no username in research context. "
                    f"Token usage: prompt={prompt_tokens}, completion={completion_tokens}, "
                    f"Research context: {self.research_context}"
                )
                return

            # Prepare token data for the thread-safe metrics writer.
            token_data = {
                "model_name": self.current_model,
                "provider": self.current_provider,
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "research_query": self.research_context.get("research_query"),
                "research_mode": self.research_context.get("research_mode"),
                "research_phase": self.research_context.get("research_phase"),
                "search_iteration": self.research_context.get(
                    "search_iteration"
                ),
                "response_time_ms": self.response_time_ms,
                "success_status": self.success_status,
                "error_type": self.error_type,
                "search_engines_planned": self.research_context.get(
                    "search_engines_planned"
                ),
                "search_engine_selected": self.research_context.get(
                    "search_engine_selected"
                ),
                "calling_file": self.calling_file,
                "calling_function": self.calling_function,
                "call_stack": self.call_stack,
                # Add context overflow fields using helper method
                **self._get_context_overflow_fields(),
            }

            # Convert list to JSON string if needed
            if isinstance(token_data.get("search_engines_planned"), list):
                token_data["search_engines_planned"] = json.dumps(
                    token_data["search_engines_planned"]
                )

            # Get password from research context; the encrypted per-user
            # database cannot be opened without it.
            password = self.research_context.get("user_password")
            if not password:
                logger.warning(
                    f"Cannot save token metrics - no password in research context. "
                    f"Username: {username}, Token usage: prompt={prompt_tokens}, completion={completion_tokens}"
                )
                return

            # Write metrics directly using thread-safe database
            try:
                from ..database.thread_metrics import metrics_writer

                # Set password for this thread
                metrics_writer.set_user_password(username, password)

                # Write metrics to encrypted database
                metrics_writer.write_token_metrics(
                    username, self.research_id, token_data
                )
            except Exception:
                logger.exception("Failed to write metrics from thread")
            return

        # In MainThread, save directly
        try:
            from flask import session as flask_session
            from ..database.session_context import get_user_db_session

            username = flask_session.get("username")
            if not username:
                logger.debug("No user session, skipping token metrics save")
                return

            with get_user_db_session(username) as session:
                # Phase 1 Enhancement: Prepare additional context
                research_query = self.research_context.get("research_query")
                research_mode = self.research_context.get("research_mode")
                research_phase = self.research_context.get("research_phase")
                search_iteration = self.research_context.get("search_iteration")
                search_engines_planned = self.research_context.get(
                    "search_engines_planned"
                )
                search_engine_selected = self.research_context.get(
                    "search_engine_selected"
                )

                # Debug logging for search engine context
                if search_engines_planned or search_engine_selected:
                    logger.info(
                        f"Token tracking - Search context: planned={search_engines_planned}, selected={search_engine_selected}, phase={research_phase}"
                    )
                else:
                    logger.debug(
                        f"Token tracking - No search engine context yet, phase={research_phase}"
                    )

                # Convert list to JSON string if needed
                if isinstance(search_engines_planned, list):
                    search_engines_planned = json.dumps(search_engines_planned)

                # Log context overflow detection values before saving
                logger.debug(
                    f"Saving TokenUsage - context_limit: {self.context_limit}, "
                    f"context_truncated: {self.context_truncated}, "
                    f"tokens_truncated: {self.tokens_truncated}, "
                    f"ollama_prompt_eval_count: {self.ollama_metrics.get('prompt_eval_count')}, "
                    f"prompt_tokens: {prompt_tokens}, "
                    f"completion_tokens: {completion_tokens}"
                )

                # Add token usage record with enhanced fields
                token_usage = TokenUsage(
                    research_id=self.research_id,
                    model_name=self.current_model,
                    model_provider=self.current_provider,  # Added provider
                    # for accurate cost tracking
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                    # Phase 1 Enhancement: Research context
                    research_query=research_query,
                    research_mode=research_mode,
                    research_phase=research_phase,
                    search_iteration=search_iteration,
                    # Phase 1 Enhancement: Performance metrics
                    response_time_ms=self.response_time_ms,
                    success_status=self.success_status,
                    error_type=self.error_type,
                    # Phase 1 Enhancement: Search engine context
                    search_engines_planned=search_engines_planned,
                    search_engine_selected=search_engine_selected,
                    # Phase 1 Enhancement: Call stack tracking
                    calling_file=self.calling_file,
                    calling_function=self.calling_function,
                    call_stack=self.call_stack,
                    # Add context overflow fields using helper method
                    **self._get_context_overflow_fields(),
                )
                session.add(token_usage)

                # Update or create model usage statistics
                model_usage = (
                    session.query(ModelUsage)
                    .filter_by(
                        model_name=self.current_model,
                    )
                    .first()
                )

                if model_usage:
                    model_usage.total_tokens += (
                        prompt_tokens + completion_tokens
                    )
                    model_usage.total_calls += 1
                else:
                    model_usage = ModelUsage(
                        model_name=self.current_model,
                        model_provider=self.current_provider,
                        total_tokens=prompt_tokens + completion_tokens,
                        total_calls=1,
                    )
                    session.add(model_usage)

                # Commit the transaction
                session.commit()

        except Exception:
            logger.exception("Error saving token usage to database")

589 

590 def get_counts(self) -> Dict[str, Any]: 

591 """Get the current token counts.""" 

592 return self.counts 

593 

594 

class TokenCounter:
    """Manager class for token counting across the application."""

    def __init__(self):
        """Initialize the token counter."""
        # The thread-safe metrics writer is resolved lazily on first use;
        # until then nothing else needs to be stored here.
        self._thread_metrics_db = None

602 

603 @property 

604 def thread_metrics_db(self): 

605 """Lazy load thread metrics writer.""" 

606 if self._thread_metrics_db is None: 

607 try: 

608 from ..database.thread_metrics import metrics_writer 

609 

610 self._thread_metrics_db = metrics_writer 

611 except ImportError: 

612 logger.warning("Thread metrics writer not available") 

613 return self._thread_metrics_db 

614 

615 def create_callback( 

616 self, 

617 research_id: Optional[str] = None, 

618 research_context: Optional[Dict[str, Any]] = None, 

619 ) -> TokenCountingCallback: 

620 """Create a new token counting callback. 

621 

622 Args: 

623 research_id: The ID of the research to track tokens for 

624 research_context: Additional research context for enhanced tracking 

625 

626 Returns: 

627 A new TokenCountingCallback instance 

628 """ 

629 return TokenCountingCallback( 

630 research_id=research_id, research_context=research_context 

631 ) 

632 

633 def get_research_metrics(self, research_id: str) -> Dict[str, Any]: 

634 """Get token metrics for a specific research. 

635 

636 Args: 

637 research_id: The ID of the research 

638 

639 Returns: 

640 Dictionary containing token usage metrics 

641 """ 

642 from flask import session as flask_session 

643 

644 from ..database.session_context import get_user_db_session 

645 

646 username = flask_session.get("username") 

647 if not username: 647 ↛ 655line 647 didn't jump to line 655 because the condition on line 647 was always true

648 return { 

649 "research_id": research_id, 

650 "total_tokens": 0, 

651 "total_calls": 0, 

652 "model_usage": [], 

653 } 

654 

655 with get_user_db_session(username) as session: 

656 # Get token usage for this research from TokenUsage table 

657 from sqlalchemy import func 

658 

659 token_usages = ( 

660 session.query( 

661 TokenUsage.model_name, 

662 TokenUsage.model_provider, 

663 func.sum(TokenUsage.prompt_tokens).label("prompt_tokens"), 

664 func.sum(TokenUsage.completion_tokens).label( 

665 "completion_tokens" 

666 ), 

667 func.sum(TokenUsage.total_tokens).label("total_tokens"), 

668 func.count().label("calls"), 

669 ) 

670 .filter_by(research_id=research_id) 

671 .group_by(TokenUsage.model_name, TokenUsage.model_provider) 

672 .order_by(func.sum(TokenUsage.total_tokens).desc()) 

673 .all() 

674 ) 

675 

676 model_usage = [] 

677 total_tokens = 0 

678 total_calls = 0 

679 

680 for usage in token_usages: 

681 model_usage.append( 

682 { 

683 "model": usage.model_name, 

684 "provider": usage.model_provider, 

685 "tokens": usage.total_tokens or 0, 

686 "calls": usage.calls or 0, 

687 "prompt_tokens": usage.prompt_tokens or 0, 

688 "completion_tokens": usage.completion_tokens or 0, 

689 } 

690 ) 

691 total_tokens += usage.total_tokens or 0 

692 total_calls += usage.calls or 0 

693 

694 return { 

695 "research_id": research_id, 

696 "total_tokens": total_tokens, 

697 "total_calls": total_calls, 

698 "model_usage": model_usage, 

699 } 

700 

701 def get_overall_metrics( 

702 self, period: str = "30d", research_mode: str = "all" 

703 ) -> Dict[str, Any]: 

704 """Get overall token metrics across all researches. 

705 

706 Args: 

707 period: Time period to filter by ('7d', '30d', '3m', '1y', 'all') 

708 research_mode: Research mode to filter by ('quick', 'detailed', 'all') 

709 

710 Returns: 

711 Dictionary containing overall metrics 

712 """ 

713 # First get metrics from user's encrypted database 

714 encrypted_metrics = self._get_metrics_from_encrypted_db( 

715 period, research_mode 

716 ) 

717 

718 # Then get metrics from thread-safe metrics database 

719 thread_metrics = self._get_metrics_from_thread_db(period, research_mode) 

720 

721 # Merge the results 

722 return self._merge_metrics(encrypted_metrics, thread_metrics) 

723 

724 def _get_metrics_from_encrypted_db( 

725 self, period: str, research_mode: str 

726 ) -> Dict[str, Any]: 

727 """Get metrics from user's encrypted database.""" 

728 from flask import session as flask_session 

729 

730 from ..database.session_context import get_user_db_session 

731 

732 username = flask_session.get("username") 

733 if not username: 

734 return self._get_empty_metrics() 

735 

736 try: 

737 with get_user_db_session(username) as session: 

738 # Build base query with filters 

739 query = session.query(TokenUsage) 

740 

741 # Apply time filter 

742 time_condition = get_time_filter_condition( 

743 period, TokenUsage.timestamp 

744 ) 

745 if time_condition is not None: 

746 query = query.filter(time_condition) 

747 

748 # Apply research mode filter 

749 mode_condition = get_research_mode_condition( 

750 research_mode, TokenUsage.research_mode 

751 ) 

752 if mode_condition is not None: 

753 query = query.filter(mode_condition) 

754 

755 # Total tokens from TokenUsage 

756 total_tokens = ( 

757 query.with_entities( 

758 func.sum(TokenUsage.total_tokens) 

759 ).scalar() 

760 or 0 

761 ) 

762 

763 # Import ResearchHistory model 

764 from ..database.models.research import ResearchHistory 

765 

766 # Count researches from ResearchHistory table 

767 research_query = session.query(func.count(ResearchHistory.id)) 

768 

769 # Debug: Check if any research history records exist at all 

770 all_research_count = ( 

771 session.query(func.count(ResearchHistory.id)).scalar() or 0 

772 ) 

773 logger.debug( 

774 f"Total ResearchHistory records in database: {all_research_count}" 

775 ) 

776 

777 # Debug: List first few research IDs and their timestamps 

778 sample_researches = ( 

779 session.query( 

780 ResearchHistory.id, 

781 ResearchHistory.created_at, 

782 ResearchHistory.mode, 

783 ) 

784 .limit(5) 

785 .all() 

786 ) 

787 if sample_researches: 787 ↛ 788line 787 didn't jump to line 788 because the condition on line 787 was never true

788 logger.debug("Sample ResearchHistory records:") 

789 for r_id, r_created, r_mode in sample_researches: 

790 logger.debug( 

791 f" - ID: {r_id}, Created: {r_created}, Mode: {r_mode}" 

792 ) 

793 else: 

794 logger.debug("No ResearchHistory records found in database") 

795 

796 # Get time filter conditions for ResearchHistory query 

797 start_time, end_time = None, None 

798 if period != "all": 

799 if period == "today": 799 ↛ 800line 799 didn't jump to line 800 because the condition on line 799 was never true

800 start_time = datetime.now(UTC).replace( 

801 hour=0, minute=0, second=0, microsecond=0 

802 ) 

803 elif period == "week": 803 ↛ 804line 803 didn't jump to line 804 because the condition on line 803 was never true

804 start_time = datetime.now(UTC) - timedelta(days=7) 

805 elif period == "month": 805 ↛ 806line 805 didn't jump to line 806 because the condition on line 805 was never true

806 start_time = datetime.now(UTC) - timedelta(days=30) 

807 

808 if start_time: 808 ↛ 809line 808 didn't jump to line 809 because the condition on line 808 was never true

809 end_time = datetime.now(UTC) 

810 

811 # Apply time filter if specified 

812 if start_time and end_time: 812 ↛ 813line 812 didn't jump to line 813 because the condition on line 812 was never true

813 research_query = research_query.filter( 

814 ResearchHistory.created_at >= start_time.isoformat(), 

815 ResearchHistory.created_at <= end_time.isoformat(), 

816 ) 

817 

818 # Apply mode filter if specified 

819 mode_filter = research_mode if research_mode != "all" else None 

820 if mode_filter: 

821 logger.debug(f"Applying mode filter: {mode_filter}") 

822 research_query = research_query.filter( 

823 ResearchHistory.mode == mode_filter 

824 ) 

825 

826 total_researches = research_query.scalar() or 0 

827 logger.debug( 

828 f"Final filtered research count: {total_researches}" 

829 ) 

830 

831 # Also check distinct research_ids in TokenUsage for comparison 

832 token_research_count = ( 

833 session.query( 

834 func.count(func.distinct(TokenUsage.research_id)) 

835 ).scalar() 

836 or 0 

837 ) 

838 logger.debug( 

839 f"Distinct research_ids in TokenUsage: {token_research_count}" 

840 ) 

841 

842 # Model statistics using ORM aggregation 

843 model_stats_query = session.query( 

844 TokenUsage.model_name, 

845 func.sum(TokenUsage.total_tokens).label("tokens"), 

846 func.count().label("calls"), 

847 func.sum(TokenUsage.prompt_tokens).label("prompt_tokens"), 

848 func.sum(TokenUsage.completion_tokens).label( 

849 "completion_tokens" 

850 ), 

851 ).filter(TokenUsage.model_name.isnot(None)) 

852 

853 # Apply same filters to model stats 

854 if time_condition is not None: 

855 model_stats_query = model_stats_query.filter(time_condition) 

856 if mode_condition is not None: 

857 model_stats_query = model_stats_query.filter(mode_condition) 

858 

859 model_stats = ( 

860 model_stats_query.group_by(TokenUsage.model_name) 

861 .order_by(func.sum(TokenUsage.total_tokens).desc()) 

862 .all() 

863 ) 

864 

865 # Get provider info from ModelUsage table 

866 by_model = [] 

867 for stat in model_stats: 867 ↛ 869line 867 didn't jump to line 869 because the loop on line 867 never started

868 # Try to get provider from ModelUsage table 

869 provider_info = ( 

870 session.query(ModelUsage.model_provider) 

871 .filter(ModelUsage.model_name == stat.model_name) 

872 .first() 

873 ) 

874 provider = ( 

875 provider_info.model_provider 

876 if provider_info 

877 else "unknown" 

878 ) 

879 

880 by_model.append( 

881 { 

882 "model": stat.model_name, 

883 "provider": provider, 

884 "tokens": stat.tokens, 

885 "calls": stat.calls, 

886 "prompt_tokens": stat.prompt_tokens, 

887 "completion_tokens": stat.completion_tokens, 

888 } 

889 ) 

890 

891 # Get recent researches with token usage 

892 # Note: This requires research_history table - for now we'll use available data 

893 recent_research_query = session.query( 

894 TokenUsage.research_id, 

895 func.sum(TokenUsage.total_tokens).label("token_count"), 

896 func.max(TokenUsage.timestamp).label("latest_timestamp"), 

897 ).filter(TokenUsage.research_id.isnot(None)) 

898 

899 if time_condition is not None: 

900 recent_research_query = recent_research_query.filter( 

901 time_condition 

902 ) 

903 if mode_condition is not None: 

904 recent_research_query = recent_research_query.filter( 

905 mode_condition 

906 ) 

907 

908 recent_research_data = ( 

909 recent_research_query.group_by(TokenUsage.research_id) 

910 .order_by(func.max(TokenUsage.timestamp).desc()) 

911 .limit(10) 

912 .all() 

913 ) 

914 

915 recent_researches = [] 

916 for research_data in recent_research_data: 916 ↛ 918line 916 didn't jump to line 918 because the loop on line 916 never started

917 # Get research query from token_usage table if available 

918 research_query_data = ( 

919 session.query(TokenUsage.research_query) 

920 .filter( 

921 TokenUsage.research_id == research_data.research_id, 

922 TokenUsage.research_query.isnot(None), 

923 ) 

924 .first() 

925 ) 

926 

927 query_text = ( 

928 research_query_data.research_query 

929 if research_query_data 

930 else f"Research {research_data.research_id}" 

931 ) 

932 

933 recent_researches.append( 

934 { 

935 "id": research_data.research_id, 

936 "query": query_text, 

937 "tokens": research_data.token_count or 0, 

938 "created_at": research_data.latest_timestamp, 

939 } 

940 ) 

941 

942 # Token breakdown statistics 

943 breakdown_query = query.with_entities( 

944 func.sum(TokenUsage.prompt_tokens).label( 

945 "total_input_tokens" 

946 ), 

947 func.sum(TokenUsage.completion_tokens).label( 

948 "total_output_tokens" 

949 ), 

950 func.avg(TokenUsage.prompt_tokens).label( 

951 "avg_input_tokens" 

952 ), 

953 func.avg(TokenUsage.completion_tokens).label( 

954 "avg_output_tokens" 

955 ), 

956 func.avg(TokenUsage.total_tokens).label("avg_total_tokens"), 

957 ) 

958 token_breakdown = breakdown_query.first() 

959 

960 # Get rate limiting metrics 

961 from ..database.models import ( 

962 RateLimitAttempt, 

963 RateLimitEstimate, 

964 ) 

965 

966 # Get rate limit attempts 

967 rate_limit_query = session.query(RateLimitAttempt) 

968 

969 # Apply time filter 

970 if time_condition is not None: 

971 # RateLimitAttempt uses timestamp as float, not datetime 

972 if period == "7d": 

973 cutoff_time = time.time() - (7 * 24 * 3600) 

974 elif period == "30d": 974 ↛ 976line 974 didn't jump to line 976 because the condition on line 974 was always true

975 cutoff_time = time.time() - (30 * 24 * 3600) 

976 elif period == "3m": 

977 cutoff_time = time.time() - (90 * 24 * 3600) 

978 elif period == "1y": 

979 cutoff_time = time.time() - (365 * 24 * 3600) 

980 else: # all 

981 cutoff_time = 0 

982 

983 if cutoff_time > 0: 983 ↛ 989line 983 didn't jump to line 989 because the condition on line 983 was always true

984 rate_limit_query = rate_limit_query.filter( 

985 RateLimitAttempt.timestamp >= cutoff_time 

986 ) 

987 

988 # Get rate limit statistics 

989 total_attempts = rate_limit_query.count() 

990 successful_attempts = rate_limit_query.filter( 

991 RateLimitAttempt.success 

992 ).count() 

993 failed_attempts = total_attempts - successful_attempts 

994 

995 # Count rate limiting events (failures with RateLimitError) 

996 rate_limit_events = rate_limit_query.filter( 

997 ~RateLimitAttempt.success, 

998 RateLimitAttempt.error_type == "RateLimitError", 

999 ).count() 

1000 

1001 logger.debug( 

1002 f"Rate limit attempts in database: total={total_attempts}, successful={successful_attempts}" 

1003 ) 

1004 

1005 # Get all attempts for detailed calculations 

1006 attempts = rate_limit_query.all() 

1007 

1008 # Calculate average wait times 

1009 if attempts: 1009 ↛ 1010line 1009 didn't jump to line 1010 because the condition on line 1009 was never true

1010 avg_wait_time = sum(a.wait_time for a in attempts) / len( 

1011 attempts 

1012 ) 

1013 successful_wait_times = [ 

1014 a.wait_time for a in attempts if a.success 

1015 ] 

1016 avg_successful_wait = ( 

1017 sum(successful_wait_times) / len(successful_wait_times) 

1018 if successful_wait_times 

1019 else 0 

1020 ) 

1021 else: 

1022 avg_wait_time = 0 

1023 avg_successful_wait = 0 

1024 

1025 # Get tracked engines - count distinct engine types from attempts 

1026 tracked_engines_query = session.query( 

1027 func.count(func.distinct(RateLimitAttempt.engine_type)) 

1028 ) 

1029 if cutoff_time > 0: 1029 ↛ 1033line 1029 didn't jump to line 1033 because the condition on line 1029 was always true

1030 tracked_engines_query = tracked_engines_query.filter( 

1031 RateLimitAttempt.timestamp >= cutoff_time 

1032 ) 

1033 tracked_engines = tracked_engines_query.scalar() or 0 

1034 

1035 # Get engine-specific stats from attempts 

1036 engine_stats = [] 

1037 

1038 # Get distinct engine types from attempts 

1039 engine_types_query = session.query( 

1040 RateLimitAttempt.engine_type 

1041 ).distinct() 

1042 if cutoff_time > 0: 1042 ↛ 1046line 1042 didn't jump to line 1046 because the condition on line 1042 was always true

1043 engine_types_query = engine_types_query.filter( 

1044 RateLimitAttempt.timestamp >= cutoff_time 

1045 ) 

1046 engine_types = [ 

1047 row.engine_type for row in engine_types_query.all() 

1048 ] 

1049 

1050 for engine_type in engine_types: 1050 ↛ 1051line 1050 didn't jump to line 1051 because the loop on line 1050 never started

1051 engine_attempts_list = [ 

1052 a for a in attempts if a.engine_type == engine_type 

1053 ] 

1054 engine_attempts = len(engine_attempts_list) 

1055 engine_success = len( 

1056 [a for a in engine_attempts_list if a.success] 

1057 ) 

1058 

1059 # Get estimate if exists 

1060 estimate = ( 

1061 session.query(RateLimitEstimate) 

1062 .filter(RateLimitEstimate.engine_type == engine_type) 

1063 .first() 

1064 ) 

1065 

1066 # Calculate recent success rate 

1067 recent_success_rate = ( 

1068 (engine_success / engine_attempts * 100) 

1069 if engine_attempts > 0 

1070 else 0 

1071 ) 

1072 

1073 # Determine status based on success rate 

1074 if estimate: 

1075 status = ( 

1076 "healthy" 

1077 if estimate.success_rate > 0.8 

1078 else "degraded" 

1079 if estimate.success_rate > 0.5 

1080 else "poor" 

1081 ) 

1082 else: 

1083 status = ( 

1084 "healthy" 

1085 if recent_success_rate > 80 

1086 else "degraded" 

1087 if recent_success_rate > 50 

1088 else "poor" 

1089 ) 

1090 

1091 engine_stat = { 

1092 "engine": engine_type, 

1093 "base_wait": estimate.base_wait_seconds 

1094 if estimate 

1095 else 0.0, 

1096 "base_wait_seconds": round( 

1097 estimate.base_wait_seconds if estimate else 0.0, 2 

1098 ), 

1099 "min_wait_seconds": round( 

1100 estimate.min_wait_seconds if estimate else 0.0, 2 

1101 ), 

1102 "max_wait_seconds": round( 

1103 estimate.max_wait_seconds if estimate else 0.0, 2 

1104 ), 

1105 "success_rate": round(estimate.success_rate * 100, 1) 

1106 if estimate 

1107 else recent_success_rate, 

1108 "total_attempts": estimate.total_attempts 

1109 if estimate 

1110 else engine_attempts, 

1111 "recent_attempts": engine_attempts, 

1112 "recent_success_rate": round(recent_success_rate, 1), 

1113 "attempts": engine_attempts, 

1114 "status": status, 

1115 } 

1116 

1117 if estimate: 

1118 engine_stat["last_updated"] = datetime.fromtimestamp( 

1119 estimate.last_updated 

1120 ).strftime("%Y-%m-%d %H:%M:%S") 

1121 else: 

1122 engine_stat["last_updated"] = "Never" 

1123 

1124 engine_stats.append(engine_stat) 

1125 

1126 logger.debug( 

1127 f"Tracked engines: {tracked_engines}, engine_stats: {engine_stats}" 

1128 ) 

1129 

1130 result = { 

1131 "total_tokens": total_tokens, 

1132 "total_researches": total_researches, 

1133 "by_model": by_model, 

1134 "recent_researches": recent_researches, 

1135 "token_breakdown": { 

1136 "total_input_tokens": int( 

1137 token_breakdown.total_input_tokens or 0 

1138 ), 

1139 "total_output_tokens": int( 

1140 token_breakdown.total_output_tokens or 0 

1141 ), 

1142 "avg_input_tokens": int( 

1143 token_breakdown.avg_input_tokens or 0 

1144 ), 

1145 "avg_output_tokens": int( 

1146 token_breakdown.avg_output_tokens or 0 

1147 ), 

1148 "avg_total_tokens": int( 

1149 token_breakdown.avg_total_tokens or 0 

1150 ), 

1151 }, 

1152 "rate_limiting": { 

1153 "total_attempts": total_attempts, 

1154 "successful_attempts": successful_attempts, 

1155 "failed_attempts": failed_attempts, 

1156 "success_rate": ( 

1157 successful_attempts / total_attempts * 100 

1158 ) 

1159 if total_attempts > 0 

1160 else 0, 

1161 "rate_limit_events": rate_limit_events, 

1162 "avg_wait_time": round(float(avg_wait_time), 2), 

1163 "avg_successful_wait": round( 

1164 float(avg_successful_wait), 2 

1165 ), 

1166 "tracked_engines": tracked_engines, 

1167 "engine_stats": engine_stats, 

1168 "total_engines_tracked": tracked_engines, 

1169 "healthy_engines": len( 

1170 [ 

1171 s 

1172 for s in engine_stats 

1173 if s["status"] == "healthy" 

1174 ] 

1175 ), 

1176 "degraded_engines": len( 

1177 [ 

1178 s 

1179 for s in engine_stats 

1180 if s["status"] == "degraded" 

1181 ] 

1182 ), 

1183 "poor_engines": len( 

1184 [s for s in engine_stats if s["status"] == "poor"] 

1185 ), 

1186 }, 

1187 } 

1188 

1189 logger.debug( 

1190 f"Returning from _get_metrics_from_encrypted_db - total_researches: {result['total_researches']}" 

1191 ) 

1192 return result 

1193 except Exception: 

1194 logger.exception( 

1195 "CRITICAL ERROR accessing encrypted database for metrics" 

1196 ) 

1197 return self._get_empty_metrics() 

1198 

1199 def _get_metrics_from_thread_db( 

1200 self, period: str, research_mode: str 

1201 ) -> Dict[str, Any]: 

1202 """Get metrics from thread-safe metrics database.""" 

1203 if not self.thread_metrics_db: 1203 ↛ 1204line 1203 didn't jump to line 1204 because the condition on line 1203 was never true

1204 return { 

1205 "total_tokens": 0, 

1206 "total_researches": 0, 

1207 "by_model": [], 

1208 "recent_researches": [], 

1209 "token_breakdown": { 

1210 "total_input_tokens": 0, 

1211 "total_output_tokens": 0, 

1212 "avg_input_tokens": 0, 

1213 "avg_output_tokens": 0, 

1214 "avg_total_tokens": 0, 

1215 }, 

1216 } 

1217 

1218 try: 

1219 with self.thread_metrics_db.get_session() as session: 

1220 # Build base query with filters 

1221 query = session.query(TokenUsage) 

1222 

1223 # Apply time filter 

1224 time_condition = get_time_filter_condition( 

1225 period, TokenUsage.timestamp 

1226 ) 

1227 if time_condition is not None: 1227 ↛ 1231line 1227 didn't jump to line 1231 because the condition on line 1227 was always true

1228 query = query.filter(time_condition) 

1229 

1230 # Apply research mode filter 

1231 mode_condition = get_research_mode_condition( 

1232 research_mode, TokenUsage.research_mode 

1233 ) 

1234 if mode_condition is not None: 1234 ↛ 1235line 1234 didn't jump to line 1235 because the condition on line 1234 was never true

1235 query = query.filter(mode_condition) 

1236 

1237 # Get totals 

1238 total_tokens = ( 

1239 query.with_entities( 

1240 func.sum(TokenUsage.total_tokens) 

1241 ).scalar() 

1242 or 0 

1243 ) 

1244 total_researches = ( 

1245 query.with_entities( 

1246 func.count(func.distinct(TokenUsage.research_id)) 

1247 ).scalar() 

1248 or 0 

1249 ) 

1250 

1251 # Get model statistics 

1252 model_stats = ( 

1253 query.with_entities( 

1254 TokenUsage.model_name, 

1255 func.sum(TokenUsage.total_tokens).label("tokens"), 

1256 func.count().label("calls"), 

1257 func.sum(TokenUsage.prompt_tokens).label( 

1258 "prompt_tokens" 

1259 ), 

1260 func.sum(TokenUsage.completion_tokens).label( 

1261 "completion_tokens" 

1262 ), 

1263 ) 

1264 .filter(TokenUsage.model_name.isnot(None)) 

1265 .group_by(TokenUsage.model_name) 

1266 .all() 

1267 ) 

1268 

1269 by_model = [] 

1270 for stat in model_stats: 1270 ↛ 1271line 1270 didn't jump to line 1271 because the loop on line 1270 never started

1271 by_model.append( 

1272 { 

1273 "model": stat.model_name, 

1274 "provider": "unknown", # Provider info might not be in thread DB 

1275 "tokens": stat.tokens, 

1276 "calls": stat.calls, 

1277 "prompt_tokens": stat.prompt_tokens, 

1278 "completion_tokens": stat.completion_tokens, 

1279 } 

1280 ) 

1281 

1282 # Token breakdown 

1283 breakdown = query.with_entities( 

1284 func.sum(TokenUsage.prompt_tokens).label( 

1285 "total_input_tokens" 

1286 ), 

1287 func.sum(TokenUsage.completion_tokens).label( 

1288 "total_output_tokens" 

1289 ), 

1290 ).first() 

1291 

1292 return { 

1293 "total_tokens": total_tokens, 

1294 "total_researches": total_researches, 

1295 "by_model": by_model, 

1296 "recent_researches": [], # Skip for thread DB 

1297 "token_breakdown": { 

1298 "total_input_tokens": int( 

1299 breakdown.total_input_tokens or 0 

1300 ), 

1301 "total_output_tokens": int( 

1302 breakdown.total_output_tokens or 0 

1303 ), 

1304 "avg_input_tokens": 0, 

1305 "avg_output_tokens": 0, 

1306 "avg_total_tokens": 0, 

1307 }, 

1308 } 

1309 except Exception: 

1310 logger.exception("Error reading thread metrics database") 

1311 return { 

1312 "total_tokens": 0, 

1313 "total_researches": 0, 

1314 "by_model": [], 

1315 "recent_researches": [], 

1316 "token_breakdown": { 

1317 "total_input_tokens": 0, 

1318 "total_output_tokens": 0, 

1319 "avg_input_tokens": 0, 

1320 "avg_output_tokens": 0, 

1321 "avg_total_tokens": 0, 

1322 }, 

1323 } 

1324 

1325 def _merge_metrics( 

1326 self, encrypted: Dict[str, Any], thread: Dict[str, Any] 

1327 ) -> Dict[str, Any]: 

1328 """Merge metrics from both databases.""" 

1329 # Combine totals 

1330 total_tokens = encrypted.get("total_tokens", 0) + thread.get( 

1331 "total_tokens", 0 

1332 ) 

1333 total_researches = max( 

1334 encrypted.get("total_researches", 0), 

1335 thread.get("total_researches", 0), 

1336 ) 

1337 logger.debug( 

1338 f"Merged metrics - encrypted researches: {encrypted.get('total_researches', 0)}, thread researches: {thread.get('total_researches', 0)}, final: {total_researches}" 

1339 ) 

1340 

1341 # Merge model usage 

1342 model_map = {} 

1343 for model_data in encrypted.get("by_model", []): 

1344 key = model_data["model"] 

1345 model_map[key] = model_data 

1346 

1347 for model_data in thread.get("by_model", []): 

1348 key = model_data["model"] 

1349 if key in model_map: 

1350 # Merge with existing 

1351 model_map[key]["tokens"] += model_data["tokens"] 

1352 model_map[key]["calls"] += model_data["calls"] 

1353 model_map[key]["prompt_tokens"] += model_data["prompt_tokens"] 

1354 model_map[key]["completion_tokens"] += model_data[ 

1355 "completion_tokens" 

1356 ] 

1357 else: 

1358 model_map[key] = model_data 

1359 

1360 by_model = sorted( 

1361 model_map.values(), key=lambda x: x["tokens"], reverse=True 

1362 ) 

1363 

1364 # Merge token breakdown 

1365 token_breakdown = { 

1366 "total_input_tokens": ( 

1367 encrypted.get("token_breakdown", {}).get( 

1368 "total_input_tokens", 0 

1369 ) 

1370 + thread.get("token_breakdown", {}).get("total_input_tokens", 0) 

1371 ), 

1372 "total_output_tokens": ( 

1373 encrypted.get("token_breakdown", {}).get( 

1374 "total_output_tokens", 0 

1375 ) 

1376 + thread.get("token_breakdown", {}).get( 

1377 "total_output_tokens", 0 

1378 ) 

1379 ), 

1380 "avg_input_tokens": encrypted.get("token_breakdown", {}).get( 

1381 "avg_input_tokens", 0 

1382 ), 

1383 "avg_output_tokens": encrypted.get("token_breakdown", {}).get( 

1384 "avg_output_tokens", 0 

1385 ), 

1386 "avg_total_tokens": encrypted.get("token_breakdown", {}).get( 

1387 "avg_total_tokens", 0 

1388 ), 

1389 } 

1390 

1391 result = { 

1392 "total_tokens": total_tokens, 

1393 "total_researches": total_researches, 

1394 "by_model": by_model, 

1395 "recent_researches": encrypted.get("recent_researches", []), 

1396 "token_breakdown": token_breakdown, 

1397 } 

1398 

1399 logger.debug( 

1400 f"Final get_token_metrics result - total_researches: {result['total_researches']}" 

1401 ) 

1402 return result 

1403 

1404 def _get_empty_metrics(self) -> Dict[str, Any]: 

1405 """Return empty metrics structure when no data is available.""" 

1406 return { 

1407 "total_tokens": 0, 

1408 "total_researches": 0, 

1409 "by_model": [], 

1410 "recent_researches": [], 

1411 "token_breakdown": { 

1412 "prompt_tokens": 0, 

1413 "completion_tokens": 0, 

1414 "avg_prompt_tokens": 0, 

1415 "avg_completion_tokens": 0, 

1416 "avg_total_tokens": 0, 

1417 }, 

1418 } 

1419 

    def get_enhanced_metrics(
        self, period: str = "30d", research_mode: str = "all"
    ) -> Dict[str, Any]:
        """Get enhanced Phase 1 tracking metrics.

        Aggregates per-call timing, success/error status, research-mode,
        search-engine, research-phase, and call-stack statistics from the
        user's token-usage table, plus a cumulative time series for the
        "Token Consumption Over Time" chart.

        Args:
            period: Time period to filter by ('7d', '30d', '3m', '1y', 'all')
            research_mode: Research mode to filter by ('quick', 'detailed', 'all')

        Returns:
            Dictionary containing enhanced metrics data including time series.
            An all-empty structure of the same shape is returned when there is
            no logged-in user session or when any database error occurs.
        """
        # Imported lazily so this module can be used outside a Flask request.
        from flask import session as flask_session

        from ..database.session_context import get_user_db_session

        username = flask_session.get("username")
        if not username:
            # Return empty metrics structure when no user session
            return {
                "recent_enhanced_data": [],
                "performance_stats": {
                    "avg_response_time": 0,
                    "min_response_time": 0,
                    "max_response_time": 0,
                    "success_rate": 0,
                    "error_rate": 0,
                    "total_enhanced_calls": 0,
                },
                "mode_breakdown": [],
                "search_engine_stats": [],
                "phase_breakdown": [],
                "time_series_data": [],
                "call_stack_analysis": {
                    "by_file": [],
                    "by_function": [],
                },
            }

        try:
            with get_user_db_session(username) as session:
                # Build base query with filters
                query = session.query(TokenUsage)

                # Apply time filter
                time_condition = get_time_filter_condition(
                    period, TokenUsage.timestamp
                )
                if time_condition is not None:
                    query = query.filter(time_condition)

                # Apply research mode filter
                mode_condition = get_research_mode_condition(
                    research_mode, TokenUsage.research_mode
                )
                if mode_condition is not None:
                    query = query.filter(mode_condition)

                # Get time series data for the chart - most important for "Token Consumption Over Time"
                time_series_query = query.filter(
                    TokenUsage.timestamp.isnot(None),
                    TokenUsage.total_tokens > 0,
                ).order_by(TokenUsage.timestamp.asc())

                # Limit to recent data for performance
                if period != "all":
                    time_series_query = time_series_query.limit(200)

                time_series_data = time_series_query.all()

                # Format time series data with cumulative calculations
                # (running totals accumulated in chronological order).
                time_series = []
                cumulative_tokens = 0
                cumulative_prompt_tokens = 0
                cumulative_completion_tokens = 0

                for usage in time_series_data:
                    cumulative_tokens += usage.total_tokens or 0
                    cumulative_prompt_tokens += usage.prompt_tokens or 0
                    cumulative_completion_tokens += usage.completion_tokens or 0

                    time_series.append(
                        {
                            "timestamp": str(usage.timestamp)
                            if usage.timestamp
                            else None,
                            "tokens": usage.total_tokens or 0,
                            "prompt_tokens": usage.prompt_tokens or 0,
                            "completion_tokens": usage.completion_tokens or 0,
                            "cumulative_tokens": cumulative_tokens,
                            "cumulative_prompt_tokens": cumulative_prompt_tokens,
                            "cumulative_completion_tokens": cumulative_completion_tokens,
                            "research_id": usage.research_id,
                        }
                    )

                # Basic performance stats using ORM
                # Only rows with a recorded response time are "enhanced" calls.
                performance_query = query.filter(
                    TokenUsage.response_time_ms.isnot(None)
                )
                total_calls = performance_query.count()

                if total_calls > 0:
                    avg_response_time = (
                        performance_query.with_entities(
                            func.avg(TokenUsage.response_time_ms)
                        ).scalar()
                        or 0
                    )
                    min_response_time = (
                        performance_query.with_entities(
                            func.min(TokenUsage.response_time_ms)
                        ).scalar()
                        or 0
                    )
                    max_response_time = (
                        performance_query.with_entities(
                            func.max(TokenUsage.response_time_ms)
                        ).scalar()
                        or 0
                    )
                    success_count = performance_query.filter(
                        TokenUsage.success_status == "success"
                    ).count()
                    error_count = performance_query.filter(
                        TokenUsage.success_status == "error"
                    ).count()

                    perf_stats = {
                        "avg_response_time": round(avg_response_time),
                        "min_response_time": min_response_time,
                        "max_response_time": max_response_time,
                        # Rates as percentages rounded to one decimal place.
                        "success_rate": (
                            round((success_count / total_calls * 100), 1)
                            if total_calls > 0
                            else 0
                        ),
                        "error_rate": (
                            round((error_count / total_calls * 100), 1)
                            if total_calls > 0
                            else 0
                        ),
                        "total_enhanced_calls": total_calls,
                    }
                else:
                    perf_stats = {
                        "avg_response_time": 0,
                        "min_response_time": 0,
                        "max_response_time": 0,
                        "success_rate": 0,
                        "error_rate": 0,
                        "total_enhanced_calls": 0,
                    }

                # Research mode breakdown using ORM
                mode_stats = (
                    query.filter(TokenUsage.research_mode.isnot(None))
                    .with_entities(
                        TokenUsage.research_mode,
                        func.count().label("count"),
                        func.avg(TokenUsage.total_tokens).label("avg_tokens"),
                        func.avg(TokenUsage.response_time_ms).label(
                            "avg_response_time"
                        ),
                    )
                    .group_by(TokenUsage.research_mode)
                    .all()
                )

                modes = [
                    {
                        "mode": stat.research_mode,
                        "count": stat.count,
                        "avg_tokens": round(stat.avg_tokens or 0),
                        "avg_response_time": round(stat.avg_response_time or 0),
                    }
                    for stat in mode_stats
                ]

                # Recent enhanced data (simplified)
                # Most recent 50 calls that have a research query attached.
                recent_enhanced_query = (
                    query.filter(TokenUsage.research_query.isnot(None))
                    .order_by(TokenUsage.timestamp.desc())
                    .limit(50)
                )

                recent_enhanced_data = recent_enhanced_query.all()
                recent_enhanced = [
                    {
                        "research_query": usage.research_query,
                        "research_mode": usage.research_mode,
                        "research_phase": usage.research_phase,
                        "search_iteration": usage.search_iteration,
                        "response_time_ms": usage.response_time_ms,
                        "success_status": usage.success_status,
                        "error_type": usage.error_type,
                        "search_engines_planned": usage.search_engines_planned,
                        "search_engine_selected": usage.search_engine_selected,
                        "total_tokens": usage.total_tokens,
                        "prompt_tokens": usage.prompt_tokens,
                        "completion_tokens": usage.completion_tokens,
                        "timestamp": str(usage.timestamp)
                        if usage.timestamp
                        else None,
                        "research_id": usage.research_id,
                        "calling_file": usage.calling_file,
                        "calling_function": usage.calling_function,
                        "call_stack": usage.call_stack,
                    }
                    for usage in recent_enhanced_data
                ]

                # Search engine breakdown using ORM
                search_engine_stats = (
                    query.filter(TokenUsage.search_engine_selected.isnot(None))
                    .with_entities(
                        TokenUsage.search_engine_selected,
                        func.count().label("count"),
                        func.avg(TokenUsage.total_tokens).label("avg_tokens"),
                        func.avg(TokenUsage.response_time_ms).label(
                            "avg_response_time"
                        ),
                    )
                    .group_by(TokenUsage.search_engine_selected)
                    .all()
                )

                search_engines = [
                    {
                        "search_engine": stat.search_engine_selected,
                        "count": stat.count,
                        "avg_tokens": round(stat.avg_tokens or 0),
                        "avg_response_time": round(stat.avg_response_time or 0),
                    }
                    for stat in search_engine_stats
                ]

                # Research phase breakdown using ORM
                phase_stats = (
                    query.filter(TokenUsage.research_phase.isnot(None))
                    .with_entities(
                        TokenUsage.research_phase,
                        func.count().label("count"),
                        func.avg(TokenUsage.total_tokens).label("avg_tokens"),
                        func.avg(TokenUsage.response_time_ms).label(
                            "avg_response_time"
                        ),
                    )
                    .group_by(TokenUsage.research_phase)
                    .all()
                )

                phases = [
                    {
                        "phase": stat.research_phase,
                        "count": stat.count,
                        "avg_tokens": round(stat.avg_tokens or 0),
                        "avg_response_time": round(stat.avg_response_time or 0),
                    }
                    for stat in phase_stats
                ]

                # Call stack analysis using ORM
                # Top 10 calling files by number of LLM calls.
                file_stats = (
                    query.filter(TokenUsage.calling_file.isnot(None))
                    .with_entities(
                        TokenUsage.calling_file,
                        func.count().label("count"),
                        func.avg(TokenUsage.total_tokens).label("avg_tokens"),
                    )
                    .group_by(TokenUsage.calling_file)
                    .order_by(func.count().desc())
                    .limit(10)
                    .all()
                )

                files = [
                    {
                        "file": stat.calling_file,
                        "count": stat.count,
                        "avg_tokens": round(stat.avg_tokens or 0),
                    }
                    for stat in file_stats
                ]

                # Top 10 calling functions by number of LLM calls.
                function_stats = (
                    query.filter(TokenUsage.calling_function.isnot(None))
                    .with_entities(
                        TokenUsage.calling_function,
                        func.count().label("count"),
                        func.avg(TokenUsage.total_tokens).label("avg_tokens"),
                    )
                    .group_by(TokenUsage.calling_function)
                    .order_by(func.count().desc())
                    .limit(10)
                    .all()
                )

                functions = [
                    {
                        "function": stat.calling_function,
                        "count": stat.count,
                        "avg_tokens": round(stat.avg_tokens or 0),
                    }
                    for stat in function_stats
                ]

                return {
                    "recent_enhanced_data": recent_enhanced,
                    "performance_stats": perf_stats,
                    "mode_breakdown": modes,
                    "search_engine_stats": search_engines,
                    "phase_breakdown": phases,
                    "time_series_data": time_series,
                    "call_stack_analysis": {
                        "by_file": files,
                        "by_function": functions,
                    },
                }
        except Exception:
            logger.exception("Error in get_enhanced_metrics")
            # Return simplified response without non-existent columns
            return {
                "recent_enhanced_data": [],
                "performance_stats": {
                    "avg_response_time": 0,
                    "min_response_time": 0,
                    "max_response_time": 0,
                    "success_rate": 0,
                    "error_rate": 0,
                    "total_enhanced_calls": 0,
                },
                "mode_breakdown": [],
                "search_engine_stats": [],
                "phase_breakdown": [],
                "time_series_data": [],
                "call_stack_analysis": {
                    "by_file": [],
                    "by_function": [],
                },
            }

1761 

    def get_research_timeline_metrics(self, research_id: str) -> Dict[str, Any]:
        """Get timeline metrics for a specific research.

        Builds a chronological per-call timeline (with running cumulative
        token totals), basic research details from ``research_history``,
        summary statistics, and a per-phase breakdown.

        Args:
            research_id: The ID of the research

        Returns:
            Dictionary containing timeline metrics for the research. An empty
            structure of the same shape is returned when there is no
            logged-in user session.
        """
        # Imported lazily so this module can be used outside a Flask request.
        from flask import session as flask_session

        from ..database.session_context import get_user_db_session

        username = flask_session.get("username")
        if not username:
            return {
                "research_id": research_id,
                "research_details": {},
                "timeline": [],
                "summary": {
                    "total_calls": 0,
                    "total_tokens": 0,
                    "total_prompt_tokens": 0,
                    "total_completion_tokens": 0,
                    "avg_response_time": 0,
                    "success_rate": 0,
                },
                "phase_stats": {},
            }

        with get_user_db_session(username) as session:
            # Get all token usage for this research ordered by time including call stack.
            # NOTE: rows are consumed positionally below; the SELECT column
            # order must stay in sync with the row[...] indices.
            timeline_data = session.execute(
                text(
                    """
                SELECT
                    timestamp,
                    total_tokens,
                    prompt_tokens,
                    completion_tokens,
                    response_time_ms,
                    success_status,
                    error_type,
                    research_phase,
                    search_iteration,
                    search_engine_selected,
                    model_name,
                    calling_file,
                    calling_function,
                    call_stack
                FROM token_usage
                WHERE research_id = :research_id
                ORDER BY timestamp ASC
            """
                ),
                {"research_id": research_id},
            ).fetchall()

            # Format timeline data with cumulative tokens
            timeline = []
            cumulative_tokens = 0
            cumulative_prompt_tokens = 0
            cumulative_completion_tokens = 0

            for row in timeline_data:
                # NULL token columns count as 0 in the running totals.
                cumulative_tokens += row[1] or 0
                cumulative_prompt_tokens += row[2] or 0
                cumulative_completion_tokens += row[3] or 0

                timeline.append(
                    {
                        "timestamp": str(row[0]) if row[0] else None,
                        "tokens": row[1] or 0,
                        "prompt_tokens": row[2] or 0,
                        "completion_tokens": row[3] or 0,
                        "cumulative_tokens": cumulative_tokens,
                        "cumulative_prompt_tokens": cumulative_prompt_tokens,
                        "cumulative_completion_tokens": cumulative_completion_tokens,
                        "response_time_ms": row[4],
                        "success_status": row[5],
                        "error_type": row[6],
                        "research_phase": row[7],
                        "search_iteration": row[8],
                        "search_engine_selected": row[9],
                        "model_name": row[10],
                        "calling_file": row[11],
                        "calling_function": row[12],
                        "call_stack": row[13],
                    }
                )

            # Get research basic info
            research_info = session.execute(
                text(
                    """
                SELECT query, mode, status, created_at, completed_at
                FROM research_history
                WHERE id = :research_id
            """
                ),
                {"research_id": research_id},
            ).fetchone()

            # Empty dict when the research id has no history row.
            research_details = {}
            if research_info:
                research_details = {
                    "query": research_info[0],
                    "mode": research_info[1],
                    "status": research_info[2],
                    "created_at": str(research_info[3])
                    if research_info[3]
                    else None,
                    "completed_at": str(research_info[4])
                    if research_info[4]
                    else None,
                }

            # Calculate summary stats
            total_calls = len(timeline_data)
            total_tokens = cumulative_tokens
            # max(total_calls, 1) guards the division when there are no rows.
            avg_response_time = sum(row[4] or 0 for row in timeline_data) / max(
                total_calls, 1
            )
            success_rate = (
                sum(1 for row in timeline_data if row[5] == "success")
                / max(total_calls, 1)
                * 100
            )

            # Phase breakdown for this research
            # (avg_response_time is first accumulated as a sum, then averaged below).
            phase_stats = {}
            for row in timeline_data:
                phase = row[7] or "unknown"
                if phase not in phase_stats:
                    phase_stats[phase] = {
                        "count": 0,
                        "tokens": 0,
                        "avg_response_time": 0,
                    }
                phase_stats[phase]["count"] += 1
                phase_stats[phase]["tokens"] += row[1] or 0
                if row[4]:
                    phase_stats[phase]["avg_response_time"] += row[4]

            # Calculate averages for phases
            for phase in phase_stats:
                if phase_stats[phase]["count"] > 0:
                    phase_stats[phase]["avg_response_time"] = round(
                        phase_stats[phase]["avg_response_time"]
                        / phase_stats[phase]["count"]
                    )

            return {
                "research_id": research_id,
                "research_details": research_details,
                "timeline": timeline,
                "summary": {
                    "total_calls": total_calls,
                    "total_tokens": total_tokens,
                    "total_prompt_tokens": cumulative_prompt_tokens,
                    "total_completion_tokens": cumulative_completion_tokens,
                    "avg_response_time": round(avg_response_time),
                    "success_rate": round(success_rate, 1),
                },
                "phase_stats": phase_stats,
            }

1927 }