Coverage for src / local_deep_research / web_search_engines / engines / local_embedding_manager.py: 98%

106 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:55 +0000

1import hashlib 

2import threading 

3import uuid 

4from datetime import UTC, datetime 

5from typing import Any, Dict, List, Optional 

6 

7from langchain_community.embeddings import ( 

8 HuggingFaceEmbeddings, 

9) 

10from langchain_core.documents import Document 

11from loguru import logger 

12 

13from ...database.models.library import DocumentChunk 

14from ...database.session_context import get_user_db_session 

15from ...utilities.url_utils import normalize_url 

16 

17 

class LocalEmbeddingManager:
    """Manage embedding generation and storage for local document search."""

    def __init__(
        self,
        embedding_model: str = "all-MiniLM-L6-v2",
        embedding_device: str = "cpu",
        embedding_model_type: str = "sentence_transformers",  # or 'ollama'
        ollama_base_url: Optional[str] = None,
        settings_snapshot: Optional[Dict[str, Any]] = None,
    ):
        """
        Set up the embedding manager used by local document search.

        Args:
            embedding_model: Name of the embedding model to load.
            embedding_device: Device the embeddings run on ('cpu' or 'cuda').
            embedding_model_type: Embedding backend, either
                'sentence_transformers' or 'ollama'.
            ollama_base_url: Ollama API base URL when the ollama backend
                is selected.
            settings_snapshot: Optional settings snapshot so background
                threads can read configuration.
        """
        # Model configuration.
        self.embedding_model = embedding_model
        self.embedding_device = embedding_device
        self.embedding_model_type = embedding_model_type
        self.ollama_base_url = ollama_base_url
        self.settings_snapshot = settings_snapshot or {}

        # Username drives per-user database access; the password for the
        # encrypted database may be supplied later by the caller.
        self.username = (settings_snapshot or {}).get("_username")
        self.db_password = None

        # The embedding model is loaded lazily on first use; the lock
        # guards that initialization against concurrent threads.
        self._embeddings = None
        self._embedding_lock = threading.Lock()

        # Per-collection cache of loaded vector stores.
        self.vector_stores: dict[str, Any] = {}

        # Set once close() has run.
        self._closed = False

63 def close(self): 

64 """Release embedding model resources.""" 

65 if self._closed: 

66 return 

67 self._closed = True 

68 # Clear embedding model reference to allow garbage collection 

69 self._embeddings = None 

70 # Clear vector store cache 

71 self.vector_stores.clear() 

72 logger.debug("LocalEmbeddingManager closed") 

73 

74 def __enter__(self): 

75 """Context manager entry.""" 

76 return self 

77 

78 def __exit__(self, exc_type, exc_val, exc_tb): 

79 """Context manager exit - ensures resources are released.""" 

80 self.close() 

81 return False 

82 

83 @property 

84 def embeddings(self): 

85 """ 

86 Lazily initialize embeddings when first accessed. 

87 This allows the LocalEmbeddingManager to be created without 

88 immediately loading models, which is helpful when no local search is performed. 

89 

90 Uses double-checked locking to ensure thread-safe initialization. 

91 Concurrent SentenceTransformer model loading causes meta tensor errors 

92 in PyTorch when multiple threads call model.to(device) simultaneously. 

93 """ 

94 if self._embeddings is None: 

95 with self._embedding_lock: 

96 if self._embeddings is None: 

97 logger.info("Initializing embeddings on first use") 

98 self._embeddings = self._initialize_embeddings() 

99 return self._embeddings 

100 

101 def _initialize_embeddings(self): 

102 """Initialize the embedding model based on configuration""" 

103 try: 

104 # Use the new unified embedding system 

105 from ...embeddings import get_embeddings 

106 

107 # Prepare kwargs for provider-specific parameters 

108 kwargs = {} 

109 

110 # Add device for sentence transformers 

111 if self.embedding_model_type == "sentence_transformers": 

112 kwargs["device"] = self.embedding_device 

113 

114 # Add base_url for ollama if specified 

115 if self.embedding_model_type == "ollama" and self.ollama_base_url: 

116 kwargs["base_url"] = normalize_url(self.ollama_base_url) 

117 

118 logger.info( 

119 f"Initializing embeddings with provider={self.embedding_model_type}, model={self.embedding_model}" 

120 ) 

121 

122 return get_embeddings( 

123 provider=self.embedding_model_type, 

124 model=self.embedding_model, 

125 settings_snapshot=self.settings_snapshot, 

126 **kwargs, 

127 ) 

128 except Exception: 

129 logger.exception("Error initializing embeddings") 

130 logger.warning( 

131 "Falling back to HuggingFaceEmbeddings with all-MiniLM-L6-v2" 

132 ) 

133 return HuggingFaceEmbeddings( 

134 model_name="sentence-transformers/all-MiniLM-L6-v2" 

135 ) 

136 

    def _store_chunks_to_db(
        self,
        chunks: List[Document],
        collection_name: str,
        source_path: Optional[str] = None,
        source_id: Optional[int] = None,
        source_type: str = "local_file",
    ) -> List[str]:
        """
        Store document chunks in the database.

        Each chunk is deduplicated by a SHA-256 hash of its text: an
        existing row is touched (last_accessed) and its embedding ID
        reused; otherwise a new DocumentChunk row is inserted with a
        fresh UUID embedding ID.

        Args:
            chunks: List of LangChain Document chunks
            collection_name: Name of the collection (e.g., 'personal_notes', 'library')
            source_path: Path to source file (for local files)
            source_id: ID of source document (for library documents)
            source_type: Type of source ('local_file' or 'library')

        Returns:
            List of chunk embedding IDs (UUIDs) for FAISS mapping; empty
            list when no username is set or any database error occurs.
        """
        # Database access is per-user; without a username there is no
        # session to write to.
        if not self.username:
            logger.warning(
                "No username available, cannot store chunks in database"
            )
            return []

        chunk_ids: List[str] = []

        try:
            with get_user_db_session(
                self.username, self.db_password
            ) as session:
                for idx, chunk in enumerate(chunks):
                    # Generate unique hash for chunk (content-based dedup key)
                    chunk_text = chunk.page_content
                    chunk_hash = hashlib.sha256(chunk_text.encode()).hexdigest()

                    # Generate unique embedding ID (used as the FAISS key)
                    embedding_id = uuid.uuid4().hex

                    # Extract metadata; prefer filename, then title.
                    metadata = chunk.metadata or {}
                    document_title = metadata.get(
                        "filename", metadata.get("title", "Unknown")
                    )

                    # Calculate word count
                    word_count = len(chunk_text.split())

                    # Get character positions from metadata if available;
                    # default to the whole chunk span.
                    start_char = metadata.get("start_char", 0)
                    end_char = metadata.get("end_char", len(chunk_text))

                    # Check if chunk already exists.
                    # NOTE(review): the lookup matches on chunk_hash alone,
                    # not collection_name — identical text indexed into a
                    # second collection reuses the first collection's row
                    # (and its collection attribution). Confirm this global
                    # dedup is intended.
                    existing_chunk = (
                        session.query(DocumentChunk)
                        .filter_by(chunk_hash=chunk_hash)
                        .first()
                    )

                    if existing_chunk:
                        # Update existing chunk and reuse its embedding ID
                        existing_chunk.last_accessed = datetime.now(UTC)
                        chunk_ids.append(existing_chunk.embedding_id)
                        logger.debug(
                            f"Chunk already exists, reusing: {existing_chunk.embedding_id}"
                        )
                    else:
                        # Create new chunk row
                        db_chunk = DocumentChunk(
                            chunk_hash=chunk_hash,
                            source_type=source_type,
                            source_id=source_id,
                            source_path=str(source_path)
                            if source_path
                            else None,
                            collection_name=collection_name,
                            chunk_text=chunk_text,
                            chunk_index=idx,
                            start_char=start_char,
                            end_char=end_char,
                            word_count=word_count,
                            embedding_id=embedding_id,
                            embedding_model=self.embedding_model,
                            embedding_model_type=self.embedding_model_type,
                            document_title=document_title,
                            document_metadata=metadata,
                        )
                        session.add(db_chunk)
                        chunk_ids.append(embedding_id)

                # Single commit for the whole batch.
                session.commit()
                logger.info(
                    f"Stored {len(chunk_ids)} chunks to database for collection '{collection_name}'"
                )

        except Exception:
            logger.exception(
                f"Error storing chunks to database for collection '{collection_name}'"
            )
            return []

        return chunk_ids

241 

242 def _delete_chunks_from_db( 

243 self, 

244 collection_name: str, 

245 source_path: Optional[str] = None, 

246 source_id: Optional[int] = None, 

247 ) -> int: 

248 """ 

249 Delete chunks from database. 

250 

251 Args: 

252 collection_name: Name of the collection 

253 source_path: Path to source file (for local files) 

254 source_id: ID of source document (for library documents) 

255 

256 Returns: 

257 Number of chunks deleted 

258 """ 

259 if not self.username: 

260 logger.warning( 

261 "No username available, cannot delete chunks from database" 

262 ) 

263 return 0 

264 

265 try: 

266 with get_user_db_session( 

267 self.username, self.db_password 

268 ) as session: 

269 query = session.query(DocumentChunk).filter_by( 

270 collection_name=collection_name 

271 ) 

272 

273 if source_path: 

274 query = query.filter_by(source_path=str(source_path)) 

275 if source_id: 275 ↛ 276line 275 didn't jump to line 276 because the condition on line 275 was never true

276 query = query.filter_by(source_id=source_id) 

277 

278 count = int(query.delete()) 

279 session.commit() 

280 

281 logger.info( 

282 f"Deleted {count} chunks from database for collection '{collection_name}'" 

283 ) 

284 return count 

285 

286 except Exception: 

287 logger.exception( 

288 f"Error deleting chunks from database for collection '{collection_name}'" 

289 ) 

290 return 0