Coverage for src / local_deep_research / advanced_search_system / query_generation / adaptive_query_generator.py: 92%

156 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-25 01:07 +0000

1""" 

2Adaptive query generation system for improved search performance. 

3""" 

4 

5from collections import defaultdict 

6from dataclasses import dataclass 

7from typing import Dict, List, Optional, Set, Tuple 

8 

9from langchain_core.language_models import BaseChatModel 

10 

11from ...utilities.search_utilities import remove_think_tags 

12from ..constraints.base_constraint import Constraint, ConstraintType 

13 

14 

@dataclass
class QueryPattern:
    """Represents a successful query pattern."""

    # Query template with {placeholder} slots filled from constraint values.
    template: str
    # Constraint types that must all be present for this pattern to apply.
    constraint_types: List[ConstraintType]
    # Rolling success score used to rank competing patterns (higher wins).
    success_rate: float
    # Concrete queries previously produced from this template.
    example_queries: List[str]
    # Entities discovered by queries built from this pattern.
    discovered_entities: Set[str]

24 

25 

26class AdaptiveQueryGenerator: 

27 """ 

28 Generates search queries that adapt based on past performance. 

29 

30 Features: 

31 1. Pattern learning from successful queries 

32 2. Semantic expansion for broader coverage 

33 3. Constraint combination optimization 

34 4. Failure recovery strategies 

35 """ 

36 

    def __init__(self, model: BaseChatModel):
        """Initialize the adaptive query generator.

        Args:
            model: Chat model used for semantic expansions and fallback
                LLM-based query generation.
        """
        self.model = model
        # Patterns that previously worked, ranked later by success_rate.
        self.successful_patterns: List[QueryPattern] = []
        # Exact query strings known to have failed; never re-issued.
        self.failed_queries: Set[str] = set()
        # Cache: constraint value -> quoted expansion phrases (avoids repeat LLM calls).
        self.semantic_expansions: Dict[str, List[str]] = {}
        # Success counts keyed by sorted tuple of constraint types.
        self.constraint_combinations: Dict[
            Tuple[ConstraintType, ...], float
        ] = defaultdict(float)

        # Initialize default patterns
        self._initialize_default_patterns()

49 

50 def _initialize_default_patterns(self): 

51 """Initialize with proven query patterns.""" 

52 default_patterns = [ 

53 QueryPattern( 

54 template='"{entity}" {property} {location}', 

55 constraint_types=[ 

56 ConstraintType.NAME_PATTERN, 

57 ConstraintType.PROPERTY, 

58 ConstraintType.LOCATION, 

59 ], 

60 success_rate=0.7, 

61 example_queries=['"mountain" formed ice age Colorado'], 

62 discovered_entities=set(), 

63 ), 

64 QueryPattern( 

65 template="{event} {temporal} {statistic}", 

66 constraint_types=[ 

67 ConstraintType.EVENT, 

68 ConstraintType.TEMPORAL, 

69 ConstraintType.STATISTIC, 

70 ], 

71 success_rate=0.6, 

72 example_queries=["accident 2000-2021 statistics"], 

73 discovered_entities=set(), 

74 ), 

75 QueryPattern( 

76 template='"{name_pattern}" AND {comparison} {value}', 

77 constraint_types=[ 

78 ConstraintType.NAME_PATTERN, 

79 ConstraintType.COMPARISON, 

80 ], 

81 success_rate=0.65, 

82 example_queries=['"body part" AND "84.5 times" ratio'], 

83 discovered_entities=set(), 

84 ), 

85 ] 

86 self.successful_patterns.extend(default_patterns) 

87 

88 def generate_query( 

89 self, constraints: List[Constraint], context: Optional[Dict] = None 

90 ) -> str: 

91 """Generate an adaptive query based on constraints and context.""" 

92 # Try pattern-based generation first 

93 pattern_query = self._generate_from_patterns(constraints) 

94 if pattern_query and pattern_query not in self.failed_queries: 

95 return pattern_query 

96 

97 # Try semantic expansion 

98 expanded_query = self._generate_with_expansion(constraints) 

99 if expanded_query and expanded_query not in self.failed_queries: 

100 return expanded_query 

101 

102 # Fall back to LLM-based generation 

103 return self._generate_with_llm(constraints, context) 

104 

105 def _generate_from_patterns( 

106 self, constraints: List[Constraint] 

107 ) -> Optional[str]: 

108 """Generate query using learned patterns.""" 

109 constraint_types = [c.type for c in constraints] 

110 

111 # Find matching patterns 

112 matching_patterns = [] 

113 for pattern in self.successful_patterns: 

114 if all(t in constraint_types for t in pattern.constraint_types): 

115 matching_patterns.append(pattern) 

116 

117 if not matching_patterns: 

118 return None 

119 

120 # Use the highest success rate pattern 

121 best_pattern = max(matching_patterns, key=lambda p: p.success_rate) 

122 

123 # Fill in the template 

124 template_vars = {} 

125 for constraint in constraints: 

126 if constraint.type == ConstraintType.NAME_PATTERN: 

127 template_vars["name_pattern"] = constraint.value 

128 template_vars["entity"] = constraint.value 

129 elif constraint.type == ConstraintType.PROPERTY: 

130 template_vars["property"] = constraint.value 

131 elif constraint.type == ConstraintType.LOCATION: 131 ↛ 133line 131 didn't jump to line 133 because the condition on line 131 was always true

132 template_vars["location"] = constraint.value 

133 elif constraint.type == ConstraintType.EVENT: 

134 template_vars["event"] = constraint.value 

135 elif constraint.type == ConstraintType.TEMPORAL: 

136 template_vars["temporal"] = constraint.value 

137 elif constraint.type == ConstraintType.STATISTIC: 

138 template_vars["statistic"] = constraint.value 

139 elif constraint.type == ConstraintType.COMPARISON: 

140 template_vars["comparison"] = f'"{constraint.value}"' 

141 template_vars["value"] = constraint.value 

142 

143 try: 

144 query = best_pattern.template.format(**template_vars) 

145 return query 

146 except KeyError: 

147 return None 

148 

149 def _generate_with_expansion( 

150 self, constraints: List[Constraint] 

151 ) -> Optional[str]: 

152 """Generate query with semantic expansion.""" 

153 expanded_terms = [] 

154 

155 for constraint in constraints: 

156 # Get expansions for this value 

157 if constraint.value not in self.semantic_expansions: 

158 self.semantic_expansions[constraint.value] = ( 

159 self._get_semantic_expansions( 

160 constraint.value, constraint.type 

161 ) 

162 ) 

163 

164 expansions = self.semantic_expansions[constraint.value] 

165 if expansions: 

166 # Use OR to include expansions 

167 expanded = ( 

168 f"({constraint.value} OR {' OR '.join(expansions[:2])})" 

169 ) 

170 expanded_terms.append(expanded) 

171 else: 

172 expanded_terms.append(f'"{constraint.value}"') 

173 

174 return " AND ".join(expanded_terms) 

175 

176 def _get_semantic_expansions( 

177 self, term: str, constraint_type: ConstraintType 

178 ) -> List[str]: 

179 """Get semantic expansions for a term.""" 

180 prompt = f""" 

181Generate 3 alternative phrases or related terms for "{term}" in the context of {constraint_type.value}. 

182 

183These should be: 

1841. Synonyms or near-synonyms 

1852. Related concepts 

1863. Alternative phrasings 

187 

188Return only the terms, one per line. 

189""" 

190 

191 response = self.model.invoke(prompt) 

192 expansions = [ 

193 line.strip() 

194 for line in remove_think_tags(response.content).strip().split("\n") 

195 if line.strip() 

196 ] 

197 

198 return [f'"{exp}"' for exp in expansions[:3]] 

199 

200 def _generate_with_llm( 

201 self, constraints: List[Constraint], context: Optional[Dict] = None 

202 ) -> str: 

203 """Generate query using LLM with context awareness.""" 

204 constraint_desc = self._format_constraints(constraints) 

205 

206 context_info = "" 

207 if context: 

208 if "failed_queries" in context: 

209 context_info += "\nFailed queries to avoid:\n" + "\n".join( 

210 context["failed_queries"][:3] 

211 ) 

212 if "successful_queries" in context: 

213 context_info += "\nSuccessful query patterns:\n" + "\n".join( 

214 context["successful_queries"][:3] 

215 ) 

216 

217 prompt = f""" 

218Create an effective search query for these constraints: 

219 

220{constraint_desc} 

221{context_info} 

222 

223Guidelines: 

2241. Focus on finding specific named entities 

2252. Use operators (AND, OR, quotes) effectively 

2263. Combine constraints strategically 

2274. Make the query neither too broad nor too narrow 

228 

229Return only the search query. 

230""" 

231 

232 response = self.model.invoke(prompt) 

233 return remove_think_tags(response.content).strip() 

234 

235 def update_patterns( 

236 self, 

237 query: str, 

238 constraints: List[Constraint], 

239 success: bool, 

240 entities_found: List[str], 

241 ): 

242 """Update patterns based on query performance.""" 

243 if success and entities_found: 

244 # Extract pattern from successful query 

245 pattern = self._extract_pattern(query, constraints) 

246 if pattern: 246 ↛ 265line 246 didn't jump to line 265 because the condition on line 246 was always true

247 # Update or add pattern 

248 existing = next( 

249 ( 

250 p 

251 for p in self.successful_patterns 

252 if p.template == pattern.template 

253 ), 

254 None, 

255 ) 

256 

257 if existing: 

258 existing.success_rate = (existing.success_rate + 1.0) / 2 

259 existing.example_queries.append(query) 

260 existing.discovered_entities.update(entities_found) 

261 else: 

262 self.successful_patterns.append(pattern) 

263 

264 # Update constraint combinations 

265 constraint_types = tuple(sorted(c.type for c in constraints)) 

266 self.constraint_combinations[constraint_types] += 1 

267 else: 

268 self.failed_queries.add(query) 

269 

270 def _extract_pattern( 

271 self, query: str, constraints: List[Constraint] 

272 ) -> Optional[QueryPattern]: 

273 """Extract a reusable pattern from a successful query.""" 

274 # Simple pattern extraction - could be made more sophisticated 

275 pattern = query 

276 

277 # Replace specific values with placeholders 

278 for constraint in constraints: 

279 if constraint.value in query: 

280 placeholder = f"{{{constraint.type.value}}}" 

281 pattern = pattern.replace(constraint.value, placeholder) 

282 

283 # Only create pattern if it has placeholders 

284 if "{" in pattern: 

285 return QueryPattern( 

286 template=pattern, 

287 constraint_types=[c.type for c in constraints], 

288 success_rate=1.0, 

289 example_queries=[query], 

290 discovered_entities=set(), 

291 ) 

292 

293 return None 

294 

295 def _format_constraints(self, constraints: List[Constraint]) -> str: 

296 """Format constraints for prompts.""" 

297 formatted = [] 

298 for c in constraints: 

299 formatted.append( 

300 f"- {c.type.value}: {c.description} [value: {c.value}]" 

301 ) 

302 return "\n".join(formatted) 

303 

304 def generate_fallback_queries( 

305 self, original_query: str, constraints: List[Constraint] 

306 ) -> List[str]: 

307 """Generate fallback queries when the original fails.""" 

308 fallback_queries = [] 

309 

310 # 1. Simplified query (fewer constraints) 

311 if len(constraints) > 2: 

312 priority_constraints = sorted( 

313 constraints, key=lambda c: c.weight, reverse=True 

314 )[:2] 

315 simplified = self.generate_query(priority_constraints) 

316 fallback_queries.append(simplified) 

317 

318 # 2. Broadened query (with OR instead of AND) 

319 terms = [f'"{c.value}"' for c in constraints] 

320 broadened = " OR ".join(terms) 

321 fallback_queries.append(broadened) 

322 

323 # 3. Decomposed queries (one constraint at a time) 

324 for constraint in constraints[:3]: 

325 single_query = self._generate_single_constraint_query(constraint) 

326 fallback_queries.append(single_query) 

327 

328 # 4. Alternative phrasing 

329 alt_prompt = f""" 

330The query "{original_query}" failed. Create 2 alternative queries with different phrasing. 

331 

332Constraints to satisfy: 

333{self._format_constraints(constraints)} 

334 

335Return only the queries, one per line. 

336""" 

337 

338 response = self.model.invoke(alt_prompt) 

339 alt_queries = [ 

340 line.strip() 

341 for line in remove_think_tags(response.content).strip().split("\n") 

342 if line.strip() 

343 ] 

344 fallback_queries.extend(alt_queries) 

345 

346 # Remove duplicates and failed queries 

347 unique_fallbacks = [] 

348 for q in fallback_queries: 

349 if q and q not in self.failed_queries and q not in unique_fallbacks: 

350 unique_fallbacks.append(q) 

351 

352 return unique_fallbacks[:5] 

353 

354 def _generate_single_constraint_query(self, constraint: Constraint) -> str: 

355 """Generate a query for a single constraint.""" 

356 type_specific_templates = { 

357 ConstraintType.NAME_PATTERN: '"{value}" names list', 

358 ConstraintType.LOCATION: '"{value}" places locations', 

359 ConstraintType.EVENT: '"{value}" incidents accidents', 

360 ConstraintType.PROPERTY: 'things with "{value}" property', 

361 ConstraintType.STATISTIC: '"{value}" statistics data', 

362 ConstraintType.TEMPORAL: "events in {value}", 

363 ConstraintType.COMPARISON: '"{value}" comparison ratio', 

364 ConstraintType.EXISTENCE: 'has "{value}" feature', 

365 } 

366 

367 template = type_specific_templates.get(constraint.type, '"{value}"') 

368 

369 return template.format(value=constraint.value) 

370 

371 def optimize_constraint_combinations( 

372 self, constraints: List[Constraint] 

373 ) -> List[List[Constraint]]: 

374 """Optimize constraint combinations based on past success.""" 

375 combinations = [] 

376 

377 # Sort constraint combinations by success rate 

378 sorted_combos = sorted( 

379 self.constraint_combinations.items(), 

380 key=lambda x: x[1], 

381 reverse=True, 

382 ) 

383 

384 # Try successful combinations first 

385 for combo_types, _ in sorted_combos: 

386 matching_constraints = [] 

387 for ctype in combo_types: 

388 matching = [c for c in constraints if c.type == ctype] 

389 if matching: 

390 matching_constraints.append(matching[0]) 

391 

392 if len(matching_constraints) == len(combo_types): 

393 combinations.append(matching_constraints) 

394 

395 # Add individual constraints 

396 combinations.extend([[c] for c in constraints]) 

397 

398 # Add pairs not yet tried 

399 for i in range(len(constraints)): 

400 for j in range(i + 1, len(constraints)): 

401 pair = [constraints[i], constraints[j]] 

402 if pair not in combinations: 

403 combinations.append(pair) 

404 

405 return combinations[:10] # Limit to top 10