Coverage for src/local_deep_research/benchmarks/graders.py: 89%

200 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-03 23:15 +0000

1""" 

2Evaluation and grading functionality. 

3 

4This module provides tools for evaluating model outputs against reference answers. 

5""" 

6 

7import json 

8from loguru import logger 

9from pathlib import Path 

10import re 

11from typing import Any, Callable, Dict, List, Optional 

12 

13from langchain_core.messages.human import HumanMessage 

14 

15from ..config.llm_config import get_llm 

16from ..llm.providers.base import normalize_provider 

17from .templates import BROWSECOMP_GRADER_TEMPLATE, SIMPLEQA_GRADER_TEMPLATE 

18 

19 

# Baseline settings for the grading LLM: Claude 3.7 Sonnet reached via
# OpenRouter. max_tokens is deliberately not listed because LDR's get_llm()
# does not support it.
DEFAULT_EVALUATION_CONFIG = dict(
    # Model ID in the form OpenRouter expects
    model_name="anthropic/claude-3.7-sonnet",
    # Use OpenRouter (through the openai_endpoint provider)
    provider="openai_endpoint",
    # OpenRouter URL
    openai_endpoint_url="https://openrouter.ai/api/v1",
    # Zero temperature for consistent evaluation
    temperature=0,
)

28 

29 

def get_evaluation_llm(
    custom_config: Optional[Dict[str, Any]] = None,
    settings_snapshot: Optional[Dict[str, Any]] = None,
):
    """
    Build the LLM used to grade benchmark answers.

    The module defaults (Claude 3.7 Sonnet via OpenRouter) form the base
    configuration; entries in ``custom_config`` override them.

    Args:
        custom_config: Optional overrides applied on top of the defaults
        settings_snapshot: Optional settings snapshot for thread-safe access.
            In this function it is only read to check whether an API key
            is configured for the openai_endpoint provider.

    Returns:
        The object returned by get_llm() for the filtered configuration
    """
    # Layer the caller's overrides over a private copy of the defaults.
    config = dict(DEFAULT_EVALUATION_CONFIG)
    if custom_config:
        config.update(custom_config)

    logger.info(
        f"Getting evaluation LLM with provider={config['provider']}, model={config['model_name']}"
    )

    # Only these keys are forwarded to get_llm(); anything else in the
    # merged configuration is dropped for compatibility with LDR.
    accepted_keys = (
        "model_name",
        "temperature",
        "provider",
        "openai_endpoint_url",
        "api_key",
    )
    llm_kwargs = {
        name: value for name, value in config.items() if name in accepted_keys
    }

    # Other providers need no API-key sanity check.
    if normalize_provider(llm_kwargs.get("provider")) != "openai_endpoint":
        return get_llm(**llm_kwargs)

    if not settings_snapshot:
        logger.warning(
            "No settings snapshot provided for benchmark grader. "
            "API key must be provided via settings_snapshot for thread safety."
        )
        api_key = None
    else:
        # The setting may be stored as a bare value or as {"value": ...}.
        stored = settings_snapshot.get("llm.openai_endpoint.api_key")
        api_key = stored.get("value") if isinstance(stored, dict) else stored

    # NOTE(review): the key found above is only used to decide whether to
    # warn; it is not added to the arguments passed to get_llm(). Confirm
    # that get_llm() resolves the key on its own.
    if not api_key:
        logger.warning(
            "Using openai_endpoint provider but no API key found. "
            "Set the llm.openai_endpoint.api_key setting in the database or "
            "LDR_LLM_OPENAI_ENDPOINT_API_KEY environment variable."
        )

    return get_llm(**llm_kwargs)

104 

105 

def extract_answer_from_response(
    response: str, dataset_type: str = "simpleqa"
) -> Dict[str, str]:
    """
    Extract structured information from LDR's response.

    Citation markers are removed first. For BrowseComp the "Exact Answer:"
    and "Confidence:" fields are pulled out of the text; for every other
    dataset type the whole cleaned response is the answer.

    Args:
        response: Response from LDR. ``None`` is tolerated and treated as
            an empty response.
        dataset_type: Type of dataset (compared case-insensitively)

    Returns:
        Dictionary with the keys ``extracted_answer`` and ``confidence``
        (both strings). ``confidence`` is "100" when no score is found.
    """
    # Strip both ASCII "[N]" and lenticular "【N】" citations so they do not
    # survive into the graded answer text and skew the match. A missing
    # response (e.g. a JSON null) is handled as empty text instead of
    # raising inside re.sub.
    response = re.sub(r"[\[【]\d+[\]】]", "", response or "")

    if dataset_type.lower() != "browsecomp":
        # SimpleQA: the whole response is the answer, with no confidence score.
        return {"extracted_answer": response, "confidence": "100"}

    # The answer is whatever follows the label up to the end of that line.
    answer_match = re.search(r"Exact Answer:\s*(.*?)(?:\n|$)", response)
    exact_answer = answer_match.group(1).strip() if answer_match else "None"

    # The percent sign is not required: "Confidence: 85" and
    # "Confidence: 85.5%" previously fell through to the "100" default,
    # overstating the model's stated confidence.
    confidence_match = re.search(r"Confidence:\s*(\d+)", response)
    confidence = confidence_match.group(1) if confidence_match else "100"

    return {"extracted_answer": exact_answer, "confidence": confidence}

141 

142 

def _invoke_grader_llm(evaluation_llm, grading_prompt: str):
    """Send the grading prompt to the LLM and return its reply."""
    invoke = getattr(evaluation_llm, "invoke", None)
    if not callable(invoke):
        # Fallback for other LLM interfaces: call the object directly.
        return str(evaluation_llm(grading_prompt))

    if hasattr(evaluation_llm, "chat_messages"):
        # Handle ChatOpenAI and similar models that use messages
        return invoke([HumanMessage(content=grading_prompt)]).content

    # Other LLM types may return a message object or a plain value.
    reply = invoke(grading_prompt)
    return reply.content if hasattr(reply, "content") else reply


def _parse_grader_reply(
    grading_response: str, is_browsecomp: bool
) -> Dict[str, Any]:
    """Pull answer, reasoning, verdict and confidence out of a grader reply."""
    if is_browsecomp:
        answer_pattern = r"extracted_final_answer:\s*(.*?)(?:\n|$)"
        reasoning_pattern = r"reasoning:\s*(.*?)(?:\n\n|\ncorrect:|\Z)"
        correct_pattern = r"correct:\s*(yes|no)"
    else:
        answer_pattern = r"Extracted Answer:\s*(.*?)(?:\n|$)"
        reasoning_pattern = r"Reasoning:\s*(.*?)(?:\nCorrect:|\Z)"
        correct_pattern = r"Correct:\s*(yes|no)"

    answer_match = re.search(answer_pattern, grading_response)
    reasoning_match = re.search(reasoning_pattern, grading_response, re.DOTALL)
    correct_match = re.search(correct_pattern, grading_response, re.IGNORECASE)

    # Only the BrowseComp grader reports a confidence; SimpleQA is fixed.
    confidence = "100"
    if is_browsecomp:
        confidence_match = re.search(r"confidence:\s*(\d+)", grading_response)
        if confidence_match:
            confidence = confidence_match.group(1)

    return {
        "extracted_by_grader": (
            answer_match.group(1).strip() if answer_match else "None"
        ),
        "reasoning": (
            reasoning_match.group(1).strip() if reasoning_match else ""
        ),
        # A reply without a recognisable verdict counts as incorrect.
        "is_correct": (
            correct_match is not None
            and correct_match.group(1).lower() == "yes"
        ),
        "graded_confidence": confidence,
        "grader_response": grading_response,
    }


def _grading_failure_fields(error: Exception) -> Dict[str, Any]:
    """Build the fields recorded when grading an item raised an exception."""
    return {
        "grading_error": str(error),
        "is_correct": False,
        "graded_confidence": "0",
        "grader_response": f"Grading failed: {error!s}",
    }


def grade_single_result(
    result_data: Dict[str, Any],
    dataset_type: str = "simpleqa",
    evaluation_config: Optional[Dict[str, Any]] = None,
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Grade a single benchmark result using LLM.

    Args:
        result_data: Dictionary containing result data; the keys problem,
            correct_answer and response are read (missing keys default to
            an empty string)
        dataset_type: Type of dataset; "browsecomp" (case-insensitive)
            selects the BrowseComp template and parsing, anything else
            uses SimpleQA
        evaluation_config: Optional custom config for evaluation LLM
        settings_snapshot: Optional settings snapshot for thread-safe access

    Returns:
        Dictionary with grading results. On success it holds
        extracted_by_grader, reasoning, is_correct, graded_confidence and
        grader_response. If grading raises, it holds grading_error,
        is_correct (False), graded_confidence ("0") and grader_response.
    """
    # Local import, as in the original implementation.
    import time

    # Created outside the try block: a failure to build the LLM propagates
    # to the caller instead of being reported as a grading error.
    evaluation_llm = get_evaluation_llm(evaluation_config, settings_snapshot)

    try:
        is_browsecomp = dataset_type.lower() == "browsecomp"
        template = (
            BROWSECOMP_GRADER_TEMPLATE
            if is_browsecomp
            else SIMPLEQA_GRADER_TEMPLATE
        )

        question = result_data.get("problem", "")
        correct_answer = result_data.get("correct_answer", "")
        response = result_data.get("response", "")

        logger.info(f"Grading single result: {question[:50]}...")

        grading_prompt = template.format(
            question=question, correct_answer=correct_answer, response=response
        )

        # perf_counter is monotonic, so the reported duration cannot be
        # distorted by system clock adjustments.
        eval_llm_start = time.perf_counter()
        logger.info(
            f"Starting grading LLM call (prompt length: {len(grading_prompt)} chars)..."
        )

        grading_response = _invoke_grader_llm(evaluation_llm, grading_prompt)

        eval_llm_elapsed = time.perf_counter() - eval_llm_start
        logger.info(f"Grading LLM call completed in {eval_llm_elapsed:.2f}s")

        return _parse_grader_reply(grading_response, is_browsecomp)

    except Exception as e:
        logger.exception("Error grading single result")
        return _grading_failure_fields(e)
    finally:
        from ..utilities.resource_utils import safe_close

        safe_close(evaluation_llm, "grader LLM")


def grade_results(
    results_file: str,
    output_file: str,
    dataset_type: str = "simpleqa",
    evaluation_config: Optional[Dict[str, Any]] = None,
    progress_callback: Optional[Callable[[int, int, Dict], None]] = None,
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
    Grade benchmark results using LLM.

    Args:
        results_file: Path to results file (one JSON object per line)
        output_file: Path to save graded results (one JSON object per line)
        dataset_type: Type of dataset
        evaluation_config: Optional custom config for evaluation LLM
        progress_callback: Optional callback for progress updates, called
            as ``callback(index, total, info)``
        settings_snapshot: Optional settings snapshot for thread-safe
            access, forwarded to get_evaluation_llm(). Defaults to None,
            which matches the previous behaviour.

    Returns:
        List of graded results
    """
    evaluation_llm = get_evaluation_llm(evaluation_config, settings_snapshot)

    try:
        return _grade_results_inner(
            evaluation_llm,
            results_file,
            output_file,
            dataset_type,
            progress_callback,
        )
    finally:
        from ..utilities.resource_utils import safe_close

        safe_close(evaluation_llm, "grader LLM")


def _grade_results_inner(
    evaluation_llm,
    results_file: str,
    output_file: str,
    dataset_type: str,
    progress_callback: Optional[Callable[[int, int, Dict], None]],
) -> List[Dict[str, Any]]:
    """
    Inner implementation of grade_results, separated for cleanup.

    Each input record is graded independently; a failure while grading one
    record is logged and stored on that record, and the loop continues.
    Every record is appended to ``output_file`` exactly once, right after
    it has been processed.
    """
    is_browsecomp = dataset_type.lower() == "browsecomp"
    template = (
        BROWSECOMP_GRADER_TEMPLATE
        if is_browsecomp
        else SIMPLEQA_GRADER_TEMPLATE
    )

    # Load results, skipping blank lines.
    results = []
    with open(results_file, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                results.append(json.loads(line))

    # Start from a clean output file; records are appended one at a time.
    output_path = Path(output_file)
    if output_path.exists():
        output_path.unlink()

    total = len(results)
    graded_results = []
    correct_count = 0

    for idx, result in enumerate(results):
        question = result.get("problem", "")
        correct_answer = result.get("correct_answer", "")
        response = result.get("response", "")

        if progress_callback:
            progress_callback(
                idx,
                total,
                {"status": "grading", "index": idx, "total": total},
            )

        logger.info(f"Grading {idx + 1}/{total}: {question[:50]}...")

        grading_prompt = template.format(
            question=question, correct_answer=correct_answer, response=response
        )

        record = result.copy()

        # Only the LLM call and the parsing are guarded. Writing the record
        # and notifying the callback happen afterwards, so a failure there
        # can no longer cause the same item to be recorded a second time
        # as an error.
        try:
            grading_response = _invoke_grader_llm(
                evaluation_llm, grading_prompt
            )
            graded_fields = _parse_grader_reply(grading_response, is_browsecomp)
        except Exception as e:
            logger.exception(f"Error grading result {idx + 1}")
            # Same failure fields as grade_single_result returns.
            record.update(_grading_failure_fields(e))
            progress_info = {
                "status": "error",
                "error": str(e),
                "result": record,
            }
        else:
            record.update(graded_fields)
            if record["is_correct"]:
                correct_count += 1
            progress_info = {
                "status": "graded",
                "is_correct": record["is_correct"],
                "result": record,
            }

        # Append immediately so finished work survives a later failure.
        with open(output_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")

        graded_results.append(record)

        if progress_callback:
            progress_callback(idx, total, progress_info)

    accuracy = correct_count / len(results) if results else 0
    logger.info(f"Grading complete. Accuracy: {accuracy:.3f}")
    logger.info(f"Correct: {correct_count}/{len(results)}")

    return graded_results

537 

538 

def human_evaluation(
    results_file: str, output_file: str, interactive: bool = True
) -> List[Dict[str, Any]]:
    """
    Collect human judgments for a file of benchmark results.

    Args:
        results_file: Path to results file (one JSON object per line)
        output_file: Path to save human-graded results; an existing file
            is removed first and records are appended one at a time
        interactive: Whether to run in interactive console mode. When
            False every record is marked incorrect with a placeholder
            reasoning string.

    Returns:
        List of human-graded results
    """
    # Read every non-blank line as one JSON record.
    with open(results_file, "r", encoding="utf-8") as source:
        results = [json.loads(line) for line in source if line.strip()]

    # Drop any previous output so the appends below start from scratch.
    output_path = Path(output_file)
    if output_path.exists():
        output_path.unlink()

    human_graded_results = []
    correct_count = 0

    if interactive:
        logger.info(f"Human evaluation: {len(results)} examples to grade")
        print(f"Human evaluation: {len(results)} examples to grade")
        print(
            "For each example, you'll see the question, correct answer, and model's response."
        )
        print("You'll be asked to judge if the model's answer is correct.")

    for idx, result in enumerate(results):
        question = result.get("problem", "")
        correct_answer = result.get("correct_answer", "")
        response = result.get("response", "")
        extracted_answer = result.get("extracted_answer", "")

        if not interactive:
            # Non-interactive mode - placeholder for API/UI implementation.
            # In a real implementation, this would be filled by UI actions.
            is_correct = False
            reasoning = "Non-interactive evaluation"
        else:
            print(f"\n\n===== Example {idx + 1}/{len(results)} =====")
            print(f"Question: {question}")
            print(f"\nCorrect Answer: {correct_answer}")
            print(f"\nModel Response: {response}")
            print(f"\nExtracted Answer: {extracted_answer}")

            # Keep asking until the reviewer answers with 'y' or 'n'.
            judgment = (
                input("\nIs the model's answer correct? (y/n): ").strip().lower()
            )
            while judgment not in ("y", "n"):
                print("Please enter 'y' or 'n'")
                judgment = (
                    input("\nIs the model's answer correct? (y/n): ")
                    .strip()
                    .lower()
                )

            is_correct = judgment == "y"
            reasoning = input(
                "Please provide reasoning for your judgment: "
            ).strip()

        if is_correct:
            correct_count += 1

        # Original record plus the human verdict.
        human_result = {
            **result,
            "is_correct": is_correct,
            "reasoning": reasoning,
            "human_evaluation": True,
        }
        human_graded_results.append(human_result)

        with open(output_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(human_result) + "\n")

    accuracy = correct_count / len(results) if results else 0
    logger.info(f"Human evaluation complete. Accuracy: {accuracy:.3f}")
    if interactive:
        print(f"\nHuman evaluation complete. Accuracy: {accuracy:.3f}")
        print(f"Correct: {correct_count}/{len(results)}")

    return human_graded_results