Coverage for src / local_deep_research / benchmarks / datasets / base.py: 33%

97 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2026-01-11 00:51 +0000

1""" 

2Base classes for benchmark datasets. 

3 

4This module provides base classes for loading, processing, and working 

5with benchmark datasets in a maintainable, extensible way. 

6""" 

7 

8from loguru import logger 

9import random 

10from abc import ABC, abstractmethod 

11from typing import Any, Dict, List, Optional 

12 

13import pandas as pd 

14 

15 

class BenchmarkDataset(ABC):
    """Base class for all benchmark datasets.

    This abstract base class defines the interface that all benchmark
    datasets must implement, providing a consistent way to load and
    process benchmark data.

    Subclasses must implement ``get_dataset_info``,
    ``get_default_dataset_path`` and ``process_example``. Loading is lazy:
    nothing is read until ``load()`` or ``get_examples()`` is called, and
    the result is cached on the instance afterwards.
    """

    def __init__(
        self,
        dataset_path: Optional[str] = None,
        num_examples: Optional[int] = None,
        seed: int = 42,
    ):
        """Initialize the dataset.

        Args:
            dataset_path: Optional path to the dataset file. If None (or any
                other falsy value), the subclass default is used.
            num_examples: Optional number of examples to sample from the dataset.
            seed: Random seed for reproducible sampling.
        """
        self.dataset_path = dataset_path or self.get_default_dataset_path()
        self.num_examples = num_examples
        self.seed = seed
        # Filled in by load(); stays empty until the first successful load.
        self.examples: List[Dict[str, Any]] = []
        self._is_loaded = False

    @classmethod
    @abstractmethod
    def get_dataset_info(cls) -> Dict[str, str]:
        """Get basic information about the dataset.

        Returns:
            Dictionary with dataset metadata (id, name, description, etc.).
        """

    @classmethod
    @abstractmethod
    def get_default_dataset_path(cls) -> str:
        """Get the default path or URL for the dataset.

        Returns:
            String path or URL to the default dataset source.
        """

    @abstractmethod
    def process_example(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Process a single example from the dataset.

        This method is called for each example during loading. It can be used
        to decrypt data, transform fields, or any other necessary processing.

        Args:
            example: Raw example from the dataset.

        Returns:
            Processed example ready for use.
        """

    def _read_raw_examples(self) -> List[Any]:
        """Read the unprocessed records from ``self.dataset_path``.

        The format is chosen from the file extension: ``.csv`` is read with
        pandas (one dict per row), ``.jsonl`` is parsed line by line with
        blank lines skipped, and ``.json`` is parsed as a single document.

        Returns:
            The raw records, in file order.

        Raises:
            ValueError: If the path has none of the supported extensions.
        """
        path = self.dataset_path

        if path.endswith(".csv"):
            df = pd.read_csv(path)
            return [row.to_dict() for _, row in df.iterrows()]

        if path.endswith((".json", ".jsonl")):
            import json

            # Explicit UTF-8: JSON is UTF-8 by specification, and relying on
            # the platform default encoding misreads non-ASCII text on
            # systems whose locale is not UTF-8.
            with open(path, "r", encoding="utf-8") as f:
                if path.endswith(".jsonl"):
                    return [json.loads(line) for line in f if line.strip()]
                return json.load(f)

        raise ValueError(f"Unsupported file format: {self.dataset_path}")

    def load(self) -> List[Dict[str, Any]]:
        """Load and process the dataset.

        This method loads the dataset, processes each example, and optionally
        samples a subset of examples. The result is cached, so repeated calls
        return the same list without re-reading the file.

        Examples whose ``process_example`` call raises are logged and
        skipped rather than aborting the whole load.

        Returns:
            List of processed examples.

        Raises:
            Exception: Any error raised while reading the file (including
                ``ValueError`` for an unsupported extension) is logged and
                re-raised.
        """
        if self._is_loaded:
            return self.examples

        logger.info(f"Loading dataset from {self.dataset_path}")

        try:
            raw_examples = self._read_raw_examples()

            # Process each example; a failure on one record is best-effort
            # logged so the remaining records are still usable.
            processed_examples = []
            for i, example in enumerate(raw_examples):
                try:
                    processed_examples.append(self.process_example(example))
                except Exception:
                    logger.exception(f"Error processing example {i}")

            # Sample only when a positive count smaller than the dataset
            # was requested; otherwise keep everything.
            if self.num_examples and self.num_examples < len(
                processed_examples
            ):
                # A private generator gives the same sample as seeding the
                # module-level RNG would, without resetting the global
                # random state for the rest of the process.
                rng = random.Random(self.seed)
                self.examples = rng.sample(
                    processed_examples, self.num_examples
                )
                logger.info(
                    f"Sampled {self.num_examples} examples out of {len(processed_examples)}"
                )
            else:
                logger.info(f"Loaded {len(processed_examples)} examples")
                self.examples = processed_examples

            self._is_loaded = True
            return self.examples

        except Exception:
            logger.exception("Error loading dataset")
            raise

    def get_examples(self) -> List[Dict[str, Any]]:
        """Get the loaded examples, loading the dataset if needed.

        Returns:
            List of processed examples.
        """
        if not self._is_loaded:
            return self.load()
        return self.examples

    def get_example(self, index: int) -> Dict[str, Any]:
        """Get a specific example by index.

        Negative indices are rejected rather than counted from the end.

        Args:
            index: Index of the example to retrieve.

        Returns:
            The specified example.

        Raises:
            IndexError: If the index is out of range.
        """
        examples = self.get_examples()
        if index < 0 or index >= len(examples):
            raise IndexError(
                f"Example index {index} out of range (0-{len(examples) - 1})"
            )
        return examples[index]

    def get_question(self, example: Dict[str, Any]) -> str:
        """Extract the question from an example.

        This method may be overridden by subclasses to customize question extraction.

        Args:
            example: The example to extract the question from.

        Returns:
            The value of the "problem" field, or an empty string if absent.
        """
        return example.get("problem", "")

    def get_answer(self, example: Dict[str, Any]) -> str:
        """Extract the answer from an example.

        This method may be overridden by subclasses to customize answer extraction.

        Args:
            example: The example to extract the answer from.

        Returns:
            The "correct_answer" field if present, otherwise the "answer"
            field, otherwise an empty string.
        """
        # Try the correct_answer field first, then fall back to answer
        return example.get("correct_answer", example.get("answer", ""))

201 

202 

class DatasetRegistry:
    """Central lookup table of benchmark dataset classes.

    Dataset classes are stored under the "id" value of their
    ``get_dataset_info()`` dictionary, so that they can later be found,
    listed, instantiated and loaded by that ID.
    """

    # Maps dataset ID -> dataset class; shared by every user of the registry.
    _registry = {}

    @classmethod
    def register(cls, dataset_class):
        """Add a dataset class to the registry.

        An existing entry with the same ID is replaced.

        Args:
            dataset_class: A class inheriting from BenchmarkDataset.

        Returns:
            The same class that was passed in, so the method can be used
            as a class decorator.

        Raises:
            ValueError: If the class's dataset info has no (truthy) "id".
        """
        dataset_id = dataset_class.get_dataset_info().get("id")
        if dataset_id:
            cls._registry[dataset_id] = dataset_class
            logger.debug(f"Registered dataset: {dataset_id}")
            return dataset_class

        raise ValueError("Dataset must have an ID")

    @classmethod
    def get_dataset_class(cls, dataset_id: str):
        """Look up a registered dataset class.

        Args:
            dataset_id: ID of the dataset to retrieve.

        Returns:
            The dataset class registered under that ID.

        Raises:
            ValueError: If no dataset with the given ID is registered.
        """
        if dataset_id in cls._registry:
            return cls._registry[dataset_id]

        raise ValueError(f"Unknown dataset: {dataset_id}")

    @classmethod
    def create_dataset(
        cls,
        dataset_id: str,
        dataset_path: Optional[str] = None,
        num_examples: Optional[int] = None,
        seed: int = 42,
    ) -> BenchmarkDataset:
        """Instantiate the dataset registered under ``dataset_id``.

        Args:
            dataset_id: ID of the dataset to create.
            dataset_path: Optional path to the dataset file.
            num_examples: Optional number of examples to sample.
            seed: Random seed for sampling.

        Returns:
            A new, not yet loaded, dataset instance.

        Raises:
            ValueError: If no dataset with the given ID is registered.
        """
        constructor = cls.get_dataset_class(dataset_id)
        instance = constructor(
            dataset_path=dataset_path,
            num_examples=num_examples,
            seed=seed,
        )
        return instance

    @classmethod
    def get_available_datasets(cls) -> List[Dict[str, str]]:
        """Collect the info dictionaries of every registered dataset.

        Returns:
            One ``get_dataset_info()`` dictionary per registered dataset,
            in registration order.
        """
        infos = []
        for registered_id in cls._registry:
            registered_class = cls.get_dataset_class(registered_id)
            infos.append(registered_class.get_dataset_info())
        return infos

    @classmethod
    def load_dataset(
        cls,
        dataset_id: str,
        dataset_path: Optional[str] = None,
        num_examples: Optional[int] = None,
        seed: int = 42,
    ) -> List[Dict[str, Any]]:
        """Create a dataset by ID and return its loaded examples.

        Shortcut for ``create_dataset(...)`` followed by ``load()``.

        Args:
            dataset_id: ID of the dataset to load.
            dataset_path: Optional path to the dataset file.
            num_examples: Optional number of examples to sample.
            seed: Random seed for sampling.

        Returns:
            List of processed examples.

        Raises:
            ValueError: If no dataset with the given ID is registered.
        """
        return cls.create_dataset(
            dataset_id=dataset_id,
            dataset_path=dataset_path,
            num_examples=num_examples,
            seed=seed,
        ).load()