Coverage for src/local_deep_research/benchmarks/datasets/base.py: 33%
97 statements
« prev ^ index » next coverage.py v7.12.0, created at 2026-01-11 00:51 +0000
1"""
2Base classes for benchmark datasets.
4This module provides base classes for loading, processing, and working
5with benchmark datasets in a maintainable, extensible way.
6"""
8from loguru import logger
9import random
10from abc import ABC, abstractmethod
11from typing import Any, Dict, List, Optional
13import pandas as pd
class BenchmarkDataset(ABC):
    """Base class for all benchmark datasets.

    This abstract base class defines the interface that all benchmark
    datasets must implement, providing a consistent way to load and
    process benchmark data.

    Subclasses supply the dataset metadata, the default data location and
    the per-example processing step; file loading, optional sampling and
    cached access to the examples are implemented here.
    """

    def __init__(
        self,
        dataset_path: Optional[str] = None,
        num_examples: Optional[int] = None,
        seed: int = 42,
    ):
        """Initialize the dataset.

        Args:
            dataset_path: Optional path to the dataset file. If None (or
                empty), the value of ``get_default_dataset_path()`` is used.
            num_examples: Optional number of examples to sample from the
                dataset. ``None`` or ``0`` keeps every processed example.
            seed: Random seed for reproducible sampling.
        """
        self.dataset_path = dataset_path or self.get_default_dataset_path()
        self.num_examples = num_examples
        self.seed = seed
        # Populated by load(); stays empty until the first successful load.
        self.examples: List[Dict[str, Any]] = []
        self._is_loaded = False

    @classmethod
    @abstractmethod
    def get_dataset_info(cls) -> Dict[str, str]:
        """Get basic information about the dataset.

        Returns:
            Dictionary with dataset metadata (id, name, description, etc.).
        """

    @classmethod
    @abstractmethod
    def get_default_dataset_path(cls) -> str:
        """Get the default path or URL for the dataset.

        Returns:
            String path or URL to the default dataset source.
        """

    @abstractmethod
    def process_example(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single example from the dataset.

        This method is called for each example during loading. It can be used
        to decrypt data, transform fields, or any other necessary processing.
        If it raises, ``load()`` logs the error and skips that example.

        Args:
            example: Raw example from the dataset.

        Returns:
            Processed example ready for use.
        """

    def load(self) -> List[Dict[str, Any]]:
        """Load and process the dataset.

        The file format is chosen from the suffix of ``dataset_path``:
        ``.csv`` is read with pandas (one dict per row), ``.jsonl`` is parsed
        line by line (blank lines are ignored) and ``.json`` is parsed as a
        single document whose items are the raw examples. Every raw example
        is passed through ``process_example``; examples whose processing
        raises are logged and skipped. If ``num_examples`` is set and smaller
        than the number of processed examples, a seeded random subset of
        that size is kept.

        The result is cached: subsequent calls return the same list without
        touching the file again.

        Returns:
            List of processed examples.

        Raises:
            ValueError: If the file suffix is not ``.csv``, ``.json`` or
                ``.jsonl``.
            Exception: Any error raised while reading or sampling is logged
                and re-raised unchanged.
        """
        if self._is_loaded:
            return self.examples

        logger.info(f"Loading dataset from {self.dataset_path}")

        try:
            # Load raw data
            if self.dataset_path.endswith(".csv"):
                df = pd.read_csv(self.dataset_path)
                raw_examples = [row.to_dict() for _, row in df.iterrows()]
            elif self.dataset_path.endswith((".json", ".jsonl")):
                import json

                # JSON is UTF-8 by specification; do not depend on the
                # platform's default locale encoding when decoding the file.
                with open(self.dataset_path, "r", encoding="utf-8") as f:
                    if self.dataset_path.endswith(".jsonl"):
                        raw_examples = [
                            json.loads(line) for line in f if line.strip()
                        ]
                    else:
                        raw_examples = json.load(f)
            else:
                raise ValueError(
                    f"Unsupported file format: {self.dataset_path}"
                )

            # Process each example; one bad example must not abort the load.
            processed_examples = []
            for i, example in enumerate(raw_examples):
                try:
                    processed = self.process_example(example)
                    processed_examples.append(processed)
                except Exception:
                    logger.exception(f"Error processing example {i}")

            # Sample if needed
            if self.num_examples and self.num_examples < len(
                processed_examples
            ):
                # A private generator yields the same sample as seeding the
                # module-level one, without resetting the process-wide
                # random state that unrelated code may rely on.
                rng = random.Random(self.seed)
                sampled_examples = rng.sample(
                    processed_examples, self.num_examples
                )
                logger.info(
                    f"Sampled {self.num_examples} examples out of {len(processed_examples)}"
                )
                self.examples = sampled_examples
            else:
                logger.info(f"Loaded {len(processed_examples)} examples")
                self.examples = processed_examples

            self._is_loaded = True
            return self.examples

        except Exception:
            logger.exception("Error loading dataset")
            raise

    def get_examples(self) -> List[Dict[str, Any]]:
        """Get the loaded examples, loading the dataset if needed.

        Returns:
            List of processed examples.
        """
        if not self._is_loaded:
            return self.load()
        return self.examples

    def get_example(self, index: int) -> Dict[str, Any]:
        """Get a specific example by index.

        Negative indices are rejected rather than counted from the end.

        Args:
            index: Index of the example to retrieve.

        Returns:
            The specified example.

        Raises:
            IndexError: If the index is out of range.
        """
        examples = self.get_examples()
        if index < 0 or index >= len(examples):
            raise IndexError(
                f"Example index {index} out of range (0-{len(examples) - 1})"
            )
        return examples[index]

    def get_question(self, example: Dict[str, Any]) -> str:
        """Extract the question from an example.

        This method may be overridden by subclasses to customize question extraction.

        Args:
            example: The example to extract the question from.

        Returns:
            The value of the ``"problem"`` field, or an empty string if absent.
        """
        return example.get("problem", "")

    def get_answer(self, example: Dict[str, Any]) -> str:
        """Extract the answer from an example.

        This method may be overridden by subclasses to customize answer extraction.

        Args:
            example: The example to extract the answer from.

        Returns:
            The answer string, or an empty string if neither answer field
            is present.
        """
        # Try the correct_answer field first, then fall back to answer
        return example.get("correct_answer", example.get("answer", ""))
class DatasetRegistry:
    """Central lookup table for benchmark dataset classes.

    Dataset classes are stored under the ID reported by their
    ``get_dataset_info()`` so that they can be discovered and
    instantiated by name.
    """

    # Maps dataset ID -> registered dataset class. Defined on the class, so
    # every caller shares the same table.
    _registry = {}

    @classmethod
    def register(cls, dataset_class):
        """Add a dataset class to the registry under its own ID.

        A class registered under an ID that is already present replaces
        the earlier entry.

        Args:
            dataset_class: A class inheriting from BenchmarkDataset.

        Returns:
            The dataset class itself, so this can be used as a decorator.

        Raises:
            ValueError: If the class's dataset info has no (truthy) "id".
        """
        info = dataset_class.get_dataset_info()
        dataset_id = info.get("id")
        if dataset_id:
            cls._registry[dataset_id] = dataset_class
            logger.debug(f"Registered dataset: {dataset_id}")
            return dataset_class
        raise ValueError("Dataset must have an ID")

    @classmethod
    def get_dataset_class(cls, dataset_id: str):
        """Look up a registered dataset class.

        Args:
            dataset_id: ID of the dataset to retrieve.

        Returns:
            The dataset class registered under that ID.

        Raises:
            ValueError: If no dataset with the given ID is registered.
        """
        if dataset_id in cls._registry:
            return cls._registry[dataset_id]
        raise ValueError(f"Unknown dataset: {dataset_id}")

    @classmethod
    def create_dataset(
        cls,
        dataset_id: str,
        dataset_path: Optional[str] = None,
        num_examples: Optional[int] = None,
        seed: int = 42,
    ) -> BenchmarkDataset:
        """Instantiate the dataset registered under ``dataset_id``.

        Args:
            dataset_id: ID of the dataset to create.
            dataset_path: Optional path to the dataset file.
            num_examples: Optional number of examples to sample.
            seed: Random seed for sampling.

        Returns:
            A new instance of the registered dataset class.

        Raises:
            ValueError: If no dataset with the given ID is registered.
        """
        return cls.get_dataset_class(dataset_id)(
            dataset_path=dataset_path,
            num_examples=num_examples,
            seed=seed,
        )

    @classmethod
    def get_available_datasets(cls) -> List[Dict[str, str]]:
        """Collect the metadata of every registered dataset.

        Returns:
            List of the ``get_dataset_info()`` dictionaries, one per
            registered dataset.
        """
        infos = []
        for dataset_id in cls._registry:
            infos.append(cls.get_dataset_class(dataset_id).get_dataset_info())
        return infos

    @classmethod
    def load_dataset(
        cls,
        dataset_id: str,
        dataset_path: Optional[str] = None,
        num_examples: Optional[int] = None,
        seed: int = 42,
    ) -> List[Dict[str, Any]]:
        """Create a dataset by ID and load its examples in one call.

        Args:
            dataset_id: ID of the dataset to load.
            dataset_path: Optional path to the dataset file.
            num_examples: Optional number of examples to sample.
            seed: Random seed for sampling.

        Returns:
            List of processed examples, as returned by the dataset's
            ``load()``.
        """
        return cls.create_dataset(
            dataset_id=dataset_id,
            dataset_path=dataset_path,
            num_examples=num_examples,
            seed=seed,
        ).load()