Coverage for src / local_deep_research / advanced_search_system / query_generation / adaptive_query_generator.py: 92%
156 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 01:07 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 01:07 +0000
"""
Adaptive query generation system for improved search performance.
"""
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple

from langchain_core.language_models import BaseChatModel

from ...utilities.search_utilities import remove_think_tags
from ..constraints.base_constraint import Constraint, ConstraintType
@dataclass
class QueryPattern:
    """A search-query template together with its track record.

    ``template`` is a ``str.format`` string whose named placeholders are
    filled from constraint values; ``constraint_types`` lists the constraint
    types the template is associated with.
    """

    # str.format template, e.g. '"{entity}" {property} {location}'.
    template: str
    # Constraint types associated with this template.
    constraint_types: List[ConstraintType]
    # Success score used to rank competing templates.
    success_rate: float
    # Concrete queries produced with, or learned from, this template.
    example_queries: List[str]
    # Entities discovered by queries that used this template.
    discovered_entities: Set[str]
class AdaptiveQueryGenerator:
    """
    Generate search queries that adapt based on past performance.

    Strategies, tried in order by :meth:`generate_query`:

    1. Pattern templates (:class:`QueryPattern`) whose constraint types are
       all present in the request, highest ``success_rate`` first.
    2. Semantic expansion: each constraint value is OR-ed with up to two
       LLM-suggested alternative phrasings.
    3. Free-form generation by the LLM, optionally steered by ``context``.

    Feedback passed to :meth:`update_patterns` learns or reinforces
    templates, counts successful constraint-type combinations, and records
    failed queries so they are not produced again.
    """

    # Leading bullet / enumeration markers ("- ", "* ", "1. ", "2) ") that
    # LLMs commonly prepend to list items. Whitespace after the marker is
    # required so values such as "3.5 times" or "-5 degrees" stay intact.
    _LIST_MARKER_RE = re.compile(r"^\s*(?:[-*\u2022]|\d+[.)])\s+")

    def __init__(self, model: BaseChatModel):
        """Initialize the adaptive query generator.

        Args:
            model: Chat model used (via ``invoke``) for semantic expansion,
                free-form query generation and fallback rephrasing.
        """
        self.model = model
        # Templates in insertion order; selection ranks them by success_rate.
        self.successful_patterns: List[QueryPattern] = []
        # Exact query strings reported as failures; never returned again.
        self.failed_queries: Set[str] = set()
        # Cache: constraint value -> quoted alternative phrasings.
        self.semantic_expansions: Dict[str, List[str]] = {}
        # Sorted tuple of constraint types -> count of successful queries.
        self.constraint_combinations: Dict[
            Tuple[ConstraintType, ...], float
        ] = defaultdict(float)

        # Seed with hand-written templates.
        self._initialize_default_patterns()

    def _initialize_default_patterns(self):
        """Seed ``successful_patterns`` with hand-written templates."""
        default_patterns = [
            QueryPattern(
                template='"{entity}" {property} {location}',
                constraint_types=[
                    ConstraintType.NAME_PATTERN,
                    ConstraintType.PROPERTY,
                    ConstraintType.LOCATION,
                ],
                success_rate=0.7,
                example_queries=['"mountain" formed ice age Colorado'],
                discovered_entities=set(),
            ),
            QueryPattern(
                template="{event} {temporal} {statistic}",
                constraint_types=[
                    ConstraintType.EVENT,
                    ConstraintType.TEMPORAL,
                    ConstraintType.STATISTIC,
                ],
                success_rate=0.6,
                example_queries=["accident 2000-2021 statistics"],
                discovered_entities=set(),
            ),
            QueryPattern(
                template='"{name_pattern}" AND {comparison} {value}',
                constraint_types=[
                    ConstraintType.NAME_PATTERN,
                    ConstraintType.COMPARISON,
                ],
                success_rate=0.65,
                example_queries=['"body part" AND "84.5 times" ratio'],
                discovered_entities=set(),
            ),
        ]
        self.successful_patterns.extend(default_patterns)

    def generate_query(
        self, constraints: List[Constraint], context: Optional[Dict] = None
    ) -> str:
        """Generate a search query for ``constraints``.

        Pattern-based and expansion-based candidates are used only if they
        are non-empty and not in ``self.failed_queries``; otherwise the LLM
        writes the query.

        Args:
            constraints: Constraints the query should target.
            context: Optional hints for the LLM fallback; the keys
                ``"failed_queries"`` and ``"successful_queries"`` are read.

        Returns:
            The query string.
        """
        pattern_query = self._generate_from_patterns(constraints)
        if pattern_query and pattern_query not in self.failed_queries:
            return pattern_query

        expanded_query = self._generate_with_expansion(constraints)
        if expanded_query and expanded_query not in self.failed_queries:
            return expanded_query

        return self._generate_with_llm(constraints, context)

    def _generate_from_patterns(
        self, constraints: List[Constraint]
    ) -> Optional[str]:
        """Fill the best applicable pattern template with constraint values.

        A pattern applies when each of its ``constraint_types`` occurs among
        ``constraints``. Applicable patterns are tried from highest to lowest
        ``success_rate`` (ties keep insertion order); the first template that
        can be filled is returned.

        Returns:
            The formatted query, or ``None`` if no pattern applies or none of
            the applicable templates can be filled.
        """
        available_types = {c.type for c in constraints}
        candidates = sorted(
            (
                pattern
                for pattern in self.successful_patterns
                if all(t in available_types for t in pattern.constraint_types)
            ),
            key=lambda p: p.success_rate,
            reverse=True,  # sorted() is stable, so ties keep insertion order
        )
        if not candidates:
            return None

        template_vars = self._build_template_vars(constraints)
        for pattern in candidates:
            try:
                return pattern.template.format(**template_vars)
            except (KeyError, IndexError, ValueError):
                # Unfillable placeholder, positional field such as "{0}", or
                # malformed braces in a learned template: try the next one.
                continue
        return None

    @staticmethod
    def _build_template_vars(
        constraints: List[Constraint],
    ) -> Dict[str, object]:
        """Map template placeholder names to values taken from ``constraints``.

        Each constraint is first exposed under ``str(constraint.type.value)``
        (the placeholder name :meth:`_extract_pattern` writes into learned
        templates), then under the aliases used by the default templates,
        which take precedence. When several constraints map to the same key,
        the last one wins.
        """
        template_vars: Dict[str, object] = {
            str(c.type.value): c.value for c in constraints
        }
        for constraint in constraints:
            ctype, value = constraint.type, constraint.value
            if ctype == ConstraintType.NAME_PATTERN:
                template_vars["name_pattern"] = value
                template_vars["entity"] = value
            elif ctype == ConstraintType.PROPERTY:
                template_vars["property"] = value
            elif ctype == ConstraintType.LOCATION:
                template_vars["location"] = value
            elif ctype == ConstraintType.EVENT:
                template_vars["event"] = value
            elif ctype == ConstraintType.TEMPORAL:
                template_vars["temporal"] = value
            elif ctype == ConstraintType.STATISTIC:
                template_vars["statistic"] = value
            elif ctype == ConstraintType.COMPARISON:
                template_vars["comparison"] = f'"{value}"'
                template_vars["value"] = value
        return template_vars

    def _generate_with_expansion(
        self, constraints: List[Constraint]
    ) -> Optional[str]:
        """Build an AND-query in which each value is OR-ed with alternatives.

        Alternatives come from :meth:`_get_semantic_expansions`, called once
        per distinct value and cached in ``self.semantic_expansions``. The
        value itself is quoted so multi-word terms are matched as phrases,
        like the (already quoted) alternatives.

        Returns:
            The query, or ``None`` when ``constraints`` is empty.
        """
        if not constraints:
            return None

        expanded_terms = []
        for constraint in constraints:
            value = constraint.value
            if value not in self.semantic_expansions:
                self.semantic_expansions[value] = (
                    self._get_semantic_expansions(value, constraint.type)
                )

            quoted_value = f'"{value}"'
            alternatives = self.semantic_expansions[value][:2]
            if alternatives:
                expanded_terms.append(
                    f"({quoted_value} OR {' OR '.join(alternatives)})"
                )
            else:
                expanded_terms.append(quoted_value)

        return " AND ".join(expanded_terms)

    def _get_semantic_expansions(
        self, term: str, constraint_type: ConstraintType
    ) -> List[str]:
        """Ask the LLM for up to three alternative phrasings of ``term``.

        Each non-empty response line becomes one alternative after a leading
        list marker (``- ``, ``1. ``, ...) and surrounding double quotes are
        removed, so results are not double-quoted.

        Returns:
            Up to three alternatives, each wrapped in double quotes.
        """
        prompt = f"""
Generate 3 alternative phrases or related terms for "{term}" in the context of {constraint_type.value}.

These should be:
1. Synonyms or near-synonyms
2. Related concepts
3. Alternative phrasings

Return only the terms, one per line.
"""

        response = self.model.invoke(prompt)
        text = remove_think_tags(response.content).strip()

        expansions = []
        for line in text.split("\n"):
            cleaned = self._LIST_MARKER_RE.sub("", line)
            cleaned = cleaned.strip().strip('"').strip()
            if cleaned:
                expansions.append(cleaned)

        return [f'"{exp}"' for exp in expansions[:3]]

    def _generate_with_llm(
        self, constraints: List[Constraint], context: Optional[Dict] = None
    ) -> str:
        """Let the LLM write a query, optionally steered by ``context``.

        Up to three entries each of ``context["failed_queries"]`` and
        ``context["successful_queries"]`` are included in the prompt. Any
        iterable is accepted, including the ``set`` this class itself uses
        for failed queries.

        Returns:
            The LLM's answer with think-tags removed and whitespace stripped.
        """
        constraint_desc = self._format_constraints(constraints)

        def first_three(items) -> List[str]:
            return [str(item) for item in list(items)[:3]]

        context_info = ""
        if context:
            if "failed_queries" in context:
                context_info += "\nFailed queries to avoid:\n" + "\n".join(
                    first_three(context["failed_queries"])
                )
            if "successful_queries" in context:
                context_info += "\nSuccessful query patterns:\n" + "\n".join(
                    first_three(context["successful_queries"])
                )

        prompt = f"""
Create an effective search query for these constraints:

{constraint_desc}
{context_info}

Guidelines:
1. Focus on finding specific named entities
2. Use operators (AND, OR, quotes) effectively
3. Combine constraints strategically
4. Make the query neither too broad nor too narrow

Return only the search query.
"""

        response = self.model.invoke(prompt)
        return remove_think_tags(response.content).strip()

    def update_patterns(
        self,
        query: str,
        constraints: List[Constraint],
        success: bool,
        entities_found: List[str],
    ):
        """Record the outcome of ``query``.

        A query counts as successful only if ``success`` is true AND
        ``entities_found`` is non-empty. On success:

        - the template extracted from the query is added; if an identical
          template already exists, its ``success_rate`` moves halfway toward
          1.0 and the query and entities are recorded on it;
        - the sorted tuple of constraint types is counted in
          ``constraint_combinations``.

        Otherwise the query is added to ``failed_queries``.
        """
        if not (success and entities_found):
            self.failed_queries.add(query)
            return

        pattern = self._extract_pattern(query, constraints)
        if pattern is not None:
            existing = next(
                (
                    p
                    for p in self.successful_patterns
                    if p.template == pattern.template
                ),
                None,
            )
            if existing is not None:
                existing.success_rate = (existing.success_rate + 1.0) / 2
                existing.example_queries.append(query)
                existing.discovered_entities.update(entities_found)
            else:
                # Record entities for new patterns too, as for existing ones.
                pattern.discovered_entities.update(entities_found)
                self.successful_patterns.append(pattern)

        # Sort by the enum value: plain Enum members are not orderable, and
        # for str/int mixin enums this matches their natural ordering.
        combo_key = tuple(
            sorted((c.type for c in constraints), key=lambda t: t.value)
        )
        self.constraint_combinations[combo_key] += 1

    def _extract_pattern(
        self, query: str, constraints: List[Constraint]
    ) -> Optional[QueryPattern]:
        """Turn a successful query into a reusable template.

        Each occurrence of a constraint value in ``query`` is replaced by the
        placeholder ``{<constraint.type.value>}``. Literal braces in the
        query are escaped (``{{`` / ``}}``) so the template survives
        ``str.format``. All values are substituted in one pass, longest
        first, so a value that is a substring of another value, or of an
        inserted placeholder name, cannot corrupt the result. When several
        constraints share a value, the first one picks the placeholder.
        ``None`` and empty values are ignored.

        Returns:
            A new pattern with ``success_rate`` 1.0, or ``None`` if no
            constraint value occurs in ``query``.
        """

        def escape_braces(text: str) -> str:
            return text.replace("{", "{{").replace("}", "}}")

        placeholders: Dict[str, str] = {}
        for constraint in constraints:
            if constraint.value is None:
                continue
            value = escape_braces(str(constraint.value))
            if value:  # an empty value would match between every character
                placeholders.setdefault(
                    value, f"{{{constraint.type.value}}}"
                )

        if not placeholders:
            return None

        alternation = "|".join(
            re.escape(v) for v in sorted(placeholders, key=len, reverse=True)
        )
        template, substitutions = re.subn(
            alternation,
            lambda match: placeholders[match.group(0)],
            escape_braces(query),
        )
        if not substitutions:
            return None

        return QueryPattern(
            template=template,
            constraint_types=[c.type for c in constraints],
            success_rate=1.0,
            example_queries=[query],
            discovered_entities=set(),
        )

    def _format_constraints(self, constraints: List[Constraint]) -> str:
        """Render constraints as ``- type: description [value: value]`` lines."""
        return "\n".join(
            f"- {c.type.value}: {c.description} [value: {c.value}]"
            for c in constraints
        )

    def generate_fallback_queries(
        self, original_query: str, constraints: List[Constraint]
    ) -> List[str]:
        """Generate up to five alternative queries after ``original_query`` failed.

        The candidates, in order, are:

        1. When there are more than two constraints, a query for the two
           with the highest ``weight``.
        2. An OR of all quoted constraint values.
        3. A single-constraint query for each of the first three constraints.
        4. LLM-suggested rephrasings, one per response line.

        Empty candidates, known failures and duplicates are dropped, keeping
        first-seen order.
        """
        fallback_queries: List[str] = []

        # 1. Simplified query (fewer constraints).
        if len(constraints) > 2:
            priority_constraints = sorted(
                constraints, key=lambda c: c.weight, reverse=True
            )[:2]
            fallback_queries.append(self.generate_query(priority_constraints))

        # 2. Broadened query (OR instead of AND).
        fallback_queries.append(
            " OR ".join(f'"{c.value}"' for c in constraints)
        )

        # 3. Decomposed queries (one constraint at a time).
        fallback_queries.extend(
            self._generate_single_constraint_query(c) for c in constraints[:3]
        )

        # 4. Alternative phrasing.
        alt_prompt = f"""
The query "{original_query}" failed. Create 2 alternative queries with different phrasing.

Constraints to satisfy:
{self._format_constraints(constraints)}

Return only the queries, one per line.
"""

        response = self.model.invoke(alt_prompt)
        fallback_queries.extend(
            line.strip()
            for line in remove_think_tags(response.content).strip().split("\n")
            if line.strip()
        )

        # dict.fromkeys de-duplicates while preserving first-seen order.
        unique_fallbacks = [
            q
            for q in dict.fromkeys(fallback_queries)
            if q and q not in self.failed_queries
        ]
        return unique_fallbacks[:5]

    def _generate_single_constraint_query(self, constraint: Constraint) -> str:
        """Build a query for one constraint from a per-type template.

        Types without a dedicated template fall back to the quoted value.
        """
        type_specific_templates = {
            ConstraintType.NAME_PATTERN: '"{value}" names list',
            ConstraintType.LOCATION: '"{value}" places locations',
            ConstraintType.EVENT: '"{value}" incidents accidents',
            ConstraintType.PROPERTY: 'things with "{value}" property',
            ConstraintType.STATISTIC: '"{value}" statistics data',
            ConstraintType.TEMPORAL: "events in {value}",
            ConstraintType.COMPARISON: '"{value}" comparison ratio',
            ConstraintType.EXISTENCE: 'has "{value}" feature',
        }

        template = type_specific_templates.get(constraint.type, '"{value}"')
        return template.format(value=constraint.value)

    def optimize_constraint_combinations(
        self, constraints: List[Constraint], max_combinations: int = 10
    ) -> List[List[Constraint]]:
        """Order constraint subsets to try, most promising first.

        1. Constraint-type combinations that led to successful queries, most
           frequently successful first. Each is instantiated with *distinct*
           constraints of the required types and skipped if ``constraints``
           cannot supply enough of them.
        2. Every constraint on its own.
        3. Every unordered pair of constraints.

        Empty and duplicate combinations are omitted.

        Args:
            constraints: Candidate constraints.
            max_combinations: Maximum number of combinations returned
                (default 10).

        Returns:
            Up to ``max_combinations`` lists of constraints.
        """
        combinations: List[List[Constraint]] = []

        def add(combo: List[Constraint]) -> None:
            if combo and combo not in combinations:
                combinations.append(combo)

        ranked_type_combos = sorted(
            self.constraint_combinations.items(),
            key=lambda item: item[1],
            reverse=True,
        )

        for combo_types, _count in ranked_type_combos:
            remaining = list(constraints)
            chosen: List[Constraint] = []
            for ctype in combo_types:
                index = next(
                    (i for i, c in enumerate(remaining) if c.type == ctype),
                    None,
                )
                if index is None:
                    break  # not enough constraints of this type
                # Pop so a repeated type is satisfied by a different constraint.
                chosen.append(remaining.pop(index))
            else:
                add(chosen)

        for constraint in constraints:
            add([constraint])

        for i, first in enumerate(constraints):
            for second in constraints[i + 1 :]:
                add([first, second])

        return combinations[:max_combinations]