Coverage for src / local_deep_research / advanced_search_system / search_optimization / cross_constraint_manager.py: 82%

259 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-25 01:07 +0000

1""" 

2Cross-constraint search optimization manager. 

3""" 

4 

5import itertools 

6from collections import defaultdict 

7from dataclasses import dataclass, field 

8from typing import Dict, List, Optional, Set, Tuple 

9 

10from langchain_core.language_models import BaseChatModel 

11 

12from ...utilities.search_utilities import remove_think_tags 

13from ..candidates.base_candidate import Candidate 

14from ..constraints.base_constraint import Constraint 

15 

16 

@dataclass
class ConstraintRelationship:
    """A typed, weighted link between a pair of constraints.

    Attributes:
        constraint1_id: ID of the first constraint in the analyzed pair.
        constraint2_id: ID of the second constraint in the analyzed pair.
        relationship_type: One of 'complementary', 'dependent', 'exclusive'.
        strength: Strength of the link, from 0.0 (none) to 1.0 (strongest).
        evidence: Free-text justifications supporting the relationship.
    """

    constraint1_id: str
    constraint2_id: str
    relationship_type: str  # 'complementary', 'dependent', 'exclusive'
    strength: float  # 0.0 to 1.0
    evidence: List[str] = field(default_factory=list)

26 

27 

@dataclass
class ConstraintCluster:
    """A group of related constraints meant to be searched as one unit.

    Attributes:
        constraints: The member constraints of this cluster.
        cluster_type: Grouping strategy, e.g. 'temporal', 'spatial',
            'causal', 'descriptive'.
        coherence_score: How well the members belong together (0.0-1.0).
        search_queries: Queries generated for this cluster; populated lazily.
    """

    constraints: List[Constraint]
    cluster_type: str  # 'temporal', 'spatial', 'causal', 'descriptive'
    coherence_score: float
    search_queries: List[str] = field(default_factory=list)

36 

37 

class CrossConstraintManager:
    """
    Manages cross-constraint relationships and optimizes multi-constraint searches.

    Key features:
    1. Identifies relationships between constraints
    2. Clusters related constraints for efficient searching
    3. Generates cross-constraint validation queries
    4. Tracks cross-constraint evidence patterns
    """

    def __init__(self, model: BaseChatModel):
        """Initialize the cross-constraint manager.

        Args:
            model: Chat model used for all LLM-backed analysis prompts.
        """
        self.model = model
        # Keyed by (constraint1_id, constraint2_id) in the order the pair was
        # analyzed; use _get_relationship() for order-insensitive lookups.
        self.relationships: Dict[Tuple[str, str], ConstraintRelationship] = {}
        self.clusters: List[ConstraintCluster] = []
        self.cross_validation_patterns: Dict[str, List[Dict]] = defaultdict(
            list
        )
        # Undirected adjacency map of constraints with a meaningful relationship.
        self.constraint_graph: Dict[str, Set[str]] = defaultdict(set)

    def _get_relationship(
        self, id1: str, id2: str
    ) -> Optional[ConstraintRelationship]:
        """Return the stored relationship for a pair, regardless of key order.

        Relationships are stored under the (c1.id, c2.id) ordering produced
        during analysis, but callers may present the pair in either order, so
        both orientations are probed.
        """
        return self.relationships.get((id1, id2)) or self.relationships.get(
            (id2, id1)
        )

    def analyze_constraint_relationships(
        self, constraints: List[Constraint]
    ) -> Dict[Tuple[str, str], ConstraintRelationship]:
        """Analyze relationships between all pairs of constraints.

        Only relationships with strength > 0.3 are retained; each retained
        pair also populates the undirected constraint graph.
        """
        relationships = {}

        # Analyze each pair of constraints
        for c1, c2 in itertools.combinations(constraints, 2):
            relationship = self._analyze_pair(c1, c2)
            if relationship.strength > 0.3:  # Only keep meaningful relationships
                relationships[(c1.id, c2.id)] = relationship

                # Update constraint graph (undirected edge)
                self.constraint_graph[c1.id].add(c2.id)
                self.constraint_graph[c2.id].add(c1.id)

        self.relationships.update(relationships)
        return relationships

    def _analyze_pair(
        self, c1: Constraint, c2: Constraint
    ) -> ConstraintRelationship:
        """Ask the model to classify the relationship between two constraints.

        Falls back to type 'none' with strength 0.0 when the model's reply
        cannot be parsed.
        """
        prompt = f"""
Analyze the relationship between these two constraints:

Constraint 1: {c1.description} (Type: {c1.type.value})
Constraint 2: {c2.description} (Type: {c2.type.value})

Determine:
1. Relationship type (complementary, dependent, exclusive, or none)
2. Strength of relationship (0.0 to 1.0)
3. Brief explanation

Format:
Type: [relationship_type]
Strength: [0.0-1.0]
Evidence: [explanation]
"""

        response = self.model.invoke(prompt)
        content = remove_think_tags(response.content)

        # Parse the line-oriented reply; unknown lines are ignored.
        rel_type = "none"
        strength = 0.0
        evidence = []

        for line in content.strip().split("\n"):
            if line.startswith("Type:"):
                rel_type = line.split(":", 1)[1].strip().lower()
            elif line.startswith("Strength:"):
                try:
                    strength = float(line.split(":", 1)[1].strip())
                except ValueError:
                    strength = 0.0
            elif line.startswith("Evidence:"):
                evidence.append(line.split(":", 1)[1].strip())

        return ConstraintRelationship(
            constraint1_id=c1.id,
            constraint2_id=c2.id,
            relationship_type=rel_type,
            strength=strength,
            evidence=evidence,
        )

    def create_constraint_clusters(
        self, constraints: List[Constraint]
    ) -> List[ConstraintCluster]:
        """Create clusters of related constraints using three strategies.

        Combines type-based, relationship-based, and semantic clustering,
        then deduplicates clusters that cover identical constraint sets.
        """
        # First, analyze relationships if not done
        if not self.relationships:
            self.analyze_constraint_relationships(constraints)

        clusters = []

        # 1. Type-based clusters: constraints sharing the same type.
        type_groups = defaultdict(list)
        for c in constraints:
            type_groups[c.type].append(c)

        for group in type_groups.values():
            if len(group) > 1:
                clusters.append(
                    ConstraintCluster(
                        constraints=group,
                        cluster_type="type_based",
                        coherence_score=0.7,
                    )
                )

        # 2. Relationship-based clusters: connected components over strong edges.
        strong_relationships = [
            rel for rel in self.relationships.values() if rel.strength > 0.6
        ]
        clusters.extend(
            self._create_relationship_clusters(
                constraints, strong_relationships
            )
        )

        # 3. Semantic clusters: LLM-driven thematic grouping.
        clusters.extend(self._create_semantic_clusters(constraints))

        # Remove duplicate clusters
        unique_clusters = self._deduplicate_clusters(clusters)

        self.clusters = unique_clusters
        return unique_clusters

    def _create_relationship_clusters(
        self,
        constraints: List[Constraint],
        relationships: List[ConstraintRelationship],
    ) -> List[ConstraintCluster]:
        """Create clusters from connected components of strong relationships."""
        clusters = []
        processed = set()

        # Build adjacency list over the supplied relationships only.
        adj_list = defaultdict(list)
        for rel in relationships:
            adj_list[rel.constraint1_id].append(rel.constraint2_id)
            adj_list[rel.constraint2_id].append(rel.constraint1_id)

        # Index constraints by ID so BFS lookups are O(1) instead of a scan.
        by_id = {c.id: c for c in constraints}

        # Find connected components via BFS.
        for constraint in constraints:
            if constraint.id in processed:
                continue

            component = []
            queue = [constraint.id]
            visited = {constraint.id}

            while queue:
                current_id = queue.pop(0)
                current = by_id.get(current_id)
                # Neighbor IDs may reference constraints outside this list.
                if current:
                    component.append(current)
                    processed.add(current_id)

                    for neighbor_id in adj_list[current_id]:
                        if neighbor_id not in visited:
                            visited.add(neighbor_id)
                            queue.append(neighbor_id)

            # Singleton components are not useful clusters.
            if len(component) > 1:
                clusters.append(
                    ConstraintCluster(
                        constraints=component,
                        cluster_type="relationship_based",
                        coherence_score=self._calculate_cluster_coherence(
                            component
                        ),
                    )
                )

        return clusters

    def _create_semantic_clusters(
        self, constraints: List[Constraint]
    ) -> List[ConstraintCluster]:
        """Create clusters based on semantic similarity (LLM-driven)."""
        # Local import kept at method scope, hoisted out of the parse loop.
        import re

        prompt = f"""
Group these constraints into semantic clusters based on their meaning and intent:

{self._format_constraints_for_clustering(constraints)}

For each cluster:
1. List the constraint IDs
2. Describe the cluster theme
3. Rate coherence (0.0-1.0)

Format:
CLUSTER_1:
Constraints: [id1, id2, ...]
Theme: [description]
Coherence: [0.0-1.0]
"""

        response = self.model.invoke(prompt)
        content = remove_think_tags(response.content)

        clusters = []
        current_cluster = {}

        for line in content.strip().split("\n"):
            line = line.strip()

            if line.startswith("CLUSTER_"):
                # Flush the previous cluster before starting a new one.
                cluster = self._build_semantic_cluster(
                    current_cluster, constraints
                )
                if cluster:
                    clusters.append(cluster)
                current_cluster = {}

            elif line.startswith("Constraints:"):
                ids_str = line.split(":", 1)[1].strip()
                # Extract IDs from various formats, e.g. "[c1, c2]" or "c1 c2".
                # NOTE: assumes constraint IDs follow the 'c<number>' scheme.
                current_cluster["constraints"] = re.findall(r"c\d+", ids_str)

            elif line.startswith("Theme:"):
                current_cluster["theme"] = line.split(":", 1)[1].strip()

            elif line.startswith("Coherence:"):
                try:
                    current_cluster["coherence"] = float(
                        line.split(":", 1)[1].strip()
                    )
                except ValueError:
                    current_cluster["coherence"] = 0.5

        # Don't forget the last cluster
        cluster = self._build_semantic_cluster(current_cluster, constraints)
        if cluster:
            clusters.append(cluster)

        return clusters

    def _build_semantic_cluster(
        self, data: Dict, constraints: List[Constraint]
    ) -> Optional[ConstraintCluster]:
        """Build a semantic ConstraintCluster from parsed LLM output.

        Returns None when the parsed data has no constraint IDs or matches
        fewer than two known constraints.
        """
        if not data or "constraints" not in data:
            return None

        constraint_ids = data["constraints"]
        cluster_constraints = [c for c in constraints if c.id in constraint_ids]

        if len(cluster_constraints) <= 1:
            return None

        return ConstraintCluster(
            constraints=cluster_constraints,
            cluster_type="semantic",
            coherence_score=float(data.get("coherence", 0.5)),
        )

    def generate_cross_constraint_queries(
        self, cluster: ConstraintCluster
    ) -> List[str]:
        """Generate optimized queries for a constraint cluster.

        Produces combined, progressive, intersection, and validation queries,
        caches them on the cluster, and returns the full list.
        """
        queries = []

        # 1. Combined query (all constraints)
        queries.append(self._generate_combined_query(cluster.constraints))

        # 2. Progressive queries (build up constraints)
        queries.extend(
            self._generate_progressive_queries(cluster.constraints)
        )

        # 3. Intersection query (shared aspects), if one exists
        intersection_query = self._generate_intersection_query(
            cluster.constraints
        )
        if intersection_query:
            queries.append(intersection_query)

        # 4. Validation queries (cross-check)
        queries.extend(self._generate_validation_queries(cluster.constraints))

        # Store queries in cluster for reuse by later validation calls.
        cluster.search_queries = queries

        return queries

    def _generate_combined_query(self, constraints: List[Constraint]) -> str:
        """Generate a single query combining all given constraints."""
        prompt = f"""
Create a search query that finds entities satisfying ALL of these related constraints:

{self._format_constraints_for_query(constraints)}

The query should:
1. Efficiently combine all constraints
2. Use appropriate operators (AND, OR)
3. Focus on finding specific entities
4. Be neither too broad nor too narrow

Return only the search query.
"""

        response = self.model.invoke(prompt)
        return remove_think_tags(response.content).strip()

    def _generate_progressive_queries(
        self, constraints: List[Constraint]
    ) -> List[str]:
        """Generate queries that progressively add constraints.

        Constraints are added in descending weight order; at most two
        progressive queries are produced (prefixes of length 2 and 3).
        """
        queries = []

        # Sort by weight/importance so the strongest constraints come first.
        sorted_constraints = sorted(
            constraints, key=lambda c: c.weight, reverse=True
        )

        # Build up constraint prefixes of increasing length.
        for i in range(2, min(len(sorted_constraints) + 1, 4)):
            queries.append(self._generate_combined_query(sorted_constraints[:i]))

        return queries

    def _generate_intersection_query(
        self, constraints: List[Constraint]
    ) -> Optional[str]:
        """Generate a query targeting the common theme of the constraints.

        Returns None for fewer than two constraints or when the model
        reports no clear intersection.
        """
        if len(constraints) < 2:
            return None

        prompt = f"""
Identify the common theme or intersection among these constraints:

{self._format_constraints_for_query(constraints)}

Create a search query that targets this common aspect.
Return only the search query, or 'NONE' if no clear intersection exists.
"""

        response = self.model.invoke(prompt)
        query = remove_think_tags(response.content).strip()

        if query.upper() == "NONE":
            return None

        return query

    def _generate_validation_queries(
        self, constraints: List[Constraint]
    ) -> List[str]:
        """Generate pairwise cross-validation queries (at most two)."""
        queries = []

        # Pairwise validation over the first three constraints only, to
        # bound the number of model calls.
        for c1, c2 in itertools.combinations(constraints[:3], 2):
            prompt = f"""
Create a validation query that checks if an entity satisfies both:
- {c1.description}
- {c2.description}

Return only the search query.
"""

            response = self.model.invoke(prompt)
            queries.append(remove_think_tags(response.content).strip())

        return queries[:2]  # Limit to 2 validation queries

    def validate_candidate_across_constraints(
        self, candidate: Candidate, constraints: List[Constraint]
    ) -> Dict[str, float]:
        """Validate a candidate across multiple constraints simultaneously.

        Combines cluster-level validation with pairwise boosts for
        complementary constraint pairs; each constraint keeps its best score.
        """
        validation_scores = {}

        # Find clusters that share at least one constraint with the request.
        relevant_clusters = [
            cluster
            for cluster in self.clusters
            if any(c in cluster.constraints for c in constraints)
        ]

        for cluster in relevant_clusters:
            # Use cluster-specific queries for validation
            cluster_score = self._validate_with_cluster(candidate, cluster)

            # Propagate the cluster score to each requested member constraint.
            for constraint in cluster.constraints:
                if constraint in constraints:
                    validation_scores[constraint.id] = max(
                        validation_scores.get(constraint.id, 0.0), cluster_score
                    )

        # Additional pairwise validation for complementary pairs.
        for c1, c2 in itertools.combinations(constraints, 2):
            # Order-insensitive lookup: the stored key orientation depends on
            # the order pairs were analyzed in, not the caller's ordering.
            rel = self._get_relationship(c1.id, c2.id)
            if rel and rel.relationship_type == "complementary":
                # Boost scores for complementary constraints
                pair_score = self._validate_pair(candidate, c1, c2)
                validation_scores[c1.id] = max(
                    validation_scores.get(c1.id, 0.0), pair_score
                )
                validation_scores[c2.id] = max(
                    validation_scores.get(c2.id, 0.0), pair_score
                )

        return validation_scores

    def _validate_with_cluster(
        self, candidate: Candidate, cluster: ConstraintCluster
    ) -> float:
        """Validate a candidate against a cluster's combined query."""
        if not cluster.search_queries:
            cluster.search_queries = self.generate_cross_constraint_queries(
                cluster
            )

        # Use the most comprehensive query (the combined query comes first).
        validation_query = cluster.search_queries[0]

        prompt = f"""
Does "{candidate.name}" satisfy this multi-constraint query:
Query: {validation_query}

Constraints being checked:
{self._format_constraints_for_query(cluster.constraints)}

Provide a confidence score (0.0-1.0) based on how well the candidate matches.

Format:
Score: [0.0-1.0]
Explanation: [brief explanation]
"""

        response = self.model.invoke(prompt)
        return self._parse_score(remove_think_tags(response.content))

    def _validate_pair(
        self, candidate: Candidate, c1: Constraint, c2: Constraint
    ) -> float:
        """Validate a candidate against a pair of constraints jointly."""
        prompt = f"""
Evaluate if "{candidate.name}" satisfies BOTH constraints:

1. {c1.description} (Type: {c1.type.value})
2. {c2.description} (Type: {c2.type.value})

Consider how these constraints relate to each other and whether the candidate satisfies both.

Provide a confidence score (0.0-1.0).

Format:
Score: [0.0-1.0]
"""

        response = self.model.invoke(prompt)
        return self._parse_score(remove_think_tags(response.content))

    @staticmethod
    def _parse_score(content: str) -> float:
        """Extract the first 'Score:' value from model output; 0.0 on failure."""
        score = 0.0
        for line in content.strip().split("\n"):
            if line.startswith("Score:"):
                try:
                    score = float(line.split(":", 1)[1].strip())
                except ValueError:
                    score = 0.0
                break
        return score

    def _calculate_cluster_coherence(
        self, constraints: List[Constraint]
    ) -> float:
        """Calculate coherence for a cluster from pairwise relationship strengths.

        Returns 0.0 for fewer than two constraints and 0.5 when no pair has
        a stored relationship; otherwise the average strength adjusted by a
        size factor favoring larger clusters.
        """
        if len(constraints) < 2:
            return 0.0

        total_strength = 0.0
        pair_count = 0

        for c1, c2 in itertools.combinations(constraints, 2):
            # Order-insensitive lookup (keys are stored in analysis order).
            rel = self._get_relationship(c1.id, c2.id)
            if rel is not None:
                total_strength += rel.strength
                pair_count += 1

        if pair_count == 0:
            return 0.5  # Default coherence

        average_strength = total_strength / pair_count

        # Adjust for cluster size (larger clusters with high average strength are better)
        size_factor = min(len(constraints) / 5.0, 1.0)

        return average_strength * (0.7 + 0.3 * size_factor)

    def _deduplicate_clusters(
        self, clusters: List[ConstraintCluster]
    ) -> List[ConstraintCluster]:
        """Remove clusters covering identical constraint ID sets.

        The first cluster seen for a given ID set wins; set membership makes
        this O(n) instead of the previous O(n^2) list scan.
        """
        unique_clusters = []
        seen_sets: Set[frozenset] = set()

        for cluster in clusters:
            constraint_set = frozenset(c.id for c in cluster.constraints)
            if constraint_set not in seen_sets:
                unique_clusters.append(cluster)
                seen_sets.add(constraint_set)

        return unique_clusters

    def _format_constraints_for_clustering(
        self, constraints: List[Constraint]
    ) -> str:
        """Format constraints (with IDs and weights) for the clustering prompt."""
        return "\n".join(
            f"{c.id}: {c.description} (Type: {c.type.value}, Weight: {c.weight})"
            for c in constraints
        )

    def _format_constraints_for_query(
        self, constraints: List[Constraint]
    ) -> str:
        """Format constraints as a bulleted list for query-generation prompts."""
        return "\n".join(
            f"- {c.description} [{c.type.value}]" for c in constraints
        )

    def optimize_search_order(
        self, clusters: List[ConstraintCluster]
    ) -> List[ConstraintCluster]:
        """Order clusters so the most promising ones are searched first.

        Ranks by coherence weighted by cluster size, descending.
        """
        return sorted(
            clusters,
            key=lambda c: (c.coherence_score * len(c.constraints)),
            reverse=True,
        )