Coverage for src / local_deep_research / benchmarks / datasets / base.py: 92%

101 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:55 +0000

1""" 

2Base classes for benchmark datasets. 

3 

4This module provides base classes for loading, processing, and working 

5with benchmark datasets in a maintainable, extensible way. 

6""" 

7 

8from loguru import logger 

9import random 

10from abc import ABC, abstractmethod 

11from typing import Any, Dict, List, Optional 

12 

13import pandas as pd 

14 

15 

class BenchmarkDataset(ABC):
    """Base class for all benchmark datasets.

    This abstract base class defines the interface that all benchmark
    datasets must implement, providing a consistent way to load and
    process benchmark data. Subclasses supply dataset metadata, a default
    source path, and a per-example ``process_example`` hook. This class
    handles reading CSV/JSON/JSONL files, per-example processing,
    reproducible sub-sampling, and caching of the loaded result.
    """

    def __init__(
        self,
        dataset_path: Optional[str] = None,
        num_examples: Optional[int] = None,
        seed: int = 42,
    ):
        """Initialize the dataset.

        Nothing is read here; loading is deferred until ``load`` (or
        ``get_examples``) is first called.

        Args:
            dataset_path: Optional path to the dataset file. If None (or any
                falsy value such as ""), ``get_default_dataset_path()`` is used.
            num_examples: Optional number of examples to sample from the
                dataset. None or 0 keeps every processed example.
            seed: Random seed for reproducible sampling.
        """
        self.dataset_path = dataset_path or self.get_default_dataset_path()
        self.num_examples = num_examples
        self.seed = seed
        self.examples: List[Dict[str, Any]] = []
        # Flipped to True only after a successful load(); later calls reuse
        # self.examples instead of re-reading the file.
        self._is_loaded = False

    @classmethod
    @abstractmethod
    def get_dataset_info(cls) -> Dict[str, str]:
        """Get basic information about the dataset.

        Returns:
            Dictionary with dataset metadata (id, name, description, etc.).
        """
        pass

    @classmethod
    @abstractmethod
    def get_default_dataset_path(cls) -> str:
        """Get the default path or URL for the dataset.

        Returns:
            String path or URL to the default dataset source.
        """
        pass

    @abstractmethod
    def process_example(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single example from the dataset.

        This method is called for each example during loading. It can be used
        to decrypt data, transform fields, or any other necessary processing.

        Args:
            example: Raw example from the dataset.

        Returns:
            Processed example ready for use.
        """
        pass

    def _read_raw_examples(self) -> List[Dict[str, Any]]:
        """Read unprocessed records from ``self.dataset_path``.

        ``.csv`` paths are parsed with pandas. Every other path is treated as
        JSON, and as JSON Lines when it ends with ``.jsonl``. Callers are
        expected to have validated the extension beforehand.

        Returns:
            One dict per CSV row or JSON Lines record, or whatever value the
            JSON document holds at its top level for ``.json`` files.
        """
        import json

        if self.dataset_path.endswith(".csv"):
            df = pd.read_csv(self.dataset_path)
            # to_dict("records") yields one dict per row directly; iterrows()
            # builds an intermediate Series per row, which does not preserve
            # dtypes across columns (e.g. ints may come back as floats).
            return df.to_dict(orient="records")

        # NOTE(review): pd.read_csv accepts URLs, but open() only accepts
        # local paths, so a URL default path presumably works for CSV only.
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(self.dataset_path, "r", encoding="utf-8") as f:
            if self.dataset_path.endswith(".jsonl"):
                # Blank lines (e.g. a trailing newline) are skipped.
                return [json.loads(line) for line in f if line.strip()]
            return json.load(f)

    def _sample_examples(
        self, examples: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Return ``self.num_examples`` items drawn reproducibly from ``examples``.

        A private ``random.Random`` seeded with ``self.seed`` is used. It
        produces the same sample as seeding the module-level RNG would, but
        leaves the global random state untouched for the rest of the process.
        """
        # Security: seeded random for reproducible benchmark sampling, not security-sensitive
        rng = random.Random(self.seed)
        return rng.sample(examples, self.num_examples)

    def load(self) -> List[Dict[str, Any]]:
        """Load and process the dataset.

        The file is read, each example is passed through ``process_example``,
        and, when ``num_examples`` is set and smaller than the number of
        processed examples, a seeded random subset is kept. The result is
        cached, so repeated calls return the same list without re-reading.

        Examples whose processing raises are logged and skipped rather than
        aborting the whole load.

        Returns:
            List of processed examples.

        Raises:
            ValueError: If the path does not end in .csv, .json, or .jsonl.
            Exception: Any error from reading/parsing the file is logged and
                re-raised unchanged.
        """
        if self._is_loaded:
            return self.examples

        logger.info(f"Loading dataset from {self.dataset_path}")

        # Validate file format before attempting to load. Note ".json" does
        # not match "x.jsonl", so both JSON suffixes must be listed.
        if not self.dataset_path.endswith((".csv", ".json", ".jsonl")):
            raise ValueError(f"Unsupported file format: {self.dataset_path}")

        try:
            raw_examples = self._read_raw_examples()

            # Process each example; one bad record must not sink the dataset.
            processed_examples = []
            failure_count = 0
            for i, example in enumerate(raw_examples):
                try:
                    processed_examples.append(self.process_example(example))
                except Exception:
                    failure_count += 1
                    logger.exception(f"Error processing example {i}")

            if failure_count > 0:
                logger.warning(
                    f"Dataset processing: {failure_count}/{len(raw_examples)} "
                    f"examples failed, {len(processed_examples)} succeeded"
                )

            # Sample only when a positive limit below the available count is set.
            if self.num_examples and self.num_examples < len(
                processed_examples
            ):
                self.examples = self._sample_examples(processed_examples)
                logger.info(
                    f"Sampled {self.num_examples} examples out of {len(processed_examples)}"
                )
            else:
                logger.info(f"Loaded {len(processed_examples)} examples")
                self.examples = processed_examples

            self._is_loaded = True
            return self.examples

        except Exception:
            logger.exception("Error loading dataset")
            raise

    def get_examples(self) -> List[Dict[str, Any]]:
        """Get the loaded examples, loading the dataset if needed.

        Returns:
            List of processed examples.
        """
        if not self._is_loaded:
            return self.load()
        return self.examples

    def get_example(self, index: int) -> Dict[str, Any]:
        """Get a specific example by index.

        Negative indices are rejected rather than counted from the end.

        Args:
            index: Index of the example to retrieve.

        Returns:
            The specified example.

        Raises:
            IndexError: If the index is out of range.
        """
        examples = self.get_examples()
        if index < 0 or index >= len(examples):
            if not examples:
                # Avoid the confusing "(0--1)" range for an empty dataset.
                raise IndexError(
                    f"Example index {index} out of range (dataset is empty)"
                )
            raise IndexError(
                f"Example index {index} out of range (0-{len(examples) - 1})"
            )
        return examples[index]

    def get_question(self, example: Dict[str, Any]) -> str:
        """Extract the question from an example.

        This method may be overridden by subclasses to customize question extraction.

        Args:
            example: The example to extract the question from.

        Returns:
            The value of the "problem" field, or "" when it is absent.
        """
        return example.get("problem", "")

    def get_answer(self, example: Dict[str, Any]) -> str:
        """Extract the answer from an example.

        This method may be overridden by subclasses to customize answer extraction.

        Args:
            example: The example to extract the answer from.

        Returns:
            The "correct_answer" field if present, else the "answer" field,
            else "".
        """
        # Try the correct_answer field first, then fall back to answer
        return example.get("correct_answer", example.get("answer", ""))

212 

213 

class DatasetRegistry:
    """Registry mapping dataset IDs to ``BenchmarkDataset`` subclasses.

    Registered classes live in one class-level dictionary shared by every
    caller. Once registered, a dataset can be looked up, instantiated, or
    loaded purely by its string ID.
    """

    # Shared mapping of dataset ID -> dataset class, filled by register().
    _registry = {}

    @classmethod
    def register(cls, dataset_class):
        """Add ``dataset_class`` to the registry under its declared ID.

        The ID is taken from ``dataset_class.get_dataset_info()["id"]``. A
        class previously registered under the same ID is replaced. The class
        is returned unchanged, so this method also works as a class decorator.

        Args:
            dataset_class: A class inheriting from BenchmarkDataset.

        Returns:
            The same ``dataset_class`` object.

        Raises:
            ValueError: If the dataset info has a missing or empty "id".
        """
        info = dataset_class.get_dataset_info()
        registry_key = info.get("id")
        if registry_key:
            cls._registry[registry_key] = dataset_class
            logger.debug(f"Registered dataset: {registry_key}")
            return dataset_class
        raise ValueError("Dataset must have an ID")

    @classmethod
    def get_dataset_class(cls, dataset_id: str):
        """Return the class registered under ``dataset_id``.

        Args:
            dataset_id: ID the dataset was registered with.

        Returns:
            The registered dataset class.

        Raises:
            ValueError: If nothing is registered under ``dataset_id``.
        """
        if dataset_id in cls._registry:
            return cls._registry[dataset_id]
        raise ValueError(f"Unknown dataset: {dataset_id}")

    @classmethod
    def create_dataset(
        cls,
        dataset_id: str,
        dataset_path: Optional[str] = None,
        num_examples: Optional[int] = None,
        seed: int = 42,
    ) -> BenchmarkDataset:
        """Instantiate the dataset class registered under ``dataset_id``.

        Args:
            dataset_id: ID of the dataset to create.
            dataset_path: Optional path forwarded to the dataset constructor.
            num_examples: Optional sample size forwarded to the constructor.
            seed: Sampling seed forwarded to the constructor.

        Returns:
            A new, not-yet-loaded dataset instance.

        Raises:
            ValueError: If ``dataset_id`` is not registered.
        """
        target_class = cls.get_dataset_class(dataset_id)
        constructor_kwargs = {
            "dataset_path": dataset_path,
            "num_examples": num_examples,
            "seed": seed,
        }
        return target_class(**constructor_kwargs)

    @classmethod
    def get_available_datasets(cls) -> List[Dict[str, str]]:
        """Collect ``get_dataset_info()`` from every registered class.

        Returns:
            One info dictionary per registered dataset, in registration order.
        """
        lookup = cls.get_dataset_class
        return [
            lookup(registered_id).get_dataset_info()
            for registered_id in cls._registry
        ]

    @classmethod
    def load_dataset(
        cls,
        dataset_id: str,
        dataset_path: Optional[str] = None,
        num_examples: Optional[int] = None,
        seed: int = 42,
    ) -> List[Dict[str, Any]]:
        """Create the dataset registered under ``dataset_id`` and load it.

        Shortcut for ``create_dataset(...)`` followed by ``.load()``.

        Args:
            dataset_id: ID of the dataset to load.
            dataset_path: Optional path forwarded to the dataset constructor.
            num_examples: Optional sample size forwarded to the constructor.
            seed: Sampling seed forwarded to the constructor.

        Returns:
            List of processed examples returned by the dataset's ``load()``.
        """
        return cls.create_dataset(
            dataset_id=dataset_id,
            dataset_path=dataset_path,
            num_examples=num_examples,
            seed=seed,
        ).load()