Coverage for src / local_deep_research / metrics / pricing / cost_calculator.py: 95%
79 statements
« prev ^ index » next coverage.py v7.12.0, created at 2026-01-11 00:51 +0000
1"""
2Cost Calculator
4Calculates LLM usage costs based on token usage and pricing data.
5Integrates with pricing fetcher and cache systems.
6"""
8from typing import Any, Dict, List, Optional
10from loguru import logger
12from .pricing_cache import PricingCache
13from .pricing_fetcher import PricingFetcher
class CostCalculator:
    """Calculates LLM usage costs from token counts and per-model pricing.

    Pricing is resolved from the local :class:`PricingCache` first; when the
    calculator is used as an async context manager, cache misses fall back to
    a live :class:`PricingFetcher` lookup and the result is cached.
    """

    def __init__(self, cache_dir: Optional[str] = None):
        self.cache = PricingCache(cache_dir)
        # Created lazily in __aenter__; stays None outside the async context,
        # in which case only cached pricing is available.
        self.pricing_fetcher: Optional[PricingFetcher] = None

    async def __aenter__(self):
        self.pricing_fetcher = PricingFetcher()
        await self.pricing_fetcher.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Propagate exit to the fetcher so its resources (e.g. HTTP session)
        # are released; no-op if __aenter__ was never called.
        if self.pricing_fetcher:
            await self.pricing_fetcher.__aexit__(exc_type, exc_val, exc_tb)

    async def get_model_pricing(
        self, model_name: str, provider: Optional[str] = None
    ) -> Optional[Dict[str, float]]:
        """Get pricing for a model and provider (cached or fetched).

        Args:
            model_name: Model identifier to price.
            provider: Optional provider name; included in the cache key so
                identically named models from different providers don't collide.

        Returns:
            A pricing dict (expected keys: "prompt", "completion" — per-1K-token
            rates), or None when no pricing data could be found.
        """
        # Create cache key that includes provider
        cache_key = f"{provider}:{model_name}" if provider else model_name

        # Try cache first
        cached_pricing = self.cache.get(f"model:{cache_key}")
        if cached_pricing:
            return cached_pricing

        # Fetch from API (only possible inside the async context manager)
        if self.pricing_fetcher:
            pricing = await self.pricing_fetcher.get_model_pricing(
                model_name, provider
            )
            if pricing:
                self.cache.set(f"model:{cache_key}", pricing)
                return pricing

        # No pricing found
        logger.warning(
            f"No pricing found for {model_name} (provider: {provider})"
        )
        return None

    async def calculate_cost(
        self,
        model_name: str,
        prompt_tokens: int,
        completion_tokens: int,
        provider: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Calculate cost for a single LLM call.

        Returns:
            Dict with prompt_cost, completion_cost, total_cost, pricing_used,
            and (when pricing is unavailable) an "error" message with all costs
            set to 0.0.
        """
        pricing = await self.get_model_pricing(model_name, provider)

        # If no pricing found, return zero cost
        if pricing is None:
            return {
                "prompt_cost": 0.0,
                "completion_cost": 0.0,
                "total_cost": 0.0,
                "pricing_used": None,
                "error": "No pricing data available for this model",
            }

        # Pricing rates are per 1K tokens, so convert token counts to thousands
        prompt_cost = (prompt_tokens / 1000) * pricing["prompt"]
        completion_cost = (completion_tokens / 1000) * pricing["completion"]
        total_cost = prompt_cost + completion_cost

        return {
            "prompt_cost": round(prompt_cost, 6),
            "completion_cost": round(completion_cost, 6),
            "total_cost": round(total_cost, 6),
            "pricing_used": pricing,
        }

    async def calculate_batch_costs(
        self, usage_records: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Calculate costs for multiple usage records.

        Expected record format:
            {
                "model_name": str,
                "provider": str (optional),
                "prompt_tokens": int,
                "completion_tokens": int,
                "research_id": int (optional),
                "timestamp": datetime (optional)
            }

        Returns:
            One result per input record: the record merged with its cost fields.
            Records that fail to price are kept with zero costs and an "error"
            key, so the output length always matches the input length.
        """
        results = []

        for record in usage_records:
            try:
                cost_data = await self.calculate_cost(
                    record["model_name"],
                    record["prompt_tokens"],
                    record["completion_tokens"],
                    record.get("provider"),
                )

                result = {**record, **cost_data}
                results.append(result)

            except Exception:
                # logger.exception already records the traceback
                logger.exception(f"Failed to calculate cost for record {record}")
                # Add record with zero cost on error
                results.append(
                    {
                        **record,
                        "prompt_cost": 0.0,
                        "completion_cost": 0.0,
                        "total_cost": 0.0,
                        "error": "Cost calculation failed",
                    }
                )

        return results

    def calculate_cost_sync(
        self, model_name: str, prompt_tokens: int, completion_tokens: int
    ) -> Dict[str, Any]:
        """
        Synchronous cost calculation using cached pricing only.
        Fallback for when async is not available.

        Falls back to PricingFetcher's static pricing table (exact model-name
        match, with and without a "provider/" prefix) when the cache misses.
        """
        # Use cached pricing only
        pricing = self.cache.get_model_pricing(model_name)
        if not pricing:
            # Use static fallback with exact matching only
            fetcher = PricingFetcher()
            # Try exact match
            pricing = fetcher.static_pricing.get(model_name)
            if not pricing:
                # Try exact match without provider prefix
                if "/" in model_name:
                    model_only = model_name.split("/")[-1]
                    pricing = fetcher.static_pricing.get(model_only)

        # If no pricing found, return zero cost
        if not pricing:
            return {
                "prompt_cost": 0.0,
                "completion_cost": 0.0,
                "total_cost": 0.0,
                "pricing_used": None,
                "error": "No pricing data available for this model",
            }

        # Rates are per 1K tokens
        prompt_cost = (prompt_tokens / 1000) * pricing["prompt"]
        completion_cost = (completion_tokens / 1000) * pricing["completion"]
        total_cost = prompt_cost + completion_cost

        return {
            "prompt_cost": round(prompt_cost, 6),
            "completion_cost": round(completion_cost, 6),
            "total_cost": round(total_cost, 6),
            "pricing_used": pricing,
        }

    async def get_research_cost_summary(
        self, usage_records: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Get cost summary for research session(s).

        Returns:
            Aggregate totals (cost, tokens, calls), a per-model breakdown,
            average cost per call, and cost per token.
        """
        costs = await self.calculate_batch_costs(usage_records)

        total_cost = sum(c["total_cost"] for c in costs)
        total_prompt_cost = sum(c["prompt_cost"] for c in costs)
        total_completion_cost = sum(c["completion_cost"] for c in costs)

        total_prompt_tokens = sum(r["prompt_tokens"] for r in usage_records)
        total_completion_tokens = sum(
            r["completion_tokens"] for r in usage_records
        )
        total_tokens = total_prompt_tokens + total_completion_tokens

        # Model breakdown
        model_costs = {}
        for cost in costs:
            model = cost["model_name"]
            if model not in model_costs:
                model_costs[model] = {
                    "total_cost": 0.0,
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "calls": 0,
                }

            model_costs[model]["total_cost"] += cost["total_cost"]
            model_costs[model]["prompt_tokens"] += cost["prompt_tokens"]
            model_costs[model]["completion_tokens"] += cost["completion_tokens"]
            model_costs[model]["calls"] += 1

        return {
            "total_cost": round(total_cost, 6),
            "prompt_cost": round(total_prompt_cost, 6),
            "completion_cost": round(total_completion_cost, 6),
            "total_tokens": total_tokens,
            "prompt_tokens": total_prompt_tokens,
            "completion_tokens": total_completion_tokens,
            "total_calls": len(usage_records),
            "model_breakdown": model_costs,
            # Guard against division by zero on empty input
            "avg_cost_per_call": (
                round(total_cost / len(usage_records), 6)
                if usage_records
                else 0.0
            ),
            "cost_per_token": (
                round(total_cost / total_tokens, 8) if total_tokens > 0 else 0.0
            ),
        }