Coverage for src/local_deep_research/web_search_engines/engines/search_engine_elasticsearch.py: 100%
116 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-25 01:07 +0000
1from loguru import logger
2from typing import Any, Dict, List, Optional
4from elasticsearch import Elasticsearch
5from langchain_core.language_models import BaseLLM
7from ...config import search_config
8from ...constants import SNIPPET_LENGTH_SHORT
9from ..search_engine_base import BaseSearchEngine
class ElasticsearchSearchEngine(BaseSearchEngine):
    """Elasticsearch search engine implementation with two-phase approach.

    Phase one (`_get_previews`) runs a multi-match query and returns
    lightweight preview dictionaries; phase two (`_get_full_content`)
    fetches the full documents for the previews deemed relevant.
    """

    def __init__(
        self,
        hosts: Optional[List[str]] = None,
        index_name: str = "documents",
        username: Optional[str] = None,
        password: Optional[str] = None,
        api_key: Optional[str] = None,
        cloud_id: Optional[str] = None,
        max_results: int = 10,
        highlight_fields: Optional[List[str]] = None,
        search_fields: Optional[List[str]] = None,
        filter_query: Optional[Dict[str, Any]] = None,
        llm: Optional[BaseLLM] = None,
        max_filtered_results: Optional[int] = None,
    ):
        """
        Initialize the Elasticsearch search engine.

        Args:
            hosts: List of Elasticsearch hosts
                (defaults to ["http://localhost:9200"] when omitted)
            index_name: Name of the index to search
            username: Optional username for authentication
            password: Optional password for authentication
            api_key: Optional API key for authentication
            cloud_id: Optional Elastic Cloud ID
            max_results: Maximum number of search results
            highlight_fields: Fields to highlight in search results
                (defaults to ["content", "title"] when omitted)
            search_fields: Fields to search in
                (defaults to ["content", "title"] when omitted)
            filter_query: Optional filter query in Elasticsearch DSL format
            llm: Language model for relevance filtering
            max_filtered_results: Maximum number of results to keep after filtering

        Raises:
            ConnectionError: If the Elasticsearch cluster cannot be reached.
        """
        # Initialize the BaseSearchEngine with LLM, max_filtered_results, and max_results
        super().__init__(
            llm=llm,
            max_filtered_results=max_filtered_results,
            max_results=max_results,
        )

        self.index_name = index_name
        # NOTE: list defaults use a None sentinel and are created per call,
        # so a mutated default can never be shared across instances.
        self.highlight_fields = self._ensure_list(
            highlight_fields
            if highlight_fields is not None
            else ["content", "title"],
            default=["content", "title"],
        )
        self.search_fields = self._ensure_list(
            search_fields if search_fields is not None else ["content", "title"],
            default=["content", "title"],
        )
        self.filter_query = filter_query or {}

        # Normalize hosts – may arrive as a JSON-encoded string from settings
        hosts = self._ensure_list(
            hosts if hosts is not None else ["http://localhost:9200"],
            default=["http://localhost:9200"],
        )

        # Initialize the Elasticsearch client
        es_args = {}

        # Basic authentication
        if username and password:
            es_args["basic_auth"] = (username, password)

        # API key authentication
        if api_key:
            es_args["api_key"] = api_key

        # Cloud ID for Elastic Cloud
        if cloud_id:
            es_args["cloud_id"] = cloud_id

        # Connect to Elasticsearch
        self.client = Elasticsearch(hosts, **es_args)

        # Verify connection
        try:
            info = self.client.info()
            logger.info(
                f"Connected to Elasticsearch cluster: {info.get('cluster_name')}"
            )
            logger.info(
                f"Elasticsearch version: {info.get('version', {}).get('number')}"
            )
        except Exception as e:
            logger.exception("Failed to connect to Elasticsearch")
            # Chain the original exception so the root cause stays visible.
            raise ConnectionError(
                f"Could not connect to Elasticsearch: {e!s}"
            ) from e

    def _get_previews(self, query: str) -> List[Dict[str, Any]]:
        """
        Get preview information for Elasticsearch documents.

        Args:
            query: The search query

        Returns:
            List of preview dictionaries (empty list on search failure)
        """
        logger.info(
            f"Getting document previews from Elasticsearch with query: {query}"
        )

        try:
            # Build the search query
            search_query = {
                "query": {
                    "multi_match": {
                        "query": query,
                        "fields": self.search_fields,
                        "type": "best_fields",
                        "tie_breaker": 0.3,
                    }
                },
                "highlight": {
                    "fields": {field: {} for field in self.highlight_fields},
                    "pre_tags": ["<em>"],
                    "post_tags": ["</em>"],
                },
                "size": self.max_results,
            }

            # Add filter if provided: wrap the scoring query in a bool/filter
            if self.filter_query:
                search_query["query"] = {
                    "bool": {
                        "must": search_query["query"],
                        "filter": self.filter_query,
                    }
                }

            # Execute the search
            response = self.client.search(
                index=self.index_name,
                body=search_query,
            )

            # Reuse the shared hit formatter instead of duplicating the
            # preview-building logic inline.
            previews = self._process_es_response(response)

            logger.info(
                f"Found {len(previews)} preview results from Elasticsearch"
            )
            return previews

        except Exception:
            logger.exception("Error getting Elasticsearch previews")
            return []

    def _get_full_content(
        self, relevant_items: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Get full content for the relevant Elasticsearch documents.

        In snippet-only mode (``search_config.SEARCH_SNIPPETS_ONLY``) the
        preview items are returned untouched.

        Args:
            relevant_items: List of relevant preview dictionaries

        Returns:
            List of result dictionaries with full content
        """
        # Check if we should get full content
        if (
            hasattr(search_config, "SEARCH_SNIPPETS_ONLY")
            and search_config.SEARCH_SNIPPETS_ONLY
        ):
            logger.info("Snippet-only mode, skipping full content retrieval")
            return relevant_items

        logger.info("Getting full content for relevant Elasticsearch documents")

        results = []
        for item in relevant_items:
            # Start with the preview data
            result = item.copy()

            # Get the document ID
            doc_id = item.get("id")
            if not doc_id:
                # Skip the fetch (but keep the preview) for items without ID
                logger.warning(f"Skipping item without ID: {item}")
                results.append(result)
                continue

            try:
                # Fetch the full document by ID
                doc_response = self.client.get(
                    index=self.index_name,
                    id=doc_id,
                )

                # Get the source document
                source = doc_response.get("_source", {})

                # Add full content to the result; fall back to the preview
                # snippet when the document has no "content" field.
                result["content"] = source.get(
                    "content", result.get("snippet", "")
                )
                result["full_content"] = source.get("content", "")

                # Add metadata from source without overwriting existing keys
                # or duplicating the content field
                for key, value in source.items():
                    if key not in result and key not in ["content"]:
                        result[key] = value

            except Exception:
                logger.exception(
                    f"Error fetching full content for document {doc_id}"
                )
                # Keep the preview data if we can't get the full content

            results.append(result)

        return results

    def search_by_query_string(self, query_string: str) -> List[Dict[str, Any]]:
        """
        Perform a search using Elasticsearch Query String syntax.

        Args:
            query_string: The query in Elasticsearch Query String syntax

        Returns:
            List of search results (empty list on failure)
        """
        try:
            # Build the search query
            search_query = {
                "query": {
                    "query_string": {
                        "query": query_string,
                        "fields": self.search_fields,
                    }
                },
                "highlight": {
                    "fields": {field: {} for field in self.highlight_fields},
                    "pre_tags": ["<em>"],
                    "post_tags": ["</em>"],
                },
                "size": self.max_results,
            }

            # Execute the search
            response = self.client.search(
                index=self.index_name,
                body=search_query,
            )

            # Process and return the results
            previews = self._process_es_response(response)
            return self._get_full_content(previews)

        except Exception:
            logger.exception("Error in query_string search")
            return []

    def search_by_dsl(self, query_dsl: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Perform a search using Elasticsearch DSL (Query Domain Specific Language).

        Args:
            query_dsl: The query in Elasticsearch DSL format

        Returns:
            List of search results (empty list on failure)
        """
        try:
            # Execute the search with the provided DSL
            response = self.client.search(
                index=self.index_name,
                body=query_dsl,
            )

            # Process and return the results
            previews = self._process_es_response(response)
            return self._get_full_content(previews)

        except Exception:
            logger.exception("Error in DSL search")
            return []

    def _build_snippet(
        self, source: Dict[str, Any], highlight: Dict[str, Any]
    ) -> str:
        """
        Build a snippet from highlight fragments, falling back to content.

        Args:
            source: The hit's ``_source`` document
            highlight: The hit's ``highlight`` section (may be empty)

        Returns:
            A stripped snippet string (may be empty)
        """
        snippet = ""
        for field in self.highlight_fields:
            if highlight.get(field):
                # Join all highlight fragments for this field
                snippet += " ... ".join(highlight[field]) + " "

        # If no highlights, use a portion of the content
        if not snippet and "content" in source:
            content = source.get("content", "")
            snippet = (
                content[:SNIPPET_LENGTH_SHORT] + "..."
                if len(content) > SNIPPET_LENGTH_SHORT
                else content
            )

        return snippet.strip()

    def _process_es_response(
        self, response: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """
        Process Elasticsearch response into preview dictionaries.

        Args:
            response: Elasticsearch response dictionary

        Returns:
            List of preview dictionaries
        """
        hits = response.get("hits", {}).get("hits", [])

        # Format results as previews
        previews = []
        for hit in hits:
            source = hit.get("_source", {})
            highlight = hit.get("highlight", {})

            previews.append(
                {
                    "id": hit.get("_id", ""),
                    "title": source.get("title", "Untitled Document"),
                    # Fall back to a synthetic URI when the doc has no URL
                    "link": source.get("url", "")
                    or f"elasticsearch://{self.index_name}/{hit.get('_id', '')}",
                    "snippet": self._build_snippet(source, highlight),
                    "score": hit.get("_score", 0),
                    "_index": hit.get("_index", self.index_name),
                }
            )

        return previews