Coverage for src/local_deep_research/metrics/pricing/cost_calculator.py: 95%

79 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2026-01-11 00:51 +0000

1""" 

2Cost Calculator 

3 

4Calculates LLM usage costs based on token usage and pricing data. 

5Integrates with pricing fetcher and cache systems. 

6""" 

7 

8from typing import Any, Dict, List, Optional 

9 

10from loguru import logger 

11 

12from .pricing_cache import PricingCache 

13from .pricing_fetcher import PricingFetcher 

14 

15 

class CostCalculator:
    """Calculates LLM usage costs."""

    def __init__(self, cache_dir: Optional[str] = None):
        """Set up the pricing cache; the fetcher is created lazily.

        Args:
            cache_dir: Optional directory for the pricing cache.
        """
        # The fetcher is only created when entering the async context
        # (see __aenter__), so sync-only callers never need it.
        self.pricing_fetcher = None
        self.cache = PricingCache(cache_dir)

22 

23 async def __aenter__(self): 

24 self.pricing_fetcher = PricingFetcher() 

25 await self.pricing_fetcher.__aenter__() 

26 return self 

27 

28 async def __aexit__(self, exc_type, exc_val, exc_tb): 

29 if self.pricing_fetcher: 29 ↛ exitline 29 didn't return from function '__aexit__' because the condition on line 29 was always true

30 await self.pricing_fetcher.__aexit__(exc_type, exc_val, exc_tb) 

31 

32 async def get_model_pricing( 

33 self, model_name: str, provider: str = None 

34 ) -> Dict[str, float]: 

35 """Get pricing for a model and provider (cached or fetched).""" 

36 # Create cache key that includes provider 

37 cache_key = f"{provider}:{model_name}" if provider else model_name 

38 

39 # Try cache first 

40 cached_pricing = self.cache.get(f"model:{cache_key}") 

41 if cached_pricing: 

42 return cached_pricing 

43 

44 # Fetch from API 

45 if self.pricing_fetcher: 45 ↛ 54line 45 didn't jump to line 54 because the condition on line 45 was always true

46 pricing = await self.pricing_fetcher.get_model_pricing( 

47 model_name, provider 

48 ) 

49 if pricing: 

50 self.cache.set(f"model:{cache_key}", pricing) 

51 return pricing 

52 

53 # No pricing found 

54 logger.warning( 

55 f"No pricing found for {model_name} (provider: {provider})" 

56 ) 

57 return None 

58 

59 async def calculate_cost( 

60 self, 

61 model_name: str, 

62 prompt_tokens: int, 

63 completion_tokens: int, 

64 provider: str = None, 

65 ) -> Dict[str, float]: 

66 """ 

67 Calculate cost for a single LLM call. 

68 

69 Returns: 

70 Dict with prompt_cost, completion_cost, total_cost 

71 """ 

72 pricing = await self.get_model_pricing(model_name, provider) 

73 

74 # If no pricing found, return zero cost 

75 if pricing is None: 

76 return { 

77 "prompt_cost": 0.0, 

78 "completion_cost": 0.0, 

79 "total_cost": 0.0, 

80 "pricing_used": None, 

81 "error": "No pricing data available for this model", 

82 } 

83 

84 # Convert tokens to thousands for pricing calculation 

85 prompt_cost = (prompt_tokens / 1000) * pricing["prompt"] 

86 completion_cost = (completion_tokens / 1000) * pricing["completion"] 

87 total_cost = prompt_cost + completion_cost 

88 

89 return { 

90 "prompt_cost": round(prompt_cost, 6), 

91 "completion_cost": round(completion_cost, 6), 

92 "total_cost": round(total_cost, 6), 

93 "pricing_used": pricing, 

94 } 

95 

96 async def calculate_batch_costs( 

97 self, usage_records: List[Dict[str, Any]] 

98 ) -> List[Dict[str, Any]]: 

99 """ 

100 Calculate costs for multiple usage records. 

101 

102 Expected record format: 

103 { 

104 "model_name": str, 

105 "provider": str (optional), 

106 "prompt_tokens": int, 

107 "completion_tokens": int, 

108 "research_id": int (optional), 

109 "timestamp": datetime (optional) 

110 } 

111 """ 

112 results = [] 

113 

114 for record in usage_records: 

115 try: 

116 cost_data = await self.calculate_cost( 

117 record["model_name"], 

118 record["prompt_tokens"], 

119 record["completion_tokens"], 

120 record.get("provider"), 

121 ) 

122 

123 result = {**record, **cost_data} 

124 results.append(result) 

125 

126 except Exception as e: 

127 logger.exception( 

128 f"Failed to calculate cost for record {record}: {e}" 

129 ) 

130 # Add record with zero cost on error 

131 results.append( 

132 { 

133 **record, 

134 "prompt_cost": 0.0, 

135 "completion_cost": 0.0, 

136 "total_cost": 0.0, 

137 "error": "Cost calculation failed", 

138 } 

139 ) 

140 

141 return results 

142 

143 def calculate_cost_sync( 

144 self, model_name: str, prompt_tokens: int, completion_tokens: int 

145 ) -> Dict[str, float]: 

146 """ 

147 Synchronous cost calculation using cached pricing only. 

148 Fallback for when async is not available. 

149 """ 

150 # Use cached pricing only 

151 pricing = self.cache.get_model_pricing(model_name) 

152 if not pricing: 

153 # Use static fallback with exact matching only 

154 fetcher = PricingFetcher() 

155 # Try exact match 

156 pricing = fetcher.static_pricing.get(model_name) 

157 if not pricing: 

158 # Try exact match without provider prefix 

159 if "/" in model_name: 

160 model_only = model_name.split("/")[-1] 

161 pricing = fetcher.static_pricing.get(model_only) 

162 

163 # If no pricing found, return zero cost 

164 if not pricing: 

165 return { 

166 "prompt_cost": 0.0, 

167 "completion_cost": 0.0, 

168 "total_cost": 0.0, 

169 "pricing_used": None, 

170 "error": "No pricing data available for this model", 

171 } 

172 

173 prompt_cost = (prompt_tokens / 1000) * pricing["prompt"] 

174 completion_cost = (completion_tokens / 1000) * pricing["completion"] 

175 total_cost = prompt_cost + completion_cost 

176 

177 return { 

178 "prompt_cost": round(prompt_cost, 6), 

179 "completion_cost": round(completion_cost, 6), 

180 "total_cost": round(total_cost, 6), 

181 "pricing_used": pricing, 

182 } 

183 

184 async def get_research_cost_summary( 

185 self, usage_records: List[Dict[str, Any]] 

186 ) -> Dict[str, Any]: 

187 """ 

188 Get cost summary for research session(s). 

189 """ 

190 costs = await self.calculate_batch_costs(usage_records) 

191 

192 total_cost = sum(c["total_cost"] for c in costs) 

193 total_prompt_cost = sum(c["prompt_cost"] for c in costs) 

194 total_completion_cost = sum(c["completion_cost"] for c in costs) 

195 

196 total_prompt_tokens = sum(r["prompt_tokens"] for r in usage_records) 

197 total_completion_tokens = sum( 

198 r["completion_tokens"] for r in usage_records 

199 ) 

200 total_tokens = total_prompt_tokens + total_completion_tokens 

201 

202 # Model breakdown 

203 model_costs = {} 

204 for cost in costs: 

205 model = cost["model_name"] 

206 if model not in model_costs: 

207 model_costs[model] = { 

208 "total_cost": 0.0, 

209 "prompt_tokens": 0, 

210 "completion_tokens": 0, 

211 "calls": 0, 

212 } 

213 

214 model_costs[model]["total_cost"] += cost["total_cost"] 

215 model_costs[model]["prompt_tokens"] += cost["prompt_tokens"] 

216 model_costs[model]["completion_tokens"] += cost["completion_tokens"] 

217 model_costs[model]["calls"] += 1 

218 

219 return { 

220 "total_cost": round(total_cost, 6), 

221 "prompt_cost": round(total_prompt_cost, 6), 

222 "completion_cost": round(total_completion_cost, 6), 

223 "total_tokens": total_tokens, 

224 "prompt_tokens": total_prompt_tokens, 

225 "completion_tokens": total_completion_tokens, 

226 "total_calls": len(usage_records), 

227 "model_breakdown": model_costs, 

228 "avg_cost_per_call": ( 

229 round(total_cost / len(usage_records), 6) 

230 if usage_records 

231 else 0.0 

232 ), 

233 "cost_per_token": ( 

234 round(total_cost / total_tokens, 8) if total_tokens > 0 else 0.0 

235 ), 

236 }