Coverage for src/local_deep_research/utilities/search_cache.py: 82%

221 statements  

coverage.py v7.12.0, created at 2026-01-11 00:51 +0000

1""" 

2Search Cache Utility 

3Provides intelligent caching for search results to avoid repeated queries. 

4Includes TTL, LRU eviction, and query normalization. 

5""" 

6 

7import hashlib 

8import threading 

9import time 

10from functools import lru_cache 

11from pathlib import Path 

12from typing import Any, Dict, List, Optional 

13 

14from loguru import logger 

15from sqlalchemy import create_engine 

16from sqlalchemy.orm import sessionmaker 

17 

18from ..config.paths import get_cache_directory 

19from ..database.models import Base, SearchCache as SearchCacheModel 

20 

21 

22class SearchCache: 

23 """ 

24 Persistent cache for search results with TTL and LRU eviction. 

25 Stores results in SQLite for persistence across sessions. 

26 """ 

27 

28 def __init__( 

29 self, 

30 cache_dir: Optional[str] = None, 

31 max_memory_items: int = 1000, 

32 default_ttl: int = 3600, 

33 ): 

34 """ 

35 Initialize search cache. 

36 

37 Args: 

38 cache_dir: Directory for the cache database. Defaults to get_cache_directory() / "search_cache" 

39 max_memory_items: Maximum items in memory cache 

40 default_ttl: Default time-to-live in seconds (1 hour default) 

41 """ 

42 self.max_memory_items = max_memory_items 

43 self.default_ttl = default_ttl 

44 

45 # Setup cache directory 

46 if cache_dir is None: 

47 cache_dir = get_cache_directory() / "search_cache" 

48 else: 

49 cache_dir = Path(cache_dir) 

50 

51 cache_dir.mkdir(parents=True, exist_ok=True) 

52 self.db_path = cache_dir / "search_cache.db" 

53 

54 # Initialize database 

55 self._init_db() 

56 

57 # In-memory cache for frequently accessed items 

58 self._memory_cache = {} 

59 self._access_times = {} 

60 

61 # Stampede protection: events and locks for each query being fetched 

62 self._fetch_events = {} # query_hash -> threading.Event (signals completion) 

63 self._fetch_locks = {} # query_hash -> threading.Lock (prevents concurrent fetch) 

64 self._fetch_locks_lock = ( 

65 threading.Lock() 

66 ) # Protects the fetch dictionaries 

67 self._fetch_results = {} # query_hash -> results (temporary storage during fetch) 

68 

69 def _init_db(self): 

70 """Initialize SQLite database for persistent cache using SQLAlchemy.""" 

71 try: 

72 # Create engine and session 

73 self.engine = create_engine(f"sqlite:///{self.db_path}") 

74 Base.metadata.create_all( 

75 self.engine, tables=[SearchCacheModel.__table__] 

76 ) 

77 self.Session = sessionmaker(bind=self.engine) 

78 except Exception: 

79 logger.exception("Failed to initialize search cache database") 

80 

81 def _normalize_query(self, query: str) -> str: 

82 """Normalize query for consistent caching.""" 

83 # Convert to lowercase and remove extra whitespace 

84 normalized = " ".join(query.lower().strip().split()) 

85 

86 # Remove common punctuation that doesn't affect search 

87 normalized = normalized.replace('"', "").replace("'", "") 

88 

89 return normalized 
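# Illustrative only (not in the source): a query like '  Find "Python\'s"  Docs '
# normalizes to 'find pythons docs' (lowercased, whitespace collapsed, quotes stripped).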

90 

91 def _get_query_hash( 

92 self, query: str, search_engine: str = "default" 

93 ) -> str: 

94 """Generate hash for query + search engine combination.""" 

95 normalized_query = self._normalize_query(query) 

96 cache_key = f"{search_engine}:{normalized_query}" 

97 return hashlib.md5( # DevSkim: ignore DS126858 

98 cache_key.encode(), usedforsecurity=False 

99 ).hexdigest() 
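# Illustrative only: for a hypothetical engine "wikipedia" and query 'Find "Python" Docs',
# the hashed key is 'wikipedia:find python docs'; MD5 serves purely as a cache key
# (usedforsecurity=False), not as a security measure.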

100 

101 def _cleanup_expired(self): 

102 """Remove expired entries from database.""" 

103 try: 

104 current_time = int(time.time()) 

105 with self.Session() as session: 

106 deleted = ( 

107 session.query(SearchCacheModel) 

108 .filter(SearchCacheModel.expires_at < current_time) 

109 .delete() 

110 ) 

111 session.commit() 

112 if deleted > 0: 

113 logger.debug(f"Cleaned up {deleted} expired cache entries") 

114 except Exception: 

115 logger.exception("Failed to cleanup expired cache entries") 

116 

117 def _evict_lru_memory(self): 

118 """Evict least recently used items from memory cache.""" 

119 if len(self._memory_cache) <= self.max_memory_items: 

120 return 

121 

122 # Sort by access time and remove oldest 

123 sorted_items = sorted(self._access_times.items(), key=lambda x: x[1]) 

124 items_to_remove = ( 

125 len(self._memory_cache) - self.max_memory_items + 100 

126 ) # Remove extra for efficiency 

127 

128 for query_hash, _ in sorted_items[:items_to_remove]: 

129 self._memory_cache.pop(query_hash, None) 

130 self._access_times.pop(query_hash, None) 

131 

132 def get( 

133 self, query: str, search_engine: str = "default" 

134 ) -> Optional[List[Dict[str, Any]]]: 

135 """ 

136 Get cached search results for a query. 

137 

138 Args: 

139 query: Search query 

140 search_engine: Search engine identifier for cache partitioning 

141 

142 Returns: 

143 Cached results or None if not found/expired 

144 """ 

145 query_hash = self._get_query_hash(query, search_engine) 

146 current_time = int(time.time()) 

147 

148 # Check memory cache first 

149 if query_hash in self._memory_cache: 

150 entry = self._memory_cache[query_hash] 

151 if entry["expires_at"] > current_time: 

152 self._access_times[query_hash] = current_time 

153 logger.debug(f"Cache hit (memory) for query: {query[:50]}...") 

154 return entry["results"] 

155 else: 

156 # Expired, remove from memory 

157 self._memory_cache.pop(query_hash, None) 

158 self._access_times.pop(query_hash, None) 

159 

160 # Check database cache 

161 try: 

162 with self.Session() as session: 

163 cache_entry = ( 

164 session.query(SearchCacheModel) 

165 .filter( 

166 SearchCacheModel.query_hash == query_hash, 

167 SearchCacheModel.expires_at > current_time, 

168 ) 

169 .first() 

170 ) 

171 

172 if cache_entry:  [branch 172 ↛ 173 never taken: the condition on line 172 was never true]

173 results = cache_entry.results 

174 

175 # Update access statistics 

176 cache_entry.access_count += 1 

177 cache_entry.last_accessed = current_time 

178 session.commit() 

179 

180 # Add to memory cache 

181 self._memory_cache[query_hash] = { 

182 "results": results, 

183 "expires_at": cache_entry.expires_at, 

184 } 

185 self._access_times[query_hash] = current_time 

186 self._evict_lru_memory() 

187 

188 logger.debug( 

189 f"Cache hit (database) for query: {query[:50]}..." 

190 ) 

191 return results 

192 

193 except Exception: 

194 logger.exception("Failed to retrieve from search cache") 

195 

196 logger.debug(f"Cache miss for query: {query[:50]}...") 

197 return None 

198 

199 def put( 

200 self, 

201 query: str, 

202 results: List[Dict[str, Any]], 

203 search_engine: str = "default", 

204 ttl: Optional[int] = None, 

205 ) -> bool: 

206 """ 

207 Store search results in cache. 

208 

209 Args: 

210 query: Search query 

211 results: Search results to cache 

212 search_engine: Search engine identifier 

213 ttl: Time-to-live in seconds (uses default if None) 

214 

215 Returns: 

216 True if successfully cached 

217 """ 

218 if not results: # Don't cache empty results 

219 return False 

220 

221 query_hash = self._get_query_hash(query, search_engine) 

222 current_time = int(time.time()) 

223 expires_at = current_time + (ttl or self.default_ttl) 

224 

225 try: 

226 # Store in database 

227 with self.Session() as session: 

228 # Check if entry exists 

229 existing = ( 

230 session.query(SearchCacheModel) 

231 .filter_by(query_hash=query_hash) 

232 .first() 

233 ) 

234 

235 if existing: 

236 # Update existing entry 

237 existing.query_text = self._normalize_query(query) 

238 existing.results = results 

239 existing.created_at = current_time 

240 existing.expires_at = expires_at 

241 existing.access_count = 1 

242 existing.last_accessed = current_time 

243 else: 

244 # Create new entry 

245 cache_entry = SearchCacheModel( 

246 query_hash=query_hash, 

247 query_text=self._normalize_query(query), 

248 results=results, 

249 created_at=current_time, 

250 expires_at=expires_at, 

251 access_count=1, 

252 last_accessed=current_time, 

253 ) 

254 session.add(cache_entry) 

255 

256 session.commit() 

257 

258 # Store in memory cache 

259 self._memory_cache[query_hash] = { 

260 "results": results, 

261 "expires_at": expires_at, 

262 } 

263 self._access_times[query_hash] = current_time 

264 self._evict_lru_memory() 

265 

266 logger.debug(f"Cached results for query: {query[:50]}...") 

267 return True 

268 

269 except Exception: 

270 logger.exception("Failed to store in search cache") 

271 return False 

272 

273 def get_or_fetch( 

274 self, 

275 query: str, 

276 fetch_func, 

277 search_engine: str = "default", 

278 ttl: Optional[int] = None, 

279 ) -> Optional[List[Dict[str, Any]]]: 

280 """ 

281 Get cached results or fetch with stampede protection. 

282 

283 This is the recommended way to use the cache. It ensures only one thread 

284 fetches data for a given query, preventing cache stampedes. 

285 

286 Args: 

287 query: Search query 

288 fetch_func: Function to call if cache miss (should return results list) 

289 search_engine: Search engine identifier 

290 ttl: Time-to-live for cached results 

291 

292 Returns: 

293 Search results (from cache or freshly fetched) 

294 """ 

295 query_hash = self._get_query_hash(query, search_engine) 

296 

297 # Try to get from cache first 

298 results = self.get(query, search_engine) 

299 if results is not None: 

300 return results 

301 

302 # Acquire lock for this query to prevent stampede 

303 with self._fetch_locks_lock: 

304 # Double-check after acquiring lock 

305 results = self.get(query, search_engine) 

306 if results is not None:  [branch 306 ↛ 307 never taken: the condition on line 306 was never true]

307 return results 

308 

309 # Check if another thread started fetching while we waited 

310 if query_hash in self._fetch_events: 

311 existing_event = self._fetch_events[query_hash] 

312 # Check if this is a stale event (already set means fetch completed) 

313 if existing_event.is_set(): 

314 # Previous fetch completed, clean up and start fresh 

315 del self._fetch_events[query_hash] 

316 del self._fetch_locks[query_hash] 

317 if query_hash in self._fetch_results:  [branch 317 ↛ 320 never taken: the condition on line 317 was always true]

318 del self._fetch_results[query_hash] 

319 # Create new event/lock for this fetch 

320 event = threading.Event() 

321 self._fetch_events[query_hash] = event 

322 self._fetch_locks[query_hash] = threading.Lock() 

323 event = None # Signal we should fetch 

324 else: 

325 # Another thread is actively fetching 

326 event = existing_event 

327 else: 

328 # We are the first thread to fetch this query 

329 event = threading.Event() 

330 self._fetch_events[query_hash] = event 

331 self._fetch_locks[query_hash] = threading.Lock() 

332 event = None # Signal we should fetch 

333 

334 # If another thread is fetching, wait for it 

335 if event is not None: 

336 event.wait(timeout=30) 

337 if query_hash in self._fetch_results:  [branch 337 ↛ 342 never taken: the condition on line 337 was always true]

338 result = self._fetch_results.get(query_hash) 

339 if result is not None:  [branch 339 ↛ 342 never taken: the condition on line 339 was always true]

340 return result 

341 # Re-check cache, and if still miss, return None (fetch failed) 

342 return self.get(query, search_engine) 

343 

344 # We are the thread that should fetch 

345 fetch_lock = self._fetch_locks[query_hash] 

346 fetch_event = self._fetch_events[query_hash] 

347 

348 with fetch_lock: 

349 # Triple-check (another thread might have fetched while we waited for lock) 

350 results = self.get(query, search_engine) 

351 if results is not None:  [branch 351 ↛ 352 never taken: the condition on line 351 was never true]

352 fetch_event.set() # Signal completion 

353 return results 

354 

355 logger.debug( 

356 f"Fetching results for query: {query[:50]}... (stampede protected)" 

357 ) 

358 

359 try: 

360 # Fetch the results 

361 results = fetch_func() 

362 

363 if results:  [branch 363 ↛ 370 never taken: the condition on line 363 was always true]

364 # Store in cache 

365 self.put(query, results, search_engine, ttl) 

366 

367 # Store temporarily for other waiting threads 

368 self._fetch_results[query_hash] = results 

369 

370 return results 

371 

372 except Exception: 

373 logger.exception( 

374 f"Failed to fetch results for query: {query[:50]}" 

375 ) 

376 # Store None to indicate fetch failed 

377 self._fetch_results[query_hash] = None 

378 return None 

379 

380 finally: 

381 # Signal completion to waiting threads 

382 fetch_event.set() 

383 

384 # Clean up after a delay 

385 def cleanup(): 

386 time.sleep(2) # Give waiting threads time to get results 

387 with self._fetch_locks_lock: 

388 self._fetch_locks.pop(query_hash, None) 

389 self._fetch_events.pop(query_hash, None) 

390 self._fetch_results.pop(query_hash, None) 

391 

392 # Run cleanup in background 

393 threading.Thread(target=cleanup, daemon=True).start() 

394 

395 def invalidate(self, query: str, search_engine: str = "default") -> bool: 

396 """Invalidate cached results for a specific query.""" 

397 query_hash = self._get_query_hash(query, search_engine) 

398 

399 try: 

400 # Remove from memory 

401 self._memory_cache.pop(query_hash, None) 

402 self._access_times.pop(query_hash, None) 

403 

404 # Remove from database 

405 with self.Session() as session: 

406 deleted = ( 

407 session.query(SearchCacheModel) 

408 .filter_by(query_hash=query_hash) 

409 .delete() 

410 ) 

411 session.commit() 

412 

413 logger.debug(f"Invalidated cache for query: {query[:50]}...") 

414 return deleted > 0 

415 

416 except Exception: 

417 logger.exception("Failed to invalidate cache") 

418 return False 

419 

420 def clear_all(self) -> bool: 

421 """Clear all cached results.""" 

422 try: 

423 self._memory_cache.clear() 

424 self._access_times.clear() 

425 

426 with self.Session() as session: 

427 session.query(SearchCacheModel).delete() 

428 session.commit() 

429 

430 logger.info("Cleared all search cache") 

431 return True 

432 

433 except Exception: 

434 logger.exception("Failed to clear search cache") 

435 return False 

436 

437 def get_stats(self) -> Dict[str, Any]: 

438 """Get cache statistics.""" 

439 try: 

440 current_time = int(time.time()) 

441 with self.Session() as session: 

442 # Total entries 

443 total_entries = ( 

444 session.query(SearchCacheModel) 

445 .filter(SearchCacheModel.expires_at > current_time) 

446 .count() 

447 ) 

448 

449 # Total expired entries 

450 expired_entries = ( 

451 session.query(SearchCacheModel) 

452 .filter(SearchCacheModel.expires_at <= current_time) 

453 .count() 

454 ) 

455 

456 # Average access count 

457 from sqlalchemy import func 

458 

459 avg_access_result = ( 

460 session.query(func.avg(SearchCacheModel.access_count)) 

461 .filter(SearchCacheModel.expires_at > current_time) 

462 .scalar() 

463 ) 

464 avg_access = avg_access_result or 0 

465 

466 return { 

467 "total_valid_entries": total_entries, 

468 "expired_entries": expired_entries, 

469 "memory_cache_size": len(self._memory_cache), 

470 "average_access_count": round(avg_access, 2), 

471 "cache_hit_potential": ( 

472 f"{(total_entries / (total_entries + 1)) * 100:.1f}%" 

473 if total_entries > 0 

474 else "0%" 

475 ), 

476 } 

477 

478 except Exception: 

479 logger.exception("Failed to get cache stats") 

480 return {"error": "Cache stats unavailable"} 
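# Illustrative successful return value (hypothetical numbers):
# {"total_valid_entries": 42, "expired_entries": 3, "memory_cache_size": 40,
#  "average_access_count": 1.7, "cache_hit_potential": "97.7%"}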

481 

482 

483# Global cache instance 

484_global_cache = None 

485 

486 

487def get_search_cache() -> SearchCache: 

488 """Get global search cache instance.""" 

489 global _global_cache 

490 if _global_cache is None: 

491 _global_cache = SearchCache() 

492 return _global_cache 

493 

494 

495@lru_cache(maxsize=100) 

496def normalize_entity_query(entity: str, constraint: str) -> str: 

497 """ 

498 Normalize entity + constraint combination for consistent caching. 

499 Uses LRU cache for frequent normalizations. 

500 """ 

501 # Lowercase and collapse whitespace 

502 entity_clean = " ".join(entity.strip().lower().split()) 

503 constraint_clean = " ".join(constraint.strip().lower().split()) 

504 

505 # Create canonical form 

506 return f"{entity_clean} {constraint_clean}"
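
A minimal usage sketch, not part of the covered module, showing the intended call pattern: get_search_cache() returns the process-wide SearchCache, and get_or_fetch() is the recommended entry point because it adds stampede protection. The import path assumes the package is importable as local_deep_research; the engine name, query, and fetch function below are hypothetical placeholders.

import threading

from local_deep_research.utilities.search_cache import get_search_cache


def fetch_from_engine():
    # Hypothetical fetcher; a real caller would query a search engine API here.
    return [{"title": "Example result", "url": "https://example.org"}]


cache = get_search_cache()


def worker():
    # Several threads asking for the same query: stampede protection ensures
    # fetch_from_engine() runs at most once while the others wait for its result.
    results = cache.get_or_fetch(
        'What is a "search cache"?',  # query text; normalized internally
        fetch_from_engine,            # called only on a cache miss
        search_engine="example",      # hypothetical engine identifier
        ttl=600,                      # keep results for 10 minutes
    )
    print(len(results or []))


threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(cache.get_stats())  # counts of valid/expired entries, memory cache size, etc.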