Coverage for src / local_deep_research / advanced_search_system / query_generation / adaptive_query_generator.py: 97%

155 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:55 +0000

1""" 

2Adaptive query generation system for improved search performance. 

3""" 

4 

5from collections import defaultdict 

6from dataclasses import dataclass 

7from typing import Dict, List, Optional, Set, Tuple 

8 

9from langchain_core.language_models import BaseChatModel 

10 

11from ...utilities.search_utilities import remove_think_tags 

12from ..constraints.base_constraint import Constraint, ConstraintType 

13 

14 

@dataclass
class QueryPattern:
    """A query template together with its track record.

    Fields are positional in the order below; instances are built both for
    the built-in default patterns and for templates learned from successful
    queries.
    """

    # str.format template; its fields are filled from constraint values.
    template: str
    # Constraint types that must all be present for this template to apply.
    constraint_types: List[ConstraintType]
    # Ranking score: the highest-scoring applicable template is used.
    success_rate: float
    # Concrete queries associated with this template.
    example_queries: List[str]
    # Entities reported as found by queries matching this template.
    discovered_entities: Set[str]

24 

25 

class AdaptiveQueryGenerator:
    """
    Generates search queries that adapt based on past performance.

    Features:
    1. Pattern learning from successful queries
    2. Semantic expansion for broader coverage
    3. Constraint combination optimization
    4. Failure recovery strategies

    Instance state:
    - ``successful_patterns``: query templates; three defaults are seeded
      at construction and ``update_patterns`` adds learned ones.
    - ``failed_queries``: exact query strings reported as unsuccessful.
      The pattern and expansion paths of ``generate_query`` and
      ``generate_fallback_queries`` skip them (the LLM path does not check).
    - ``semantic_expansions``: cache of LLM-produced expansions, keyed by
      constraint value only.
    - ``constraint_combinations``: number of successful queries seen for
      each tuple of constraint types.
    """

    def __init__(self, model: BaseChatModel):
        """Initialize the adaptive query generator.

        Args:
            model: Chat model used for semantic expansion, LLM-based query
                generation, and alternative phrasings of failed queries.
        """
        self.model = model
        self.successful_patterns: List[QueryPattern] = []
        self.failed_queries: Set[str] = set()
        self.semantic_expansions: Dict[str, List[str]] = {}
        self.constraint_combinations: Dict[
            Tuple[ConstraintType, ...], float
        ] = defaultdict(float)

        # Initialize default patterns
        self._initialize_default_patterns()

    def _initialize_default_patterns(self):
        """Seed ``successful_patterns`` with three hand-written templates."""
        default_patterns = [
            QueryPattern(
                template='"{entity}" {property} {location}',
                constraint_types=[
                    ConstraintType.NAME_PATTERN,
                    ConstraintType.PROPERTY,
                    ConstraintType.LOCATION,
                ],
                success_rate=0.7,
                example_queries=['"mountain" formed ice age Colorado'],
                discovered_entities=set(),
            ),
            QueryPattern(
                template="{event} {temporal} {statistic}",
                constraint_types=[
                    ConstraintType.EVENT,
                    ConstraintType.TEMPORAL,
                    ConstraintType.STATISTIC,
                ],
                success_rate=0.6,
                example_queries=["accident 2000-2021 statistics"],
                discovered_entities=set(),
            ),
            QueryPattern(
                template='"{name_pattern}" AND {comparison} {value}',
                constraint_types=[
                    ConstraintType.NAME_PATTERN,
                    ConstraintType.COMPARISON,
                ],
                success_rate=0.65,
                example_queries=['"body part" AND "84.5 times" ratio'],
                discovered_entities=set(),
            ),
        ]
        self.successful_patterns.extend(default_patterns)

    def generate_query(
        self, constraints: List[Constraint], context: Optional[Dict] = None
    ) -> str:
        """Generate an adaptive query based on constraints and context.

        Strategies are tried in order: templates, semantic expansion, then
        the LLM. A candidate from the first two is used only when it is
        non-empty and not in ``failed_queries``.

        Args:
            constraints: Constraints the query should target.
            context: Optional hints for the LLM fallback; the keys
                ``"failed_queries"`` and ``"successful_queries"`` are read.

        Returns:
            The search query string.
        """
        pattern_query = self._generate_from_patterns(constraints)
        if pattern_query and pattern_query not in self.failed_queries:
            return pattern_query

        expanded_query = self._generate_with_expansion(constraints)
        if expanded_query and expanded_query not in self.failed_queries:
            return expanded_query

        return self._generate_with_llm(constraints, context)

    def _generate_from_patterns(
        self, constraints: List[Constraint]
    ) -> Optional[str]:
        """Fill the best-matching template with constraint values.

        A pattern matches when every constraint type it lists occurs in
        ``constraints``; among matches the highest ``success_rate`` wins.

        Returns:
            The formatted query, or ``None`` when no pattern matches or the
            chosen template cannot be filled from these constraints.
        """
        constraint_types = [c.type for c in constraints]

        matching_patterns = [
            pattern
            for pattern in self.successful_patterns
            if all(t in constraint_types for t in pattern.constraint_types)
        ]
        if not matching_patterns:
            return None

        best_pattern = max(matching_patterns, key=lambda p: p.success_rate)

        # Later constraints of the same type overwrite earlier ones.
        template_vars = {}
        for constraint in constraints:
            # Learned templates use "{<constraint type value>}" placeholders
            # (written by _extract_pattern), so expose every value under that
            # key; the named keys below serve the default templates.
            template_vars[str(constraint.type.value)] = constraint.value
            if constraint.type == ConstraintType.NAME_PATTERN:
                template_vars["name_pattern"] = constraint.value
                template_vars["entity"] = constraint.value
            elif constraint.type == ConstraintType.PROPERTY:
                template_vars["property"] = constraint.value
            elif constraint.type == ConstraintType.LOCATION:
                template_vars["location"] = constraint.value
            elif constraint.type == ConstraintType.EVENT:
                template_vars["event"] = constraint.value
            elif constraint.type == ConstraintType.TEMPORAL:
                template_vars["temporal"] = constraint.value
            elif constraint.type == ConstraintType.STATISTIC:
                template_vars["statistic"] = constraint.value
            elif constraint.type == ConstraintType.COMPARISON:
                template_vars["comparison"] = f'"{constraint.value}"'
                template_vars["value"] = constraint.value

        try:
            return best_pattern.template.format(**template_vars)
        except (KeyError, IndexError, ValueError):
            # KeyError: a named field is missing; IndexError/ValueError: the
            # template has positional fields or unbalanced braces.
            return None

    def _generate_with_expansion(
        self, constraints: List[Constraint]
    ) -> Optional[str]:
        """Build an AND query where each term is OR-ed with its expansions.

        Expansions are requested from the model once per distinct constraint
        value and cached in ``semantic_expansions``. A constraint without
        expansions contributes its quoted value. Returns ``""`` for an
        empty ``constraints`` list.
        """
        expanded_terms = []

        for constraint in constraints:
            if constraint.value not in self.semantic_expansions:
                self.semantic_expansions[constraint.value] = (
                    self._get_semantic_expansions(
                        constraint.value, constraint.type
                    )
                )

            expansions = self.semantic_expansions[constraint.value]
            if expansions:
                # Use OR to include at most two expansions
                expanded_terms.append(
                    f"({constraint.value} OR {' OR '.join(expansions[:2])})"
                )
            else:
                expanded_terms.append(f'"{constraint.value}"')

        return " AND ".join(expanded_terms)

    def _get_semantic_expansions(
        self, term: str, constraint_type: ConstraintType
    ) -> List[str]:
        """Ask the model for up to three related phrasings of ``term``.

        Returns:
            At most three non-empty lines of the model's reply, each wrapped
            in double quotes.
        """
        prompt = f"""
Generate 3 alternative phrases or related terms for "{term}" in the context of {constraint_type.value}.

These should be:
1. Synonyms or near-synonyms
2. Related concepts
3. Alternative phrasings

Return only the terms, one per line.
"""

        response = self.model.invoke(prompt)
        expansions = []
        for line in remove_think_tags(response.content).strip().split("\n"):
            # Strip quotes the model may already have added, so re-quoting
            # below cannot produce '""term""'.
            phrase = line.strip().strip('"').strip()
            if phrase:
                expansions.append(phrase)

        return [f'"{exp}"' for exp in expansions[:3]]

    def _generate_with_llm(
        self, constraints: List[Constraint], context: Optional[Dict] = None
    ) -> str:
        """Ask the model for a query, optionally showing past queries.

        Args:
            constraints: Constraints to describe in the prompt.
            context: Optional dict; up to three items from its
                ``"failed_queries"`` and ``"successful_queries"`` entries
                (any iterable, e.g. a list or a set) go into the prompt.

        Returns:
            The model's reply with think tags removed and whitespace
            stripped.
        """
        constraint_desc = self._format_constraints(constraints)

        context_info = ""
        if context:
            if "failed_queries" in context:
                # list() first: sets (like self.failed_queries) can't slice.
                recent_failures = list(context["failed_queries"])[:3]
                context_info += "\nFailed queries to avoid:\n" + "\n".join(
                    str(q) for q in recent_failures
                )
            if "successful_queries" in context:
                recent_successes = list(context["successful_queries"])[:3]
                context_info += "\nSuccessful query patterns:\n" + "\n".join(
                    str(q) for q in recent_successes
                )

        prompt = f"""
Create an effective search query for these constraints:

{constraint_desc}
{context_info}

Guidelines:
1. Focus on finding specific named entities
2. Use operators (AND, OR, quotes) effectively
3. Combine constraints strategically
4. Make the query neither too broad nor too narrow

Return only the search query.
"""

        response = self.model.invoke(prompt)
        return remove_think_tags(response.content).strip()

    def update_patterns(
        self,
        query: str,
        constraints: List[Constraint],
        success: bool,
        entities_found: List[str],
    ):
        """Update patterns based on query performance.

        A query counts as successful only when ``success`` is true AND at
        least one entity was found; otherwise it is added to
        ``failed_queries``.

        On success, the template extracted from ``query`` either reinforces
        an existing pattern with the same template (success rate moved
        halfway toward 1.0, query and entities recorded) or is stored as a
        new pattern. In both cases the count for this constraint-type
        combination is incremented.
        """
        if not (success and entities_found):
            self.failed_queries.add(query)
            return

        pattern = self._extract_pattern(query, constraints)
        if pattern is not None:
            existing = next(
                (
                    p
                    for p in self.successful_patterns
                    if p.template == pattern.template
                ),
                None,
            )

            if existing is not None:
                existing.success_rate = (existing.success_rate + 1.0) / 2
                existing.example_queries.append(query)
                existing.discovered_entities.update(entities_found)
            else:
                # Record the entities on new patterns as well.
                pattern.discovered_entities.update(entities_found)
                self.successful_patterns.append(pattern)

        # Sort by the member's value: plain Enum members do not define "<",
        # so sorting the members themselves raises TypeError for 2+ types.
        constraint_types = tuple(
            sorted((c.type for c in constraints), key=lambda t: t.value)
        )
        self.constraint_combinations[constraint_types] += 1

    def _extract_pattern(
        self, query: str, constraints: List[Constraint]
    ) -> Optional[QueryPattern]:
        """Extract a reusable pattern from a successful query.

        Occurrences of each constraint value are replaced by a
        ``{<constraint type value>}`` placeholder. Literal braces already in
        the query are escaped (``{{`` / ``}}``), so the result is a valid
        ``str.format`` template whose only fields are those placeholders.

        Returns:
            A new ``QueryPattern`` with ``success_rate=1.0``, or ``None``
            when no constraint value could be replaced.
        """

        def escape_braces(text: str) -> str:
            return text.replace("{", "{{").replace("}", "}}")

        pattern = escape_braces(query)
        has_placeholder = False

        for constraint in constraints:
            # An empty value is "in" every string; replacing it would put
            # the placeholder between every character.
            if not constraint.value:
                continue
            escaped_value = escape_braces(str(constraint.value))
            if escaped_value in pattern:
                placeholder = f"{{{constraint.type.value}}}"
                pattern = pattern.replace(escaped_value, placeholder)
                has_placeholder = True

        if not has_placeholder:
            return None

        return QueryPattern(
            template=pattern,
            constraint_types=[c.type for c in constraints],
            success_rate=1.0,
            example_queries=[query],
            discovered_entities=set(),
        )

    def _format_constraints(self, constraints: List[Constraint]) -> str:
        """Format constraints for prompts, one "- type: desc [value: v]" line each."""
        formatted = []
        for c in constraints:
            formatted.append(
                f"- {c.type.value}: {c.description} [value: {c.value}]"
            )
        return "\n".join(formatted)

    def generate_fallback_queries(
        self, original_query: str, constraints: List[Constraint]
    ) -> List[str]:
        """Generate fallback queries when the original fails.

        Candidates, in order: a query from the two highest-weight
        constraints (only when there are more than two), an OR of all quoted
        values, single-constraint queries for the first three constraints,
        and LLM rephrasings of ``original_query``.

        Returns:
            Up to five distinct, non-empty candidates not in
            ``failed_queries``, in the order above.
        """
        fallback_queries: List[str] = []

        # 1. Simplified query (fewer constraints)
        if len(constraints) > 2:
            priority_constraints = sorted(
                constraints, key=lambda c: c.weight, reverse=True
            )[:2]
            fallback_queries.append(self.generate_query(priority_constraints))

        # 2. Broadened query (with OR instead of AND)
        fallback_queries.append(" OR ".join(f'"{c.value}"' for c in constraints))

        # 3. Decomposed queries (one constraint at a time)
        fallback_queries.extend(
            self._generate_single_constraint_query(constraint)
            for constraint in constraints[:3]
        )

        # 4. Alternative phrasing
        alt_prompt = f"""
The query "{original_query}" failed. Create 2 alternative queries with different phrasing.

Constraints to satisfy:
{self._format_constraints(constraints)}

Return only the queries, one per line.
"""

        response = self.model.invoke(alt_prompt)
        fallback_queries.extend(
            line.strip()
            for line in remove_think_tags(response.content).strip().split("\n")
            if line.strip()
        )

        # Remove empties, failed queries and duplicates (first one wins)
        unique_fallbacks: List[str] = []
        seen: Set[str] = set()
        for candidate in fallback_queries:
            if (
                candidate
                and candidate not in self.failed_queries
                and candidate not in seen
            ):
                seen.add(candidate)
                unique_fallbacks.append(candidate)

        return unique_fallbacks[:5]

    def _generate_single_constraint_query(self, constraint: Constraint) -> str:
        """Generate a query for a single constraint via a per-type template."""
        type_specific_templates = {
            ConstraintType.NAME_PATTERN: '"{value}" names list',
            ConstraintType.LOCATION: '"{value}" places locations',
            ConstraintType.EVENT: '"{value}" incidents accidents',
            ConstraintType.PROPERTY: 'things with "{value}" property',
            ConstraintType.STATISTIC: '"{value}" statistics data',
            ConstraintType.TEMPORAL: "events in {value}",
            ConstraintType.COMPARISON: '"{value}" comparison ratio',
            ConstraintType.EXISTENCE: 'has "{value}" feature',
        }

        template = type_specific_templates.get(constraint.type, '"{value}"')

        return template.format(value=constraint.value)

    def optimize_constraint_combinations(
        self, constraints: List[Constraint]
    ) -> List[List[Constraint]]:
        """Order candidate constraint subsets, most promising first.

        1. Type combinations recorded by ``update_patterns``, most frequent
           first, each filled with *distinct* constraints of those types
           (skipped when ``constraints`` cannot supply them all).
        2. Each constraint on its own.
        3. Every pair of constraints.

        A subset already listed (same constraints, any order) is not
        repeated, and empty subsets are never returned.

        Returns:
            At most ten constraint lists.
        """
        combinations: List[List[Constraint]] = []
        seen: Set[frozenset] = set()

        def add(indices: List[int]) -> None:
            # Dedup by index set so [a, b] and [b, a] count as one subset.
            key = frozenset(indices)
            if indices and key not in seen:
                seen.add(key)
                combinations.append([constraints[i] for i in indices])

        # Most frequently successful type combinations first
        sorted_combos = sorted(
            self.constraint_combinations.items(),
            key=lambda x: x[1],
            reverse=True,
        )

        for combo_types, _ in sorted_combos:
            used: List[int] = []
            for ctype in combo_types:
                # Pick an unused constraint so a combo that repeats a type
                # does not reuse the same constraint twice.
                idx = next(
                    (
                        i
                        for i, c in enumerate(constraints)
                        if c.type == ctype and i not in used
                    ),
                    None,
                )
                if idx is None:
                    break
                used.append(idx)
            else:
                add(used)

        # Add individual constraints
        for i in range(len(constraints)):
            add([i])

        # Add pairs not yet tried
        for i in range(len(constraints)):
            for j in range(i + 1, len(constraints)):
                add([i, j])

        return combinations[:10]  # Limit to top 10