Coverage for src/local_deep_research/benchmarks/graders.py: 35% (193 statements)
1"""
2Evaluation and grading functionality.
4This module provides tools for evaluating model outputs against reference answers.
5"""
7import json
8from loguru import logger
9from pathlib import Path
10import re
11from typing import Any, Callable, Dict, List, Optional
13from langchain_core.messages.human import HumanMessage
15from ..config.llm_config import get_llm
16from .templates import BROWSECOMP_GRADER_TEMPLATE, SIMPLEQA_GRADER_TEMPLATE
19# Default evaluation configuration using Claude 3.7 Sonnet via OpenRouter
20DEFAULT_EVALUATION_CONFIG = {
21 "model_name": "anthropic/claude-3.7-sonnet", # Correct model ID for OpenRouter
22 "provider": "openai_endpoint", # Use OpenRouter
23 "openai_endpoint_url": "https://openrouter.ai/api/v1", # OpenRouter URL
24 "temperature": 0, # Zero temp for consistent evaluation
25 # Note: max_tokens removed as it's not supported by LDR's get_llm()
26}
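# Illustrative sketch (not executed here): get_evaluation_llm() copies this dict
# and merges any caller-supplied overrides on top of it with dict.update(), so a
# benchmark run can swap the grading model without editing this module. The
# override values below are assumptions chosen for illustration only:
#
#     config = DEFAULT_EVALUATION_CONFIG.copy()
#     config.update({"model_name": "anthropic/claude-3.5-sonnet", "temperature": 0})
#     # config still keeps "provider" and "openai_endpoint_url" from the defaults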
def get_evaluation_llm(
    custom_config: Optional[Dict[str, Any]] = None,
    settings_snapshot: Optional[Dict[str, Any]] = None,
):
    """
    Get an LLM for evaluation purposes using Claude 3.7 Sonnet via OpenRouter
    by default, which can be overridden with custom settings.

    Args:
        custom_config: Optional custom configuration that overrides defaults
        settings_snapshot: Optional settings snapshot for thread-safe access

    Returns:
        An LLM instance for evaluation
    """
    # Start with default config (Claude 3.7 Sonnet via OpenRouter)
    config = DEFAULT_EVALUATION_CONFIG.copy()

    # Override with any custom settings
    if custom_config:
        config.update(custom_config)

    logger.info(
        f"Getting evaluation LLM with provider={config['provider']}, model={config['model_name']}"
    )

    # Remove any parameters that LDR's get_llm doesn't support
    # This ensures compatibility with LDR's implementation
    ldr_supported_params = {
        "model_name",
        "temperature",
        "provider",
        "openai_endpoint_url",
        "api_key",
    }

    filtered_config = {
        k: v for k, v in config.items() if k in ldr_supported_params
    }

    # Check if we're using openai_endpoint but don't have an API key configured
    if filtered_config.get("provider") == "openai_endpoint":
        # Try to get API key from settings snapshot or environment
        api_key = None

        if settings_snapshot:
            # Get from settings snapshot for thread safety
            api_key_setting = settings_snapshot.get(
                "llm.openai_endpoint.api_key"
            )
            if api_key_setting:
                api_key = (
                    api_key_setting.get("value")
                    if isinstance(api_key_setting, dict)
                    else api_key_setting
                )
        else:
            # No settings snapshot available
            logger.warning(
                "No settings snapshot provided for benchmark grader. "
                "API key must be provided via settings_snapshot for thread safety."
            )

        if not api_key:
            logger.warning(
                "Using openai_endpoint provider but no API key found. "
                "Set the llm.openai_endpoint.api_key setting in the database or "
                "LDR_LLM_OPENAI_ENDPOINT_API_KEY environment variable."
            )
            # Try to fall back to LDR's config if API key not explicitly provided
            # The get_llm function will handle this case

    # Get the LLM using LDR's existing function
    return get_llm(**filtered_config)
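# Example usage (a minimal sketch, assuming the caller already holds a settings
# snapshot containing an OpenRouter key; the snapshot shape shown is the
# dict-with-"value" form handled above, and the key string is a placeholder):
#
#     snapshot = {"llm.openai_endpoint.api_key": {"value": "sk-or-..."}}
#     grader_llm = get_evaluation_llm(
#         custom_config={"temperature": 0},
#         settings_snapshot=snapshot,
#     )
#     # grader_llm is whatever LLM object LDR's get_llm() returns for this config.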
def extract_answer_from_response(
    response: str, dataset_type: str = "simpleqa"
) -> Dict[str, str]:
    """
    Extract structured information from LDR's response.

    Args:
        response: Response from LDR
        dataset_type: Type of dataset

    Returns:
        Dictionary with extracted answer and confidence
    """
    # Clean up citations
    response = re.sub(r"\[\d+\]", "", response)

    # Extract differently based on dataset type
    if dataset_type.lower() == "browsecomp":
        # Extract the final answer from structured response
        answer_match = re.search(r"Exact Answer:\s*(.*?)(?:\n|$)", response)
        exact_answer = answer_match.group(1).strip() if answer_match else "None"

        # Extract confidence
        confidence_match = re.search(r"Confidence:\s*(\d+)%", response)
        confidence = confidence_match.group(1) if confidence_match else "100"

        return {"extracted_answer": exact_answer, "confidence": confidence}

    # For SimpleQA, return the whole response as the answer
    return {
        "extracted_answer": response,
        "confidence": "100",  # SimpleQA doesn't have confidence scores
    }
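# Example (illustrative, based on the regexes above; the response text is made
# up for the sketch and is not taken from a real benchmark run):
#
#     text = "Final reasoning...\nExact Answer: 42\nConfidence: 85%"
#     extract_answer_from_response(text, dataset_type="browsecomp")
#     # -> {"extracted_answer": "42", "confidence": "85"}
#
#     extract_answer_from_response("Paris is the capital of France.")
#     # -> {"extracted_answer": "Paris is the capital of France.", "confidence": "100"}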
def grade_single_result(
    result_data: Dict[str, Any],
    dataset_type: str = "simpleqa",
    evaluation_config: Optional[Dict[str, Any]] = None,
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Grade a single benchmark result using LLM.

    Args:
        result_data: Dictionary containing result data with keys: id, problem, correct_answer, response, extracted_answer
        dataset_type: Type of dataset
        evaluation_config: Optional custom config for evaluation LLM
        settings_snapshot: Optional settings snapshot for thread-safe access

    Returns:
        Dictionary with grading results
    """
    # Get evaluation LLM
    evaluation_llm = get_evaluation_llm(evaluation_config, settings_snapshot)

    # Select appropriate template
    template = (
        BROWSECOMP_GRADER_TEMPLATE
        if dataset_type.lower() == "browsecomp"
        else SIMPLEQA_GRADER_TEMPLATE
    )

    question = result_data.get("problem", "")
    correct_answer = result_data.get("correct_answer", "")
    response = result_data.get("response", "")

    logger.info(f"Grading single result: {question[:50]}...")

    # Format grading prompt
    grading_prompt = template.format(
        question=question, correct_answer=correct_answer, response=response
    )

    try:
        import time

        eval_llm_start = time.time()
        logger.info(
            f"Starting grading LLM call (prompt length: {len(grading_prompt)} chars)..."
        )

        # Grade using LLM
        if hasattr(evaluation_llm, "invoke") and callable(
            evaluation_llm.invoke
        ):
            if hasattr(evaluation_llm, "chat_messages"):
                # Handle ChatOpenAI and similar models that use messages
                grading_response = evaluation_llm.invoke(
                    [HumanMessage(content=grading_prompt)]
                ).content
            else:
                # Handle other LLM types
                grading_response = evaluation_llm.invoke(grading_prompt)
                if hasattr(grading_response, "content"):
                    grading_response = grading_response.content
        else:
            # Fallback for other LLM interfaces
            grading_response = str(evaluation_llm(grading_prompt))

        eval_llm_elapsed = time.time() - eval_llm_start
        logger.info(f"Grading LLM call completed in {eval_llm_elapsed:.2f}s")

        # Extract grading information using regex
        if dataset_type.lower() == "browsecomp":
            # BrowseComp-specific extraction
            extracted_answer_match = re.search(
                r"extracted_final_answer:\s*(.*?)(?:\n|$)", grading_response
            )
            extracted_answer = (
                extracted_answer_match.group(1).strip()
                if extracted_answer_match
                else "None"
            )

            reasoning_match = re.search(
                r"reasoning:\s*(.*?)(?:\n\n|\ncorrect:|\Z)",
                grading_response,
                re.DOTALL,
            )
            reasoning = (
                reasoning_match.group(1).strip() if reasoning_match else ""
            )

            correct_match = re.search(
                r"correct:\s*(yes|no)", grading_response, re.IGNORECASE
            )
            is_correct = (
                (correct_match.group(1).lower() == "yes")
                if correct_match
                else False
            )

            confidence_match = re.search(
                r"confidence:\s*(\d+)", grading_response
            )
            confidence = (
                confidence_match.group(1) if confidence_match else "100"
            )
        else:
            # SimpleQA extraction
            extracted_answer_match = re.search(
                r"Extracted Answer:\s*(.*?)(?:\n|$)", grading_response
            )
            extracted_answer = (
                extracted_answer_match.group(1).strip()
                if extracted_answer_match
                else "None"
            )

            reasoning_match = re.search(
                r"Reasoning:\s*(.*?)(?:\nCorrect:|\Z)",
                grading_response,
                re.DOTALL,
            )
            reasoning = (
                reasoning_match.group(1).strip() if reasoning_match else ""
            )

            correct_match = re.search(
                r"Correct:\s*(yes|no)", grading_response, re.IGNORECASE
            )
            is_correct = (
                (correct_match.group(1).lower() == "yes")
                if correct_match
                else False
            )

            confidence = "100"  # SimpleQA doesn't have confidence

        # Format graded result
        graded_result = {
            "extracted_by_grader": extracted_answer,
            "reasoning": reasoning,
            "is_correct": is_correct,
            "graded_confidence": confidence,
            "grader_response": grading_response,
        }

        return graded_result

    except Exception as e:
        logger.exception(f"Error grading single result: {e!s}")
        return {
            "grading_error": str(e),
            "is_correct": False,
            "graded_confidence": "0",
            "grader_response": f"Grading failed: {e!s}",
        }
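# Example usage (a hedged sketch: the result_data fields mirror the keys read
# above, the question/answer pair is invented for illustration, and the call
# still needs a reachable grading LLM plus a settings snapshot with an API key):
#
#     result_data = {
#         "id": "example-1",
#         "problem": "What is the capital of France?",
#         "correct_answer": "Paris",
#         "response": "The capital of France is Paris.",
#         "extracted_answer": "Paris",
#     }
#     graded = grade_single_result(result_data, dataset_type="simpleqa")
#     # On success, graded contains "extracted_by_grader", "reasoning",
#     # "is_correct", "graded_confidence", and the raw "grader_response";
#     # on failure it contains "grading_error" with is_correct set to False.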
def grade_results(
    results_file: str,
    output_file: str,
    dataset_type: str = "simpleqa",
    evaluation_config: Optional[Dict[str, Any]] = None,
    progress_callback: Optional[Callable[[int, int, Dict], None]] = None,
) -> List[Dict[str, Any]]:
    """
    Grade benchmark results using LLM.

    Args:
        results_file: Path to results file
        output_file: Path to save graded results
        dataset_type: Type of dataset
        evaluation_config: Optional custom config for evaluation LLM
        progress_callback: Optional callback for progress updates

    Returns:
        List of graded results
    """
    # Get evaluation LLM
    evaluation_llm = get_evaluation_llm(evaluation_config)

    # Select appropriate template
    template = (
        BROWSECOMP_GRADER_TEMPLATE
        if dataset_type.lower() == "browsecomp"
        else SIMPLEQA_GRADER_TEMPLATE
    )

    # Load results
    results = []
    with open(results_file, "r") as f:
        for line in f:
            if line.strip():
                results.append(json.loads(line))

    # Remove output file if it exists
    output_path = Path(output_file)
    if output_path.exists():
        output_path.unlink()

    graded_results = []
    correct_count = 0

    # Process each result
    for idx, result in enumerate(results):
        question = result.get("problem", "")
        correct_answer = result.get("correct_answer", "")
        response = result.get("response", "")

        # Call progress callback if provided
        if progress_callback:
            progress_callback(
                idx,
                len(results),
                {"status": "grading", "index": idx, "total": len(results)},
            )

        logger.info(f"Grading {idx + 1}/{len(results)}: {question[:50]}...")

        # Format grading prompt
        grading_prompt = template.format(
            question=question, correct_answer=correct_answer, response=response
        )

        try:
            # Grade using LLM
            if hasattr(evaluation_llm, "invoke") and callable(
                evaluation_llm.invoke
            ):
                if hasattr(evaluation_llm, "chat_messages"):
                    # Handle ChatOpenAI and similar models that use messages
                    grading_response = evaluation_llm.invoke(
                        [HumanMessage(content=grading_prompt)]
                    ).content
                else:
                    # Handle other LLM types
                    grading_response = evaluation_llm.invoke(grading_prompt)
                    if hasattr(grading_response, "content"):
                        grading_response = grading_response.content
            else:
                # Fallback for other LLM interfaces
                grading_response = str(evaluation_llm(grading_prompt))

            # Extract grading information using regex
            if dataset_type.lower() == "browsecomp":
                # BrowseComp-specific extraction
                extracted_answer_match = re.search(
                    r"extracted_final_answer:\s*(.*?)(?:\n|$)", grading_response
                )
                extracted_answer = (
                    extracted_answer_match.group(1).strip()
                    if extracted_answer_match
                    else "None"
                )

                reasoning_match = re.search(
                    r"reasoning:\s*(.*?)(?:\n\n|\ncorrect:|\Z)",
                    grading_response,
                    re.DOTALL,
                )
                reasoning = (
                    reasoning_match.group(1).strip() if reasoning_match else ""
                )

                correct_match = re.search(
                    r"correct:\s*(yes|no)", grading_response, re.IGNORECASE
                )
                is_correct = (
                    (correct_match.group(1).lower() == "yes")
                    if correct_match
                    else False
                )

                confidence_match = re.search(
                    r"confidence:\s*(\d+)", grading_response
                )
                confidence = (
                    confidence_match.group(1) if confidence_match else "100"
                )
            else:
                # SimpleQA extraction
                extracted_answer_match = re.search(
                    r"Extracted Answer:\s*(.*?)(?:\n|$)", grading_response
                )
                extracted_answer = (
                    extracted_answer_match.group(1).strip()
                    if extracted_answer_match
                    else "None"
                )

                reasoning_match = re.search(
                    r"Reasoning:\s*(.*?)(?:\nCorrect:|\Z)",
                    grading_response,
                    re.DOTALL,
                )
                reasoning = (
                    reasoning_match.group(1).strip() if reasoning_match else ""
                )

                correct_match = re.search(
                    r"Correct:\s*(yes|no)", grading_response, re.IGNORECASE
                )
                is_correct = (
                    (correct_match.group(1).lower() == "yes")
                    if correct_match
                    else False
                )

                confidence = "100"  # SimpleQA doesn't have confidence

            if is_correct:
                correct_count += 1

            # Format graded result
            graded_result = result.copy()
            graded_result.update(
                {
                    "extracted_by_grader": extracted_answer,
                    "reasoning": reasoning,
                    "is_correct": is_correct,
                    "graded_confidence": confidence,
                    "grader_response": grading_response,
                }
            )

            graded_results.append(graded_result)

            # Write to output file
            with open(output_file, "a") as f:
                f.write(json.dumps(graded_result) + "\n")

            # Call progress callback if provided
            if progress_callback:
                progress_callback(
                    idx,
                    len(results),
                    {
                        "status": "graded",
                        "is_correct": is_correct,
                        "result": graded_result,
                    },
                )

        except Exception as e:
            logger.exception(f"Error grading result {idx + 1}: {e!s}")

            # Handle error
            error_result = result.copy()
            error_result["grading_error"] = str(e)

            with open(output_file, "a") as f:
                f.write(json.dumps(error_result) + "\n")

            graded_results.append(error_result)

            # Call progress callback if provided
            if progress_callback:
                progress_callback(
                    idx,
                    len(results),
                    {
                        "status": "error",
                        "error": str(e),
                        "result": error_result,
                    },
                )

    accuracy = correct_count / len(results) if results else 0
    logger.info(f"Grading complete. Accuracy: {accuracy:.3f}")
    logger.info(f"Correct: {correct_count}/{len(results)}")

    return graded_results
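# Example usage (an illustrative sketch; the file paths are placeholders and the
# input is expected to be JSONL with one result object per line, matching the
# json.loads-per-line reader above):
#
#     def log_progress(idx, total, info):
#         print(f"{idx + 1}/{total}: {info['status']}")
#
#     graded = grade_results(
#         results_file="results.jsonl",
#         output_file="graded_results.jsonl",
#         dataset_type="simpleqa",
#         progress_callback=log_progress,
#     )
#     # graded_results.jsonl is appended line by line as each result is graded.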
def human_evaluation(
    results_file: str, output_file: str, interactive: bool = True
) -> List[Dict[str, Any]]:
    """
    Allow for human evaluation of results.

    Args:
        results_file: Path to results file
        output_file: Path to save human-graded results
        interactive: Whether to run in interactive console mode

    Returns:
        List of human-graded results
    """
    # Load results
    results = []
    with open(results_file, "r") as f:
        for line in f:
            if line.strip():
                results.append(json.loads(line))

    # Remove output file if it exists
    output_path = Path(output_file)
    if output_path.exists():
        output_path.unlink()

    human_graded_results = []
    correct_count = 0

    if interactive:
        logger.info(f"Human evaluation: {len(results)} examples to grade")
        print(f"Human evaluation: {len(results)} examples to grade")
        print(
            "For each example, you'll see the question, correct answer, and model's response."
        )
        print("You'll be asked to judge if the model's answer is correct.")

    for idx, result in enumerate(results):
        question = result.get("problem", "")
        correct_answer = result.get("correct_answer", "")
        response = result.get("response", "")
        extracted_answer = result.get("extracted_answer", "")

        if interactive:
            print(f"\n\n===== Example {idx + 1}/{len(results)} =====")
            print(f"Question: {question}")
            print(f"\nCorrect Answer: {correct_answer}")
            print(f"\nModel Response: {response}")
            print(f"\nExtracted Answer: {extracted_answer}")

            # Get human judgment
            while True:
                judgment = (
                    input("\nIs the model's answer correct? (y/n): ")
                    .strip()
                    .lower()
                )
                if judgment in ["y", "n"]:
                    break
                print("Please enter 'y' or 'n'")

            is_correct = judgment == "y"

            # Get reasoning
            reasoning = input(
                "Please provide reasoning for your judgment: "
            ).strip()
        else:
            # Non-interactive mode - placeholder for API/UI implementation
            # In a real implementation, this would be filled by UI actions
            is_correct = False
            reasoning = "Non-interactive evaluation"

        if is_correct:
            correct_count += 1

        # Update result with human judgment
        human_result = result.copy()
        human_result.update(
            {
                "is_correct": is_correct,
                "reasoning": reasoning,
                "human_evaluation": True,
            }
        )

        human_graded_results.append(human_result)

        # Write to output file
        with open(output_file, "a") as f:
            f.write(json.dumps(human_result) + "\n")

    accuracy = correct_count / len(results) if results else 0
    logger.info(f"Human evaluation complete. Accuracy: {accuracy:.3f}")
    if interactive:
        print(f"\nHuman evaluation complete. Accuracy: {accuracy:.3f}")
        print(f"Correct: {correct_count}/{len(results)}")

    return human_graded_results
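# Example usage (a minimal sketch with placeholder paths; interactive=True
# prompts on the console for a y/n judgment and a short written rationale for
# each result in the JSONL file):
#
#     human_graded = human_evaluation(
#         results_file="results.jsonl",
#         output_file="human_graded.jsonl",
#         interactive=True,
#     )
#     # Each record gains "is_correct", "reasoning", and "human_evaluation": True.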