Coverage for src/local_deep_research/benchmarks/optimization/optuna_optimizer.py: 78%
357 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
1"""
2Optuna-based parameter optimizer for Local Deep Research.
4This module provides the core optimization functionality using Optuna
5to find optimal parameters for the research system, balancing quality
6and performance metrics.
7"""
9import os
10from pathlib import Path
11import time
12from datetime import datetime, UTC
13from functools import partial
14from typing import Any, Callable, Dict, List, Optional, Tuple
16import joblib
17import numpy as np
18import optuna
19from optuna.visualization import (
20 plot_contour,
21 plot_optimization_history,
22 plot_param_importances,
23 plot_slice,
24)
26from local_deep_research.benchmarks.efficiency.speed_profiler import (
27 SpeedProfiler,
28)
29from local_deep_research.security import sanitize_data
30from loguru import logger
32from local_deep_research.benchmarks.evaluators import (
33 CompositeBenchmarkEvaluator,
34)
36# Import benchmark evaluator components
38# Try to import visualization libraries, but don't fail if not available
# Try to import visualization libraries, but don't fail if not available.
# PLOTTING_AVAILABLE gates every matplotlib-based visualization method below;
# the Optuna plotly-based plots are unaffected by this flag.
try:
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D

    # We'll use matplotlib for plotting visualization results
    PLOTTING_AVAILABLE = True
except ImportError:
    PLOTTING_AVAILABLE = False
    logger.warning("Matplotlib not available, visualization will be limited")
51class OptunaOptimizer:
52 """
53 Optimize parameters for Local Deep Research using Optuna.
55 This class provides functionality to:
56 1. Define search spaces for parameter optimization
57 2. Evaluate parameter combinations using objective functions
58 3. Find optimal parameters via Optuna
59 4. Visualize and analyze optimization results
60 """
62 def __init__(
63 self,
64 base_query: str,
65 output_dir: str = "optimization_results",
66 model_name: Optional[str] = None,
67 provider: Optional[str] = None,
68 search_tool: Optional[str] = None,
69 temperature: float = 0.7,
70 n_trials: int = 30,
71 timeout: Optional[int] = None,
72 n_jobs: int = 1,
73 study_name: Optional[str] = None,
74 optimization_metrics: Optional[List[str]] = None,
75 metric_weights: Optional[Dict[str, float]] = None,
76 progress_callback: Optional[Callable[[int, int, Dict], None]] = None,
77 benchmark_weights: Optional[Dict[str, float]] = None,
78 ):
79 """
80 Initialize the optimizer.
82 Args:
83 base_query: The research query to use for all experiments
84 output_dir: Directory to save optimization results
85 model_name: Name of the LLM model to use
86 provider: LLM provider
87 search_tool: Search engine to use
88 temperature: LLM temperature
89 n_trials: Number of parameter combinations to try
90 timeout: Maximum seconds to run optimization (None for no limit)
91 n_jobs: Number of parallel jobs for optimization
92 study_name: Name of the Optuna study
93 optimization_metrics: List of metrics to optimize (default: ["quality", "speed"])
94 metric_weights: Dictionary of weights for each metric (e.g., {"quality": 0.6, "speed": 0.4})
95 progress_callback: Optional callback for progress updates
96 benchmark_weights: Dictionary mapping benchmark types to weights
97 (e.g., {"simpleqa": 0.6, "browsecomp": 0.4})
98 If None, only SimpleQA is used with weight 1.0
99 """
100 self.base_query = base_query
101 self.output_dir = output_dir
102 self.model_name = model_name
103 self.provider = provider
104 self.search_tool = search_tool
105 self.temperature = temperature
106 self.n_trials = n_trials
107 self.timeout = timeout
108 self.n_jobs = n_jobs
109 self.optimization_metrics = optimization_metrics or ["quality", "speed"]
110 self.metric_weights = metric_weights or {"quality": 0.6, "speed": 0.4}
111 self.progress_callback = progress_callback
113 # Initialize benchmark evaluator with weights
114 self.benchmark_weights = benchmark_weights or {"simpleqa": 1.0}
115 self.benchmark_evaluator = CompositeBenchmarkEvaluator(
116 self.benchmark_weights
117 )
119 # Normalize weights to sum to 1.0
120 total_weight = sum(self.metric_weights.values())
121 if total_weight > 0:
122 self.metric_weights = {
123 k: v / total_weight for k, v in self.metric_weights.items()
124 }
126 # Generate a unique study name if not provided
127 timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
128 self.study_name = study_name or f"ldr_opt_{timestamp}"
130 # Create output directory
131 os.makedirs(output_dir, exist_ok=True)
133 # Store the trial history for analysis
134 self.trials_history: List[Dict[str, Any]] = []
136 # Storage for the best parameters and study
137 self.best_params: Optional[Dict[str, Any]] = None
138 self.study: Optional[optuna.Study] = None
140 def optimize(
141 self, param_space: Optional[Dict[str, Any]] = None
142 ) -> Tuple[Dict[str, Any], float]:
143 """
144 Run the optimization process using Optuna.
146 Args:
147 param_space: Dictionary defining parameter search spaces
148 (if None, use default spaces)
150 Returns:
151 Tuple containing (best_parameters, best_score)
152 """
153 param_space = param_space or self._get_default_param_space()
155 # Create a study object
156 storage_name = f"sqlite:///{self.output_dir}/{self.study_name}.db"
157 self.study = optuna.create_study(
158 study_name=self.study_name,
159 storage=storage_name,
160 load_if_exists=True,
161 direction="maximize",
162 sampler=optuna.samplers.TPESampler(seed=42),
163 )
165 # Create partial function with param_space
166 objective = partial(self._objective, param_space=param_space)
168 # Log optimization start
169 logger.info(
170 f"Starting optimization with {self.n_trials} trials, {self.n_jobs} parallel jobs"
171 )
172 logger.info(f"Parameter space: {param_space}")
173 logger.info(f"Metric weights: {self.metric_weights}")
174 logger.info(f"Benchmark weights: {self.benchmark_weights}")
176 # Initialize progress tracking
177 if self.progress_callback:
178 self.progress_callback(
179 0,
180 self.n_trials,
181 {
182 "status": "starting",
183 "stage": "initialization",
184 "trials_completed": 0,
185 "total_trials": self.n_trials,
186 },
187 )
189 try:
190 # Run optimization
191 self.study.optimize(
192 objective,
193 n_trials=self.n_trials,
194 timeout=self.timeout,
195 n_jobs=self.n_jobs,
196 callbacks=[self._optimization_callback],
197 show_progress_bar=True,
198 )
200 # Store best parameters
201 if self.study is None: 201 ↛ 202line 201 didn't jump to line 202 because the condition on line 201 was never true
202 raise RuntimeError("Study was not created")
203 _completed_study = self.study
204 self.best_params = _completed_study.best_params
206 # Save the results
207 self._save_results()
209 # Create visualizations
210 self._create_visualizations()
212 if self.best_params is None: 212 ↛ 213line 212 didn't jump to line 213 because the condition on line 212 was never true
213 raise RuntimeError("No best parameters found")
214 logger.info(
215 f"Optimization complete. Best parameters: {self.best_params}"
216 )
217 logger.info(f"Best value: {_completed_study.best_value}")
219 # Report completion
220 if self.progress_callback:
221 self.progress_callback(
222 self.n_trials,
223 self.n_trials,
224 {
225 "status": "completed",
226 "stage": "finished",
227 "trials_completed": len(_completed_study.trials),
228 "total_trials": self.n_trials,
229 "best_params": self.best_params,
230 "best_value": _completed_study.best_value,
231 },
232 )
234 return self.best_params, _completed_study.best_value
236 except KeyboardInterrupt:
237 logger.info("Optimization interrupted by user")
238 # Still save what we have
239 self._save_results()
240 self._create_visualizations()
242 if self.study is None: 242 ↛ 243line 242 didn't jump to line 243 because the condition on line 242 was never true
243 raise RuntimeError("Study was not created")
244 _interrupted_study = self.study
245 # Report interruption
246 if self.progress_callback:
247 self.progress_callback(
248 len(_interrupted_study.trials),
249 self.n_trials,
250 {
251 "status": "interrupted",
252 "stage": "interrupted",
253 "trials_completed": len(_interrupted_study.trials),
254 "total_trials": self.n_trials,
255 "best_params": _interrupted_study.best_params,
256 "best_value": _interrupted_study.best_value,
257 },
258 )
260 return _interrupted_study.best_params, _interrupted_study.best_value
262 def _get_default_param_space(self) -> Dict[str, Any]:
263 """
264 Get default parameter search space.
266 Returns:
267 Dictionary defining the default parameter search spaces
268 """
269 return {
270 "iterations": {
271 "type": "int",
272 "low": 1,
273 "high": 5,
274 "step": 1,
275 },
276 "questions_per_iteration": {
277 "type": "int",
278 "low": 1,
279 "high": 5,
280 "step": 1,
281 },
282 "search_strategy": {
283 "type": "categorical",
284 "choices": [
285 "iterdrag",
286 "standard",
287 "rapid",
288 "parallel",
289 "source_based",
290 ],
291 },
292 "max_results": {
293 "type": "int",
294 "low": 10,
295 "high": 100,
296 "step": 10,
297 },
298 }
300 def _objective(
301 self, trial: optuna.Trial, param_space: Dict[str, Any]
302 ) -> float:
303 """
304 Objective function for Optuna optimization.
306 Args:
307 trial: Optuna trial object
308 param_space: Dictionary defining parameter search spaces
310 Returns:
311 Score to maximize
312 """
313 # Generate parameters for this trial
314 params: Dict[str, Any] = {}
315 for param_name, param_config in param_space.items():
316 param_type = param_config["type"]
318 if param_type == "int":
319 params[param_name] = trial.suggest_int(
320 param_name,
321 param_config["low"],
322 param_config["high"],
323 step=param_config.get("step", 1),
324 )
325 elif param_type == "float":
326 params[param_name] = trial.suggest_float(
327 param_name,
328 param_config["low"],
329 param_config["high"],
330 step=param_config.get("step"),
331 log=param_config.get("log", False),
332 )
333 elif param_type == "categorical":
334 params[param_name] = trial.suggest_categorical(
335 param_name, param_config["choices"]
336 )
338 # Log the trial parameters
339 logger.info(f"Trial {trial.number}: {params}")
341 # Update progress callback if available
342 if self.progress_callback:
343 self.progress_callback(
344 trial.number,
345 self.n_trials,
346 {
347 "status": "running",
348 "stage": "trial_started",
349 "trial_number": trial.number,
350 "params": params,
351 "trials_completed": trial.number,
352 "total_trials": self.n_trials,
353 },
354 )
356 # Run an experiment with these parameters
357 try:
358 start_time = time.time()
359 result = self._run_experiment(params)
360 duration = time.time() - start_time
362 # Store details about the trial
363 trial_info = {
364 "trial_number": trial.number,
365 "params": params,
366 "result": result,
367 "score": result.get("score", 0),
368 "duration": duration,
369 "timestamp": datetime.now(UTC).isoformat(),
370 }
371 self.trials_history.append(trial_info)
373 # Update callback with results
374 if self.progress_callback:
375 self.progress_callback(
376 trial.number,
377 self.n_trials,
378 {
379 "status": "completed",
380 "stage": "trial_completed",
381 "trial_number": trial.number,
382 "params": params,
383 "score": result.get("score", 0),
384 "trials_completed": trial.number + 1,
385 "total_trials": self.n_trials,
386 },
387 )
389 logger.info(
390 f"Trial {trial.number} completed: {params}, score: {result['score']:.4f}"
391 )
393 return float(result["score"])
394 except Exception as e:
395 logger.exception(f"Error in trial {trial.number}")
397 # Update callback with error
398 if self.progress_callback:
399 self.progress_callback(
400 trial.number,
401 self.n_trials,
402 {
403 "status": "error",
404 "stage": "trial_error",
405 "trial_number": trial.number,
406 "params": params,
407 "error": str(e),
408 "trials_completed": trial.number,
409 "total_trials": self.n_trials,
410 },
411 )
413 return float("-inf") # Return a very low score for failed trials
415 def _run_experiment(self, params: Dict[str, Any]) -> Dict[str, Any]:
416 """
417 Run a single experiment with the given parameters.
419 Args:
420 params: Dictionary of parameters to test
422 Returns:
423 Results dictionary with metrics and score
424 """
425 # Extract parameters
426 iterations = params.get("iterations", 2)
427 questions_per_iteration = params.get("questions_per_iteration", 2)
428 search_strategy = params.get("search_strategy", "iterdrag")
429 max_results = params.get("max_results", 50)
431 # Initialize profiling tools
432 speed_profiler = SpeedProfiler()
434 # Start profiling
435 speed_profiler.start()
437 try:
438 # Create system configuration
439 system_config = {
440 "iterations": iterations,
441 "questions_per_iteration": questions_per_iteration,
442 "search_strategy": search_strategy,
443 "search_tool": self.search_tool,
444 "max_results": max_results,
445 "model_name": self.model_name,
446 "provider": self.provider,
447 }
449 # Evaluate quality using composite benchmark evaluator
450 # Use a small number of examples for efficiency
451 benchmark_dir = str(Path(self.output_dir) / "benchmark_temp")
452 quality_results = self.benchmark_evaluator.evaluate(
453 system_config=system_config,
454 num_examples=5, # Small number for optimization efficiency
455 output_dir=benchmark_dir,
456 )
458 # Stop timing
459 speed_profiler.stop()
460 timing_results = speed_profiler.get_summary()
462 # Extract key metrics
463 quality_score = quality_results.get("quality_score", 0.0)
464 benchmark_results = quality_results.get("benchmark_results", {})
466 # Speed score: convert duration to a 0-1 score where faster is better
467 # Using a reasonable threshold (e.g., 180 seconds for 5 examples)
468 # Below this threshold: high score, above it: declining score
469 total_duration = timing_results.get("total_duration", 180)
470 speed_score = max(0.0, min(1.0, 1.0 - (total_duration - 60) / 180))
472 # Calculate combined score based on weights
473 combined_score = (
474 self.metric_weights.get("quality", 0.6) * quality_score
475 + self.metric_weights.get("speed", 0.4) * speed_score
476 )
478 # Return streamlined results
479 return {
480 "quality_score": quality_score,
481 "benchmark_results": benchmark_results,
482 "speed_score": speed_score,
483 "total_duration": total_duration,
484 "score": combined_score,
485 "success": True,
486 }
488 except Exception as e:
489 # Stop profiling on error
490 speed_profiler.stop()
492 # Log error
493 logger.exception("Error in experiment")
495 # Return error information
496 return {"error": str(e), "score": 0.0, "success": False}
498 def _optimization_callback(self, study: optuna.Study, trial: optuna.Trial):
499 """
500 Callback for the Optuna optimization process.
502 Args:
503 study: Optuna study object
504 trial: Current trial
505 """
506 # Save intermediate results periodically
507 if trial.number % 10 == 0 and trial.number > 0:
508 self._save_results()
509 self._create_quick_visualizations()
511 def _save_results(self):
512 """Save the optimization results to disk."""
513 # Create a timestamp for filenames
514 timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
516 # Save trial history
517 from ...security.file_write_verifier import write_json_verified
519 history_file = str(
520 Path(self.output_dir) / f"{self.study_name}_history.json"
521 )
523 # Convert numpy values to native Python types for JSON serialization
524 clean_history: List[Dict[str, Any]] = []
525 for trial in self.trials_history:
526 clean_trial: Dict[str, Any] = {}
527 for k, v in trial.items():
528 if isinstance(v, dict):
529 clean_trial[k] = {
530 dk: (float(dv) if isinstance(dv, np.number) else dv)
531 for dk, dv in v.items()
532 }
533 elif isinstance(v, np.number):
534 clean_trial[k] = float(v)
535 else:
536 clean_trial[k] = v
537 clean_history.append(clean_trial)
539 # Sanitize sensitive data before writing to disk
540 sanitized_history = sanitize_data(clean_history)
542 write_json_verified(
543 history_file,
544 sanitized_history,
545 "benchmark.allow_file_output",
546 context="optimization history",
547 )
549 # Save current best parameters
550 if (
551 self.study
552 and hasattr(self.study, "best_params")
553 and self.study.best_params
554 ):
555 best_params_file = str(
556 Path(self.output_dir) / f"{self.study_name}_best_params.json"
557 )
559 best_params_data = {
560 "best_params": self.study.best_params,
561 "best_value": float(self.study.best_value),
562 "n_trials": len(self.study.trials),
563 "timestamp": timestamp,
564 "base_query": self.base_query,
565 "model_name": self.model_name,
566 "provider": self.provider,
567 "search_tool": self.search_tool,
568 "metric_weights": self.metric_weights,
569 "benchmark_weights": self.benchmark_weights,
570 }
572 # Sanitize sensitive data before writing to disk
573 sanitized_best_params = sanitize_data(best_params_data)
575 write_json_verified(
576 best_params_file,
577 sanitized_best_params,
578 "benchmark.allow_file_output",
579 context="optimization best params",
580 )
582 # Save the Optuna study
583 if self.study:
584 study_file = str(
585 Path(self.output_dir) / f"{self.study_name}_study.pkl"
586 )
587 joblib.dump(self.study, study_file)
589 logger.info(f"Results saved to {self.output_dir}")
591 def _create_visualizations(self):
592 """Create and save comprehensive visualizations of the optimization results."""
593 if not PLOTTING_AVAILABLE:
594 logger.warning(
595 "Matplotlib not available, skipping visualization creation"
596 )
597 return
599 if not self.study or len(self.study.trials) < 2:
600 logger.warning("Not enough trials to create visualizations")
601 return
603 # Create directory for visualizations
604 _viz_dir_path = Path(self.output_dir) / "visualizations"
605 _viz_dir_path.mkdir(parents=True, exist_ok=True)
606 viz_dir = str(_viz_dir_path)
608 # Create Optuna visualizations
609 self._create_optuna_visualizations(viz_dir)
611 # Create custom visualizations
612 self._create_custom_visualizations(viz_dir)
614 logger.info(f"Visualizations saved to {viz_dir}")
616 def _create_quick_visualizations(self):
617 """Create a smaller set of visualizations for intermediate progress."""
618 if (
619 not PLOTTING_AVAILABLE
620 or not self.study
621 or len(self.study.trials) < 2
622 ):
623 return
625 # Create directory for visualizations
626 _quick_viz_dir_path = Path(self.output_dir) / "visualizations"
627 _quick_viz_dir_path.mkdir(parents=True, exist_ok=True)
628 viz_dir = str(_quick_viz_dir_path)
630 # Create optimization history only (faster than full visualization)
631 try:
632 fig = plot_optimization_history(self.study)
633 fig.write_image(
634 str(
635 Path(viz_dir)
636 / f"{self.study_name}_optimization_history_current.png"
637 )
638 )
639 except Exception:
640 logger.exception("Error creating optimization history plot")
642 def _create_optuna_visualizations(self, viz_dir: str):
643 """
644 Create and save Optuna's built-in visualizations.
646 Args:
647 viz_dir: Directory to save visualizations
648 """
649 if not self.study: 649 ↛ 650line 649 didn't jump to line 650 because the condition on line 649 was never true
650 return
651 study = self.study
652 timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
654 # 1. Optimization history
655 try:
656 fig = plot_optimization_history(study)
657 fig.write_image(
658 str(
659 Path(viz_dir)
660 / f"{self.study_name}_optimization_history_{timestamp}.png"
661 )
662 )
663 except Exception:
664 logger.exception("Error creating optimization history plot")
666 # 2. Parameter importances
667 try:
668 fig = plot_param_importances(study)
669 fig.write_image(
670 str(
671 Path(viz_dir)
672 / f"{self.study_name}_param_importances_{timestamp}.png"
673 )
674 )
675 except Exception:
676 logger.exception("Error creating parameter importances plot")
678 # 3. Slice plot for each parameter
679 try:
680 for param_name in study.best_params.keys():
681 fig = plot_slice(study, [param_name])
682 fig.write_image(
683 str(
684 Path(viz_dir)
685 / f"{self.study_name}_slice_{param_name}_{timestamp}.png"
686 )
687 )
688 except Exception:
689 logger.exception("Error creating slice plots")
691 # 4. Contour plots for important parameter pairs
692 try:
693 # Get all parameter names
694 param_names = list(study.best_params.keys())
696 # Create contour plots for each pair
697 for i in range(len(param_names)):
698 for j in range(i + 1, len(param_names)): 698 ↛ 699line 698 didn't jump to line 699 because the loop on line 698 never started
699 try:
700 fig = plot_contour(
701 study, params=[param_names[i], param_names[j]]
702 )
703 fig.write_image(
704 str(
705 Path(viz_dir)
706 / f"{self.study_name}_contour_{param_names[i]}_{param_names[j]}_{timestamp}.png"
707 )
708 )
709 except Exception:
710 logger.warning(
711 f"Error creating contour plot for {param_names[i]} vs {param_names[j]}"
712 )
713 except Exception:
714 logger.exception("Error creating contour plots")
716 def _create_custom_visualizations(self, viz_dir: str):
717 """
718 Create custom visualizations based on trial history.
720 Args:
721 viz_dir: Directory to save visualizations
722 """
723 if not self.trials_history:
724 return
726 timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
728 # Create quality vs speed plot
729 self._create_quality_vs_speed_plot(viz_dir, timestamp)
731 # Create parameter evolution plots
732 self._create_parameter_evolution_plots(viz_dir, timestamp)
734 # Create trial duration vs score plot
735 self._create_duration_vs_score_plot(viz_dir, timestamp)
737 def _create_quality_vs_speed_plot(self, viz_dir: str, timestamp: str):
738 """Create a plot showing quality vs. speed trade-off."""
739 if not self.trials_history: 739 ↛ 740line 739 didn't jump to line 740 because the condition on line 739 was never true
740 return
742 # Extract data from successful trials
743 successful_trials = [
744 t
745 for t in self.trials_history
746 if t.get("result", {}).get("success", False)
747 ]
749 if not successful_trials:
750 logger.warning("No successful trials for visualization")
751 return
753 try:
754 plt.figure(figsize=(10, 8))
756 # Extract metrics
757 quality_scores = []
758 speed_scores = []
759 labels = []
760 iterations_values = []
761 questions_values = []
763 for trial in successful_trials: 763 ↛ 777line 763 didn't jump to line 777 because the loop on line 763 didn't complete
764 result = trial["result"]
765 quality = result.get("quality_score", 0)
766 speed = result.get("speed_score", 0)
767 iterations = trial["params"].get("iterations", 0)
768 questions = trial["params"].get("questions_per_iteration", 0)
770 quality_scores.append(quality)
771 speed_scores.append(speed)
772 labels.append(f"Trial {trial['trial_number']}")
773 iterations_values.append(iterations)
774 questions_values.append(questions)
776 # Create scatter plot with size based on iterations*questions
777 sizes = [
778 i * q * 5
779 for i, q in zip(
780 iterations_values, questions_values, strict=False
781 )
782 ]
783 scatter = plt.scatter(
784 quality_scores,
785 speed_scores,
786 s=sizes,
787 alpha=0.7,
788 c=range(len(quality_scores)),
789 cmap="viridis",
790 )
792 # Highlight best trial
793 best_trial = max(
794 successful_trials,
795 key=lambda x: x.get("result", {}).get("score", 0),
796 )
797 best_quality = best_trial["result"].get("quality_score", 0)
798 best_speed = best_trial["result"].get("speed_score", 0)
799 best_iter = best_trial["params"].get("iterations", 0)
800 best_questions = best_trial["params"].get(
801 "questions_per_iteration", 0
802 )
804 plt.scatter(
805 [best_quality],
806 [best_speed],
807 s=200,
808 facecolors="none",
809 edgecolors="red",
810 linewidth=2,
811 label=f"Best: {best_iter}×{best_questions}",
812 )
814 # Add annotations for key points
815 for i, (q, s, label) in enumerate(
816 zip(quality_scores, speed_scores, labels, strict=False)
817 ):
818 if i % max(1, len(quality_scores) // 5) == 0: # Label ~5 points
819 plt.annotate(
820 f"{iterations_values[i]}×{questions_values[i]}",
821 (q, s),
822 xytext=(5, 5),
823 textcoords="offset points",
824 )
826 # Add colorbar and labels
827 cbar = plt.colorbar(scatter)
828 cbar.set_label("Trial Progression")
830 # Add benchmark weight information
831 weights_str = ", ".join(
832 [f"{k}:{v:.1f}" for k, v in self.benchmark_weights.items()]
833 )
834 plt.title(
835 f"Quality vs. Speed Trade-off\nBenchmark Weights: {weights_str}"
836 )
837 plt.xlabel("Quality Score (Benchmark Accuracy)")
838 plt.ylabel("Speed Score")
839 plt.grid(True, linestyle="--", alpha=0.7)
841 # Add legend explaining size
842 legend_elements = [
843 Line2D(
844 [0],
845 [0],
846 marker="o",
847 color="w",
848 markerfacecolor="gray",
849 markersize=np.sqrt(n * 5 / np.pi),
850 label=f"{n} Total Questions",
851 )
852 for n in [5, 10, 15, 20, 25]
853 ]
854 plt.legend(handles=legend_elements, title="Workload")
856 # Save the figure
857 plt.tight_layout()
858 plt.savefig(
859 str(
860 Path(viz_dir)
861 / f"{self.study_name}_quality_vs_speed_{timestamp}.png"
862 )
863 )
864 plt.close()
865 except Exception:
866 logger.exception("Error creating quality vs speed plot")
868 def _create_parameter_evolution_plots(self, viz_dir: str, timestamp: str):
869 """Create plots showing how parameter values evolve over trials."""
870 try:
871 successful_trials = [
872 t
873 for t in self.trials_history
874 if t.get("result", {}).get("success", False)
875 ]
877 if not successful_trials or len(successful_trials) < 5: 877 ↛ 881line 877 didn't jump to line 881 because the condition on line 877 was always true
878 return
880 # Get key parameters
881 main_params = list(successful_trials[0]["params"].keys())
883 # For each parameter, plot its values over trials
884 for param_name in main_params:
885 plt.figure(figsize=(12, 6))
887 trial_numbers = []
888 param_values = []
889 scores = []
891 for trial in self.trials_history:
892 if "params" in trial and param_name in trial["params"]:
893 trial_numbers.append(trial["trial_number"])
894 param_values.append(trial["params"][param_name])
895 scores.append(trial.get("score", 0))
897 # Create evolution plot
898 scatter = plt.scatter(
899 trial_numbers,
900 param_values,
901 c=scores,
902 cmap="plasma",
903 alpha=0.8,
904 s=80,
905 )
907 # Add best trial marker
908 best_trial_idx = scores.index(max(scores))
909 plt.scatter(
910 [trial_numbers[best_trial_idx]],
911 [param_values[best_trial_idx]],
912 s=150,
913 facecolors="none",
914 edgecolors="red",
915 linewidth=2,
916 label=f"Best Value: {param_values[best_trial_idx]}",
917 )
919 # Add colorbar
920 cbar = plt.colorbar(scatter)
921 cbar.set_label("Score")
923 # Set chart properties
924 plt.title(f"Evolution of {param_name} Values")
925 plt.xlabel("Trial Number")
926 plt.ylabel(param_name)
927 plt.grid(True, linestyle="--", alpha=0.7)
928 plt.legend()
930 # For categorical parameters, adjust y-axis
931 if isinstance(param_values[0], str):
932 unique_values = sorted(set(param_values))
933 plt.yticks(range(len(unique_values)), unique_values)
935 # Save the figure
936 plt.tight_layout()
937 plt.savefig(
938 str(
939 Path(viz_dir)
940 / f"{self.study_name}_param_evolution_{param_name}_{timestamp}.png"
941 )
942 )
943 plt.close()
944 except Exception:
945 logger.exception("Error creating parameter evolution plots")
947 def _create_duration_vs_score_plot(self, viz_dir: str, timestamp: str):
948 """Create a plot showing trial duration vs score."""
949 try:
950 plt.figure(figsize=(10, 6))
952 successful_trials = [
953 t
954 for t in self.trials_history
955 if t.get("result", {}).get("success", False)
956 ]
958 if not successful_trials: 958 ↛ 959line 958 didn't jump to line 959 because the condition on line 958 was never true
959 return
961 trial_durations = []
962 trial_scores = []
963 trial_iterations = []
964 trial_questions = []
966 for trial in successful_trials:
967 duration = trial.get("duration", 0)
968 score = trial.get("score", 0)
969 iterations = trial.get("params", {}).get("iterations", 1)
970 questions = trial.get("params", {}).get(
971 "questions_per_iteration", 1
972 )
974 trial_durations.append(duration)
975 trial_scores.append(score)
976 trial_iterations.append(iterations)
977 trial_questions.append(questions)
979 # Total questions per trial
980 total_questions = [
981 i * q
982 for i, q in zip(trial_iterations, trial_questions, strict=False)
983 ]
985 # Create scatter plot with size based on total questions
986 plt.scatter(
987 trial_durations,
988 trial_scores,
989 s=[
990 q * 5 for q in total_questions
991 ], # Size based on total questions
992 alpha=0.7,
993 c=range(len(trial_durations)),
994 cmap="viridis",
995 )
997 # Add labels
998 plt.xlabel("Trial Duration (seconds)")
999 plt.ylabel("Score")
1000 plt.title("Trial Duration vs. Score")
1001 plt.grid(True, linestyle="--", alpha=0.7)
1003 # Add trial number annotations for selected points
1004 for i, (d, s) in enumerate(
1005 zip(trial_durations, trial_scores, strict=False)
1006 ):
1007 if ( 1007 ↛ 1004line 1007 didn't jump to line 1004 because the condition on line 1007 was always true
1008 i % max(1, len(trial_durations) // 5) == 0
1009 ): # Annotate ~5 points
1010 plt.annotate(
1011 f"{trial_iterations[i]}×{trial_questions[i]}",
1012 (d, s),
1013 xytext=(5, 5),
1014 textcoords="offset points",
1015 )
1017 # Save the figure
1018 plt.tight_layout()
1019 plt.savefig(
1020 str(
1021 Path(viz_dir)
1022 / f"{self.study_name}_duration_vs_score_{timestamp}.png"
1023 )
1024 )
1025 plt.close()
1026 except Exception:
1027 logger.exception("Error creating duration vs score plot")
def optimize_parameters(
    query: str,
    param_space: Optional[Dict[str, Any]] = None,
    output_dir: str = str(Path("data") / "optimization_results"),
    model_name: Optional[str] = None,
    provider: Optional[str] = None,
    search_tool: Optional[str] = None,
    temperature: float = 0.7,
    n_trials: int = 30,
    timeout: Optional[int] = None,
    n_jobs: int = 1,
    study_name: Optional[str] = None,
    optimization_metrics: Optional[List[str]] = None,
    metric_weights: Optional[Dict[str, float]] = None,
    progress_callback: Optional[Callable[[int, int, Dict], None]] = None,
    benchmark_weights: Optional[Dict[str, float]] = None,
) -> Tuple[Dict[str, Any], float]:
    """
    Optimize parameters for Local Deep Research.

    Convenience wrapper: builds an :class:`OptunaOptimizer` from the given
    settings and runs the optimization once.

    Args:
        query: The research query to use for all experiments
        param_space: Dictionary defining parameter search spaces (optional)
        output_dir: Directory to save optimization results
        model_name: Name of the LLM model to use
        provider: LLM provider
        search_tool: Search engine to use
        temperature: LLM temperature
        n_trials: Number of parameter combinations to try
        timeout: Maximum seconds to run optimization (None for no limit)
        n_jobs: Number of parallel jobs for optimization
        study_name: Name of the Optuna study
        optimization_metrics: List of metrics to optimize (default: ["quality", "speed"])
        metric_weights: Dictionary of weights for each metric (e.g., {"quality": 0.6, "speed": 0.4})
        progress_callback: Optional callback for progress updates
        benchmark_weights: Dictionary mapping benchmark types to weights
            (e.g., {"simpleqa": 0.6, "browsecomp": 0.4})
            If None, only SimpleQA is used with weight 1.0

    Returns:
        Tuple of (best_parameters, best_score)
    """
    tuner = OptunaOptimizer(
        base_query=query,
        output_dir=output_dir,
        model_name=model_name,
        provider=provider,
        search_tool=search_tool,
        temperature=temperature,
        n_trials=n_trials,
        timeout=timeout,
        n_jobs=n_jobs,
        study_name=study_name,
        optimization_metrics=optimization_metrics,
        metric_weights=metric_weights,
        progress_callback=progress_callback,
        benchmark_weights=benchmark_weights,
    )
    return tuner.optimize(param_space)
def optimize_for_speed(
    query: str,
    n_trials: int = 20,
    output_dir: str = str(Path("data") / "optimization_results"),
    model_name: Optional[str] = None,
    provider: Optional[str] = None,
    search_tool: Optional[str] = None,
    progress_callback: Optional[Callable[[int, int, Dict], None]] = None,
    benchmark_weights: Optional[Dict[str, float]] = None,
) -> Tuple[Dict[str, Any], float]:
    """
    Optimize parameters with a focus on speed performance.

    Uses a trimmed search space (fewer iterations/questions, only the fast
    search strategies) and weights speed 0.8 vs quality 0.2.

    Args:
        query: The research query to use for all experiments
        n_trials: Number of parameter combinations to try
        output_dir: Directory to save optimization results
        model_name: Name of the LLM model to use
        provider: LLM provider
        search_tool: Search engine to use
        progress_callback: Optional callback for progress updates
        benchmark_weights: Dictionary mapping benchmark types to weights
            (e.g., {"simpleqa": 0.6, "browsecomp": 0.4})
            If None, only SimpleQA is used with weight 1.0

    Returns:
        Tuple of (best_parameters, best_score)
    """

    def small_int_range() -> Dict[str, Any]:
        # Shared 1..3 integer range for the reduced search space.
        return {"type": "int", "low": 1, "high": 3, "step": 1}

    param_space = {
        "iterations": small_int_range(),
        "questions_per_iteration": small_int_range(),
        "search_strategy": {
            "type": "categorical",
            "choices": ["rapid", "parallel", "source_based"],
        },
    }

    return optimize_parameters(
        query=query,
        param_space=param_space,
        output_dir=output_dir,
        model_name=model_name,
        provider=provider,
        search_tool=search_tool,
        n_trials=n_trials,
        metric_weights={"speed": 0.8, "quality": 0.2},
        optimization_metrics=["speed", "quality"],
        progress_callback=progress_callback,
        benchmark_weights=benchmark_weights,
    )
def optimize_for_quality(
    query: str,
    n_trials: int = 30,
    output_dir: str = str(Path("data") / "optimization_results"),
    model_name: Optional[str] = None,
    provider: Optional[str] = None,
    search_tool: Optional[str] = None,
    progress_callback: Optional[Callable[[int, int, Dict], None]] = None,
    benchmark_weights: Optional[Dict[str, float]] = None,
) -> Tuple[Dict[str, Any], float]:
    """
    Optimize parameters with a focus on result quality.

    Args:
        query: The research query to use for all experiments
        n_trials: Number of parameter combinations to try
        output_dir: Directory to save optimization results
        model_name: Name of the LLM model to use
        provider: LLM provider
        search_tool: Search engine to use
        progress_callback: Optional callback for progress updates
        benchmark_weights: Dictionary mapping benchmark types to weights
                          (e.g., {"simpleqa": 0.6, "browsecomp": 0.4})
                          If None, only SimpleQA is used with weight 1.0

    Returns:
        Tuple of (best_parameters, best_score)
    """
    # Use the default (full) parameter space; only the metric weighting
    # differs here — quality dominates, speed is a minor tie-breaker.
    quality_weights = {"quality": 0.9, "speed": 0.1}

    return optimize_parameters(
        query=query,
        output_dir=output_dir,
        model_name=model_name,
        provider=provider,
        search_tool=search_tool,
        n_trials=n_trials,
        metric_weights=quality_weights,
        optimization_metrics=["quality", "speed"],
        progress_callback=progress_callback,
        benchmark_weights=benchmark_weights,
    )
def optimize_for_efficiency(
    query: str,
    n_trials: int = 25,
    output_dir: str = str(Path("data") / "optimization_results"),
    model_name: Optional[str] = None,
    provider: Optional[str] = None,
    search_tool: Optional[str] = None,
    progress_callback: Optional[Callable[[int, int, Dict], None]] = None,
    benchmark_weights: Optional[Dict[str, float]] = None,
) -> Tuple[Dict[str, Any], float]:
    """
    Optimize parameters with a focus on resource efficiency.

    Args:
        query: The research query to use for all experiments
        n_trials: Number of parameter combinations to try
        output_dir: Directory to save optimization results
        model_name: Name of the LLM model to use
        provider: LLM provider
        search_tool: Search engine to use
        progress_callback: Optional callback for progress updates
        benchmark_weights: Dictionary mapping benchmark types to weights
                          (e.g., {"simpleqa": 0.6, "browsecomp": 0.4})
                          If None, only SimpleQA is used with weight 1.0

    Returns:
        Tuple of (best_parameters, best_score)
    """
    # Balanced objective: quality leads slightly, with speed and
    # resource usage sharing the remainder equally.
    efficiency_weights = {"quality": 0.4, "speed": 0.3, "resource": 0.3}

    return optimize_parameters(
        query=query,
        output_dir=output_dir,
        model_name=model_name,
        provider=provider,
        search_tool=search_tool,
        n_trials=n_trials,
        metric_weights=efficiency_weights,
        optimization_metrics=["quality", "speed", "resource"],
        progress_callback=progress_callback,
        benchmark_weights=benchmark_weights,
    )