Coverage for src / local_deep_research / web_search_engines / engines / search_engine_semantic_scholar.py: 98%

265 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:55 +0000

1import re 

2from typing import Any, Dict, List, Optional, Tuple 

3 

4import requests 

5from langchain_core.language_models import BaseLLM 

6from loguru import logger 

7from requests.adapters import HTTPAdapter 

8from urllib3.util import Retry 

9 

10from ...constants import SNIPPET_LENGTH_SHORT 

11from ..rate_limiting import RateLimitError 

12from ..search_engine_base import BaseSearchEngine 

13from ...security import SafeSession 

14 

15 

class SemanticScholarSearchEngine(BaseSearchEngine):
    """
    Semantic Scholar search engine implementation with two-phase approach.

    Phase one (`_get_previews`) performs a lightweight search and returns
    preview dictionaries; phase two (`_get_full_content`) enriches only the
    relevant previews with citations/references/embeddings via extra API
    calls. Provides efficient access to scientific literature across all
    fields.
    """

    # Mark as public search engine
    is_public = True
    # Scientific/academic search engine
    is_scientific = True
    is_lexical = True
    needs_llm_relevance_filter = True

    def __init__(
        self,
        max_results: int = 10,
        api_key: Optional[str] = None,
        year_range: Optional[Tuple[int, int]] = None,
        get_abstracts: bool = True,
        get_references: bool = False,
        get_citations: bool = False,
        get_embeddings: bool = False,
        get_tldr: bool = True,
        citation_limit: int = 10,
        reference_limit: int = 10,
        llm: Optional[BaseLLM] = None,
        max_filtered_results: Optional[int] = None,
        optimize_queries: bool = True,
        max_retries: int = 5,
        retry_backoff_factor: float = 1.0,
        fields_of_study: Optional[List[str]] = None,
        publication_types: Optional[List[str]] = None,
        settings_snapshot: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """
        Initialize the Semantic Scholar search engine.

        Args:
            max_results: Maximum number of search results
            api_key: Semantic Scholar API key for higher rate limits (optional)
            year_range: Optional tuple of (start_year, end_year) to filter results
            get_abstracts: Whether to fetch abstracts for all results
            get_references: Whether to fetch references for papers
            get_citations: Whether to fetch citations for papers
            get_embeddings: Whether to fetch SPECTER embeddings for papers
            get_tldr: Whether to fetch TLDR summaries for papers
            citation_limit: Maximum number of citations to fetch per paper
            reference_limit: Maximum number of references to fetch per paper
            llm: Language model for relevance filtering
            max_filtered_results: Maximum number of results to keep after filtering
            optimize_queries: Whether to optimize natural language queries
            max_retries: Maximum number of retries for API requests
            retry_backoff_factor: Backoff factor for retries
            fields_of_study: List of fields of study to filter results
            publication_types: List of publication types to filter results
            settings_snapshot: Settings snapshot for configuration
            **kwargs: Additional parameters to pass to parent class
        """
        # Initialize the BaseSearchEngine with LLM, max_filtered_results, and max_results
        super().__init__(
            llm=llm,
            max_filtered_results=max_filtered_results,
            max_results=max_results,
            settings_snapshot=settings_snapshot,
            **kwargs,
        )

        # Get API key from settings if not provided.  Deferred import keeps
        # module import cheap and avoids a potential circular import.
        if not api_key and settings_snapshot:
            from ...config.search_config import get_setting_from_snapshot

            try:
                api_key = get_setting_from_snapshot(
                    "search.engine.web.semantic_scholar.api_key",
                    settings_snapshot=settings_snapshot,
                )
            except Exception:
                # Missing key is non-fatal: the engine also works (with
                # lower rate limits) without an API key.
                logger.debug(
                    "Failed to read semantic_scholar.api_key from settings snapshot",
                    exc_info=True,
                )

        self.api_key = api_key
        self.year_range = year_range
        self.get_abstracts = get_abstracts
        self.get_references = get_references
        self.get_citations = get_citations
        self.get_embeddings = get_embeddings
        self.get_tldr = get_tldr
        self.citation_limit = citation_limit
        self.reference_limit = reference_limit
        self.optimize_queries = optimize_queries
        self.max_retries = max_retries
        self.retry_backoff_factor = retry_backoff_factor
        # _ensure_list is not defined in this class — assumed inherited from
        # BaseSearchEngine (normalizes str/iterable to a list); TODO confirm.
        self.fields_of_study = (
            self._ensure_list(fields_of_study)
            if fields_of_study is not None
            else None
        )
        self.publication_types = (
            self._ensure_list(publication_types)
            if publication_types is not None
            else None
        )

        # Base API URLs
        self.base_url = "https://api.semanticscholar.org/graph/v1"
        self.paper_search_url = f"{self.base_url}/paper/search"
        self.paper_details_url = f"{self.base_url}/paper"

        # Create a session with retry capabilities.  Becomes None after
        # close() so every user must null-check before use.
        self.session: SafeSession | None = self._create_session()

        # Log API key status
        if self.api_key:
            logger.info(
                "Using Semantic Scholar with API key (higher rate limits)"
            )
        else:
            logger.info(
                "Using Semantic Scholar without API key (lower rate limits)"
            )

    def _create_session(self) -> SafeSession:
        """Create and configure a requests session with retry capabilities"""
        session = SafeSession()

        # Configure automatic retries with exponential backoff
        retry_strategy = Retry(
            total=self.max_retries,
            backoff_factor=self.retry_backoff_factor,
            status_forcelist=[429, 500, 502, 503, 504],
            # NOTE(review): POST is normally non-idempotent; including it here
            # is presumably safe because this API's POST endpoints are
            # read-only searches — confirm before reusing this session
            # elsewhere.
            allowed_methods={"HEAD", "GET", "POST", "OPTIONS"},
        )

        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("https://", adapter)

        # Set up headers
        headers = {"Accept": "application/json"}
        if self.api_key:
            headers["x-api-key"] = self.api_key

        session.headers.update(headers)

        return session

    def close(self) -> None:
        """
        Close the HTTP session and clean up resources.

        Call this method when done using the search engine to prevent
        connection/file descriptor leaks.
        """
        # hasattr guard: close() may run from __del__ before __init__
        # finished assigning self.session.
        if hasattr(self, "session") and self.session:
            try:
                self.session.close()
            except Exception:
                logger.exception("Error closing SemanticScholar session")
            finally:
                # Always null out so a second close() is a no-op.
                self.session = None

    def __del__(self) -> None:
        """Destructor to ensure session is closed."""
        self.close()

    def __enter__(self) -> "SemanticScholarSearchEngine":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """Context manager exit - ensures session cleanup."""
        self.close()
        # Returning False propagates any in-flight exception.
        return False

    def _respect_rate_limit(self) -> None:
        """Apply rate limiting between requests"""
        # Apply rate limiting before request.  rate_tracker/engine_type are
        # assumed provided by BaseSearchEngine — TODO confirm.
        self._last_wait_time = self.rate_tracker.apply_rate_limit(
            self.engine_type
        )
        logger.debug(f"Applied rate limit wait: {self._last_wait_time:.2f}s")

    def _make_request(
        self,
        url: str,
        params: Optional[Dict] = None,
        data: Optional[Dict] = None,
        method: str = "GET",
    ) -> Dict:
        """
        Make a request to the Semantic Scholar API.

        Args:
            url: API endpoint URL
            params: Query parameters
            data: JSON data for POST requests
            method: HTTP method (GET or POST)

        Returns:
            API response as dictionary

        Raises:
            RateLimitError: When the API responds with HTTP 429.
            RuntimeError: When called after close() nulled the session.
            ValueError: For HTTP methods other than GET/POST.
        """
        self._respect_rate_limit()

        try:
            if self.session is None:
                raise RuntimeError("Session is not initialized")
            if method.upper() == "GET":
                response = self.session.get(url, params=params, timeout=30)
            elif method.upper() == "POST":
                response = self.session.post(
                    url, params=params, json=data, timeout=30
                )
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")

            # Handle rate limiting.
            # NOTE(review): 429 is also in the Retry status_forcelist; with
            # Retry's default raise_on_status, exhausted retries likely
            # surface as a RetryError caught below (returning {}), so this
            # branch may only fire in edge cases — confirm intended behavior.
            if response.status_code == 429:
                logger.warning("Semantic Scholar rate limit exceeded")
                raise RateLimitError("Semantic Scholar rate limit exceeded")

            response.raise_for_status()
            return response.json()  # type: ignore[no-any-return]
        except requests.RequestException:
            # Network/HTTP failures degrade to an empty result rather than
            # aborting the caller's search pipeline.
            logger.exception("API request failed")
            return {}

    def _optimize_query(self, query: str) -> str:
        """
        Optimize a natural language query for Semantic Scholar search.
        If LLM is available, uses it to extract key terms and concepts.

        Args:
            query: Natural language query

        Returns:
            Optimized query string (falls back to the original query when
            no LLM is configured, optimization is disabled, the LLM output
            looks like an explanation, or the LLM call fails)
        """
        if not self.llm or not self.optimize_queries:
            return query

        try:
            prompt = f"""Transform this natural language question into an optimized academic search query.

Original query: "{query}"

INSTRUCTIONS:
1. Extract key academic concepts, technical terms, and proper nouns
2. Remove generic words, filler words, and non-technical terms
3. Add quotation marks around specific phrases that should be kept together
4. Return ONLY the optimized search query with no explanation
5. Keep it under 100 characters if possible

EXAMPLE TRANSFORMATIONS:
"What are the latest findings about mRNA vaccines and COVID-19?" → "mRNA vaccines COVID-19 recent findings"
"How does machine learning impact climate change prediction?" → "machine learning "climate change" prediction"
"Tell me about quantum computing approaches for encryption" → "quantum computing encryption"

Return ONLY the optimized search query with no explanation.
"""

            response = self.llm.invoke(prompt)
            # Chat models return an object with .content; plain LLMs return
            # a string — normalize both.
            optimized_query = (
                str(response.content)
                if hasattr(response, "content")
                else str(response)
            ).strip()

            # Clean up the query - remove any explanations
            lines = optimized_query.split("\n")
            optimized_query = lines[0].strip()

            # Safety check - if query looks too much like an explanation, use original
            if len(optimized_query.split()) > 15 or ":" in optimized_query:
                logger.warning(
                    "Query optimization result looks too verbose, using original"
                )
                return query

            logger.info(f"Original query: '{query}'")
            logger.info(f"Optimized for search: '{optimized_query}'")

            return optimized_query
        except Exception:
            logger.exception("Error optimizing query")
            return query  # Fall back to original query on error

    def _direct_search(self, query: str) -> List[Dict[str, Any]]:
        """
        Make a direct search request to the Semantic Scholar API.

        Args:
            query: The search query

        Returns:
            List of paper dictionaries (empty on error or no results)
        """
        try:
            # Configure fields to retrieve
            fields = [
                "paperId",
                "externalIds",
                "url",
                "title",
                "abstract",
                "venue",
                "year",
                "authors",
                "citationCount",  # Add citation count for ranking
                "openAccessPdf",  # PDF URL for open access papers
            ]

            if self.get_tldr:
                fields.append("tldr")

            params = {
                "query": query,
                "limit": min(
                    self.max_results, 100
                ),  # API limit is 100 per request
                "fields": ",".join(fields),
            }

            # Add year filter if specified
            if self.year_range:
                start_year, end_year = self.year_range
                params["year"] = f"{start_year}-{end_year}"

            # Add fields of study filter if specified
            if self.fields_of_study:
                params["fieldsOfStudy"] = ",".join(self.fields_of_study)

            # Add publication types filter if specified
            if self.publication_types:
                params["publicationTypes"] = ",".join(self.publication_types)

            response = self._make_request(self.paper_search_url, params)

            if "data" in response:
                papers = response["data"]
                logger.info(
                    f"Found {len(papers)} papers with direct search for query: '{query}'"
                )
                return papers  # type: ignore[no-any-return]
            logger.warning(
                f"No data in response for direct search query: '{query}'"
            )
            return []

        except Exception:
            logger.exception("Error in direct search")
            return []

    def _adaptive_search(self, query: str) -> Tuple[List[Dict[str, Any]], str]:
        """
        Perform an adaptive search that adjusts based on result volume.
        Uses LLM to generate better fallback queries when available.

        Fallback cascade when the standard search returns nothing:
        unquoted query → LLM-suggested alternatives → longest key terms →
        single longest word.

        Args:
            query: The search query

        Returns:
            Tuple of (list of paper results, search strategy used); strategy
            is one of "standard", "unquoted", "llm_alternative",
            "key_terms", "single_term"
        """
        # Start with a standard search
        papers = self._direct_search(query)
        strategy = "standard"

        # If no results, try different variations
        if not papers:
            # Try removing quotes to broaden search
            if '"' in query:
                unquoted_query = query.replace('"', "")
                logger.info(
                    "No results with quoted terms, trying without quotes: {}",
                    unquoted_query,
                )
                papers = self._direct_search(unquoted_query)

                if papers:
                    strategy = "unquoted"
                    return papers, strategy

            # If LLM is available, use it to generate better fallback queries
            if self.llm:
                try:
                    # Generate alternate search queries focusing on core concepts
                    prompt = f"""You are helping refine a search query that returned no results.

Original query: "{query}"

The query might be too specific or use natural language phrasing that doesn't match academic paper keywords.

Please provide THREE alternative search queries that:
1. Focus on the core academic concepts
2. Use precise terminology commonly found in academic papers
3. Break down complex queries into more searchable components
4. Format each as a concise keyword-focused search term (not a natural language question)

Format each query on a new line with no numbering or explanation. Keep each query under 8 words and very focused.
"""
                    # Get the LLM's response
                    response = self.llm.invoke(prompt)

                    # Extract the alternative queries
                    alt_queries = []
                    if hasattr(
                        response, "content"
                    ):  # Handle various LLM response formats
                        content = response.content
                        alt_queries = [
                            q.strip()
                            for q in content.strip().split("\n")
                            if q.strip()
                        ]
                    elif isinstance(response, str):
                        alt_queries = [
                            q.strip()
                            for q in response.strip().split("\n")
                            if q.strip()
                        ]

                    # Try each alternative query
                    for alt_query in alt_queries[
                        :3
                    ]:  # Limit to first 3 alternatives
                        logger.info("Trying LLM-suggested query: {}", alt_query)
                        alt_papers = self._direct_search(alt_query)

                        if alt_papers:
                            logger.info(
                                "Found {} papers using LLM-suggested query: {}",
                                len(alt_papers),
                                alt_query,
                            )
                            strategy = "llm_alternative"
                            return alt_papers, strategy
                except Exception:
                    logger.exception("Error using LLM for query refinement")
                    # Fall through to simpler strategies

            # Fallback: Try with the longest words (likely specific terms)
            words = re.findall(r"\w+", query)
            longer_words = [word for word in words if len(word) > 6]
            if longer_words:
                # Use up to 3 of the longest words
                longer_words = sorted(longer_words, key=len, reverse=True)[:3]
                key_terms_query = " ".join(longer_words)
                logger.info("Trying with key terms: {}", key_terms_query)
                papers = self._direct_search(key_terms_query)

                if papers:
                    strategy = "key_terms"
                    return papers, strategy

            # Final fallback: Try with just the longest word
            if words:
                longest_word = max(words, key=len)
                if len(longest_word) > 5:  # Only use if it's reasonably long
                    logger.info("Trying with single key term: {}", longest_word)
                    papers = self._direct_search(longest_word)

                    if papers:
                        strategy = "single_term"
                        return papers, strategy

        return papers, strategy

    def _get_paper_details(self, paper_id: str) -> Dict[str, Any]:
        """
        Get detailed information about a specific paper.

        Args:
            paper_id: Semantic Scholar Paper ID

        Returns:
            Dictionary with paper details (empty on error)
        """
        try:
            # Construct fields parameter
            fields = [
                "paperId",
                "externalIds",
                "corpusId",
                "url",
                "title",
                "abstract",
                "venue",
                "year",
                "authors",
                "fieldsOfStudy",
                "citationCount",  # Add citation count
            ]

            if self.get_tldr:
                fields.append("tldr")

            if self.get_embeddings:
                fields.append("embedding")

            # Add citation and reference fields if requested
            if self.get_citations:
                fields.append(f"citations.limit({self.citation_limit})")

            if self.get_references:
                fields.append(f"references.limit({self.reference_limit})")

            # Make the request
            url = f"{self.paper_details_url}/{paper_id}"
            params = {"fields": ",".join(fields)}

            return self._make_request(url, params)

        except Exception:
            logger.exception("Error getting paper details for paper")
            return {}

    def _get_previews(self, query: str) -> List[Dict[str, Any]]:
        """
        Get preview information for Semantic Scholar papers.

        Args:
            query: The search query

        Returns:
            List of preview dictionaries, sorted newest-first by year;
            internal keys (prefixed "_") are stripped later by
            _get_full_content
        """
        logger.info(f"Getting Semantic Scholar previews for query: {query}")

        # Optimize the query if LLM is available
        optimized_query = self._optimize_query(query)

        # Use the adaptive search approach
        papers, strategy = self._adaptive_search(optimized_query)

        if not papers:
            logger.warning("No Semantic Scholar results found")
            return []

        # Format as previews
        previews = []
        for paper in papers:
            try:
                # Format authors - ensure we have a valid list with string values
                authors = []
                if paper.get("authors"):
                    authors = [
                        author.get("name", "")
                        for author in paper["authors"]
                        if author and author.get("name")
                    ]

                # Ensure we have valid strings for all fields
                paper_id = paper.get("paperId", "")
                title = paper.get("title", "")
                url = paper.get("url", "")

                # Handle abstract safely, ensuring we always have a string
                abstract = paper.get("abstract")
                snippet = ""
                if abstract:
                    snippet = (
                        abstract[:SNIPPET_LENGTH_SHORT] + "..."
                        if len(abstract) > SNIPPET_LENGTH_SHORT
                        else abstract
                    )

                venue = paper.get("venue", "")
                year = paper.get("year")
                external_ids = paper.get("externalIds", {})

                # Handle TLDR safely
                tldr_text = ""
                if paper.get("tldr") and isinstance(paper.get("tldr"), dict):
                    tldr_text = paper.get("tldr", {}).get("text", "")

                # Create preview with basic information, ensuring no None values
                preview = {
                    "id": paper_id if paper_id else "",
                    "title": title if title else "",
                    "link": url if url else "",
                    "snippet": snippet,
                    "authors": authors,
                    "venue": venue if venue else "",
                    "year": year,
                    "external_ids": external_ids if external_ids else {},
                    "source": "Semantic Scholar",
                    "_paper_id": paper_id if paper_id else "",
                    "_search_strategy": strategy,
                    "tldr": tldr_text,
                }

                # Store the full paper object for later reference
                preview["_full_paper"] = paper

                previews.append(preview)
            except Exception:
                logger.exception("Error processing paper preview")
                # Continue with the next paper

        # Sort by year (newer first) if available
        def _year_key(p: dict[str, Any]) -> int:
            # Missing/unparseable years sort last (treated as 0).
            year = p.get("year")
            try:
                return int(year) if year is not None else 0
            except (TypeError, ValueError):
                return 0

        previews = sorted(previews, key=_year_key, reverse=True)

        logger.info(
            f"Found {len(previews)} Semantic Scholar previews using strategy: {strategy}"
        )
        return previews

    def _get_full_content(
        self, relevant_items: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Get full content for the relevant Semantic Scholar papers.
        Gets additional details like citations, references, and full metadata.

        Args:
            relevant_items: List of relevant preview dictionaries

        Returns:
            List of result dictionaries with full content
        """
        # For Semantic Scholar, we already have most content from the preview
        # Additional API calls are only needed for citations/references

        logger.info(
            f"Getting content for {len(relevant_items)} Semantic Scholar papers"
        )

        results = []
        for item in relevant_items:
            # Copy so the caller's preview dicts are not mutated.
            result = item.copy()
            paper_id = item.get("_paper_id", "")

            # Skip if no paper ID
            if not paper_id:
                results.append(result)
                continue

            # Get paper details if citations or references are requested
            if self.get_citations or self.get_references or self.get_embeddings:
                paper_details = self._get_paper_details(paper_id)

                if paper_details:
                    # Add citation information
                    if self.get_citations and "citations" in paper_details:
                        result["citations"] = paper_details["citations"]

                    # Add reference information
                    if self.get_references and "references" in paper_details:
                        result["references"] = paper_details["references"]

                    # Add embedding if available
                    if self.get_embeddings and "embedding" in paper_details:
                        result["embedding"] = paper_details["embedding"]

                    # Add fields of study
                    if "fieldsOfStudy" in paper_details:
                        result["fields_of_study"] = paper_details[
                            "fieldsOfStudy"
                        ]

            # Remove temporary fields
            if "_paper_id" in result:
                del result["_paper_id"]
            if "_search_strategy" in result:
                del result["_search_strategy"]
            if "_full_paper" in result:
                del result["_full_paper"]

            results.append(result)

        return results