Coverage for src / local_deep_research / advanced_search_system / query_generation / adaptive_query_generator.py: 97%
155 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
1"""
2Adaptive query generation system for improved search performance.
3"""
5from collections import defaultdict
6from dataclasses import dataclass
7from typing import Dict, List, Optional, Set, Tuple
9from langchain_core.language_models import BaseChatModel
11from ...utilities.search_utilities import remove_think_tags
12from ..constraints.base_constraint import Constraint, ConstraintType
@dataclass
class QueryPattern:
    """A query template together with its observed track record.

    Attributes:
        template: Format string with ``{placeholder}`` slots that are filled
            from constraint values when a query is generated.
        constraint_types: Constraint types that must all be present for the
            template to be considered a match.
        success_rate: Score used to rank competing templates; higher wins.
        example_queries: Concrete queries that were produced from, or gave
            rise to, this template.
        discovered_entities: Entities found by queries using this template.
    """

    # Field order is part of the positional constructor signature.
    template: str
    constraint_types: List[ConstraintType]
    success_rate: float
    example_queries: List[str]
    discovered_entities: Set[str]
class AdaptiveQueryGenerator:
    """
    Generates search queries that adapt based on past performance.

    Features:
    1. Pattern learning from successful queries
    2. Semantic expansion for broader coverage
    3. Constraint combination optimization
    4. Failure recovery strategies

    Instance state:
        model: Chat model used for semantic expansion and LLM fallbacks.
        successful_patterns: Known query templates, seeded with defaults
            and extended by ``update_patterns``.
        failed_queries: Queries reported as unsuccessful; never returned
            again by the pattern/expansion stages or by the fallbacks.
        semantic_expansions: Cache of expansions keyed by constraint value.
        constraint_combinations: Success counter keyed by a tuple of
            constraint types.
    """

    def __init__(self, model: BaseChatModel):
        """Initialize the adaptive query generator.

        Args:
            model: Chat model whose ``invoke(prompt)`` result exposes the
                generated text as ``.content``.
        """
        self.model = model
        self.successful_patterns: List[QueryPattern] = []
        self.failed_queries: Set[str] = set()
        self.semantic_expansions: Dict[str, List[str]] = {}
        self.constraint_combinations: Dict[
            Tuple[ConstraintType, ...], float
        ] = defaultdict(float)

        # Initialize default patterns
        self._initialize_default_patterns()

    def _initialize_default_patterns(self):
        """Seed ``successful_patterns`` with three built-in templates."""
        default_patterns = [
            QueryPattern(
                template='"{entity}" {property} {location}',
                constraint_types=[
                    ConstraintType.NAME_PATTERN,
                    ConstraintType.PROPERTY,
                    ConstraintType.LOCATION,
                ],
                success_rate=0.7,
                example_queries=['"mountain" formed ice age Colorado'],
                discovered_entities=set(),
            ),
            QueryPattern(
                template="{event} {temporal} {statistic}",
                constraint_types=[
                    ConstraintType.EVENT,
                    ConstraintType.TEMPORAL,
                    ConstraintType.STATISTIC,
                ],
                success_rate=0.6,
                example_queries=["accident 2000-2021 statistics"],
                discovered_entities=set(),
            ),
            QueryPattern(
                template='"{name_pattern}" AND {comparison} {value}',
                constraint_types=[
                    ConstraintType.NAME_PATTERN,
                    ConstraintType.COMPARISON,
                ],
                success_rate=0.65,
                example_queries=['"body part" AND "84.5 times" ratio'],
                discovered_entities=set(),
            ),
        ]
        self.successful_patterns.extend(default_patterns)

    def generate_query(
        self, constraints: List[Constraint], context: Optional[Dict] = None
    ) -> str:
        """Generate an adaptive query based on constraints and context.

        Strategies are tried in order: learned/default patterns, semantic
        expansion, then the LLM. The first two are skipped when they yield
        nothing or a query already recorded in ``failed_queries``.

        Args:
            constraints: Constraints the query should cover.
            context: Optional dict; only consulted by the LLM stage (see
                ``_generate_with_llm``).

        Returns:
            The generated query string.
        """
        # Try pattern-based generation first
        pattern_query = self._generate_from_patterns(constraints)
        if pattern_query and pattern_query not in self.failed_queries:
            return pattern_query

        # Try semantic expansion
        expanded_query = self._generate_with_expansion(constraints)
        if expanded_query and expanded_query not in self.failed_queries:
            return expanded_query

        # Fall back to LLM-based generation
        return self._generate_with_llm(constraints, context)

    def _generate_from_patterns(
        self, constraints: List[Constraint]
    ) -> Optional[str]:
        """Generate query using learned patterns.

        A pattern matches when every one of its constraint types occurs in
        ``constraints``. Matching patterns are tried from highest to lowest
        ``success_rate`` and the first template that can be filled wins.

        Returns:
            The filled-in query, or ``None`` when no pattern matches or no
            matching template can be formatted.
        """
        constraint_types = [c.type for c in constraints]

        # Find matching patterns
        matching_patterns = [
            pattern
            for pattern in self.successful_patterns
            if all(t in constraint_types for t in pattern.constraint_types)
        ]
        if not matching_patterns:
            return None

        # Collect the values available for template placeholders
        template_vars = {}
        for constraint in constraints:
            if constraint.type == ConstraintType.NAME_PATTERN:
                template_vars["name_pattern"] = constraint.value
                template_vars["entity"] = constraint.value
            elif constraint.type == ConstraintType.PROPERTY:
                template_vars["property"] = constraint.value
            elif constraint.type == ConstraintType.LOCATION:
                template_vars["location"] = constraint.value
            elif constraint.type == ConstraintType.EVENT:
                template_vars["event"] = constraint.value
            elif constraint.type == ConstraintType.TEMPORAL:
                template_vars["temporal"] = constraint.value
            elif constraint.type == ConstraintType.STATISTIC:
                template_vars["statistic"] = constraint.value
            elif constraint.type == ConstraintType.COMPARISON:
                template_vars["comparison"] = f'"{constraint.value}"'
                template_vars["value"] = constraint.value

        # Best pattern first. sorted() is stable, so ties keep their order
        # in successful_patterns, the same element max() would have picked.
        ranked_patterns = sorted(
            matching_patterns, key=lambda p: p.success_rate, reverse=True
        )
        for pattern in ranked_patterns:
            try:
                return pattern.template.format(**template_vars)
            except (KeyError, IndexError, ValueError):
                # Learned templates are derived from raw query text: they
                # may name placeholders we have no value for (KeyError),
                # or contain literal braces that str.format rejects
                # (IndexError / ValueError). Fall through to the next-best
                # pattern instead of giving up on pattern generation.
                continue

        return None

    def _generate_with_expansion(
        self, constraints: List[Constraint]
    ) -> Optional[str]:
        """Generate query with semantic expansion.

        Each constraint becomes either ``(value OR exp1 OR exp2)`` when
        expansions are available or ``"value"`` otherwise; the pieces are
        joined with `` AND ``. Expansions are fetched once per distinct
        constraint value and cached in ``semantic_expansions``.

        Returns:
            The combined query; an empty string when ``constraints`` is
            empty.
        """
        expanded_terms = []

        for constraint in constraints:
            # Get expansions for this value
            if constraint.value not in self.semantic_expansions:
                self.semantic_expansions[constraint.value] = (
                    self._get_semantic_expansions(
                        constraint.value, constraint.type
                    )
                )

            expansions = self.semantic_expansions[constraint.value]
            if expansions:
                # Use OR to include expansions
                expanded = (
                    f"({constraint.value} OR {' OR '.join(expansions[:2])})"
                )
                expanded_terms.append(expanded)
            else:
                expanded_terms.append(f'"{constraint.value}"')

        return " AND ".join(expanded_terms)

    def _get_semantic_expansions(
        self, term: str, constraint_type: ConstraintType
    ) -> List[str]:
        """Get semantic expansions for a term.

        Asks the model for alternatives, one per line, and returns at most
        the first three non-empty lines, each wrapped in double quotes.
        """
        prompt = f"""
Generate 3 alternative phrases or related terms for "{term}" in the context of {constraint_type.value}.

These should be:
1. Synonyms or near-synonyms
2. Related concepts
3. Alternative phrasings

Return only the terms, one per line.
"""

        response = self.model.invoke(prompt)
        expansions = [
            line.strip()
            for line in remove_think_tags(response.content).strip().split("\n")
            if line.strip()
        ]

        return [f'"{exp}"' for exp in expansions[:3]]

    def _generate_with_llm(
        self, constraints: List[Constraint], context: Optional[Dict] = None
    ) -> str:
        """Generate query using LLM with context awareness.

        Args:
            constraints: Constraints described to the model.
            context: Optional dict. If it has ``"failed_queries"`` and/or
                ``"successful_queries"``, up to three entries of each are
                added to the prompt. Any iterable of strings is accepted.

        Returns:
            The model's answer with think tags removed and whitespace
            stripped.
        """
        constraint_desc = self._format_constraints(constraints)

        context_info = ""
        if context:
            # list(...) before slicing so that sets (such as
            # self.failed_queries) and tuples work as well as lists.
            if "failed_queries" in context:
                context_info += "\nFailed queries to avoid:\n" + "\n".join(
                    list(context["failed_queries"])[:3]
                )
            if "successful_queries" in context:
                context_info += "\nSuccessful query patterns:\n" + "\n".join(
                    list(context["successful_queries"])[:3]
                )

        prompt = f"""
Create an effective search query for these constraints:

{constraint_desc}
{context_info}

Guidelines:
1. Focus on finding specific named entities
2. Use operators (AND, OR, quotes) effectively
3. Combine constraints strategically
4. Make the query neither too broad nor too narrow

Return only the search query.
"""

        response = self.model.invoke(prompt)
        return remove_think_tags(response.content).strip()

    def update_patterns(
        self,
        query: str,
        constraints: List[Constraint],
        success: bool,
        entities_found: List[str],
    ):
        """Update patterns based on query performance.

        A query counts as successful only when ``success`` is true and
        ``entities_found`` is non-empty; otherwise it is added to
        ``failed_queries``.

        On success the query is turned into a template (when possible) that
        is merged into ``successful_patterns``, and the counter for this
        combination of constraint types is incremented.
        """
        if success and entities_found:
            # Extract pattern from successful query
            pattern = self._extract_pattern(query, constraints)
            if pattern:
                # Update or add pattern
                existing = next(
                    (
                        p
                        for p in self.successful_patterns
                        if p.template == pattern.template
                    ),
                    None,
                )

                if existing:
                    existing.success_rate = (existing.success_rate + 1.0) / 2
                    existing.example_queries.append(query)
                    existing.discovered_entities.update(entities_found)
                else:
                    # Record the entities for a brand-new pattern too, so
                    # first-time and repeat successes are tracked alike.
                    pattern.discovered_entities.update(entities_found)
                    self.successful_patterns.append(pattern)

            # Update constraint combinations. Sort on the member's value:
            # Enum members themselves are not orderable unless the enum
            # defines comparisons, so sorting them directly can raise
            # TypeError as soon as two constraints are involved.
            constraint_types = tuple(
                sorted(
                    (c.type for c in constraints), key=lambda t: t.value
                )
            )
            self.constraint_combinations[constraint_types] += 1
        else:
            self.failed_queries.add(query)

    def _extract_pattern(
        self, query: str, constraints: List[Constraint]
    ) -> Optional[QueryPattern]:
        """Extract a reusable pattern from a successful query.

        Every constraint value that occurs in ``query`` is replaced by a
        ``{<constraint type value>}`` placeholder.

        Returns:
            A new ``QueryPattern`` with ``success_rate`` 1.0, or ``None``
            when no constraint value could be substituted.
        """
        # Simple pattern extraction - could be made more sophisticated
        pattern = query
        substituted = False

        # Replace specific values with placeholders
        for constraint in constraints:
            value = constraint.value
            # Skip empty values: "" is contained in every string and
            # replacing it would insert the placeholder between every
            # character of the query.
            if value and value in query:
                placeholder = f"{{{constraint.type.value}}}"
                pattern = pattern.replace(value, placeholder)
                substituted = True

        # Only create pattern if it has placeholders. Tracking the
        # substitution (rather than looking for "{") keeps a query that
        # merely contains a literal brace from being stored as a template.
        if not substituted:
            return None

        return QueryPattern(
            template=pattern,
            constraint_types=[c.type for c in constraints],
            success_rate=1.0,
            example_queries=[query],
            discovered_entities=set(),
        )

    def _format_constraints(self, constraints: List[Constraint]) -> str:
        """Format constraints for prompts, one ``- type: description [value: v]`` line each."""
        return "\n".join(
            f"- {c.type.value}: {c.description} [value: {c.value}]"
            for c in constraints
        )

    def generate_fallback_queries(
        self, original_query: str, constraints: List[Constraint]
    ) -> List[str]:
        """Generate fallback queries when the original fails.

        Candidates are produced in this order: a simplified query built
        from the two highest-weight constraints (only when there are more
        than two), an OR-joined query of all quoted values, one
        single-constraint query for each of the first three constraints,
        and alternative phrasings requested from the model.

        Returns:
            Up to five candidates, de-duplicated in order, excluding empty
            strings and anything in ``failed_queries``.
        """
        fallback_queries = []

        # 1. Simplified query (fewer constraints)
        if len(constraints) > 2:
            priority_constraints = sorted(
                constraints, key=lambda c: c.weight, reverse=True
            )[:2]
            simplified = self.generate_query(priority_constraints)
            fallback_queries.append(simplified)

        # 2. Broadened query (with OR instead of AND)
        terms = [f'"{c.value}"' for c in constraints]
        broadened = " OR ".join(terms)
        fallback_queries.append(broadened)

        # 3. Decomposed queries (one constraint at a time)
        for constraint in constraints[:3]:
            single_query = self._generate_single_constraint_query(constraint)
            fallback_queries.append(single_query)

        # 4. Alternative phrasing
        alt_prompt = f"""
The query "{original_query}" failed. Create 2 alternative queries with different phrasing.

Constraints to satisfy:
{self._format_constraints(constraints)}

Return only the queries, one per line.
"""

        response = self.model.invoke(alt_prompt)
        alt_queries = [
            line.strip()
            for line in remove_think_tags(response.content).strip().split("\n")
            if line.strip()
        ]
        fallback_queries.extend(alt_queries)

        # Remove duplicates and failed queries
        unique_fallbacks = []
        for q in fallback_queries:
            if q and q not in self.failed_queries and q not in unique_fallbacks:
                unique_fallbacks.append(q)

        return unique_fallbacks[:5]

    def _generate_single_constraint_query(self, constraint: Constraint) -> str:
        """Generate a query for a single constraint.

        Uses a template chosen by constraint type; types without a
        dedicated template fall back to the quoted value alone.
        """
        type_specific_templates = {
            ConstraintType.NAME_PATTERN: '"{value}" names list',
            ConstraintType.LOCATION: '"{value}" places locations',
            ConstraintType.EVENT: '"{value}" incidents accidents',
            ConstraintType.PROPERTY: 'things with "{value}" property',
            ConstraintType.STATISTIC: '"{value}" statistics data',
            ConstraintType.TEMPORAL: "events in {value}",
            ConstraintType.COMPARISON: '"{value}" comparison ratio',
            ConstraintType.EXISTENCE: 'has "{value}" feature',
        }

        template = type_specific_templates.get(constraint.type, '"{value}"')

        return template.format(value=constraint.value)

    def optimize_constraint_combinations(
        self, constraints: List[Constraint]
    ) -> List[List[Constraint]]:
        """Optimize constraint combinations based on past success.

        The result lists, in order: previously successful type
        combinations that can be fully satisfied from ``constraints``
        (most successful first, using the first constraint of each type),
        every individual constraint, then every pair. A combination that
        is already in the list is not added again.

        Returns:
            At most ten constraint combinations.
        """
        combinations: List[List[Constraint]] = []

        def add_unique(combo: List[Constraint]) -> None:
            # Keeps the ten-slot budget from being spent on repeats.
            if combo not in combinations:
                combinations.append(combo)

        # Sort constraint combinations by success rate
        sorted_combos = sorted(
            self.constraint_combinations.items(),
            key=lambda x: x[1],
            reverse=True,
        )

        # Try successful combinations first
        for combo_types, _ in sorted_combos:
            if not combo_types:
                # An empty key (recorded for a success with no
                # constraints) would otherwise yield an empty combination.
                continue

            matching_constraints = []
            for ctype in combo_types:
                matching = [c for c in constraints if c.type == ctype]
                if matching:
                    matching_constraints.append(matching[0])

            if len(matching_constraints) == len(combo_types):
                add_unique(matching_constraints)

        # Add individual constraints
        for constraint in constraints:
            add_unique([constraint])

        # Add pairs not yet tried
        for i, first in enumerate(constraints):
            for second in constraints[i + 1 :]:
                add_unique([first, second])

        return combinations[:10]  # Limit to top 10