Coverage for src / local_deep_research / benchmarks / datasets / base.py: 37%

101 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-25 01:07 +0000

1""" 

2Base classes for benchmark datasets. 

3 

4This module provides base classes for loading, processing, and working 

5with benchmark datasets in a maintainable, extensible way. 

6""" 

7 

8from loguru import logger 

9import random 

10from abc import ABC, abstractmethod 

11from typing import Any, Dict, List, Optional 

12 

13import pandas as pd 

14 

15 

class BenchmarkDataset(ABC):
    """Base class for all benchmark datasets.

    This abstract base class defines the interface that all benchmark
    datasets must implement, providing a consistent way to load and
    process benchmark data.

    Subclasses supply ``get_dataset_info``, ``get_default_dataset_path``
    and ``process_example``; reading the file, per-example processing,
    optional sampling and caching of the result are implemented here.
    """

    def __init__(
        self,
        dataset_path: Optional[str] = None,
        num_examples: Optional[int] = None,
        seed: int = 42,
    ):
        """Initialize the dataset.

        Nothing is read from disk here; data is loaded on the first call
        to ``load()`` or ``get_examples()``.

        Args:
            dataset_path: Optional path to the dataset file. If None (or
                empty), the subclass's default path is used.
            num_examples: Optional number of examples to sample from the dataset.
            seed: Random seed for reproducible sampling.
        """
        self.dataset_path = dataset_path or self.get_default_dataset_path()
        self.num_examples = num_examples
        self.seed = seed
        self.examples: List[Dict[str, Any]] = []
        # Set once load() has succeeded so later calls reuse self.examples.
        self._is_loaded = False

    @classmethod
    @abstractmethod
    def get_dataset_info(cls) -> Dict[str, str]:
        """Get basic information about the dataset.

        Returns:
            Dictionary with dataset metadata (id, name, description, etc.).
        """
        pass

    @classmethod
    @abstractmethod
    def get_default_dataset_path(cls) -> str:
        """Get the default path or URL for the dataset.

        Returns:
            String path or URL to the default dataset source.
        """
        pass

    @abstractmethod
    def process_example(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single example from the dataset.

        This method is called for each example during loading. It can be used
        to decrypt data, transform fields, or any other necessary processing.

        Args:
            example: Raw example from the dataset.

        Returns:
            Processed example ready for use.
        """
        pass

    def load(self) -> List[Dict[str, Any]]:
        """Load and process the dataset.

        The file format is chosen from the suffix of ``dataset_path``
        (``.csv``, ``.json`` or ``.jsonl``). Every raw example is passed
        through ``process_example``; examples whose processing raises are
        logged and skipped. If ``num_examples`` is set and smaller than the
        number of processed examples, a reproducible random subset is kept.
        The result is cached, so repeated calls return the same list.

        Returns:
            List of processed examples.

        Raises:
            ValueError: If the path does not end in a supported suffix.
            Exception: Any error raised while reading or parsing the file
                is logged and re-raised unchanged.
        """
        if self._is_loaded:
            return self.examples

        logger.info(f"Loading dataset from {self.dataset_path}")

        try:
            # Load raw data
            if self.dataset_path.endswith(".csv"):
                df = pd.read_csv(self.dataset_path)
                raw_examples = [row.to_dict() for _, row in df.iterrows()]
            elif self.dataset_path.endswith((".json", ".jsonl")):
                import json

                # Decode explicitly as UTF-8 so the result does not depend
                # on the platform's locale encoding.
                with open(self.dataset_path, "r", encoding="utf-8") as f:
                    if self.dataset_path.endswith(".jsonl"):
                        # One JSON document per non-blank line.
                        raw_examples = [
                            json.loads(line) for line in f if line.strip()
                        ]
                    else:
                        raw_examples = json.load(f)
            else:
                raise ValueError(
                    f"Unsupported file format: {self.dataset_path}"
                )

            # Process each example; a failure skips that example only.
            processed_examples = []
            failure_count = 0
            for i, example in enumerate(raw_examples):
                try:
                    processed = self.process_example(example)
                    processed_examples.append(processed)
                except Exception:
                    failure_count += 1
                    logger.exception(f"Error processing example {i}")

            if failure_count > 0:
                logger.warning(
                    f"Dataset processing: {failure_count}/{len(raw_examples)} "
                    f"examples failed, {len(processed_examples)} succeeded"
                )

            # Sample if needed
            if self.num_examples and self.num_examples < len(
                processed_examples
            ):
                # Security: seeded random for reproducible benchmark sampling, not security-sensitive
                # A private generator draws the same sample as seeding the
                # module-level one would, without resetting the global
                # random state for the rest of the process.
                rng = random.Random(self.seed)
                sampled_examples = rng.sample(
                    processed_examples, self.num_examples
                )
                logger.info(
                    f"Sampled {self.num_examples} examples out of {len(processed_examples)}"
                )
                self.examples = sampled_examples
            else:
                logger.info(f"Loaded {len(processed_examples)} examples")
                self.examples = processed_examples

            self._is_loaded = True
            return self.examples

        except Exception:
            logger.exception("Error loading dataset")
            raise

    def get_examples(self) -> List[Dict[str, Any]]:
        """Get the loaded examples, loading the dataset if needed.

        Returns:
            List of processed examples.
        """
        if not self._is_loaded:
            return self.load()
        return self.examples

    def get_example(self, index: int) -> Dict[str, Any]:
        """Get a specific example by index.

        Negative indices are rejected rather than counted from the end.

        Args:
            index: Index of the example to retrieve.

        Returns:
            The specified example.

        Raises:
            IndexError: If the index is out of range.
        """
        examples = self.get_examples()
        if index < 0 or index >= len(examples):
            raise IndexError(
                f"Example index {index} out of range (0-{len(examples) - 1})"
            )
        return examples[index]

    def get_question(self, example: Dict[str, Any]) -> str:
        """Extract the question from an example.

        This method may be overridden by subclasses to customize question extraction.

        Args:
            example: The example to extract the question from.

        Returns:
            The value of the "problem" field, or "" if it is absent.
        """
        return example.get("problem", "")

    def get_answer(self, example: Dict[str, Any]) -> str:
        """Extract the answer from an example.

        This method may be overridden by subclasses to customize answer extraction.

        Args:
            example: The example to extract the answer from.

        Returns:
            The "correct_answer" field if present, otherwise the "answer"
            field, otherwise "".
        """
        # Try the correct_answer field first, then fall back to answer
        return example.get("correct_answer", example.get("answer", ""))

210 

211 

class DatasetRegistry:
    """Central lookup table for benchmark dataset classes.

    Dataset classes are stored under the ID reported by their
    ``get_dataset_info()`` so that callers can discover, instantiate
    and load them by that ID.
    """

    # Maps dataset ID -> registered dataset class. Defined on the class,
    # so every caller sees the same table.
    _registry = {}

    @classmethod
    def register(cls, dataset_class):
        """Add a dataset class to the registry.

        Registering a second class under an existing ID replaces the
        earlier entry.

        Args:
            dataset_class: A class inheriting from BenchmarkDataset.

        Returns:
            The same class, unchanged, so this can be used as a decorator.

        Raises:
            ValueError: If the class's dataset info has no truthy "id".
        """
        key = dataset_class.get_dataset_info().get("id")
        if key:
            cls._registry[key] = dataset_class
            logger.debug(f"Registered dataset: {key}")
            return dataset_class
        raise ValueError("Dataset must have an ID")

    @classmethod
    def get_dataset_class(cls, dataset_id: str):
        """Look up a registered dataset class.

        Args:
            dataset_id: ID of the dataset to retrieve.

        Returns:
            The dataset class registered under that ID.

        Raises:
            ValueError: If no dataset with the given ID is registered.
        """
        if dataset_id in cls._registry:
            return cls._registry[dataset_id]
        raise ValueError(f"Unknown dataset: {dataset_id}")

    @classmethod
    def create_dataset(
        cls,
        dataset_id: str,
        dataset_path: Optional[str] = None,
        num_examples: Optional[int] = None,
        seed: int = 42,
    ) -> BenchmarkDataset:
        """Instantiate the dataset registered under an ID.

        Args:
            dataset_id: ID of the dataset to create.
            dataset_path: Optional path to the dataset file.
            num_examples: Optional number of examples to sample.
            seed: Random seed for sampling.

        Returns:
            A new instance of the registered dataset class.

        Raises:
            ValueError: If no dataset with the given ID is registered.
        """
        factory = cls.get_dataset_class(dataset_id)
        init_kwargs = {
            "dataset_path": dataset_path,
            "num_examples": num_examples,
            "seed": seed,
        }
        return factory(**init_kwargs)

    @classmethod
    def get_available_datasets(cls) -> List[Dict[str, str]]:
        """Collect the metadata of every registered dataset.

        Returns:
            One info dictionary per registered dataset, in registration
            order.
        """
        return [
            registered.get_dataset_info()
            for registered in cls._registry.values()
        ]

    @classmethod
    def load_dataset(
        cls,
        dataset_id: str,
        dataset_path: Optional[str] = None,
        num_examples: Optional[int] = None,
        seed: int = 42,
    ) -> List[Dict[str, Any]]:
        """Create a dataset by ID and load its examples in one call.

        Args:
            dataset_id: ID of the dataset to load.
            dataset_path: Optional path to the dataset file.
            num_examples: Optional number of examples to sample.
            seed: Random seed for sampling.

        Returns:
            List of processed examples.

        Raises:
            ValueError: If no dataset with the given ID is registered.
        """
        return cls.create_dataset(
            dataset_id=dataset_id,
            dataset_path=dataset_path,
            num_examples=num_examples,
            seed=seed,
        ).load()