Coverage for src / local_deep_research / benchmarks / graders.py: 89%

200 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:55 +0000

1""" 

2Evaluation and grading functionality. 

3 

4This module provides tools for evaluating model outputs against reference answers. 

5""" 

6 

7import json 

8from loguru import logger 

9from pathlib import Path 

10import re 

11from typing import Any, Callable, Dict, List, Optional 

12 

13from langchain_core.messages.human import HumanMessage 

14 

15from ..config.llm_config import get_llm 

16from ..llm.providers.base import normalize_provider 

17from .templates import BROWSECOMP_GRADER_TEMPLATE, SIMPLEQA_GRADER_TEMPLATE 

18 

19 

# Default evaluation configuration using Claude 3.7 Sonnet via OpenRouter.
# Consumed by get_evaluation_llm(); callers may override any key via
# custom_config.
DEFAULT_EVALUATION_CONFIG = {
    "model_name": "anthropic/claude-3.7-sonnet",  # Correct model ID for OpenRouter
    "provider": "openai_endpoint",  # Use OpenRouter
    "openai_endpoint_url": "https://openrouter.ai/api/v1",  # OpenRouter URL
    "temperature": 0,  # Zero temp for consistent evaluation
    # Note: max_tokens removed as it's not supported by LDR's get_llm()
}

28 

29 

def get_evaluation_llm(
    custom_config: Optional[Dict[str, Any]] = None,
    settings_snapshot: Optional[Dict[str, Any]] = None,
):
    """
    Get an LLM for evaluation purposes using Claude 3.7 Sonnet via OpenRouter
    by default, which can be overridden with custom settings.

    Args:
        custom_config: Optional custom configuration that overrides defaults
        settings_snapshot: Optional settings snapshot for thread-safe access

    Returns:
        An LLM instance for evaluation
    """
    # Start with default config (Claude 3.7 Sonnet via OpenRouter)
    config = DEFAULT_EVALUATION_CONFIG.copy()

    # Override with any custom settings
    if custom_config:
        config.update(custom_config)

    logger.info(
        f"Getting evaluation LLM with provider={config['provider']}, model={config['model_name']}"
    )

    # Keep only the parameters LDR's get_llm() accepts so unsupported keys
    # (e.g. max_tokens) never reach it.
    ldr_supported_params = {
        "model_name",
        "temperature",
        "provider",
        "openai_endpoint_url",
        "api_key",
    }

    filtered_config = {
        k: v for k, v in config.items() if k in ldr_supported_params
    }

    # When using openai_endpoint, resolve the API key from the settings
    # snapshot (thread-safe access) unless one was passed explicitly.
    if normalize_provider(filtered_config.get("provider")) == "openai_endpoint":
        api_key = None

        if settings_snapshot:
            api_key_setting = settings_snapshot.get(
                "llm.openai_endpoint.api_key"
            )
            if api_key_setting:
                # Settings entries may be stored as {"value": ...} dicts or
                # as bare values.
                api_key = (
                    api_key_setting.get("value")
                    if isinstance(api_key_setting, dict)
                    else api_key_setting
                )
        else:
            # No settings snapshot available
            logger.warning(
                "No settings snapshot provided for benchmark grader. "
                "API key must be provided via settings_snapshot for thread safety."
            )

        if api_key:
            # Bug fix: previously the key was fetched from the snapshot but
            # never forwarded to get_llm(). Pass it through without
            # overriding an explicitly supplied api_key.
            filtered_config.setdefault("api_key", api_key)
        else:
            logger.warning(
                "Using openai_endpoint provider but no API key found. "
                "Set the llm.openai_endpoint.api_key setting in the database or "
                "LDR_LLM_OPENAI_ENDPOINT_API_KEY environment variable."
            )
            # Try to fall back to LDR's config if API key not explicitly provided
            # The get_llm function will handle this case

    # Get the LLM using LDR's existing function
    return get_llm(**filtered_config)

104 

105 

def extract_answer_from_response(
    response: str, dataset_type: str = "simpleqa"
) -> Dict[str, str]:
    """
    Extract structured information from LDR's response.

    Args:
        response: Response from LDR
        dataset_type: Type of dataset

    Returns:
        Dictionary with extracted answer and confidence
    """
    # Strip inline citation markers such as "[3]" before any parsing.
    cleaned = re.sub(r"\[\d+\]", "", response)

    if dataset_type.lower() != "browsecomp":
        # SimpleQA has no structured fields: the whole cleaned response is
        # the answer, and no confidence score is reported.
        return {"extracted_answer": cleaned, "confidence": "100"}

    # BrowseComp responses carry explicit "Exact Answer:" / "Confidence:" lines.
    answer_match = re.search(r"Exact Answer:\s*(.*?)(?:\n|$)", cleaned)
    confidence_match = re.search(r"Confidence:\s*(\d+)%", cleaned)

    return {
        "extracted_answer": (
            answer_match.group(1).strip() if answer_match else "None"
        ),
        "confidence": (
            confidence_match.group(1) if confidence_match else "100"
        ),
    }

139 

140 

def _invoke_grader_llm(evaluation_llm, grading_prompt: str) -> str:
    """Invoke the grading LLM across the supported LLM interfaces.

    Handles message-based chat models, plain ``invoke`` models, and
    bare-callable fallbacks; always returns the response text.
    """
    if hasattr(evaluation_llm, "invoke") and callable(evaluation_llm.invoke):
        if hasattr(evaluation_llm, "chat_messages"):
            # Handle ChatOpenAI and similar models that use messages
            return evaluation_llm.invoke(
                [HumanMessage(content=grading_prompt)]
            ).content
        # Handle other LLM types; they may return a message object or a string
        grading_response = evaluation_llm.invoke(grading_prompt)
        if hasattr(grading_response, "content"):
            grading_response = grading_response.content
        return grading_response
    # Fallback for other LLM interfaces
    return str(evaluation_llm(grading_prompt))


def _parse_grader_output(grading_response: str, dataset_type: str):
    """Parse the grader LLM's output into its structured fields.

    Returns a tuple ``(extracted_answer, reasoning, is_correct, confidence)``.
    The two datasets use the same field structure but different labels, so
    only the regex patterns differ per dataset.
    """
    if dataset_type.lower() == "browsecomp":
        answer_re = r"extracted_final_answer:\s*(.*?)(?:\n|$)"
        reasoning_re = r"reasoning:\s*(.*?)(?:\n\n|\ncorrect:|\Z)"
        correct_re = r"correct:\s*(yes|no)"
        confidence_re = r"confidence:\s*(\d+)"
    else:
        answer_re = r"Extracted Answer:\s*(.*?)(?:\n|$)"
        reasoning_re = r"Reasoning:\s*(.*?)(?:\nCorrect:|\Z)"
        correct_re = r"Correct:\s*(yes|no)"
        confidence_re = None  # SimpleQA doesn't have confidence

    match = re.search(answer_re, grading_response)
    extracted_answer = match.group(1).strip() if match else "None"

    match = re.search(reasoning_re, grading_response, re.DOTALL)
    reasoning = match.group(1).strip() if match else ""

    match = re.search(correct_re, grading_response, re.IGNORECASE)
    is_correct = bool(match and match.group(1).lower() == "yes")

    if confidence_re is None:
        confidence = "100"
    else:
        match = re.search(confidence_re, grading_response)
        confidence = match.group(1) if match else "100"

    return extracted_answer, reasoning, is_correct, confidence


def grade_single_result(
    result_data: Dict[str, Any],
    dataset_type: str = "simpleqa",
    evaluation_config: Optional[Dict[str, Any]] = None,
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Grade a single benchmark result using LLM.

    Args:
        result_data: Dictionary containing result data with keys: id, problem, correct_answer, response, extracted_answer
        dataset_type: Type of dataset
        evaluation_config: Optional custom config for evaluation LLM
        settings_snapshot: Optional settings snapshot for thread-safe access

    Returns:
        Dictionary with grading results
    """
    # Get evaluation LLM
    evaluation_llm = get_evaluation_llm(evaluation_config, settings_snapshot)

    try:
        # Select appropriate template
        template = (
            BROWSECOMP_GRADER_TEMPLATE
            if dataset_type.lower() == "browsecomp"
            else SIMPLEQA_GRADER_TEMPLATE
        )

        question = result_data.get("problem", "")
        correct_answer = result_data.get("correct_answer", "")
        response = result_data.get("response", "")

        logger.info(f"Grading single result: {question[:50]}...")

        # Format grading prompt
        grading_prompt = template.format(
            question=question, correct_answer=correct_answer, response=response
        )

        import time

        eval_llm_start = time.time()
        logger.info(
            f"Starting grading LLM call (prompt length: {len(grading_prompt)} chars)..."
        )

        # Grade using LLM
        grading_response = _invoke_grader_llm(evaluation_llm, grading_prompt)

        eval_llm_elapsed = time.time() - eval_llm_start
        logger.info(f"Grading LLM call completed in {eval_llm_elapsed:.2f}s")

        # Extract grading information using regex
        extracted_answer, reasoning, is_correct, confidence = (
            _parse_grader_output(grading_response, dataset_type)
        )

        # Format graded result
        return {
            "extracted_by_grader": extracted_answer,
            "reasoning": reasoning,
            "is_correct": is_correct,
            "graded_confidence": confidence,
            "grader_response": grading_response,
        }

    except Exception as e:
        logger.exception("Error grading single result")
        return {
            "grading_error": str(e),
            "is_correct": False,
            "graded_confidence": "0",
            "grader_response": f"Grading failed: {e!s}",
        }
    finally:
        # Release the grader LLM even on failure.
        from ..utilities.resource_utils import safe_close

        safe_close(evaluation_llm, "grader LLM")

297 

298 

def grade_results(
    results_file: str,
    output_file: str,
    dataset_type: str = "simpleqa",
    evaluation_config: Optional[Dict[str, Any]] = None,
    progress_callback: Optional[Callable[[int, int, Dict], None]] = None,
) -> List[Dict[str, Any]]:
    """
    Grade benchmark results using LLM.

    Args:
        results_file: Path to results file
        output_file: Path to save graded results
        dataset_type: Type of dataset
        evaluation_config: Optional custom config for evaluation LLM
        progress_callback: Optional callback for progress updates

    Returns:
        List of graded results
    """
    # Acquire the grader LLM up front so it can be released in ``finally``
    # even if grading fails partway through.
    evaluation_llm = get_evaluation_llm(evaluation_config)

    try:
        graded = _grade_results_inner(
            evaluation_llm,
            results_file,
            output_file,
            dataset_type,
            progress_callback,
        )
    finally:
        from ..utilities.resource_utils import safe_close

        safe_close(evaluation_llm, "grader LLM")

    return graded

334 

335 

def _grade_results_inner(
    evaluation_llm,
    results_file: str,
    output_file: str,
    dataset_type: str,
    progress_callback: Optional[Callable[[int, int, Dict], None]],
) -> List[Dict[str, Any]]:
    """Inner implementation of grade_results, separated for cleanup.

    Streams every graded (or errored) record to ``output_file`` as JSON
    lines as it goes, so partial progress survives a mid-run failure.
    """
    is_browsecomp = dataset_type.lower() == "browsecomp"

    # Select appropriate template
    template = (
        BROWSECOMP_GRADER_TEMPLATE if is_browsecomp else SIMPLEQA_GRADER_TEMPLATE
    )

    # The grader output uses the same field structure for both datasets but
    # different labels; choose the regex patterns once, outside the loop,
    # instead of re-branching per result.
    if is_browsecomp:
        answer_re = r"extracted_final_answer:\s*(.*?)(?:\n|$)"
        reasoning_re = r"reasoning:\s*(.*?)(?:\n\n|\ncorrect:|\Z)"
        correct_re = r"correct:\s*(yes|no)"
        confidence_re = r"confidence:\s*(\d+)"
    else:
        answer_re = r"Extracted Answer:\s*(.*?)(?:\n|$)"
        reasoning_re = r"Reasoning:\s*(.*?)(?:\nCorrect:|\Z)"
        correct_re = r"Correct:\s*(yes|no)"
        confidence_re = None  # SimpleQA doesn't have confidence

    # Load results (JSON lines; skip blank lines)
    results = []
    with open(results_file, "r") as f:
        for line in f:
            if line.strip():
                results.append(json.loads(line))

    # Remove output file if it exists
    output_path = Path(output_file)
    if output_path.exists():
        output_path.unlink()

    graded_results = []
    correct_count = 0

    # Process each result
    for idx, result in enumerate(results):
        question = result.get("problem", "")
        correct_answer = result.get("correct_answer", "")
        response = result.get("response", "")

        # Call progress callback if provided
        if progress_callback:
            progress_callback(
                idx,
                len(results),
                {"status": "grading", "index": idx, "total": len(results)},
            )

        logger.info(f"Grading {idx + 1}/{len(results)}: {question[:50]}...")

        # Format grading prompt
        grading_prompt = template.format(
            question=question, correct_answer=correct_answer, response=response
        )

        try:
            # Grade using LLM, handling the different LLM interfaces
            if hasattr(evaluation_llm, "invoke") and callable(
                evaluation_llm.invoke
            ):
                if hasattr(evaluation_llm, "chat_messages"):
                    # Handle ChatOpenAI and similar models that use messages
                    grading_response = evaluation_llm.invoke(
                        [HumanMessage(content=grading_prompt)]
                    ).content
                else:
                    # Handle other LLM types
                    grading_response = evaluation_llm.invoke(grading_prompt)
                    if hasattr(grading_response, "content"):
                        grading_response = grading_response.content
            else:
                # Fallback for other LLM interfaces
                grading_response = str(evaluation_llm(grading_prompt))

            # Extract grading fields with the dataset-specific patterns
            match = re.search(answer_re, grading_response)
            extracted_answer = match.group(1).strip() if match else "None"

            match = re.search(reasoning_re, grading_response, re.DOTALL)
            reasoning = match.group(1).strip() if match else ""

            match = re.search(correct_re, grading_response, re.IGNORECASE)
            is_correct = bool(match and match.group(1).lower() == "yes")

            if confidence_re is None:
                confidence = "100"
            else:
                match = re.search(confidence_re, grading_response)
                confidence = match.group(1) if match else "100"

            if is_correct:
                correct_count += 1

            # Format graded result
            graded_result = result.copy()
            graded_result.update(
                {
                    "extracted_by_grader": extracted_answer,
                    "reasoning": reasoning,
                    "is_correct": is_correct,
                    "graded_confidence": confidence,
                    "grader_response": grading_response,
                }
            )

            graded_results.append(graded_result)

            # Append immediately so earlier results survive a later crash
            with open(output_file, "a") as f:
                f.write(json.dumps(graded_result) + "\n")

            # Call progress callback if provided
            if progress_callback:
                progress_callback(
                    idx,
                    len(results),
                    {
                        "status": "graded",
                        "is_correct": is_correct,
                        "result": graded_result,
                    },
                )

        except Exception as e:
            logger.exception(f"Error grading result {idx + 1}")

            # Record the failure for this result and keep going
            error_result = result.copy()
            error_result["grading_error"] = str(e)

            with open(output_file, "a") as f:
                f.write(json.dumps(error_result) + "\n")

            graded_results.append(error_result)

            # Call progress callback if provided
            if progress_callback:
                progress_callback(
                    idx,
                    len(results),
                    {
                        "status": "error",
                        "error": str(e),
                        "result": error_result,
                    },
                )

    accuracy = correct_count / len(results) if results else 0
    logger.info(f"Grading complete. Accuracy: {accuracy:.3f}")
    logger.info(f"Correct: {correct_count}/{len(results)}")

    return graded_results

535 

536 

def human_evaluation(
    results_file: str, output_file: str, interactive: bool = True
) -> List[Dict[str, Any]]:
    """
    Allow for human evaluation of results.

    Args:
        results_file: Path to results file
        output_file: Path to save human-graded results
        interactive: Whether to run in interactive console mode

    Returns:
        List of human-graded results
    """
    # Load the JSON-lines results file, ignoring blank lines.
    with open(results_file, "r") as f:
        results = [json.loads(line) for line in f if line.strip()]

    # Start from a clean output file.
    output_path = Path(output_file)
    if output_path.exists():
        output_path.unlink()

    human_graded_results = []
    correct_count = 0

    if interactive:
        logger.info(f"Human evaluation: {len(results)} examples to grade")
        print(f"Human evaluation: {len(results)} examples to grade")
        print(
            "For each example, you'll see the question, correct answer, and model's response."
        )
        print("You'll be asked to judge if the model's answer is correct.")

    for idx, result in enumerate(results):
        question = result.get("problem", "")
        correct_answer = result.get("correct_answer", "")
        response = result.get("response", "")
        extracted_answer = result.get("extracted_answer", "")

        if not interactive:
            # Non-interactive mode - placeholder for API/UI implementation
            # In a real implementation, this would be filled by UI actions
            is_correct = False
            reasoning = "Non-interactive evaluation"
        else:
            print(f"\n\n===== Example {idx + 1}/{len(results)} =====")
            print(f"Question: {question}")
            print(f"\nCorrect Answer: {correct_answer}")
            print(f"\nModel Response: {response}")
            print(f"\nExtracted Answer: {extracted_answer}")

            # Keep asking until the grader types 'y' or 'n'.
            while True:
                judgment = (
                    input("\nIs the model's answer correct? (y/n): ")
                    .strip()
                    .lower()
                )
                if judgment in ["y", "n"]:
                    break
                print("Please enter 'y' or 'n'")

            is_correct = judgment == "y"

            reasoning = input(
                "Please provide reasoning for your judgment: "
            ).strip()

        if is_correct:
            correct_count += 1

        # Record the human judgment alongside the original result fields.
        human_result = {
            **result,
            "is_correct": is_correct,
            "reasoning": reasoning,
            "human_evaluation": True,
        }

        human_graded_results.append(human_result)

        # Persist each judgment immediately (JSON lines, append mode).
        with open(output_file, "a") as f:
            f.write(json.dumps(human_result) + "\n")

    accuracy = correct_count / len(results) if results else 0
    logger.info(f"Human evaluation complete. Accuracy: {accuracy:.3f}")
    if interactive:
        print(f"\nHuman evaluation complete. Accuracy: {accuracy:.3f}")
        print(f"Correct: {correct_count}/{len(results)}")

    return human_graded_results