Coverage for src/local_deep_research/web_search_engines/engines/local_embedding_manager.py: 98%
106 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
1import hashlib
2import threading
3import uuid
4from datetime import UTC, datetime
5from typing import Any, Dict, List, Optional
7from langchain_community.embeddings import (
8 HuggingFaceEmbeddings,
9)
10from langchain_core.documents import Document
11from loguru import logger
13from ...database.models.library import DocumentChunk
14from ...database.session_context import get_user_db_session
15from ...utilities.url_utils import normalize_url
18class LocalEmbeddingManager:
19 """Handles embedding generation and storage for local document search"""
21 def __init__(
22 self,
23 embedding_model: str = "all-MiniLM-L6-v2",
24 embedding_device: str = "cpu",
25 embedding_model_type: str = "sentence_transformers", # or 'ollama'
26 ollama_base_url: Optional[str] = None,
27 settings_snapshot: Optional[Dict[str, Any]] = None,
28 ):
29 """
30 Initialize the embedding manager for local document search.
32 Args:
33 embedding_model: Name of the embedding model to use
34 embedding_device: Device to run embeddings on ('cpu' or 'cuda')
35 embedding_model_type: Type of embedding model ('sentence_transformers' or 'ollama')
36 ollama_base_url: Base URL for Ollama API if using ollama embeddings
37 settings_snapshot: Optional settings snapshot for background threads
38 """
40 self.embedding_model = embedding_model
41 self.embedding_device = embedding_device
42 self.embedding_model_type = embedding_model_type
43 self.ollama_base_url = ollama_base_url
44 self.settings_snapshot = settings_snapshot or {}
46 # Username for database access (extracted from settings if available)
47 self.username = (
48 settings_snapshot.get("_username") if settings_snapshot else None
49 )
50 # Password for encrypted database access (can be set later)
51 self.db_password = None
53 # Initialize the embedding model (with lock for thread-safe lazy init)
54 self._embeddings = None
55 self._embedding_lock = threading.Lock()
57 # Vector store cache
58 self.vector_stores: dict[str, Any] = {}
60 # Track if this manager has been closed
61 self._closed = False
63 def close(self):
64 """Release embedding model resources."""
65 if self._closed:
66 return
67 self._closed = True
68 # Clear embedding model reference to allow garbage collection
69 self._embeddings = None
70 # Clear vector store cache
71 self.vector_stores.clear()
72 logger.debug("LocalEmbeddingManager closed")
74 def __enter__(self):
75 """Context manager entry."""
76 return self
78 def __exit__(self, exc_type, exc_val, exc_tb):
79 """Context manager exit - ensures resources are released."""
80 self.close()
81 return False
83 @property
84 def embeddings(self):
85 """
86 Lazily initialize embeddings when first accessed.
87 This allows the LocalEmbeddingManager to be created without
88 immediately loading models, which is helpful when no local search is performed.
90 Uses double-checked locking to ensure thread-safe initialization.
91 Concurrent SentenceTransformer model loading causes meta tensor errors
92 in PyTorch when multiple threads call model.to(device) simultaneously.
93 """
94 if self._embeddings is None:
95 with self._embedding_lock:
96 if self._embeddings is None:
97 logger.info("Initializing embeddings on first use")
98 self._embeddings = self._initialize_embeddings()
99 return self._embeddings
101 def _initialize_embeddings(self):
102 """Initialize the embedding model based on configuration"""
103 try:
104 # Use the new unified embedding system
105 from ...embeddings import get_embeddings
107 # Prepare kwargs for provider-specific parameters
108 kwargs = {}
110 # Add device for sentence transformers
111 if self.embedding_model_type == "sentence_transformers":
112 kwargs["device"] = self.embedding_device
114 # Add base_url for ollama if specified
115 if self.embedding_model_type == "ollama" and self.ollama_base_url:
116 kwargs["base_url"] = normalize_url(self.ollama_base_url)
118 logger.info(
119 f"Initializing embeddings with provider={self.embedding_model_type}, model={self.embedding_model}"
120 )
122 return get_embeddings(
123 provider=self.embedding_model_type,
124 model=self.embedding_model,
125 settings_snapshot=self.settings_snapshot,
126 **kwargs,
127 )
128 except Exception:
129 logger.exception("Error initializing embeddings")
130 logger.warning(
131 "Falling back to HuggingFaceEmbeddings with all-MiniLM-L6-v2"
132 )
133 return HuggingFaceEmbeddings(
134 model_name="sentence-transformers/all-MiniLM-L6-v2"
135 )
137 def _store_chunks_to_db(
138 self,
139 chunks: List[Document],
140 collection_name: str,
141 source_path: Optional[str] = None,
142 source_id: Optional[int] = None,
143 source_type: str = "local_file",
144 ) -> List[str]:
145 """
146 Store document chunks in the database.
148 Args:
149 chunks: List of LangChain Document chunks
150 collection_name: Name of the collection (e.g., 'personal_notes', 'library')
151 source_path: Path to source file (for local files)
152 source_id: ID of source document (for library documents)
153 source_type: Type of source ('local_file' or 'library')
155 Returns:
156 List of chunk embedding IDs (UUIDs) for FAISS mapping
157 """
158 if not self.username:
159 logger.warning(
160 "No username available, cannot store chunks in database"
161 )
162 return []
164 chunk_ids = []
166 try:
167 with get_user_db_session(
168 self.username, self.db_password
169 ) as session:
170 for idx, chunk in enumerate(chunks):
171 # Generate unique hash for chunk
172 chunk_text = chunk.page_content
173 chunk_hash = hashlib.sha256(chunk_text.encode()).hexdigest()
175 # Generate unique embedding ID
176 embedding_id = uuid.uuid4().hex
178 # Extract metadata
179 metadata = chunk.metadata or {}
180 document_title = metadata.get(
181 "filename", metadata.get("title", "Unknown")
182 )
184 # Calculate word count
185 word_count = len(chunk_text.split())
187 # Get character positions from metadata if available
188 start_char = metadata.get("start_char", 0)
189 end_char = metadata.get("end_char", len(chunk_text))
191 # Check if chunk already exists
192 existing_chunk = (
193 session.query(DocumentChunk)
194 .filter_by(chunk_hash=chunk_hash)
195 .first()
196 )
198 if existing_chunk:
199 # Update existing chunk
200 existing_chunk.last_accessed = datetime.now(UTC)
201 chunk_ids.append(existing_chunk.embedding_id)
202 logger.debug(
203 f"Chunk already exists, reusing: {existing_chunk.embedding_id}"
204 )
205 else:
206 # Create new chunk
207 db_chunk = DocumentChunk(
208 chunk_hash=chunk_hash,
209 source_type=source_type,
210 source_id=source_id,
211 source_path=str(source_path)
212 if source_path
213 else None,
214 collection_name=collection_name,
215 chunk_text=chunk_text,
216 chunk_index=idx,
217 start_char=start_char,
218 end_char=end_char,
219 word_count=word_count,
220 embedding_id=embedding_id,
221 embedding_model=self.embedding_model,
222 embedding_model_type=self.embedding_model_type,
223 document_title=document_title,
224 document_metadata=metadata,
225 )
226 session.add(db_chunk)
227 chunk_ids.append(embedding_id)
229 session.commit()
230 logger.info(
231 f"Stored {len(chunk_ids)} chunks to database for collection '{collection_name}'"
232 )
234 except Exception:
235 logger.exception(
236 f"Error storing chunks to database for collection '{collection_name}'"
237 )
238 return []
240 return chunk_ids
242 def _delete_chunks_from_db(
243 self,
244 collection_name: str,
245 source_path: Optional[str] = None,
246 source_id: Optional[int] = None,
247 ) -> int:
248 """
249 Delete chunks from database.
251 Args:
252 collection_name: Name of the collection
253 source_path: Path to source file (for local files)
254 source_id: ID of source document (for library documents)
256 Returns:
257 Number of chunks deleted
258 """
259 if not self.username:
260 logger.warning(
261 "No username available, cannot delete chunks from database"
262 )
263 return 0
265 try:
266 with get_user_db_session(
267 self.username, self.db_password
268 ) as session:
269 query = session.query(DocumentChunk).filter_by(
270 collection_name=collection_name
271 )
273 if source_path:
274 query = query.filter_by(source_path=str(source_path))
275 if source_id: 275 ↛ 276line 275 didn't jump to line 276 because the condition on line 275 was never true
276 query = query.filter_by(source_id=source_id)
278 count = int(query.delete())
279 session.commit()
281 logger.info(
282 f"Deleted {count} chunks from database for collection '{collection_name}'"
283 )
284 return count
286 except Exception:
287 logger.exception(
288 f"Error deleting chunks from database for collection '{collection_name}'"
289 )
290 return 0