Coverage for src / local_deep_research / advanced_search_system / search_optimization / cross_constraint_manager.py: 82%
259 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 01:07 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 01:07 +0000
1"""
2Cross-constraint search optimization manager.
3"""
5import itertools
6from collections import defaultdict
7from dataclasses import dataclass, field
8from typing import Dict, List, Optional, Set, Tuple
10from langchain_core.language_models import BaseChatModel
12from ...utilities.search_utilities import remove_think_tags
13from ..candidates.base_candidate import Candidate
14from ..constraints.base_constraint import Constraint
@dataclass
class ConstraintRelationship:
    """Represents a relationship between constraints."""

    constraint1_id: str  # id of the first constraint of the analyzed pair
    constraint2_id: str  # id of the second constraint of the analyzed pair
    relationship_type: str  # 'complementary', 'dependent', 'exclusive'
    strength: float  # 0.0 to 1.0
    # Free-text explanations parsed from the model's "Evidence:" lines.
    evidence: List[str] = field(default_factory=list)
@dataclass
class ConstraintCluster:
    """Group of related constraints that should be searched together."""

    constraints: List[Constraint]  # member constraints of this cluster
    # Values assigned in this module: 'type_based', 'relationship_based', 'semantic'.
    cluster_type: str
    coherence_score: float  # higher = members belong together more strongly
    # Lazily populated cache of generated queries for this cluster.
    search_queries: List[str] = field(default_factory=list)
class CrossConstraintManager:
    """
    Manages cross-constraint relationships and optimizes multi-constraint searches.

    Key features:
    1. Identifies relationships between constraints
    2. Clusters related constraints for efficient searching
    3. Generates cross-constraint validation queries
    4. Tracks cross-constraint evidence patterns
    """

    def __init__(self, model: BaseChatModel):
        """Initialize the cross-constraint manager.

        Args:
            model: Chat model used for relationship analysis, clustering,
                query generation, and candidate validation.
        """
        self.model = model
        # Keyed by (c1.id, c2.id) in the order the pair was analyzed; use
        # _get_relationship() for order-insensitive lookup.
        self.relationships: Dict[Tuple[str, str], ConstraintRelationship] = {}
        self.clusters: List[ConstraintCluster] = []
        # NOTE(review): not referenced by any method in this file; retained
        # for API compatibility.
        self.cross_validation_patterns: Dict[str, List[Dict]] = defaultdict(
            list
        )
        # Undirected adjacency map: constraint id -> related constraint ids.
        self.constraint_graph: Dict[str, Set[str]] = defaultdict(set)

    def _get_relationship(
        self, id1: str, id2: str
    ) -> Optional[ConstraintRelationship]:
        """Look up a stored relationship for a pair of constraint ids.

        Relationships are stored under whichever (id, id) ordering
        itertools.combinations produced at analysis time, so callers cannot
        know the ordering; this checks both.
        """
        return self.relationships.get((id1, id2)) or self.relationships.get(
            (id2, id1)
        )

    def analyze_constraint_relationships(
        self, constraints: List[Constraint]
    ) -> Dict[Tuple[str, str], ConstraintRelationship]:
        """Analyze relationships between constraints.

        Queries the model once per constraint pair, keeping only
        relationships with strength > 0.3. Kept relationships are merged
        into self.relationships and mirrored into the constraint graph.

        Returns:
            Mapping of (constraint1_id, constraint2_id) to the relationship.
        """
        relationships = {}

        # Analyze each pair of constraints
        for c1, c2 in itertools.combinations(constraints, 2):
            relationship = self._analyze_pair(c1, c2)
            # Only keep meaningful relationships
            if relationship.strength > 0.3:
                key = (c1.id, c2.id)
                relationships[key] = relationship

                # Update constraint graph (symmetric adjacency)
                self.constraint_graph[c1.id].add(c2.id)
                self.constraint_graph[c2.id].add(c1.id)

        self.relationships.update(relationships)
        return relationships

    def _analyze_pair(
        self, c1: Constraint, c2: Constraint
    ) -> ConstraintRelationship:
        """Analyze the relationship between two constraints via the model.

        Parses a line-oriented "Type:/Strength:/Evidence:" response;
        unparseable fields fall back to ('none', 0.0, []).
        """
        prompt = f"""
Analyze the relationship between these two constraints:

Constraint 1: {c1.description} (Type: {c1.type.value})
Constraint 2: {c2.description} (Type: {c2.type.value})

Determine:
1. Relationship type (complementary, dependent, exclusive, or none)
2. Strength of relationship (0.0 to 1.0)
3. Brief explanation

Format:
Type: [relationship_type]
Strength: [0.0-1.0]
Evidence: [explanation]
"""

        response = self.model.invoke(prompt)
        content = remove_think_tags(response.content)

        # Parse response with safe defaults
        rel_type = "none"
        strength = 0.0
        evidence = []

        for line in content.strip().split("\n"):
            if line.startswith("Type:"):
                rel_type = line.split(":", 1)[1].strip().lower()
            elif line.startswith("Strength:"):
                try:
                    strength = float(line.split(":", 1)[1].strip())
                except ValueError:
                    strength = 0.0
            elif line.startswith("Evidence:"):
                evidence.append(line.split(":", 1)[1].strip())

        return ConstraintRelationship(
            constraint1_id=c1.id,
            constraint2_id=c2.id,
            relationship_type=rel_type,
            strength=strength,
            evidence=evidence,
        )

    def create_constraint_clusters(
        self, constraints: List[Constraint]
    ) -> List[ConstraintCluster]:
        """Create clusters of related constraints.

        Combines three strategies (type-based, relationship-based,
        semantic), deduplicates the results, and caches them on
        self.clusters.
        """
        # First, analyze relationships if not done
        if not self.relationships:
            self.analyze_constraint_relationships(constraints)

        # Create clusters using different strategies
        clusters = []

        # 1. Type-based clusters: constraints sharing the same type
        type_groups = defaultdict(list)
        for c in constraints:
            type_groups[c.type].append(c)

        for ctype, group in type_groups.items():
            if len(group) > 1:
                cluster = ConstraintCluster(
                    constraints=group,
                    cluster_type="type_based",
                    coherence_score=0.7,
                )
                clusters.append(cluster)

        # 2. Relationship-based clusters: connected components over
        # strong (> 0.6) relationships
        strong_relationships = [
            rel for rel in self.relationships.values() if rel.strength > 0.6
        ]

        relationship_clusters = self._create_relationship_clusters(
            constraints, strong_relationships
        )
        clusters.extend(relationship_clusters)

        # 3. Semantic clusters (model-driven grouping)
        semantic_clusters = self._create_semantic_clusters(constraints)
        clusters.extend(semantic_clusters)

        # Remove duplicate clusters (same set of constraint ids)
        unique_clusters = self._deduplicate_clusters(clusters)

        self.clusters = unique_clusters
        return unique_clusters

    def _create_relationship_clusters(
        self,
        constraints: List[Constraint],
        relationships: List[ConstraintRelationship],
    ) -> List[ConstraintCluster]:
        """Create clusters from connected components of the relationship graph."""
        clusters = []
        processed = set()

        # Build adjacency list
        adj_list = defaultdict(list)
        for rel in relationships:
            adj_list[rel.constraint1_id].append(rel.constraint2_id)
            adj_list[rel.constraint2_id].append(rel.constraint1_id)

        # Find connected components
        for constraint in constraints:
            if constraint.id in processed:
                continue

            # BFS to find the connected component containing this constraint.
            # A head index gives O(1) dequeue (list.pop(0) is O(n)).
            component = []
            queue = [constraint.id]
            visited = {constraint.id}
            head = 0

            while head < len(queue):
                current_id = queue[head]
                head += 1
                current = next(
                    (c for c in constraints if c.id == current_id), None
                )
                if current:
                    component.append(current)
                    processed.add(current_id)

                    for neighbor_id in adj_list[current_id]:
                        if neighbor_id not in visited:
                            visited.add(neighbor_id)
                            queue.append(neighbor_id)

            # Singleton components are not useful clusters
            if len(component) > 1:
                cluster = ConstraintCluster(
                    constraints=component,
                    cluster_type="relationship_based",
                    coherence_score=self._calculate_cluster_coherence(
                        component
                    ),
                )
                clusters.append(cluster)

        return clusters

    def _create_semantic_clusters(
        self, constraints: List[Constraint]
    ) -> List[ConstraintCluster]:
        """Create clusters based on semantic similarity.

        Asks the model to group constraints, then parses the line-oriented
        response (CLUSTER_n / Constraints / Theme / Coherence sections).
        """
        prompt = f"""
Group these constraints into semantic clusters based on their meaning and intent:

{self._format_constraints_for_clustering(constraints)}

For each cluster:
1. List the constraint IDs
2. Describe the cluster theme
3. Rate coherence (0.0-1.0)

Format:
CLUSTER_1:
Constraints: [id1, id2, ...]
Theme: [description]
Coherence: [0.0-1.0]
"""

        response = self.model.invoke(prompt)
        content = remove_think_tags(response.content)

        clusters = []
        current_cluster = {}

        for line in content.strip().split("\n"):
            line = line.strip()

            if line.startswith("CLUSTER_"):
                # A new cluster header flushes the one being accumulated.
                cluster = self._build_semantic_cluster(
                    current_cluster, constraints
                )
                if cluster:
                    clusters.append(cluster)
                current_cluster = {}

            elif line.startswith("Constraints:"):
                ids_str = line.split(":", 1)[1].strip()
                # Extract ids of the form c<number> regardless of list formatting
                ids = re.findall(r"c\d+", ids_str)
                current_cluster["constraints"] = ids

            elif line.startswith("Theme:"):
                current_cluster["theme"] = line.split(":", 1)[1].strip()

            elif line.startswith("Coherence:"):
                try:
                    current_cluster["coherence"] = float(
                        line.split(":", 1)[1].strip()
                    )
                except ValueError:
                    current_cluster["coherence"] = 0.5

        # Don't forget the last cluster
        cluster = self._build_semantic_cluster(current_cluster, constraints)
        if cluster:
            clusters.append(cluster)

        return clusters

    def _build_semantic_cluster(
        self, cluster_data: Dict, constraints: List[Constraint]
    ) -> Optional[ConstraintCluster]:
        """Build a semantic ConstraintCluster from parsed response data.

        Returns None when the data has no constraint ids or matches fewer
        than two known constraints.
        """
        if not cluster_data or "constraints" not in cluster_data:
            return None

        constraint_ids = cluster_data["constraints"]
        cluster_constraints = [
            c for c in constraints if c.id in constraint_ids
        ]

        if len(cluster_constraints) < 2:
            return None

        return ConstraintCluster(
            constraints=cluster_constraints,
            cluster_type="semantic",
            coherence_score=float(cluster_data.get("coherence", 0.5)),
        )

    def generate_cross_constraint_queries(
        self, cluster: ConstraintCluster
    ) -> List[str]:
        """Generate optimized queries for a constraint cluster.

        Produces (1) one combined query, (2) progressive queries that add
        constraints by weight, (3) an optional intersection query, and
        (4) up to two pairwise validation queries. Queries are cached on
        the cluster.
        """
        queries = []

        # 1. Combined query (all constraints)
        combined_query = self._generate_combined_query(cluster.constraints)
        queries.append(combined_query)

        # 2. Progressive queries (build up constraints)
        progressive_queries = self._generate_progressive_queries(
            cluster.constraints
        )
        queries.extend(progressive_queries)

        # 3. Intersection queries (shared aspects)
        intersection_query = self._generate_intersection_query(
            cluster.constraints
        )
        if intersection_query:
            queries.append(intersection_query)

        # 4. Validation queries (cross-check)
        validation_queries = self._generate_validation_queries(
            cluster.constraints
        )
        queries.extend(validation_queries)

        # Store queries in cluster for reuse by _validate_with_cluster
        cluster.search_queries = queries

        return queries

    def _generate_combined_query(self, constraints: List[Constraint]) -> str:
        """Generate a single query combining all given constraints."""
        prompt = f"""
Create a search query that finds entities satisfying ALL of these related constraints:

{self._format_constraints_for_query(constraints)}

The query should:
1. Efficiently combine all constraints
2. Use appropriate operators (AND, OR)
3. Focus on finding specific entities
4. Be neither too broad nor too narrow

Return only the search query.
"""

        response = self.model.invoke(prompt)
        return remove_think_tags(response.content).strip()

    def _generate_progressive_queries(
        self, constraints: List[Constraint]
    ) -> List[str]:
        """Generate queries that progressively add constraints.

        Constraints are sorted by descending weight; queries cover the top
        2 and (when available) top 3 constraints.
        """
        queries = []

        # Sort by weight/importance
        sorted_constraints = sorted(
            constraints, key=lambda c: c.weight, reverse=True
        )

        # Build up constraints: subsets of size 2..3
        for i in range(2, min(len(sorted_constraints) + 1, 4)):
            subset = sorted_constraints[:i]
            query = self._generate_combined_query(subset)
            queries.append(query)

        return queries

    def _generate_intersection_query(
        self, constraints: List[Constraint]
    ) -> Optional[str]:
        """Generate a query targeting the common theme of the constraints.

        Returns None for fewer than two constraints or when the model
        reports no clear intersection.
        """
        if len(constraints) < 2:
            return None

        prompt = f"""
Identify the common theme or intersection among these constraints:

{self._format_constraints_for_query(constraints)}

Create a search query that targets this common aspect.
Return only the search query, or 'NONE' if no clear intersection exists.
"""

        response = self.model.invoke(prompt)
        query = remove_think_tags(response.content).strip()

        if query.upper() == "NONE":
            return None

        return query

    def _generate_validation_queries(
        self, constraints: List[Constraint]
    ) -> List[str]:
        """Generate pairwise cross-validation queries (at most two)."""
        queries = []

        # Pairwise validation queries over the first three constraints only
        for c1, c2 in itertools.combinations(constraints[:3], 2):
            prompt = f"""
Create a validation query that checks if an entity satisfies both:
- {c1.description}
- {c2.description}

Return only the search query.
"""

            response = self.model.invoke(prompt)
            query = remove_think_tags(response.content).strip()
            queries.append(query)

        return queries[:2]  # Limit to 2 validation queries

    def validate_candidate_across_constraints(
        self, candidate: Candidate, constraints: List[Constraint]
    ) -> Dict[str, float]:
        """Validate a candidate across multiple constraints simultaneously.

        Each constraint's score is the maximum over cluster-level checks
        and (for complementary pairs) pairwise checks.

        Returns:
            Mapping of constraint id to confidence score (0.0-1.0).
        """
        validation_scores = {}

        # Find relevant clusters for these constraints
        relevant_clusters = [
            cluster
            for cluster in self.clusters
            if any(c in cluster.constraints for c in constraints)
        ]

        for cluster in relevant_clusters:
            # Use cluster-specific queries for validation
            cluster_score = self._validate_with_cluster(candidate, cluster)

            # Update individual constraint scores (keep the best evidence)
            for constraint in cluster.constraints:
                if constraint in constraints:
                    validation_scores[constraint.id] = max(
                        validation_scores.get(constraint.id, 0.0), cluster_score
                    )

        # Additional pairwise validation
        for c1, c2 in itertools.combinations(constraints, 2):
            # Relationships may be stored under either id ordering, so use
            # the order-insensitive lookup rather than a direct dict probe.
            rel = self._get_relationship(c1.id, c2.id)
            if rel is not None and rel.relationship_type == "complementary":
                # Boost scores for complementary constraints
                pair_score = self._validate_pair(candidate, c1, c2)
                validation_scores[c1.id] = max(
                    validation_scores.get(c1.id, 0.0), pair_score
                )
                validation_scores[c2.id] = max(
                    validation_scores.get(c2.id, 0.0), pair_score
                )

        return validation_scores

    def _validate_with_cluster(
        self, candidate: Candidate, cluster: ConstraintCluster
    ) -> float:
        """Validate a candidate against a cluster's combined query."""
        if not cluster.search_queries:
            cluster.search_queries = self.generate_cross_constraint_queries(
                cluster
            )

        # Use the most comprehensive query (the combined query is first)
        validation_query = cluster.search_queries[0]

        prompt = f"""
Does "{candidate.name}" satisfy this multi-constraint query:
Query: {validation_query}

Constraints being checked:
{self._format_constraints_for_query(cluster.constraints)}

Provide a confidence score (0.0-1.0) based on how well the candidate matches.

Format:
Score: [0.0-1.0]
Explanation: [brief explanation]
"""

        response = self.model.invoke(prompt)
        content = remove_think_tags(response.content)
        return self._parse_score(content)

    def _validate_pair(
        self, candidate: Candidate, c1: Constraint, c2: Constraint
    ) -> float:
        """Validate a candidate against a pair of constraints."""
        prompt = f"""
Evaluate if "{candidate.name}" satisfies BOTH constraints:

1. {c1.description} (Type: {c1.type.value})
2. {c2.description} (Type: {c2.type.value})

Consider how these constraints relate to each other and whether the candidate satisfies both.

Provide a confidence score (0.0-1.0).

Format:
Score: [0.0-1.0]
"""

        response = self.model.invoke(prompt)
        content = remove_think_tags(response.content)
        return self._parse_score(content)

    @staticmethod
    def _parse_score(content: str) -> float:
        """Extract the first 'Score:' value from model output.

        Returns 0.0 when no parseable score line is found.
        """
        for line in content.strip().split("\n"):
            if line.startswith("Score:"):
                try:
                    return float(line.split(":", 1)[1].strip())
                except ValueError:
                    return 0.0
        return 0.0

    def _calculate_cluster_coherence(
        self, constraints: List[Constraint]
    ) -> float:
        """Calculate coherence score for a constraint cluster.

        Averages known pairwise relationship strengths, then scales by a
        size factor that mildly favors larger clusters.
        """
        if len(constraints) < 2:
            return 0.0

        # Calculate based on relationship strengths
        total_strength = 0.0
        pair_count = 0

        for c1, c2 in itertools.combinations(constraints, 2):
            # Order-insensitive lookup: the pair may have been analyzed in
            # either order.
            rel = self._get_relationship(c1.id, c2.id)
            if rel is not None:
                total_strength += rel.strength
                pair_count += 1

        if pair_count == 0:
            return 0.5  # Default coherence

        average_strength = total_strength / pair_count

        # Adjust for cluster size (larger clusters with high average strength are better)
        size_factor = min(len(constraints) / 5.0, 1.0)

        return average_strength * (0.7 + 0.3 * size_factor)

    def _deduplicate_clusters(
        self, clusters: List[ConstraintCluster]
    ) -> List[ConstraintCluster]:
        """Remove clusters with duplicate constraint-id sets (keep first)."""
        unique_clusters = []
        seen_sets: Set[frozenset] = set()

        for cluster in clusters:
            constraint_set = frozenset(c.id for c in cluster.constraints)
            if constraint_set not in seen_sets:
                seen_sets.add(constraint_set)
                unique_clusters.append(cluster)

        return unique_clusters

    def _format_constraints_for_clustering(
        self, constraints: List[Constraint]
    ) -> str:
        """Format constraints (with id, type, weight) for the clustering prompt."""
        return "\n".join(
            f"{c.id}: {c.description} (Type: {c.type.value}, Weight: {c.weight})"
            for c in constraints
        )

    def _format_constraints_for_query(
        self, constraints: List[Constraint]
    ) -> str:
        """Format constraints as a bullet list for query-generation prompts."""
        return "\n".join(
            f"- {c.description} [{c.type.value}]" for c in constraints
        )

    def optimize_search_order(
        self, clusters: List[ConstraintCluster]
    ) -> List[ConstraintCluster]:
        """Optimize the order in which clusters should be searched.

        Sorts descending by coherence_score * cluster size.
        """
        return sorted(
            clusters,
            key=lambda c: (c.coherence_score * len(c.constraints)),
            reverse=True,
        )