Coverage for src/local_deep_research/metrics/token_counter.py: 86%
500 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 23:15 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 23:15 +0000
1"""Token counting functionality for LLM usage tracking."""
3import inspect
4import json
5import time
6from datetime import datetime, timedelta, UTC
7from pathlib import Path
8from typing import Any, Dict, List, Optional
10from langchain_core.callbacks import BaseCallbackHandler
11from langchain_core.outputs import LLMResult
12from loguru import logger
13from sqlalchemy import func, text
15from ..database.models import ModelUsage, TokenUsage
16from .query_utils import get_research_mode_condition, get_time_filter_condition
19class TokenCountingCallback(BaseCallbackHandler):
20 """Callback handler for counting tokens across different models."""
def __init__(
    self,
    research_id: Optional[str] = None,
    research_context: Optional[Dict[str, Any]] = None,
):
    """Set up a callback that accumulates token usage for LLM calls.

    Args:
        research_id: The ID of the research to track tokens for.
        research_context: Additional research context for enhanced
            tracking; a missing or empty value is stored as ``{}``.
    """
    super().__init__()
    self.research_id = research_id
    self.research_context = research_context if research_context else {}

    # Model/provider observed on the most recent call.
    self.current_model = None
    self.current_provider = None
    # Optional overrides; when set they win over auto-detection.
    self.preset_model = None  # Model name set during callback creation
    self.preset_provider = None  # Provider set during callback creation

    # Timing and outcome of the call in flight.
    self.start_time = None
    self.response_time_ms = None
    self.success_status = "success"
    self.error_type = None

    # Where in the project the call originated.
    self.calling_file = None
    self.calling_function = None
    self.call_stack = None

    # Context-overflow bookkeeping.
    self.context_limit = None
    self.context_truncated = False
    self.tokens_truncated = 0
    self.truncation_ratio = 0.0
    self.original_prompt_estimate = 0

    # Raw metrics copied from an Ollama response, when one is seen.
    self.ollama_metrics = {}

    # Running totals kept in memory, overall and per model.
    self.counts = dict.fromkeys(
        ("total_tokens", "total_prompt_tokens", "total_completion_tokens"),
        0,
    )
    self.counts["by_model"] = {}
def on_llm_start(
    self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
    """Record the start of an LLM call and reset all per-call state.

    Starts the response timer, estimates the prompt size, captures the
    project call site, resolves the model and provider names, and bumps
    the per-model call counter.

    Args:
        serialized: Serialized description of the LLM; ``None`` is
            tolerated and treated as an empty dict.
        prompts: Prompt strings about to be sent to the model.
        **kwargs: Extra callback arguments; ``invocation_params``,
            ``model``, ``model_name`` and ``provider`` are consulted.
    """
    self.start_time = time.time()

    # The callback instance is shared across every LLM call in a research
    # session (see llm_config.py wrap_llm), so everything that describes a
    # single call must be reset here. Otherwise one failed call would mark
    # all later calls as "error", and stale truncation / Ollama / call-site
    # data would leak into the next saved row.
    self.response_time_ms = None
    self.success_status = "success"
    self.error_type = None
    self.ollama_metrics = {}
    self.calling_file = None
    self.calling_function = None
    self.call_stack = None
    self.context_truncated = False
    self.tokens_truncated = 0
    self.truncation_ratio = 0.0
    self.original_prompt_estimate = 0

    # Guard against a missing payload so the lookups below cannot crash.
    serialized = serialized or {}

    # Estimate original prompt size (rough estimate: ~4 chars per token)
    if prompts:
        total_chars = sum(len(prompt) for prompt in prompts)
        self.original_prompt_estimate = total_chars // 4
        logger.debug(
            f"Estimated prompt tokens: {self.original_prompt_estimate} (from {total_chars} chars)"
        )

    # Get context limit from research context (will be set from settings)
    self.context_limit = self.research_context.get("context_limit")

    def is_project_frame(filename: str) -> bool:
        """Tell whether a frame's file belongs to this project's sources."""
        return (
            "local_deep_research" in filename
            and "site-packages" not in filename
            and "venv" not in filename
        )

    # Capture call-site information; best effort only.
    try:
        # context=0: only filename/function/lineno are needed, so skip
        # reading source lines for every frame.
        stack = inspect.stack(0)
        caller = next(
            (
                frame_info
                for frame_info in stack[1:]
                if is_project_frame(frame_info.filename)
            ),
            None,
        )
        if caller is not None:
            # Normalise separators so the markers below also match paths
            # written with backslashes.
            file_path = caller.filename.replace("\\", "/")
            if "src/local_deep_research" in file_path:
                marker = "src/local_deep_research"
            elif "local_deep_research/src" in file_path:
                marker = "local_deep_research/src"
            else:
                marker = "local_deep_research"
            self.calling_file = file_path.split(marker)[-1].lstrip("/")
            self.calling_function = caller.function

            # Simplified call stack: project frames among the 5 nearest.
            call_stack_frames = [
                f"{Path(frame.filename).name}:{frame.function}:{frame.lineno}"
                for frame in stack[1:6]
                if is_project_frame(frame.filename)
            ]
            self.call_stack = (
                " -> ".join(call_stack_frames) if call_stack_frames else None
            )
    except Exception:
        # Continue without call stack info if there's an error
        logger.warning("Error capturing call stack")

    invocation_params = kwargs.get("invocation_params") or {}
    serialized_kwargs = serialized.get("kwargs") or {}
    type_str = serialized.get("_type") or ""

    # Preset model wins; otherwise probe the known locations in order.
    if self.preset_model:
        self.current_model = self.preset_model
    else:
        model_name = (
            invocation_params.get("model")
            or invocation_params.get("model_name")
            or kwargs.get("model")
            or kwargs.get("model_name")
            or serialized_kwargs.get("model")
            or serialized_kwargs.get("model_name")
            or serialized.get("name")
        )
        # Ollama without any discoverable model name: use a generic label.
        if not model_name and "ChatOllama" in type_str:
            model_name = "ollama"
        # Final fallback: the serialized type, else "unknown".
        if not model_name:
            model_name = type_str or "unknown"
        self.current_model = model_name

    # Preset provider wins; otherwise infer it from the serialized type.
    if self.preset_provider:
        self.current_provider = self.preset_provider
    else:
        for type_marker, provider_name in (
            ("ChatOllama", "ollama"),
            ("ChatOpenAI", "openai"),
            ("ChatAnthropic", "anthropic"),
        ):
            if type_marker in type_str:
                self.current_provider = provider_name
                break
        else:
            self.current_provider = kwargs.get("provider", "unknown")

    # Initialize model tracking if needed
    if self.current_model not in self.counts["by_model"]:
        self.counts["by_model"][self.current_model] = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "calls": 0,
            "provider": self.current_provider,
        }

    # Increment call count
    self.counts["by_model"][self.current_model]["calls"] += 1
def _check_context_overflow(
    self,
    prompt_eval_count: int,
    completion_tokens: int = 0,
    source: str = "",
) -> None:
    """Check for context overflow based on prompt and total token usage.

    Sets ``context_truncated``, ``tokens_truncated`` and
    ``truncation_ratio`` on the instance when an overflow is detected.

    Args:
        prompt_eval_count: Number of tokens the model actually processed.
            ``None`` is treated as 0.
        completion_tokens: Number of tokens generated (for the
            total-context check). ``None`` is treated as 0.
        source: Which branch provided the data (for logging).
    """
    logger.debug(
        f"Context overflow check [{source}]: "
        f"prompt_eval_count={prompt_eval_count}, "
        f"completion_tokens={completion_tokens}, "
        f"context_limit={self.context_limit}"
    )

    # A counter may arrive as None when the key exists without a value;
    # coerce so the comparisons below cannot raise TypeError.
    prompt_eval_count = prompt_eval_count or 0
    completion_tokens = completion_tokens or 0

    if not self.context_limit or prompt_eval_count <= 0:
        return

    total = prompt_eval_count + completion_tokens

    # Input-only overflow: prompt at >= 80% of context limit. Matches the
    # chart-warning threshold and PR #3840's deliberate choice (PR #3792
    # lowered from 95% -> 80%). The total-context branch below uses 95%
    # because it's a stricter condition (input+output combined).
    if prompt_eval_count >= self.context_limit * 0.80:
        self.context_truncated = True

        if self.original_prompt_estimate > prompt_eval_count:
            # Our estimate exceeds what the model processed: treat the
            # difference as truncated tokens. Both operands are positive
            # here, so the division is safe.
            self.tokens_truncated = (
                self.original_prompt_estimate - prompt_eval_count
            )
            self.truncation_ratio = (
                self.tokens_truncated / self.original_prompt_estimate
            )
        elif prompt_eval_count > self.context_limit:
            self.tokens_truncated = prompt_eval_count - self.context_limit
            self.truncation_ratio = self.tokens_truncated / prompt_eval_count

        logger.warning(
            f"Context overflow detected [provider-confirmed] "
            f"research_id={self.research_id} "
            f"model={self.current_model} "
            f"provider={self.current_provider} "
            f"source={source} "
            f"prompt_tokens={prompt_eval_count} "
            f"context_limit={self.context_limit} "
            f"tokens_truncated={self.tokens_truncated} "
            f"truncation_ratio={self.truncation_ratio:.1%}"
        )
    # Total-context overflow: input + output exceeds 95% of context limit
    elif completion_tokens > 0 and total >= self.context_limit * 0.95:
        self.context_truncated = True
        self.tokens_truncated = max(0, total - self.context_limit)
        self.truncation_ratio = (
            self.tokens_truncated / total if total > 0 else 0
        )
        logger.warning(
            f"Context overflow detected [total-context] "
            f"research_id={self.research_id} "
            f"model={self.current_model} "
            f"provider={self.current_provider} "
            f"source={source} "
            f"prompt_tokens={prompt_eval_count} "
            f"completion_tokens={completion_tokens} "
            f"total_tokens={total} "
            f"context_limit={self.context_limit} "
            f"tokens_truncated={self.tokens_truncated} "
            f"truncation_ratio={self.truncation_ratio:.1%}"
        )
    else:
        logger.debug(
            f"Context OK [{source}]: "
            f"{prompt_eval_count}/{self.context_limit} "
            f"({prompt_eval_count / self.context_limit:.1%})"
        )
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
    """Record token usage once an LLM call has finished.

    Usage is taken from ``response.llm_output`` (``token_usage`` or
    ``usage``) when present, otherwise from the first generation message
    carrying ``usage_metadata`` or Ollama-style ``response_metadata``.
    In-memory totals are updated and, when a ``research_id`` is set, the
    call is persisted via ``_save_to_db``.

    Args:
        response: The result object returned by the LLM.
        **kwargs: Unused extra callback arguments.
    """
    if self.start_time:
        self.response_time_ms = int((time.time() - self.start_time) * 1000)

    token_usage = None

    # 1) Provider-level usage attached to the whole result.
    llm_output = getattr(response, "llm_output", None)
    if llm_output:
        token_usage = llm_output.get("token_usage") or llm_output.get(
            "usage", {}
        )

    # 2) Per-message usage (Ollama specific).
    if not token_usage and hasattr(response, "generations"):
        for generation_list in response.generations:
            for generation in generation_list:
                message = getattr(generation, "message", None)
                if message is None:
                    continue

                usage_meta = getattr(message, "usage_metadata", None)
                if usage_meta:
                    input_tokens = usage_meta.get("input_tokens") or 0
                    output_tokens = usage_meta.get("output_tokens") or 0
                    token_usage = {
                        "prompt_tokens": input_tokens,
                        "completion_tokens": output_tokens,
                        "total_tokens": usage_meta.get("total_tokens") or 0,
                    }
                    # Check context overflow before breaking
                    # (input_tokens == prompt_eval_count for Ollama)
                    self._check_context_overflow(
                        input_tokens,
                        completion_tokens=output_tokens,
                        source="usage_metadata",
                    )
                    break

                resp_meta = getattr(message, "response_metadata", None) or {}
                if resp_meta.get("prompt_eval_count") or resp_meta.get(
                    "eval_count"
                ):
                    # Capture raw Ollama metrics exactly as reported.
                    self.ollama_metrics = {
                        key: resp_meta.get(key)
                        for key in (
                            "prompt_eval_count",
                            "eval_count",
                            "total_duration",
                            "load_duration",
                            "prompt_eval_duration",
                            "eval_duration",
                        )
                    }

                    # Either counter may be absent or None; use 0 so the
                    # arithmetic below cannot raise.
                    prompt_eval_count = resp_meta.get("prompt_eval_count") or 0
                    eval_count = resp_meta.get("eval_count") or 0

                    self._check_context_overflow(
                        prompt_eval_count,
                        completion_tokens=eval_count,
                        source="response_metadata",
                    )

                    token_usage = {
                        "prompt_tokens": prompt_eval_count,
                        "completion_tokens": eval_count,
                        "total_tokens": prompt_eval_count + eval_count,
                    }
                    break
            if token_usage:
                break

    # Normalise the counters once. Usage dicts keyed with
    # input_tokens/output_tokens (presumably Anthropic-style payloads;
    # verify against the provider) would otherwise be counted as zero.
    has_usage = bool(token_usage) and isinstance(token_usage, dict)
    prompt_tokens = 0
    completion_tokens = 0
    total_tokens = 0
    if has_usage:
        prompt_tokens = (
            token_usage.get("prompt_tokens")
            or token_usage.get("input_tokens")
            or 0
        )
        completion_tokens = (
            token_usage.get("completion_tokens")
            or token_usage.get("output_tokens")
            or 0
        )
        total_tokens = token_usage.get("total_tokens") or (
            prompt_tokens + completion_tokens
        )

    # Estimation-based overflow detection for providers that don't echo
    # prompt_eval_count (OpenAI, Anthropic, OpenRouter, etc.). The Ollama
    # path above sets context_truncated=True if it detected provider-
    # confirmed truncation; we only fire here if it didn't. For hosted
    # providers the API typically rejects oversize prompts rather than
    # silently truncating, so this signal flags "would-overflow per our
    # estimate", hence the [estimated] tag in the log message.
    if not self.context_truncated and self.context_limit:
        if self.original_prompt_estimate > self.context_limit:
            # Input-only overflow: prompt estimate exceeds context limit.
            self.context_truncated = True
            self.tokens_truncated = (
                self.original_prompt_estimate - self.context_limit
            )
            self.truncation_ratio = (
                self.tokens_truncated / self.original_prompt_estimate
            )
            logger.warning(
                "Context overflow detected [estimated] "
                f"research_id={self.research_id} "
                f"model={self.current_model} "
                f"provider={self.current_provider} "
                f"estimated_prompt_tokens={self.original_prompt_estimate} "
                f"context_limit={self.context_limit} "
                f"tokens_truncated={self.tokens_truncated} "
                f"truncation_ratio={self.truncation_ratio:.1%}"
            )
        elif has_usage:
            # Total-context overflow: input + output near/over the limit.
            # "[estimated-total-context]" names the detection method (this
            # fallback path), not the data source: the counts used here
            # come from the provider, not from the character estimate.
            total = prompt_tokens + completion_tokens
            if total >= self.context_limit * 0.95:
                self.context_truncated = True
                self.tokens_truncated = max(0, total - self.context_limit)
                self.truncation_ratio = (
                    self.tokens_truncated / total if total > 0 else 0
                )
                logger.warning(
                    "Context overflow detected [estimated-total-context] "
                    f"research_id={self.research_id} "
                    f"model={self.current_model} "
                    f"provider={self.current_provider} "
                    f"prompt_tokens={prompt_tokens} "
                    f"completion_tokens={completion_tokens} "
                    f"total_tokens={total} "
                    f"context_limit={self.context_limit} "
                    f"tokens_truncated={self.tokens_truncated} "
                    f"truncation_ratio={self.truncation_ratio:.1%}"
                )

    if not has_usage:
        return

    # Update in-memory counts
    self.counts["total_prompt_tokens"] += prompt_tokens
    self.counts["total_completion_tokens"] += completion_tokens
    self.counts["total_tokens"] += total_tokens

    if self.current_model:
        # setdefault keeps this safe even if the model entry was never
        # registered for the current model name.
        model_counts = self.counts["by_model"].setdefault(
            self.current_model,
            {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
                "calls": 0,
                "provider": self.current_provider,
            },
        )
        model_counts["prompt_tokens"] += prompt_tokens
        model_counts["completion_tokens"] += completion_tokens
        model_counts["total_tokens"] += total_tokens

    # Save to database if we have a research_id
    if self.research_id:
        self._save_to_db(prompt_tokens, completion_tokens)
def on_llm_error(self, error, **kwargs: Any) -> None:
    """Mark the current call as failed and persist a zero-token record.

    Args:
        error: The exception raised by the LLM call; its class name is
            stored as ``error_type``.
        **kwargs: Unused extra callback arguments.
    """
    started_at = self.start_time
    if started_at:
        elapsed_seconds = time.time() - started_at
        self.response_time_ms = int(elapsed_seconds * 1000)

    self.error_type = str(type(error).__name__)
    self.success_status = "error"

    # Failed calls are still written (with 0/0 tokens) when a research
    # id is attached; without one there is nothing to persist.
    if not self.research_id:
        return
    self._save_to_db(0, 0)
523 def _get_context_overflow_fields(self) -> Dict[str, Any]:
524 """Get context overflow detection fields for database saving."""
525 return {
526 "context_limit": self.context_limit,
527 "context_truncated": self.context_truncated, # Now Boolean
528 "tokens_truncated": self.tokens_truncated
529 if self.context_truncated
530 else None,
531 "truncation_ratio": self.truncation_ratio
532 if self.context_truncated
533 else None,
534 # Raw Ollama metrics
535 "ollama_prompt_eval_count": self.ollama_metrics.get(
536 "prompt_eval_count"
537 ),
538 "ollama_eval_count": self.ollama_metrics.get("eval_count"),
539 "ollama_total_duration": self.ollama_metrics.get("total_duration"),
540 "ollama_load_duration": self.ollama_metrics.get("load_duration"),
541 "ollama_prompt_eval_duration": self.ollama_metrics.get(
542 "prompt_eval_duration"
543 ),
544 "ollama_eval_duration": self.ollama_metrics.get("eval_duration"),
545 }
def _save_to_db(self, prompt_tokens: int, completion_tokens: int):
    """Save token usage to the database.

    Outside the main thread the record is handed to the thread metrics
    writer, which needs ``username`` and ``user_password`` from the
    research context. In the main thread the record is written through
    the Flask user's database session, and the per-model ``ModelUsage``
    aggregate is updated. Failures are logged, never raised.

    Args:
        prompt_tokens: Prompt tokens consumed by the call.
        completion_tokens: Completion tokens produced by the call.
    """
    import threading

    if threading.current_thread().name != "MainThread":
        # Background thread: use the thread-safe metrics writer.
        username = (
            self.research_context.get("username")
            if self.research_context
            else None
        )

        if not username:
            # Log only the context keys: the context can carry
            # "user_password", which must never reach the logs.
            context_keys = sorted(str(key) for key in self.research_context)
            logger.warning(
                f"Cannot save token metrics - no username in research context. "
                f"Token usage: prompt={prompt_tokens}, completion={completion_tokens}, "
                f"Research context keys: {context_keys}"
            )
            return

        # Prepare token data
        token_data = {
            "model_name": self.current_model,
            "provider": self.current_provider,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "research_query": self.research_context.get("research_query"),
            "research_mode": self.research_context.get("research_mode"),
            "research_phase": self.research_context.get("research_phase"),
            "search_iteration": self.research_context.get(
                "search_iteration"
            ),
            "response_time_ms": self.response_time_ms,
            "success_status": self.success_status,
            "error_type": self.error_type,
            "search_engines_planned": self.research_context.get(
                "search_engines_planned"
            ),
            "search_engine_selected": self.research_context.get(
                "search_engine_selected"
            ),
            "calling_file": self.calling_file,
            "calling_function": self.calling_function,
            "call_stack": self.call_stack,
            # Add context overflow fields using helper method
            **self._get_context_overflow_fields(),
        }

        # Convert list to JSON string if needed
        if isinstance(token_data.get("search_engines_planned"), list):
            token_data["search_engines_planned"] = json.dumps(
                token_data["search_engines_planned"]
            )

        # Get password from research context
        password = self.research_context.get("user_password")
        if not password:
            logger.warning(
                f"Cannot save token metrics - no password in research context. "
                f"Username: {username}, Token usage: prompt={prompt_tokens}, completion={completion_tokens}"
            )
            return

        # Write metrics directly using thread-safe database
        try:
            from ..database.thread_metrics import metrics_writer

            # Set password for this thread
            metrics_writer.set_user_password(username, password)

            # Write metrics to encrypted database
            metrics_writer.write_token_metrics(
                username, self.research_id, token_data
            )
        except Exception:
            logger.exception("Failed to write metrics from thread")
        return

    # In MainThread, save directly
    try:
        from flask import session as flask_session
        from ..database.session_context import get_user_db_session

        username = flask_session.get("username")
        if not username:
            logger.debug("No user session, skipping token metrics save")
            return

        with get_user_db_session(username) as session:
            # Research context recorded alongside the token counts.
            research_query = self.research_context.get("research_query")
            research_mode = self.research_context.get("research_mode")
            research_phase = self.research_context.get("research_phase")
            search_iteration = self.research_context.get("search_iteration")
            search_engines_planned = self.research_context.get(
                "search_engines_planned"
            )
            search_engine_selected = self.research_context.get(
                "search_engine_selected"
            )

            # Debug logging for search engine context
            if search_engines_planned or search_engine_selected:
                logger.info(
                    f"Token tracking - Search context: planned={search_engines_planned}, selected={search_engine_selected}, phase={research_phase}"
                )
            else:
                logger.debug(
                    f"Token tracking - No search engine context yet, phase={research_phase}"
                )

            # Convert list to JSON string if needed
            if isinstance(search_engines_planned, list):
                search_engines_planned = json.dumps(search_engines_planned)

            # Log context overflow detection values before saving
            logger.debug(
                f"Saving TokenUsage - context_limit: {self.context_limit}, "
                f"context_truncated: {self.context_truncated}, "
                f"tokens_truncated: {self.tokens_truncated}, "
                f"ollama_prompt_eval_count: {self.ollama_metrics.get('prompt_eval_count')}, "
                f"prompt_tokens: {prompt_tokens}, "
                f"completion_tokens: {completion_tokens}"
            )

            call_total = prompt_tokens + completion_tokens

            # Add token usage record with enhanced fields
            token_usage = TokenUsage(
                research_id=self.research_id,
                model_name=self.current_model,
                # Provider is stored for accurate cost tracking.
                model_provider=self.current_provider,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=call_total,
                # Research context
                research_query=research_query,
                research_mode=research_mode,
                research_phase=research_phase,
                search_iteration=search_iteration,
                # Performance metrics
                response_time_ms=self.response_time_ms,
                success_status=self.success_status,
                error_type=self.error_type,
                # Search engine context
                search_engines_planned=search_engines_planned,
                search_engine_selected=search_engine_selected,
                # Call stack tracking
                calling_file=self.calling_file,
                calling_function=self.calling_function,
                call_stack=self.call_stack,
                # Add context overflow fields using helper method
                **self._get_context_overflow_fields(),
            )
            session.add(token_usage)

            # Update or create model usage statistics
            model_usage = (
                session.query(ModelUsage)
                .filter_by(
                    model_name=self.current_model,
                )
                .first()
            )

            if model_usage:
                model_usage.total_tokens += call_total
                model_usage.total_calls += 1
            else:
                model_usage = ModelUsage(
                    model_name=self.current_model,
                    model_provider=self.current_provider,
                    total_tokens=call_total,
                    total_calls=1,
                )
                session.add(model_usage)

            # Commit the transaction
            session.commit()

    except Exception:
        logger.exception("Error saving token usage to database")
def get_counts(self) -> Dict[str, Any]:
    """Return the in-memory token tally.

    Returns:
        The callback's own ``counts`` dict (the live object, not a copy).
    """
    tally = self.counts
    return tally
739class TokenCounter:
740 """Manager class for token counting across the application."""
def __init__(self):
    """Create the token counter; no attributes are set here."""
def create_callback(
    self,
    research_id: Optional[str] = None,
    research_context: Optional[Dict[str, Any]] = None,
) -> TokenCountingCallback:
    """Build a fresh ``TokenCountingCallback``.

    Args:
        research_id: The ID of the research to track tokens for.
        research_context: Additional research context for enhanced
            tracking.

    Returns:
        A newly constructed callback carrying both arguments.
    """
    callback = TokenCountingCallback(
        research_context=research_context,
        research_id=research_id,
    )
    return callback
def get_research_metrics(self, research_id: str) -> Dict[str, Any]:
    """Aggregate token usage recorded for one research.

    Args:
        research_id: The ID of the research.

    Returns:
        A dict with ``research_id``, ``total_tokens``, ``total_calls``
        and ``model_usage`` (one entry per model/provider pair, ordered
        by total tokens descending). All totals are zero and the list is
        empty when the Flask session has no username.
    """
    from flask import session as flask_session

    from ..database.session_context import get_user_db_session

    username = flask_session.get("username")
    if not username:
        return {
            "research_id": research_id,
            "total_tokens": 0,
            "total_calls": 0,
            "model_usage": [],
        }

    with get_user_db_session(username) as session:
        # One row per (model, provider), summed over the research's calls.
        grouped_query = session.query(
            TokenUsage.model_name,
            TokenUsage.model_provider,
            func.sum(TokenUsage.prompt_tokens).label("prompt_tokens"),
            func.sum(TokenUsage.completion_tokens).label("completion_tokens"),
            func.sum(TokenUsage.total_tokens).label("total_tokens"),
            func.count().label("calls"),
        )
        grouped_query = grouped_query.filter_by(research_id=research_id)
        grouped_query = grouped_query.group_by(
            TokenUsage.model_name, TokenUsage.model_provider
        )
        grouped_query = grouped_query.order_by(
            func.sum(TokenUsage.total_tokens).desc()
        )
        rows = grouped_query.all()

        # Aggregates may come back as None; report those as 0.
        model_usage = [
            {
                "model": row.model_name,
                "provider": row.model_provider,
                "tokens": row.total_tokens or 0,
                "calls": row.calls or 0,
                "prompt_tokens": row.prompt_tokens or 0,
                "completion_tokens": row.completion_tokens or 0,
            }
            for row in rows
        ]

        return {
            "research_id": research_id,
            "total_tokens": sum(entry["tokens"] for entry in model_usage),
            "total_calls": sum(entry["calls"] for entry in model_usage),
            "model_usage": model_usage,
        }
def get_overall_metrics(
    self, period: str = "30d", research_mode: str = "all"
) -> Dict[str, Any]:
    """Return token metrics aggregated across all researches.

    Delegates to ``_get_metrics_from_encrypted_db`` with both arguments
    passed through unchanged.

    Args:
        period: Time period to filter by ('7d', '30d', '3m', '1y', 'all').
            NOTE(review): the accepted values are decided downstream;
            confirm this list against the time-filter helper.
        research_mode: Research mode to filter by
            ('quick', 'detailed', 'all').

    Returns:
        Dictionary containing overall metrics.
    """
    metrics = self._get_metrics_from_encrypted_db(period, research_mode)
    return metrics
845 def _get_metrics_from_encrypted_db(
846 self, period: str, research_mode: str
847 ) -> Dict[str, Any]:
848 """Get metrics from user's encrypted database."""
849 from flask import session as flask_session
851 from ..database.session_context import get_user_db_session
853 username = flask_session.get("username")
854 if not username:
855 return self._get_empty_metrics()
857 try:
858 with get_user_db_session(username) as session:
859 # Build base query with filters
860 query = session.query(TokenUsage)
862 # Apply time filter
863 time_condition = get_time_filter_condition(
864 period, TokenUsage.timestamp
865 )
866 if time_condition is not None:
867 query = query.filter(time_condition)
869 # Apply research mode filter
870 mode_condition = get_research_mode_condition(
871 research_mode, TokenUsage.research_mode
872 )
873 if mode_condition is not None:
874 query = query.filter(mode_condition)
876 # Total tokens from TokenUsage
877 total_tokens = (
878 query.with_entities(
879 func.sum(TokenUsage.total_tokens)
880 ).scalar()
881 or 0
882 )
884 # Import ResearchHistory model
885 from ..database.models.research import ResearchHistory
887 # Count researches from ResearchHistory table
888 research_query = session.query(func.count(ResearchHistory.id))
890 # Debug: Check if any research history records exist at all
891 all_research_count = (
892 session.query(func.count(ResearchHistory.id)).scalar() or 0
893 )
894 logger.debug(
895 f"Total ResearchHistory records in database: {all_research_count}"
896 )
898 # Debug: List first few research IDs and their timestamps
899 sample_researches = (
900 session.query(
901 ResearchHistory.id,
902 ResearchHistory.created_at,
903 ResearchHistory.mode,
904 )
905 .limit(5)
906 .all()
907 )
908 if sample_researches: 908 ↛ 909line 908 didn't jump to line 909 because the condition on line 908 was never true
909 logger.debug("Sample ResearchHistory records:")
910 for r_id, r_created, r_mode in sample_researches:
911 logger.debug(
912 f" - ID: {r_id}, Created: {r_created}, Mode: {r_mode}"
913 )
914 else:
915 logger.debug("No ResearchHistory records found in database")
917 # Get time filter conditions for ResearchHistory query
918 start_time, end_time = None, None
919 if period != "all":
920 if period == "today": 920 ↛ 921line 920 didn't jump to line 921 because the condition on line 920 was never true
921 start_time = datetime.now(UTC).replace(
922 hour=0, minute=0, second=0, microsecond=0
923 )
924 elif period == "week": 924 ↛ 925line 924 didn't jump to line 925 because the condition on line 924 was never true
925 start_time = datetime.now(UTC) - timedelta(days=7)
926 elif period == "month": 926 ↛ 927line 926 didn't jump to line 927 because the condition on line 926 was never true
927 start_time = datetime.now(UTC) - timedelta(days=30)
929 if start_time: 929 ↛ 930line 929 didn't jump to line 930 because the condition on line 929 was never true
930 end_time = datetime.now(UTC)
932 # Apply time filter if specified
933 if start_time and end_time: 933 ↛ 934line 933 didn't jump to line 934 because the condition on line 933 was never true
934 research_query = research_query.filter(
935 ResearchHistory.created_at >= start_time.isoformat(),
936 ResearchHistory.created_at <= end_time.isoformat(),
937 )
939 # Apply mode filter if specified
940 mode_filter = research_mode if research_mode != "all" else None
941 if mode_filter:
942 logger.debug(f"Applying mode filter: {mode_filter}")
943 research_query = research_query.filter(
944 ResearchHistory.mode == mode_filter
945 )
947 total_researches = research_query.scalar() or 0
948 logger.debug(
949 f"Final filtered research count: {total_researches}"
950 )
952 # Also check distinct research_ids in TokenUsage for comparison
953 token_research_count = (
954 session.query(
955 func.count(func.distinct(TokenUsage.research_id))
956 ).scalar()
957 or 0
958 )
959 logger.debug(
960 f"Distinct research_ids in TokenUsage: {token_research_count}"
961 )
963 # Model statistics using ORM aggregation
964 model_stats_query = session.query(
965 TokenUsage.model_name,
966 func.sum(TokenUsage.total_tokens).label("tokens"),
967 func.count().label("calls"),
968 func.sum(TokenUsage.prompt_tokens).label("prompt_tokens"),
969 func.sum(TokenUsage.completion_tokens).label(
970 "completion_tokens"
971 ),
972 ).filter(TokenUsage.model_name.isnot(None))
974 # Apply same filters to model stats
975 if time_condition is not None:
976 model_stats_query = model_stats_query.filter(time_condition)
977 if mode_condition is not None:
978 model_stats_query = model_stats_query.filter(mode_condition)
980 model_stats = (
981 model_stats_query.group_by(TokenUsage.model_name)
982 .order_by(func.sum(TokenUsage.total_tokens).desc())
983 .all()
984 )
986 # Batch load provider info from ModelUsage table (fix N+1)
987 model_names = [stat.model_name for stat in model_stats]
988 provider_map = {}
989 if model_names: 989 ↛ 990line 989 didn't jump to line 990 because the condition on line 989 was never true
990 provider_results = (
991 session.query(
992 ModelUsage.model_name, ModelUsage.model_provider
993 )
994 .filter(ModelUsage.model_name.in_(model_names))
995 .order_by(ModelUsage.id)
996 .all()
997 )
998 for model_name, model_provider in provider_results:
999 provider_map.setdefault(model_name, model_provider)
1001 by_model = []
1002 for stat in model_stats: 1002 ↛ 1003line 1002 didn't jump to line 1003 because the loop on line 1002 never started
1003 provider = provider_map.get(stat.model_name, "unknown")
1005 by_model.append(
1006 {
1007 "model": stat.model_name,
1008 "provider": provider,
1009 "tokens": stat.tokens,
1010 "calls": stat.calls,
1011 "prompt_tokens": stat.prompt_tokens,
1012 "completion_tokens": stat.completion_tokens,
1013 }
1014 )
1016 # Get recent researches with token usage
1017 # Note: This requires research_history table - for now we'll use available data
1018 recent_research_query = session.query(
1019 TokenUsage.research_id,
1020 func.sum(TokenUsage.total_tokens).label("token_count"),
1021 func.max(TokenUsage.timestamp).label("latest_timestamp"),
1022 ).filter(TokenUsage.research_id.isnot(None))
1024 if time_condition is not None:
1025 recent_research_query = recent_research_query.filter(
1026 time_condition
1027 )
1028 if mode_condition is not None:
1029 recent_research_query = recent_research_query.filter(
1030 mode_condition
1031 )
1033 recent_research_data = (
1034 recent_research_query.group_by(TokenUsage.research_id)
1035 .order_by(func.max(TokenUsage.timestamp).desc())
1036 .limit(10)
1037 .all()
1038 )
1040 # Batch load research queries for recent researches (fix N+1)
1041 recent_research_ids = [
1042 r.research_id for r in recent_research_data
1043 ]
1044 research_query_map = {}
1045 if recent_research_ids: 1045 ↛ 1047line 1045 didn't jump to line 1047 because the condition on line 1045 was never true
1046 # Get first non-null research_query for each research_id
1047 query_results = (
1048 session.query(
1049 TokenUsage.research_id, TokenUsage.research_query
1050 )
1051 .filter(
1052 TokenUsage.research_id.in_(recent_research_ids),
1053 TokenUsage.research_query.isnot(None),
1054 )
1055 .order_by(TokenUsage.id)
1056 .all()
1057 )
1058 for research_id, research_query in query_results:
1059 if research_id not in research_query_map:
1060 research_query_map[research_id] = research_query
1062 recent_researches = []
1063 for research_data in recent_research_data: 1063 ↛ 1064line 1063 didn't jump to line 1064 because the loop on line 1063 never started
1064 query_text = research_query_map.get(
1065 research_data.research_id,
1066 f"Research {research_data.research_id}",
1067 )
1069 recent_researches.append(
1070 {
1071 "id": research_data.research_id,
1072 "query": query_text,
1073 "tokens": research_data.token_count or 0,
1074 "created_at": research_data.latest_timestamp,
1075 }
1076 )
1078 # Token breakdown statistics
1079 breakdown_query = query.with_entities(
1080 func.sum(TokenUsage.prompt_tokens).label(
1081 "total_input_tokens"
1082 ),
1083 func.sum(TokenUsage.completion_tokens).label(
1084 "total_output_tokens"
1085 ),
1086 func.avg(TokenUsage.prompt_tokens).label(
1087 "avg_input_tokens"
1088 ),
1089 func.avg(TokenUsage.completion_tokens).label(
1090 "avg_output_tokens"
1091 ),
1092 func.avg(TokenUsage.total_tokens).label("avg_total_tokens"),
1093 )
1094 token_breakdown = breakdown_query.first()
1096 # Get rate limiting metrics
1097 from ..database.models import (
1098 RateLimitAttempt,
1099 RateLimitEstimate,
1100 )
1102 # Get rate limit attempts
1103 rate_limit_query = session.query(RateLimitAttempt)
1105 # Apply time filter
1106 if time_condition is not None:
1107 # RateLimitAttempt uses timestamp as float, not datetime
1108 if period == "7d":
1109 cutoff_time = time.time() - (7 * 24 * 3600)
1110 elif period == "30d": 1110 ↛ 1112line 1110 didn't jump to line 1112 because the condition on line 1110 was always true
1111 cutoff_time = time.time() - (30 * 24 * 3600)
1112 elif period == "3m":
1113 cutoff_time = time.time() - (90 * 24 * 3600)
1114 elif period == "1y":
1115 cutoff_time = time.time() - (365 * 24 * 3600)
1116 else: # all
1117 cutoff_time = 0
1119 if cutoff_time > 0: 1119 ↛ 1125line 1119 didn't jump to line 1125 because the condition on line 1119 was always true
1120 rate_limit_query = rate_limit_query.filter(
1121 RateLimitAttempt.timestamp >= cutoff_time
1122 )
1124 # Get rate limit statistics
1125 total_attempts = rate_limit_query.count()
1126 successful_attempts = rate_limit_query.filter(
1127 RateLimitAttempt.success
1128 ).count()
1129 failed_attempts = total_attempts - successful_attempts
1131 # Count rate limiting events (failures with RateLimitError)
1132 rate_limit_events = rate_limit_query.filter(
1133 ~RateLimitAttempt.success,
1134 RateLimitAttempt.error_type == "RateLimitError",
1135 ).count()
1137 logger.debug(
1138 f"Rate limit attempts in database: total={total_attempts}, successful={successful_attempts}"
1139 )
1141 # Get all attempts for detailed calculations
1142 attempts = rate_limit_query.all()
1144 # Calculate average wait times
1145 if attempts: 1145 ↛ 1146line 1145 didn't jump to line 1146 because the condition on line 1145 was never true
1146 avg_wait_time = sum(a.wait_time for a in attempts) / len(
1147 attempts
1148 )
1149 successful_wait_times = [
1150 a.wait_time for a in attempts if a.success
1151 ]
1152 avg_successful_wait = (
1153 sum(successful_wait_times) / len(successful_wait_times)
1154 if successful_wait_times
1155 else 0
1156 )
1157 else:
1158 avg_wait_time = 0
1159 avg_successful_wait = 0
1161 # Get tracked engines - count distinct engine types from attempts
1162 tracked_engines_query = session.query(
1163 func.count(func.distinct(RateLimitAttempt.engine_type))
1164 )
1165 if cutoff_time > 0: 1165 ↛ 1169line 1165 didn't jump to line 1169 because the condition on line 1165 was always true
1166 tracked_engines_query = tracked_engines_query.filter(
1167 RateLimitAttempt.timestamp >= cutoff_time
1168 )
1169 tracked_engines = tracked_engines_query.scalar() or 0
1171 # Get engine-specific stats from attempts
1172 engine_stats = []
1174 # Get distinct engine types from attempts
1175 engine_types_query = session.query(
1176 RateLimitAttempt.engine_type
1177 ).distinct()
1178 if cutoff_time > 0: 1178 ↛ 1182line 1178 didn't jump to line 1182 because the condition on line 1178 was always true
1179 engine_types_query = engine_types_query.filter(
1180 RateLimitAttempt.timestamp >= cutoff_time
1181 )
1182 engine_types = [
1183 row.engine_type for row in engine_types_query.all()
1184 ]
1186 # Batch-load all estimates (fix N+1 query)
1187 estimates_map = {}
1188 if engine_types: 1188 ↛ 1189line 1188 didn't jump to line 1189 because the condition on line 1188 was never true
1189 all_estimates = (
1190 session.query(RateLimitEstimate)
1191 .filter(RateLimitEstimate.engine_type.in_(engine_types))
1192 .all()
1193 )
1194 estimates_map = {e.engine_type: e for e in all_estimates}
1196 for engine_type in engine_types: 1196 ↛ 1197line 1196 didn't jump to line 1197 because the loop on line 1196 never started
1197 engine_attempts_list = [
1198 a for a in attempts if a.engine_type == engine_type
1199 ]
1200 engine_attempts = len(engine_attempts_list)
1201 engine_success = len(
1202 [a for a in engine_attempts_list if a.success]
1203 )
1205 # Get estimate if exists
1206 estimate = estimates_map.get(engine_type)
1208 # Calculate recent success rate
1209 recent_success_rate = (
1210 (engine_success / engine_attempts * 100)
1211 if engine_attempts > 0
1212 else 0
1213 )
1215 # Determine status based on success rate
1216 if estimate:
1217 status = (
1218 "healthy"
1219 if estimate.success_rate > 0.8
1220 else "degraded"
1221 if estimate.success_rate > 0.5
1222 else "poor"
1223 )
1224 else:
1225 status = (
1226 "healthy"
1227 if recent_success_rate > 80
1228 else "degraded"
1229 if recent_success_rate > 50
1230 else "poor"
1231 )
1233 engine_stat = {
1234 "engine": engine_type,
1235 "base_wait": estimate.base_wait_seconds
1236 if estimate
1237 else 0.0,
1238 "base_wait_seconds": round(
1239 estimate.base_wait_seconds if estimate else 0.0, 2
1240 ),
1241 "min_wait_seconds": round(
1242 estimate.min_wait_seconds if estimate else 0.0, 2
1243 ),
1244 "max_wait_seconds": round(
1245 estimate.max_wait_seconds if estimate else 0.0, 2
1246 ),
1247 "success_rate": round(estimate.success_rate * 100, 1)
1248 if estimate
1249 else recent_success_rate,
1250 "total_attempts": estimate.total_attempts
1251 if estimate
1252 else engine_attempts,
1253 "recent_attempts": engine_attempts,
1254 "recent_success_rate": round(recent_success_rate, 1),
1255 "attempts": engine_attempts,
1256 "status": status,
1257 }
1259 if estimate:
1260 engine_stat["last_updated"] = datetime.fromtimestamp(
1261 estimate.last_updated
1262 ).strftime("%Y-%m-%d %H:%M:%S")
1263 else:
1264 engine_stat["last_updated"] = "Never"
1266 engine_stats.append(engine_stat)
1268 logger.debug(
1269 f"Tracked engines: {tracked_engines}, engine_stats: {engine_stats}"
1270 )
1272 result = {
1273 "total_tokens": total_tokens,
1274 "total_researches": total_researches,
1275 "by_model": by_model,
1276 "recent_researches": recent_researches,
1277 "token_breakdown": {
1278 "total_input_tokens": int(
1279 token_breakdown.total_input_tokens or 0
1280 ),
1281 "total_output_tokens": int(
1282 token_breakdown.total_output_tokens or 0
1283 ),
1284 "avg_input_tokens": int(
1285 token_breakdown.avg_input_tokens or 0
1286 ),
1287 "avg_output_tokens": int(
1288 token_breakdown.avg_output_tokens or 0
1289 ),
1290 "avg_total_tokens": int(
1291 token_breakdown.avg_total_tokens or 0
1292 ),
1293 },
1294 "rate_limiting": {
1295 "total_attempts": total_attempts,
1296 "successful_attempts": successful_attempts,
1297 "failed_attempts": failed_attempts,
1298 "success_rate": (
1299 successful_attempts / total_attempts * 100
1300 )
1301 if total_attempts > 0
1302 else 0,
1303 "rate_limit_events": rate_limit_events,
1304 "avg_wait_time": round(float(avg_wait_time), 2),
1305 "avg_successful_wait": round(
1306 float(avg_successful_wait), 2
1307 ),
1308 "tracked_engines": tracked_engines,
1309 "engine_stats": engine_stats,
1310 "total_engines_tracked": tracked_engines,
1311 "healthy_engines": len(
1312 [
1313 s
1314 for s in engine_stats
1315 if s["status"] == "healthy"
1316 ]
1317 ),
1318 "degraded_engines": len(
1319 [
1320 s
1321 for s in engine_stats
1322 if s["status"] == "degraded"
1323 ]
1324 ),
1325 "poor_engines": len(
1326 [s for s in engine_stats if s["status"] == "poor"]
1327 ),
1328 },
1329 }
1331 logger.debug(
1332 f"Returning from _get_metrics_from_encrypted_db - total_researches: {result['total_researches']}"
1333 )
1334 return result
1335 except Exception:
1336 logger.exception(
1337 "CRITICAL ERROR accessing encrypted database for metrics"
1338 )
1339 return self._get_empty_metrics()
def _get_empty_metrics(self) -> Dict[str, Any]:
    """Return a zeroed metrics structure for when no data is available.

    The structure carries the same keys as the populated metrics result
    (``total_input_tokens``-style entries under ``token_breakdown`` and a
    ``rate_limiting`` section), so callers can read the same keys whether
    or not any usage data exists. The older ``prompt_tokens``-style keys
    under ``token_breakdown`` are kept for backward compatibility.

    Returns:
        A freshly built dictionary with zero counts and empty lists.
    """
    return {
        "total_tokens": 0,
        "total_researches": 0,
        "by_model": [],
        "recent_researches": [],
        "token_breakdown": {
            # Keys used by the populated metrics result.
            "total_input_tokens": 0,
            "total_output_tokens": 0,
            "avg_input_tokens": 0,
            "avg_output_tokens": 0,
            "avg_total_tokens": 0,
            # Legacy keys, retained so existing readers keep working.
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "avg_prompt_tokens": 0,
            "avg_completion_tokens": 0,
        },
        "rate_limiting": {
            "total_attempts": 0,
            "successful_attempts": 0,
            "failed_attempts": 0,
            "success_rate": 0,
            "rate_limit_events": 0,
            "avg_wait_time": 0.0,
            "avg_successful_wait": 0.0,
            "tracked_engines": 0,
            "engine_stats": [],
            "total_engines_tracked": 0,
            "healthy_engines": 0,
            "degraded_engines": 0,
            "poor_engines": 0,
        },
    }
def get_enhanced_metrics(
    self, period: str = "30d", research_mode: str = "all"
) -> Dict[str, Any]:
    """Get enhanced Phase 1 tracking metrics.

    Args:
        period: Time period to filter by ('7d', '30d', '3m', '1y', 'all')
        research_mode: Research mode to filter by ('quick', 'detailed', 'all')

    Returns:
        Dictionary containing enhanced metrics data including time series.
        When the Flask session has no ``username``, or when querying
        raises, a dictionary with the same keys but zero/empty values is
        returned instead.
    """
    from flask import session as flask_session

    from ..database.session_context import get_user_db_session

    # Upper bound on chart points returned for any bounded period.
    max_time_series_points = 200

    def empty_metrics() -> Dict[str, Any]:
        """Build the zero/empty response used when nothing can be reported."""
        return {
            "recent_enhanced_data": [],
            "performance_stats": {
                "avg_response_time": 0,
                "min_response_time": 0,
                "max_response_time": 0,
                "success_rate": 0,
                "error_rate": 0,
                "total_enhanced_calls": 0,
            },
            "mode_breakdown": [],
            "search_engine_stats": [],
            "phase_breakdown": [],
            "time_series_data": [],
            "call_stack_analysis": {
                "by_file": [],
                "by_function": [],
            },
        }

    username = flask_session.get("username")
    if not username:
        # No user session, so there is no per-user database to read from.
        return empty_metrics()

    try:
        with get_user_db_session(username) as session:
            # Build base query with filters
            query = session.query(TokenUsage)

            time_condition = get_time_filter_condition(
                period, TokenUsage.timestamp
            )
            if time_condition is not None:
                query = query.filter(time_condition)

            mode_condition = get_research_mode_condition(
                research_mode, TokenUsage.research_mode
            )
            if mode_condition is not None:
                query = query.filter(mode_condition)

            def usage_breakdown(column, label: str) -> List[Dict[str, Any]]:
                """Group the filtered usage rows by ``column`` (NULLs excluded)."""
                grouped = (
                    query.filter(column.isnot(None))
                    .with_entities(
                        column,
                        func.count().label("count"),
                        func.avg(TokenUsage.total_tokens).label("avg_tokens"),
                        func.avg(TokenUsage.response_time_ms).label(
                            "avg_response_time"
                        ),
                    )
                    .group_by(column)
                    .all()
                )
                return [
                    {
                        # The grouped column is the first selected entity.
                        label: stat[0],
                        "count": stat.count,
                        "avg_tokens": round(stat.avg_tokens or 0),
                        "avg_response_time": round(
                            stat.avg_response_time or 0
                        ),
                    }
                    for stat in grouped
                ]

            # Time series data for the "Token Consumption Over Time" chart
            time_series_base = query.filter(
                TokenUsage.timestamp.isnot(None),
                TokenUsage.total_tokens > 0,
            )
            if period == "all":
                time_series_data = time_series_base.order_by(
                    TokenUsage.timestamp.asc()
                ).all()
            else:
                # Keep the MOST RECENT rows: select newest-first with the
                # limit, then flip back to chronological order. Limiting an
                # ascending query would keep the oldest rows of the period.
                newest_first = (
                    time_series_base.order_by(TokenUsage.timestamp.desc())
                    .limit(max_time_series_points)
                    .all()
                )
                time_series_data = newest_first[::-1]

            # Format time series data with cumulative calculations
            time_series = []
            cumulative_tokens = 0
            cumulative_prompt_tokens = 0
            cumulative_completion_tokens = 0

            for usage in time_series_data:
                cumulative_tokens += usage.total_tokens or 0
                cumulative_prompt_tokens += usage.prompt_tokens or 0
                cumulative_completion_tokens += usage.completion_tokens or 0

                time_series.append(
                    {
                        "timestamp": str(usage.timestamp)
                        if usage.timestamp
                        else None,
                        "tokens": usage.total_tokens or 0,
                        "prompt_tokens": usage.prompt_tokens or 0,
                        "completion_tokens": usage.completion_tokens or 0,
                        "cumulative_tokens": cumulative_tokens,
                        "cumulative_prompt_tokens": cumulative_prompt_tokens,
                        "cumulative_completion_tokens": cumulative_completion_tokens,
                        "research_id": usage.research_id,
                    }
                )

            # Performance stats only consider calls with a recorded
            # response time.
            performance_query = query.filter(
                TokenUsage.response_time_ms.isnot(None)
            )
            total_calls = performance_query.count()

            if total_calls > 0:
                avg_response_time = (
                    performance_query.with_entities(
                        func.avg(TokenUsage.response_time_ms)
                    ).scalar()
                    or 0
                )
                min_response_time = (
                    performance_query.with_entities(
                        func.min(TokenUsage.response_time_ms)
                    ).scalar()
                    or 0
                )
                max_response_time = (
                    performance_query.with_entities(
                        func.max(TokenUsage.response_time_ms)
                    ).scalar()
                    or 0
                )
                success_count = performance_query.filter(
                    TokenUsage.success_status == "success"
                ).count()
                error_count = performance_query.filter(
                    TokenUsage.success_status == "error"
                ).count()

                perf_stats = {
                    "avg_response_time": round(avg_response_time),
                    "min_response_time": min_response_time,
                    "max_response_time": max_response_time,
                    "success_rate": round(
                        success_count / total_calls * 100, 1
                    ),
                    "error_rate": round(error_count / total_calls * 100, 1),
                    "total_enhanced_calls": total_calls,
                }
            else:
                perf_stats = empty_metrics()["performance_stats"]

            modes = usage_breakdown(TokenUsage.research_mode, "mode")

            # Recent enhanced data (simplified): the 50 newest rows that
            # have a research query recorded.
            recent_enhanced_data = (
                query.filter(TokenUsage.research_query.isnot(None))
                .order_by(TokenUsage.timestamp.desc())
                .limit(50)
                .all()
            )
            recent_enhanced = [
                {
                    "research_query": usage.research_query,
                    "research_mode": usage.research_mode,
                    "research_phase": usage.research_phase,
                    "search_iteration": usage.search_iteration,
                    "response_time_ms": usage.response_time_ms,
                    "success_status": usage.success_status,
                    "error_type": usage.error_type,
                    "search_engines_planned": usage.search_engines_planned,
                    "search_engine_selected": usage.search_engine_selected,
                    "total_tokens": usage.total_tokens,
                    "prompt_tokens": usage.prompt_tokens,
                    "completion_tokens": usage.completion_tokens,
                    "timestamp": str(usage.timestamp)
                    if usage.timestamp
                    else None,
                    "research_id": usage.research_id,
                    "calling_file": usage.calling_file,
                    "calling_function": usage.calling_function,
                    "call_stack": usage.call_stack,
                }
                for usage in recent_enhanced_data
            ]

            search_engines = usage_breakdown(
                TokenUsage.search_engine_selected, "search_engine"
            )
            phases = usage_breakdown(TokenUsage.research_phase, "phase")

            # Call stack analysis: top 10 calling files by call count
            file_stats = (
                query.filter(TokenUsage.calling_file.isnot(None))
                .with_entities(
                    TokenUsage.calling_file,
                    func.count().label("count"),
                    func.avg(TokenUsage.total_tokens).label("avg_tokens"),
                )
                .group_by(TokenUsage.calling_file)
                .order_by(func.count().desc())
                .limit(10)
                .all()
            )
            files = [
                {
                    "file": stat.calling_file,
                    "count": stat.count,
                    "avg_tokens": round(stat.avg_tokens or 0),
                }
                for stat in file_stats
            ]

            # Call stack analysis: top 10 calling functions by call count
            function_stats = (
                query.filter(TokenUsage.calling_function.isnot(None))
                .with_entities(
                    TokenUsage.calling_function,
                    func.count().label("count"),
                    func.avg(TokenUsage.total_tokens).label("avg_tokens"),
                )
                .group_by(TokenUsage.calling_function)
                .order_by(func.count().desc())
                .limit(10)
                .all()
            )
            functions = [
                {
                    "function": stat.calling_function,
                    "count": stat.count,
                    "avg_tokens": round(stat.avg_tokens or 0),
                }
                for stat in function_stats
            ]

            return {
                "recent_enhanced_data": recent_enhanced,
                "performance_stats": perf_stats,
                "mode_breakdown": modes,
                "search_engine_stats": search_engines,
                "phase_breakdown": phases,
                "time_series_data": time_series,
                "call_stack_analysis": {
                    "by_file": files,
                    "by_function": functions,
                },
            }
    except Exception:
        # Best-effort endpoint: log the failure and fall back to the empty
        # structure rather than propagating the error to the caller.
        logger.exception("Error in get_enhanced_metrics")
        return empty_metrics()
def get_research_timeline_metrics(self, research_id: str) -> Dict[str, Any]:
    """Build the chronological token-usage timeline for one research.

    Args:
        research_id: The ID of the research

    Returns:
        Dictionary with the research ID, basic research details, the
        per-call timeline (including running token totals), a summary of
        totals/averages, and per-phase statistics. When the Flask session
        has no ``username``, the same keys are returned with empty/zero
        values.
    """
    from flask import session as flask_session

    from ..database.session_context import get_user_db_session

    username = flask_session.get("username")
    if not username:
        return {
            "research_id": research_id,
            "research_details": {},
            "timeline": [],
            "summary": {
                "total_calls": 0,
                "total_tokens": 0,
                "total_prompt_tokens": 0,
                "total_completion_tokens": 0,
                "avg_response_time": 0,
                "success_rate": 0,
            },
            "phase_stats": {},
        }

    usage_sql = text(
        """
        SELECT
            timestamp,
            total_tokens,
            prompt_tokens,
            completion_tokens,
            response_time_ms,
            success_status,
            error_type,
            research_phase,
            search_iteration,
            search_engine_selected,
            model_name,
            calling_file,
            calling_function,
            call_stack
        FROM token_usage
        WHERE research_id = :research_id
        ORDER BY timestamp ASC
        """
    )
    details_sql = text(
        """
        SELECT query, mode, status, created_at, completed_at
        FROM research_history
        WHERE id = :research_id
        """
    )
    bind_params = {"research_id": research_id}

    with get_user_db_session(username) as session:
        # All token usage rows for this research, oldest first.
        usage_rows = session.execute(usage_sql, bind_params).fetchall()

        timeline = []
        phase_stats = {}
        running_total = 0
        running_prompt = 0
        running_completion = 0
        response_time_sum = 0
        successes = 0

        # Single pass: timeline entries, running totals, summary
        # accumulators and per-phase buckets are all derived per row.
        for (
            stamp,
            total,
            prompt,
            completion,
            response_ms,
            status,
            error_type,
            phase,
            iteration,
            engine,
            model,
            calling_file,
            calling_function,
            call_stack,
        ) in usage_rows:
            total = total or 0
            prompt = prompt or 0
            completion = completion or 0

            running_total += total
            running_prompt += prompt
            running_completion += completion

            # A missing response time counts as 0 in the overall average.
            response_time_sum += response_ms or 0
            if status == "success":
                successes += 1

            timeline.append(
                {
                    "timestamp": str(stamp) if stamp else None,
                    "tokens": total,
                    "prompt_tokens": prompt,
                    "completion_tokens": completion,
                    "cumulative_tokens": running_total,
                    "cumulative_prompt_tokens": running_prompt,
                    "cumulative_completion_tokens": running_completion,
                    "response_time_ms": response_ms,
                    "success_status": status,
                    "error_type": error_type,
                    "research_phase": phase,
                    "search_iteration": iteration,
                    "search_engine_selected": engine,
                    "model_name": model,
                    "calling_file": calling_file,
                    "calling_function": calling_function,
                    "call_stack": call_stack,
                }
            )

            bucket = phase_stats.setdefault(
                phase or "unknown",
                {"count": 0, "tokens": 0, "avg_response_time": 0},
            )
            bucket["count"] += 1
            bucket["tokens"] += total
            if response_ms:
                # Holds the running sum for now; divided by the count below.
                bucket["avg_response_time"] += response_ms

        # Convert each phase's summed response time into a rounded mean.
        # Every bucket has count >= 1 because it is incremented on creation.
        for bucket in phase_stats.values():
            bucket["avg_response_time"] = round(
                bucket["avg_response_time"] / bucket["count"]
            )

        # Basic info about the research itself, if a record exists.
        info_row = session.execute(details_sql, bind_params).fetchone()
        research_details = {}
        if info_row:
            query_text, mode, state, created_at, completed_at = info_row
            research_details = {
                "query": query_text,
                "mode": mode,
                "status": state,
                "created_at": str(created_at) if created_at else None,
                "completed_at": str(completed_at) if completed_at else None,
            }

        # max(..., 1) avoids dividing by zero when there are no rows.
        call_count = len(usage_rows)
        divisor = max(call_count, 1)
        avg_response_time = response_time_sum / divisor
        success_rate = successes / divisor * 100

        return {
            "research_id": research_id,
            "research_details": research_details,
            "timeline": timeline,
            "summary": {
                "total_calls": call_count,
                "total_tokens": running_total,
                "total_prompt_tokens": running_prompt,
                "total_completion_tokens": running_completion,
                "avg_response_time": round(avg_response_time),
                "success_rate": round(success_rate, 1),
            },
            "phase_stats": phase_stats,
        }