Coverage for src / local_deep_research / utilities / search_cache.py: 97%

247 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:55 +0000

1""" 

2Search Cache Utility 

3Provides intelligent caching for search results to avoid repeated queries. 

4Includes TTL, LRU eviction, and query normalization. 

5""" 

6 

7import atexit 

8import hashlib 

9import threading 

10import time 

11from functools import lru_cache 

12from pathlib import Path 

13from typing import Any, Dict, List, Optional 

14 

15from loguru import logger 

16from sqlalchemy import create_engine 

17from sqlalchemy.engine import Engine 

18from sqlalchemy.orm import sessionmaker 

19 

20from ..config.paths import get_cache_directory 

21from ..database.models import Base, SearchCache as SearchCacheModel 

22from .thread_context import get_search_context 

23 

24 

class SearchCache:
    """
    Persistent cache for search results with TTL and LRU eviction.
    Stores results in SQLite for persistence across sessions.
    """

    def __init__(
        self,
        cache_dir: Optional[str] = None,
        max_memory_items: int = 1000,
        default_ttl: int = 3600,
    ):
        """
        Initialize search cache.

        Args:
            cache_dir: Directory for cache database. Defaults to data/__CACHE_DIR__
            max_memory_items: Maximum items in memory cache
            default_ttl: Default time-to-live in seconds (1 hour default)
        """
        self.max_memory_items = max_memory_items
        self.default_ttl = default_ttl

        # Resolve where the SQLite database lives, creating the directory
        # if needed.
        base_dir = (
            get_cache_directory() / "search_cache"
            if cache_dir is None
            else Path(cache_dir)
        )
        base_dir.mkdir(parents=True, exist_ok=True)
        self.db_path = base_dir / "search_cache.db"

        # Persistent backing store (SQLAlchemy engine + session factory).
        self.engine: Optional[Engine] = None
        self._init_db()

        # Hot in-memory layer; evicted LRU-style via _access_times.
        self._memory_cache: Dict[str, Dict[str, Any]] = {}
        self._access_times: Dict[str, int] = {}

        # Stampede-protection state, all keyed by query hash:
        #   _fetch_events  - Event set when a fetch for that query completes
        #   _fetch_locks   - Lock serializing the actual fetch per query
        #   _fetch_results - temporary hand-off of results to waiting threads
        self._fetch_events: Dict[str, threading.Event] = {}
        self._fetch_locks: Dict[str, threading.Lock] = {}
        # Guards the stampede-protection dictionaries above/below.
        self._fetch_locks_lock = threading.Lock()
        self._fetch_results: Dict[str, Optional[List[Dict[str, Any]]]] = {}

78 def _init_db(self): 

79 """Initialize SQLite database for persistent cache using SQLAlchemy.""" 

80 try: 

81 # Create engine and session 

82 self.engine = create_engine(f"sqlite:///{self.db_path}") 

83 Base.metadata.create_all( 

84 self.engine, tables=[SearchCacheModel.__table__] 

85 ) 

86 self.Session = sessionmaker(bind=self.engine) 

87 except Exception: 

88 logger.exception("Failed to initialize search cache database") 

89 

90 def _normalize_query(self, query: str) -> str: 

91 """Normalize query for consistent caching.""" 

92 # Convert to lowercase and remove extra whitespace 

93 normalized = " ".join(query.lower().strip().split()) 

94 

95 # Remove common punctuation that doesn't affect search 

96 return normalized.replace('"', "").replace("'", "") 

97 

98 def _get_query_hash( 

99 self, query: str, search_engine: str = "default" 

100 ) -> str: 

101 """Generate hash for query + search engine + username combination. 

102 

103 Incorporates the current user's username into the hash so that 

104 different users' cached results are isolated from each other. 

105 """ 

106 normalized_query = self._normalize_query(query) 

107 # Get username from thread context for per-user cache isolation 

108 username = "" 

109 context = get_search_context() 

110 if context: 

111 username = context.get("username", "") or "" 

112 cache_key = f"{username}:{search_engine}:{normalized_query}" 

113 return hashlib.sha256(cache_key.encode()).hexdigest() 

114 

115 def _cleanup_expired(self): 

116 """Remove expired entries from database.""" 

117 try: 

118 current_time = int(time.time()) 

119 with self.Session() as session: 

120 deleted = ( 

121 session.query(SearchCacheModel) 

122 .filter(SearchCacheModel.expires_at < current_time) 

123 .delete() 

124 ) 

125 session.commit() 

126 if deleted > 0: 126 ↛ exitline 126 didn't jump to the function exit

127 logger.debug(f"Cleaned up {deleted} expired cache entries") 

128 except Exception: 

129 logger.exception("Failed to cleanup expired cache entries") 

130 

131 def _evict_lru_memory(self): 

132 """Evict least recently used items from memory cache.""" 

133 if len(self._memory_cache) <= self.max_memory_items: 

134 return 

135 

136 # Sort by access time and remove oldest 

137 sorted_items = sorted(self._access_times.items(), key=lambda x: x[1]) 

138 items_to_remove = ( 

139 len(self._memory_cache) - self.max_memory_items + 100 

140 ) # Remove extra for efficiency 

141 

142 for query_hash, _ in sorted_items[:items_to_remove]: 

143 self._memory_cache.pop(query_hash, None) 

144 self._access_times.pop(query_hash, None) 

145 

146 def get( 

147 self, query: str, search_engine: str = "default" 

148 ) -> Optional[List[Dict[str, Any]]]: 

149 """ 

150 Get cached search results for a query. 

151 

152 Args: 

153 query: Search query 

154 search_engine: Search engine identifier for cache partitioning 

155 

156 Returns: 

157 Cached results or None if not found/expired 

158 """ 

159 query_hash = self._get_query_hash(query, search_engine) 

160 current_time = int(time.time()) 

161 

162 # Check memory cache first 

163 if query_hash in self._memory_cache: 

164 entry = self._memory_cache[query_hash] 

165 if entry["expires_at"] > current_time: 

166 self._access_times[query_hash] = current_time 

167 logger.debug(f"Cache hit (memory) for query: {query[:50]}...") 

168 return entry["results"] 

169 # Expired, remove from memory 

170 self._memory_cache.pop(query_hash, None) 

171 self._access_times.pop(query_hash, None) 

172 

173 # Check database cache 

174 try: 

175 with self.Session() as session: 

176 cache_entry = ( 

177 session.query(SearchCacheModel) 

178 .filter( 

179 SearchCacheModel.query_hash == query_hash, 

180 SearchCacheModel.expires_at > current_time, 

181 ) 

182 .first() 

183 ) 

184 

185 if cache_entry: 

186 results = cache_entry.results 

187 

188 # Update access statistics 

189 cache_entry.access_count += 1 

190 cache_entry.last_accessed = current_time 

191 session.commit() 

192 

193 # Add to memory cache 

194 self._memory_cache[query_hash] = { 

195 "results": results, 

196 "expires_at": cache_entry.expires_at, 

197 } 

198 self._access_times[query_hash] = current_time 

199 self._evict_lru_memory() 

200 

201 logger.debug( 

202 f"Cache hit (database) for query: {query[:50]}..." 

203 ) 

204 return results 

205 

206 except Exception: 

207 logger.exception("Failed to retrieve from search cache") 

208 

209 logger.debug(f"Cache miss for query: {query[:50]}...") 

210 return None 

211 

212 def put( 

213 self, 

214 query: str, 

215 results: List[Dict[str, Any]], 

216 search_engine: str = "default", 

217 ttl: Optional[int] = None, 

218 ) -> bool: 

219 """ 

220 Store search results in cache. 

221 

222 Args: 

223 query: Search query 

224 results: Search results to cache 

225 search_engine: Search engine identifier 

226 ttl: Time-to-live in seconds (uses default if None) 

227 

228 Returns: 

229 True if successfully cached 

230 """ 

231 if not results: # Don't cache empty results 

232 return False 

233 

234 query_hash = self._get_query_hash(query, search_engine) 

235 current_time = int(time.time()) 

236 expires_at = current_time + (ttl or self.default_ttl) 

237 

238 try: 

239 # Store in database 

240 with self.Session() as session: 

241 # Check if entry exists 

242 existing = ( 

243 session.query(SearchCacheModel) 

244 .filter_by(query_hash=query_hash) 

245 .first() 

246 ) 

247 

248 if existing: 

249 # Update existing entry 

250 existing.query_text = self._normalize_query(query) # type: ignore[assignment] 

251 existing.results = results # type: ignore[assignment] 

252 existing.created_at = current_time # type: ignore[assignment] 

253 existing.expires_at = expires_at # type: ignore[assignment] 

254 existing.access_count = 1 # type: ignore[assignment] 

255 existing.last_accessed = current_time # type: ignore[assignment] 

256 else: 

257 # Create new entry 

258 cache_entry = SearchCacheModel( 

259 query_hash=query_hash, 

260 query_text=self._normalize_query(query), 

261 results=results, 

262 created_at=current_time, 

263 expires_at=expires_at, 

264 access_count=1, 

265 last_accessed=current_time, 

266 ) 

267 session.add(cache_entry) 

268 

269 session.commit() 

270 

271 # Store in memory cache 

272 self._memory_cache[query_hash] = { 

273 "results": results, 

274 "expires_at": expires_at, 

275 } 

276 self._access_times[query_hash] = current_time 

277 self._evict_lru_memory() 

278 

279 logger.debug(f"Cached results for query: {query[:50]}...") 

280 return True 

281 

282 except Exception: 

283 logger.exception("Failed to store in search cache") 

284 return False 

285 

    def get_or_fetch(
        self,
        query: str,
        fetch_func,
        search_engine: str = "default",
        ttl: Optional[int] = None,
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Get cached results or fetch with stampede protection.

        This is the recommended way to use the cache. It ensures only one thread
        fetches data for a given query, preventing cache stampedes.

        Args:
            query: Search query
            fetch_func: Function to call if cache miss (should return results list)
            search_engine: Search engine identifier
            ttl: Time-to-live for cached results

        Returns:
            Search results (from cache or freshly fetched)
        """
        query_hash = self._get_query_hash(query, search_engine)

        # Try to get from cache first (no locking on the hot path).
        results = self.get(query, search_engine)
        if results is not None:
            return results

        # Acquire lock for this query to prevent stampede.
        with self._fetch_locks_lock:
            # Double-check after acquiring lock.
            results = self.get(query, search_engine)
            if results is not None:
                return results

            # Check if another thread started fetching while we waited.
            # After this block, `event` is None when THIS thread should do
            # the fetch, or the in-flight fetch's Event when we should wait.
            event: Optional[threading.Event] = None
            if query_hash in self._fetch_events:
                existing_event = self._fetch_events[query_hash]
                # A set event is stale: the earlier fetch already completed.
                if existing_event.is_set():
                    # Previous fetch completed, clean up and start fresh.
                    del self._fetch_events[query_hash]
                    del self._fetch_locks[query_hash]
                    if query_hash in self._fetch_results:
                        del self._fetch_results[query_hash]
                    # Create new event/lock for this fetch.
                    event = threading.Event()
                    self._fetch_events[query_hash] = event
                    self._fetch_locks[query_hash] = threading.Lock()
                    event = None  # Signal we should fetch
                else:
                    # Another thread is actively fetching.
                    event = existing_event
            else:
                # We are the first thread to fetch this query.
                event = threading.Event()
                self._fetch_events[query_hash] = event
                self._fetch_locks[query_hash] = threading.Lock()
                event = None  # Signal we should fetch

        # If another thread is fetching, wait for it (bounded; a fetch
        # slower than 30s makes us fall through to the cache re-check).
        if event is not None:
            event.wait(timeout=30)
            if query_hash in self._fetch_results:
                result = self._fetch_results.get(query_hash)
                if result is not None:
                    return result
            # Re-check cache, and if still miss, return None (fetch failed).
            return self.get(query, search_engine)

        # We are the thread that should fetch.
        fetch_lock = self._fetch_locks[query_hash]
        fetch_event = self._fetch_events[query_hash]

        with fetch_lock:
            # Triple-check (another thread might have fetched while we
            # waited for the per-query lock).
            results = self.get(query, search_engine)
            if results is not None:
                fetch_event.set()  # Signal completion
                return results

            logger.debug(
                f"Fetching results for query: {query[:50]}... (stampede protected)"
            )

            try:
                # Fetch the results.
                results = fetch_func()

                if results:
                    # Store in cache (skipped for empty results; put()
                    # refuses those).
                    self.put(query, results, search_engine, ttl)

                # Store temporarily for other waiting threads.
                self._fetch_results[query_hash] = results

                return results

            except Exception:
                logger.exception(
                    f"Failed to fetch results for query: {query[:50]}"
                )
                # Store None to indicate fetch failed.
                self._fetch_results[query_hash] = None
                return None

            finally:
                # Signal completion to waiting threads.
                fetch_event.set()

                # Clean up the per-query bookkeeping after a delay.
                def cleanup():
                    time.sleep(2)  # Give waiting threads time to get results
                    with self._fetch_locks_lock:
                        self._fetch_locks.pop(query_hash, None)
                        self._fetch_events.pop(query_hash, None)
                        self._fetch_results.pop(query_hash, None)

                # Run cleanup in background.
                threading.Thread(target=cleanup, daemon=True).start()

409 def invalidate(self, query: str, search_engine: str = "default") -> bool: 

410 """Invalidate cached results for a specific query.""" 

411 query_hash = self._get_query_hash(query, search_engine) 

412 

413 try: 

414 # Remove from memory 

415 self._memory_cache.pop(query_hash, None) 

416 self._access_times.pop(query_hash, None) 

417 

418 # Remove from database 

419 with self.Session() as session: 

420 deleted = ( 

421 session.query(SearchCacheModel) 

422 .filter_by(query_hash=query_hash) 

423 .delete() 

424 ) 

425 session.commit() 

426 

427 logger.debug(f"Invalidated cache for query: {query[:50]}...") 

428 return deleted > 0 

429 

430 except Exception: 

431 logger.exception("Failed to invalidate cache") 

432 return False 

433 

434 def clear_all(self) -> bool: 

435 """Clear all cached results.""" 

436 try: 

437 self._memory_cache.clear() 

438 self._access_times.clear() 

439 

440 with self.Session() as session: 

441 session.query(SearchCacheModel).delete() 

442 session.commit() 

443 

444 logger.info("Cleared all search cache") 

445 return True 

446 

447 except Exception: 

448 logger.exception("Failed to clear search cache") 

449 return False 

450 

451 def dispose(self): 

452 """ 

453 Dispose of the database engine and clean up resources. 

454 

455 Call this method during application shutdown to prevent file descriptor leaks. 

456 After calling dispose(), this cache instance should no longer be used. 

457 """ 

458 if hasattr(self, "engine") and self.engine: 

459 try: 

460 self.engine.dispose() 

461 logger.debug("SearchCache engine disposed") 

462 except Exception: 

463 logger.exception("Error disposing SearchCache engine") 

464 finally: 

465 self.engine = None 

466 

467 def __del__(self): 

468 """Destructor to ensure engine is disposed.""" 

469 self.dispose() 

470 

471 def get_stats(self) -> Dict[str, Any]: 

472 """Get cache statistics.""" 

473 try: 

474 current_time = int(time.time()) 

475 with self.Session() as session: 

476 # Total entries 

477 total_entries = ( 

478 session.query(SearchCacheModel) 

479 .filter(SearchCacheModel.expires_at > current_time) 

480 .count() 

481 ) 

482 

483 # Total expired entries 

484 expired_entries = ( 

485 session.query(SearchCacheModel) 

486 .filter(SearchCacheModel.expires_at <= current_time) 

487 .count() 

488 ) 

489 

490 # Average access count 

491 from sqlalchemy import func 

492 

493 avg_access_result = ( 

494 session.query(func.avg(SearchCacheModel.access_count)) 

495 .filter(SearchCacheModel.expires_at > current_time) 

496 .scalar() 

497 ) 

498 avg_access = avg_access_result or 0 

499 

500 return { 

501 "total_valid_entries": total_entries, 

502 "expired_entries": expired_entries, 

503 "memory_cache_size": len(self._memory_cache), 

504 "average_access_count": round(avg_access, 2), 

505 "cache_hit_potential": ( 

506 f"{(total_entries / (total_entries + 1)) * 100:.1f}%" 

507 if total_entries > 0 

508 else "0%" 

509 ), 

510 } 

511 

512 except Exception: 

513 logger.exception("Failed to get cache stats") 

514 return {"error": "Cache stats unavailable"} 

515 

516 

# Global cache instance, created lazily by get_search_cache().
_global_cache = None
# Serializes first-time creation of the singleton above.
_global_cache_lock = threading.Lock()

520 

521 

def get_search_cache() -> SearchCache:
    """Get global search cache instance."""
    global _global_cache
    # Fast path: already created, no locking needed.
    if _global_cache is not None:
        return _global_cache
    with _global_cache_lock:
        # Double-checked: another thread may have built it while we waited.
        if _global_cache is None:
            _global_cache = SearchCache()
    return _global_cache

530 

531 

def _dispose_global_cache():
    """atexit hook: dispose the singleton's engine and forget the instance."""
    global _global_cache
    cache, _global_cache = _global_cache, None
    if cache is not None:
        cache.dispose()


atexit.register(_dispose_global_cache)

540 

541 

@lru_cache(maxsize=100)
def normalize_entity_query(entity: str, constraint: str) -> str:
    """
    Normalize entity + constraint combination for consistent caching.
    Uses LRU cache for frequent normalizations.
    """

    # Lowercase and collapse runs of whitespace in one part.
    def clean(text: str) -> str:
        return " ".join(text.lower().split())

    # Canonical form: "<entity> <constraint>".
    return f"{clean(entity)} {clean(constraint)}"