Coverage for src / local_deep_research / research_library / services / library_rag_service.py: 95%
457 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
1"""
2Library RAG Service
4Handles indexing and searching library documents using RAG:
5- Index text documents into vector database
6- Chunk documents for semantic search
7- Generate embeddings using local models
8- Manage FAISS indices per research
9- Track RAG status in library
10"""
12from pathlib import Path
13from typing import Any, Dict, List, Optional, Tuple
15from langchain_core.documents import Document as LangchainDocument
16from loguru import logger
17from sqlalchemy import func
19from ...config.paths import get_cache_directory
20from ...database.models.library import (
21 Document,
22 DocumentChunk,
23 DocumentCollection,
24 Collection,
25 RAGIndex,
26 RagDocumentStatus,
27 EmbeddingProvider,
28)
29from ...database.session_context import get_user_db_session
30from ...utilities.type_utils import to_bool
31from ...embeddings.splitters import get_text_splitter
32from ...web_search_engines.engines.local_embedding_manager import (
33 LocalEmbeddingManager,
34)
35from ...security.file_integrity import FileIntegrityManager, FAISSIndexVerifier
36import hashlib
37from faiss import IndexFlatL2, IndexFlatIP, IndexHNSWFlat
38from langchain_community.vectorstores import FAISS
39from langchain_community.docstore.in_memory import InMemoryDocstore
42class LibraryRAGService:
43 """Service for managing RAG indexing of library documents."""
    def __init__(
        self,
        username: str,
        embedding_model: str = "all-MiniLM-L6-v2",
        embedding_provider: str = "sentence_transformers",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        splitter_type: str = "recursive",
        text_separators: Optional[list] = None,
        distance_metric: str = "cosine",
        normalize_vectors: bool = True,
        index_type: str = "flat",
        embedding_manager: Optional["LocalEmbeddingManager"] = None,
        db_password: Optional[str] = None,
    ):
        """
        Initialize library RAG service for a user.

        Construction order matters: the embedding manager must exist before
        the text splitter (semantic chunking needs the embeddings object),
        and the db_password property setter must not fire before the
        embedding/integrity managers exist.

        Args:
            username: Username for database access
            embedding_model: Name of the embedding model to use
            embedding_provider: Provider type ('sentence_transformers' or 'ollama')
            chunk_size: Size of text chunks for splitting
            chunk_overlap: Overlap between consecutive chunks
            splitter_type: Type of splitter ('recursive', 'token', 'sentence', 'semantic')
            text_separators: List of text separators for chunking (default: ["\n\n", "\n", ". ", " ", ""])
            distance_metric: Distance metric ('cosine', 'l2', or 'dot_product')
            normalize_vectors: Whether to normalize vectors with L2
            index_type: FAISS index type ('flat', 'hnsw', or 'ivf')
            embedding_manager: Optional pre-constructed LocalEmbeddingManager for testing/flexibility
            db_password: Optional database password for background thread access
        """
        self.username = username
        self._db_password = db_password  # Can be used for thread access
        # Initialize optional attributes to None before they're set below
        # This allows the db_password setter to check them without hasattr
        self.embedding_manager = None
        self.integrity_manager = None
        self.embedding_model = embedding_model
        self.embedding_provider = embedding_provider
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.splitter_type = splitter_type
        self.text_separators = (
            text_separators
            if text_separators is not None
            else ["\n\n", "\n", ". ", " ", ""]
        )
        self.distance_metric = distance_metric
        # Ensure normalize_vectors is always a proper boolean
        # (callers may pass strings like "true" from settings storage)
        self.normalize_vectors = to_bool(normalize_vectors, default=True)
        self.index_type = index_type

        # Use provided embedding manager or create a new one
        # (Must be created before text splitter for semantic chunking)
        if embedding_manager is not None:
            self.embedding_manager = embedding_manager
        else:
            # Initialize embedding manager with library collection
            # Load the complete user settings snapshot from database using the proper method
            from ...settings.manager import SettingsManager

            # Use proper database session for SettingsManager
            # Note: using _db_password (backing field) directly here because the
            # db_password property setter propagates to embedding_manager/integrity_manager,
            # which are still None at this point in __init__.
            with get_user_db_session(username, self._db_password) as session:
                settings_manager = SettingsManager(session)
                settings_snapshot = settings_manager.get_settings_snapshot()

            # Add the specific settings needed for this RAG service
            settings_snapshot.update(
                {
                    "_username": username,
                    "embeddings.provider": embedding_provider,
                    f"embeddings.{embedding_provider}.model": embedding_model,
                    "local_search_chunk_size": chunk_size,
                    "local_search_chunk_overlap": chunk_overlap,
                }
            )

            self.embedding_manager = LocalEmbeddingManager(
                embedding_model=embedding_model,
                embedding_model_type=embedding_provider,
                settings_snapshot=settings_snapshot,
            )

        # Initialize text splitter based on type
        # (Must be created AFTER embedding_manager for semantic chunking)
        self.text_splitter = get_text_splitter(
            splitter_type=self.splitter_type,
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            text_separators=self.text_separators,
            embeddings=self.embedding_manager.embeddings
            if self.splitter_type == "semantic"
            else None,
        )

        # Initialize or load FAISS index for library collection
        # (lazily populated by load_or_create_faiss_index / index_document)
        self.faiss_index = None
        self.rag_index_record = None

        # Initialize file integrity manager for FAISS indexes
        self.integrity_manager = FileIntegrityManager(
            username, password=self._db_password
        )
        self.integrity_manager.register_verifier(FAISSIndexVerifier())

        # Guard flag so close() is idempotent
        self._closed = False
156 def close(self):
157 """Release embedding model and index resources."""
158 if self._closed:
159 return
160 self._closed = True
162 # Clear embedding manager resources
163 if self.embedding_manager is not None:
164 # Clear references to allow garbage collection
165 self.embedding_manager = None
167 # Clear FAISS index
168 if self.faiss_index is not None:
169 self.faiss_index = None
171 # Clear other resources
172 self.rag_index_record = None
173 self.integrity_manager = None
174 self.text_splitter = None
176 def __enter__(self):
177 """Enter context manager."""
178 return self
180 def __exit__(self, exc_type, exc_val, exc_tb):
181 """Exit context manager, ensuring cleanup."""
182 self.close()
183 return False
185 @property
186 def db_password(self):
187 """Get database password."""
188 return self._db_password
190 @db_password.setter
191 def db_password(self, value):
192 """Set database password and propagate to embedding manager and integrity manager."""
193 self._db_password = value
194 if self.embedding_manager:
195 self.embedding_manager.db_password = value
196 if self.integrity_manager:
197 self.integrity_manager.password = value
199 def _get_index_hash(
200 self,
201 collection_name: str,
202 embedding_model: str,
203 embedding_model_type: str,
204 ) -> str:
205 """Generate hash for index identification."""
206 hash_input = (
207 f"{collection_name}:{embedding_model}:{embedding_model_type}"
208 )
209 return hashlib.sha256(hash_input.encode()).hexdigest()
211 def _get_index_path(self, index_hash: str) -> Path:
212 """Get path for FAISS index file."""
213 # Store in centralized cache directory (respects LDR_DATA_DIR)
214 cache_dir = get_cache_directory() / "rag_indices"
215 cache_dir.mkdir(parents=True, exist_ok=True)
216 return cache_dir / f"{index_hash}.faiss"
218 @staticmethod
219 def _deduplicate_chunks(
220 chunks: List[LangchainDocument],
221 chunk_ids: List[str],
222 existing_ids: Optional[set] = None,
223 ) -> Tuple[List[LangchainDocument], List[str]]:
224 """Deduplicate chunks by ID within a batch, optionally excluding existing IDs."""
225 seen_ids: set = set()
226 new_chunks: List[LangchainDocument] = []
227 new_ids: List[str] = []
228 for chunk, chunk_id in zip(chunks, chunk_ids):
229 if chunk_id not in seen_ids and (
230 existing_ids is None or chunk_id not in existing_ids
231 ):
232 new_chunks.append(chunk)
233 new_ids.append(chunk_id)
234 seen_ids.add(chunk_id)
235 return new_chunks, new_ids
    def _get_or_create_rag_index(self, collection_id: str) -> RAGIndex:
        """Get or create RAGIndex record for the current configuration.

        Looks up the index row by a hash of (collection name, embedding
        model, provider); on a miss, probes the embedding model once to learn
        its output dimension and persists a fresh RAGIndex row.

        Args:
            collection_id: UUID of the collection the index belongs to.

        Returns:
            The (possibly freshly created) RAGIndex ORM object.

        NOTE(review): the ORM object is returned after the session context
        closes, i.e. detached. Attribute reads on it work because refresh()
        loaded them before commit/close — confirm expire_on_commit settings
        if session configuration ever changes.
        """
        with get_user_db_session(self.username, self.db_password) as session:
            # Use collection_<uuid> format
            collection_name = f"collection_{collection_id}"
            index_hash = self._get_index_hash(
                collection_name, self.embedding_model, self.embedding_provider
            )

            # Try to get existing index
            rag_index = (
                session.query(RAGIndex).filter_by(index_hash=index_hash).first()
            )

            if not rag_index:
                # Create new index record
                index_path = self._get_index_path(index_hash)

                # Get embedding dimension by embedding a test string
                # (cheapest reliable way to learn the model's output size)
                test_embedding = self.embedding_manager.embeddings.embed_query(
                    "test"
                )
                embedding_dim = len(test_embedding)

                rag_index = RAGIndex(
                    collection_name=collection_name,
                    embedding_model=self.embedding_model,
                    embedding_model_type=EmbeddingProvider(
                        self.embedding_provider
                    ),
                    embedding_dimension=embedding_dim,
                    index_path=str(index_path),
                    index_hash=index_hash,
                    chunk_size=self.chunk_size,
                    chunk_overlap=self.chunk_overlap,
                    splitter_type=self.splitter_type,
                    text_separators=self.text_separators,
                    distance_metric=self.distance_metric,
                    normalize_vectors=self.normalize_vectors,
                    index_type=self.index_type,
                    chunk_count=0,
                    total_documents=0,
                    status="active",
                    is_current=True,
                )
                session.add(rag_index)
                session.commit()
                # refresh() re-loads server-generated fields (e.g. id) so the
                # detached object is fully usable after the session closes
                session.refresh(rag_index)
                logger.info(f"Created new RAG index: {index_hash}")

            return rag_index
289 def load_or_create_faiss_index(self, collection_id: str) -> FAISS:
290 """
291 Load existing FAISS index or create new one.
293 Args:
294 collection_id: UUID of the collection
296 Returns:
297 FAISS vector store instance
298 """
299 rag_index = self._get_or_create_rag_index(collection_id)
300 self.rag_index_record = rag_index
302 index_path = Path(rag_index.index_path)
304 if index_path.exists():
305 # Verify integrity before loading
306 verified, reason = self.integrity_manager.verify_file(index_path)
307 if not verified:
308 logger.error(
309 f"Integrity verification failed for {index_path}: {reason}. "
310 f"Refusing to load. Creating new index."
311 )
312 # Remove corrupted index
313 try:
314 index_path.unlink()
315 logger.info(f"Removed corrupted index file: {index_path}")
316 except Exception:
317 logger.exception("Failed to remove corrupted index")
318 else:
319 try:
320 # Check for embedding dimension mismatch before loading
321 current_dim = len(
322 self.embedding_manager.embeddings.embed_query(
323 "dimension_check"
324 )
325 )
326 stored_dim = rag_index.embedding_dimension
328 if stored_dim and current_dim != stored_dim:
329 logger.warning(
330 f"Embedding dimension mismatch detected! "
331 f"Index created with dim={stored_dim}, "
332 f"current model returns dim={current_dim}. "
333 f"Deleting old index and rebuilding."
334 )
335 # Delete old index files
336 try:
337 index_path.unlink()
338 pkl_path = index_path.with_suffix(".pkl")
339 if pkl_path.exists(): 339 ↛ 341line 339 didn't jump to line 341 because the condition on line 339 was always true
340 pkl_path.unlink()
341 logger.info(
342 f"Deleted old FAISS index files at {index_path}"
343 )
344 except Exception:
345 logger.exception("Failed to delete old index files")
347 # Update RAGIndex with new dimension and reset counts
348 with get_user_db_session(
349 self.username, self.db_password
350 ) as session:
351 idx = (
352 session.query(RAGIndex)
353 .filter_by(id=rag_index.id)
354 .first()
355 )
356 if idx: 356 ↛ 366line 356 didn't jump to line 366 because the condition on line 356 was always true
357 idx.embedding_dimension = current_dim
358 idx.chunk_count = 0
359 idx.total_documents = 0
360 session.commit()
361 logger.info(
362 f"Updated RAGIndex dimension to {current_dim}"
363 )
365 # Clear rag_document_status for this index
366 session.query(RagDocumentStatus).filter_by(
367 rag_index_id=rag_index.id
368 ).delete()
369 session.commit()
370 logger.info(
371 "Cleared indexed status for documents in this "
372 "collection"
373 )
375 # Update local reference for index creation below
376 rag_index.embedding_dimension = current_dim
377 # Fall through to create new index below
378 else:
379 # Dimensions match (or no stored dimension), load index
380 faiss_index = FAISS.load_local(
381 str(index_path.parent),
382 self.embedding_manager.embeddings,
383 index_name=index_path.stem,
384 allow_dangerous_deserialization=True,
385 normalize_L2=True,
386 )
387 logger.info(
388 f"Loaded existing FAISS index from {index_path}"
389 )
390 return faiss_index
391 except Exception:
392 logger.warning(
393 "Failed to load FAISS index, creating new one"
394 )
396 # Create new FAISS index with configurable type and distance metric
397 logger.info(
398 f"Creating new FAISS index: type={self.index_type}, metric={self.distance_metric}, dimension={rag_index.embedding_dimension}"
399 )
401 # Create index based on type and distance metric
402 if self.index_type == "hnsw":
403 # HNSW: Fast approximate search, best for large collections
404 # M=32 is a good default for connections per layer
405 index = IndexHNSWFlat(rag_index.embedding_dimension, 32)
406 logger.info("Created HNSW index with M=32 connections")
407 elif self.index_type == "ivf":
408 # IVF requires training, for now fall back to flat
409 # TODO: Implement IVF with proper training
410 logger.warning(
411 "IVF index type not yet fully implemented, using Flat index"
412 )
413 if self.distance_metric in ("cosine", "dot_product"):
414 index = IndexFlatIP(rag_index.embedding_dimension)
415 else:
416 index = IndexFlatL2(rag_index.embedding_dimension)
417 else: # "flat" or default
418 # Flat index: Exact search
419 if self.distance_metric in ("cosine", "dot_product"):
420 # For cosine similarity, use inner product (IP) with normalized vectors
421 index = IndexFlatIP(rag_index.embedding_dimension)
422 logger.info(
423 "Created Flat index with Inner Product (for cosine similarity)"
424 )
425 else: # l2
426 index = IndexFlatL2(rag_index.embedding_dimension)
427 logger.info("Created Flat index with L2 distance")
429 faiss_index = FAISS(
430 self.embedding_manager.embeddings,
431 index=index,
432 docstore=InMemoryDocstore(), # Minimal - chunks in DB
433 index_to_docstore_id={},
434 normalize_L2=self.normalize_vectors, # Use configurable normalization
435 )
436 logger.info(
437 f"FAISS index created with normalization={self.normalize_vectors}"
438 )
439 return faiss_index
441 def get_current_index_info(
442 self, collection_id: Optional[str] = None
443 ) -> Optional[Dict[str, Any]]:
444 """
445 Get information about the current RAG index for a collection.
447 Args:
448 collection_id: UUID of collection (defaults to Library if None)
449 """
450 with get_user_db_session(self.username, self.db_password) as session:
451 # Get collection name in the format stored in RAGIndex (collection_<uuid>)
452 if collection_id:
453 collection = (
454 session.query(Collection)
455 .filter_by(id=collection_id)
456 .first()
457 )
458 collection_name = (
459 f"collection_{collection_id}" if collection else "unknown"
460 )
461 else:
462 # Default to Library collection
463 from ...database.library_init import get_default_library_id
465 collection_id = get_default_library_id(
466 self.username, self.db_password
467 )
468 collection_name = f"collection_{collection_id}"
470 rag_index = (
471 session.query(RAGIndex)
472 .filter_by(collection_name=collection_name, is_current=True)
473 .first()
474 )
476 if not rag_index:
477 # Debug: check all RAG indices for this collection
478 all_indices = session.query(RAGIndex).all()
479 logger.info(
480 f"No RAG index found for collection_name='{collection_name}'. All indices: {[(idx.collection_name, idx.is_current) for idx in all_indices]}"
481 )
482 return None
484 # Calculate actual counts from rag_document_status table
485 from ...database.models.library import RagDocumentStatus
487 actual_chunk_count = (
488 session.query(func.sum(RagDocumentStatus.chunk_count))
489 .filter_by(collection_id=collection_id)
490 .scalar()
491 or 0
492 )
494 actual_doc_count = (
495 session.query(RagDocumentStatus)
496 .filter_by(collection_id=collection_id)
497 .count()
498 )
500 return {
501 "embedding_model": rag_index.embedding_model,
502 "embedding_model_type": rag_index.embedding_model_type.value
503 if rag_index.embedding_model_type
504 else None,
505 "embedding_dimension": rag_index.embedding_dimension,
506 "chunk_size": rag_index.chunk_size,
507 "chunk_overlap": rag_index.chunk_overlap,
508 "chunk_count": actual_chunk_count,
509 "total_documents": actual_doc_count,
510 "created_at": rag_index.created_at.isoformat(),
511 "last_updated_at": rag_index.last_updated_at.isoformat(),
512 }
    def index_document(
        self, document_id: str, collection_id: str, force_reindex: bool = False
    ) -> Dict[str, Any]:
        """
        Index a single document into RAG for a specific collection.

        Splits the document's text into chunks, stores the chunks in the
        per-user database, adds their embeddings to the collection's FAISS
        index, persists the index to disk with integrity tracking, and
        records indexed status in both rag_document_status (authoritative)
        and DocumentCollection (backward compatibility).

        Args:
            document_id: UUID of the Document to index
            collection_id: UUID of the Collection to index for
            force_reindex: Whether to force reindexing even if already indexed

        Returns:
            Dict with status, chunk_count, and any errors
        """
        with get_user_db_session(self.username, self.db_password) as session:
            # Get the document
            document = session.query(Document).filter_by(id=document_id).first()

            if not document:
                return {"status": "error", "error": "Document not found"}

            # Get or create DocumentCollection entry
            all_doc_collections = (
                session.query(DocumentCollection)
                .filter_by(document_id=document_id, collection_id=collection_id)
                .all()
            )

            logger.info(
                f"Found {len(all_doc_collections)} DocumentCollection entries for doc={document_id}, coll={collection_id}"
            )

            doc_collection = (
                all_doc_collections[0] if all_doc_collections else None
            )

            if not doc_collection:
                # Create new DocumentCollection entry
                doc_collection = DocumentCollection(
                    document_id=document_id,
                    collection_id=collection_id,
                    indexed=False,
                    chunk_count=0,
                )
                session.add(doc_collection)
                logger.info(
                    f"Created new DocumentCollection entry for doc={document_id}, coll={collection_id}"
                )

            # Check if already indexed for this collection
            if doc_collection.indexed and not force_reindex:
                return {
                    "status": "skipped",
                    "message": "Document already indexed for this collection",
                    "chunk_count": doc_collection.chunk_count,
                }

            # Validate text content
            if not document.text_content:
                return {
                    "status": "error",
                    "error": "Document has no text content",
                }

            try:
                # Create LangChain Document from text; metadata is copied onto
                # every chunk by the splitter and later surfaced in search hits
                doc = LangchainDocument(
                    page_content=document.text_content,
                    metadata={
                        "source": document.original_url,
                        "document_id": document_id,  # Add document ID for source linking
                        "collection_id": collection_id,  # Add collection ID
                        "title": document.title
                        or document.filename
                        or "Untitled",
                        "document_title": document.title
                        or document.filename
                        or "Untitled",  # Add for compatibility
                        "authors": document.authors,
                        "published_date": str(document.published_date)
                        if document.published_date
                        else None,
                        "doi": document.doi,
                        "arxiv_id": document.arxiv_id,
                        "pmid": document.pmid,
                        "pmcid": document.pmcid,
                        "extraction_method": document.extraction_method,
                        "word_count": document.word_count,
                    },
                )

                # Split into chunks
                chunks = self.text_splitter.split_documents([doc])
                logger.info(
                    f"Split document {document_id} into {len(chunks)} chunks"
                )

                # Get collection name for chunk storage
                collection = (
                    session.query(Collection)
                    .filter_by(id=collection_id)
                    .first()
                )
                # Use collection_<uuid> format for internal storage
                collection_name = (
                    f"collection_{collection_id}" if collection else "unknown"
                )

                # Store chunks in database using embedding manager;
                # returns deterministic per-chunk IDs used as FAISS doc IDs
                embedding_ids = self.embedding_manager._store_chunks_to_db(
                    chunks=chunks,
                    collection_name=collection_name,
                    source_type="document",
                    source_id=document_id,
                )

                # Load or create FAISS index
                if self.faiss_index is None:
                    self.faiss_index = self.load_or_create_faiss_index(
                        collection_id
                    )

                # If force_reindex, remove old chunks from FAISS before adding new ones
                if force_reindex:
                    existing_ids = (
                        set(self.faiss_index.docstore._dict.keys())
                        if hasattr(self.faiss_index, "docstore")
                        else set()
                    )
                    old_chunk_ids = list(
                        {eid for eid in embedding_ids if eid in existing_ids}
                    )
                    if old_chunk_ids:
                        logger.info(
                            f"Force re-index: removing {len(old_chunk_ids)} existing chunks from FAISS"
                        )
                        self.faiss_index.delete(old_chunk_ids)

                # Filter out chunks that already exist in FAISS (unless force_reindex)
                if not force_reindex:
                    existing_ids = (
                        set(self.faiss_index.docstore._dict.keys())
                        if hasattr(self.faiss_index, "docstore")
                        else set()
                    )
                else:
                    # force path already purged matching IDs above, so only
                    # in-batch duplicates need filtering
                    existing_ids = None

                unique_count = len(set(embedding_ids))
                batch_dups = len(chunks) - unique_count

                new_chunks, new_ids = self._deduplicate_chunks(
                    chunks, embedding_ids, existing_ids
                )

                # Add embeddings to FAISS index
                if new_chunks:
                    if force_reindex:
                        logger.info(
                            f"Force re-index: adding {len(new_chunks)} chunks with updated metadata to FAISS index"
                        )
                    else:
                        already_exist = unique_count - len(new_chunks)
                        logger.info(
                            f"Adding {len(new_chunks)} new embeddings to FAISS index "
                            f"({already_exist} already exist, {batch_dups} batch duplicates removed)"
                        )
                    self.faiss_index.add_documents(new_chunks, ids=new_ids)
                else:
                    logger.info(
                        f"All {len(chunks)} chunks already exist in FAISS index, skipping"
                    )

                # Save FAISS index
                index_path = Path(self.rag_index_record.index_path)
                self.faiss_index.save_local(
                    str(index_path.parent), index_name=index_path.stem
                )
                # Record file integrity
                self.integrity_manager.record_file(
                    index_path,
                    related_entity_type="rag_index",
                    related_entity_id=self.rag_index_record.id,
                )
                logger.info(
                    f"Saved FAISS index to {index_path} with integrity tracking"
                )

                from datetime import datetime, UTC
                from sqlalchemy import text

                # Check if document was already indexed (for stats update)
                existing_status = (
                    session.query(RagDocumentStatus)
                    .filter_by(
                        document_id=document_id, collection_id=collection_id
                    )
                    .first()
                )
                was_already_indexed = existing_status is not None

                # Mark document as indexed using rag_document_status table
                # Row existence = indexed, simple and clean
                timestamp = datetime.now(UTC)

                # Create or update RagDocumentStatus using ORM merge (atomic upsert)
                rag_status = RagDocumentStatus(
                    document_id=document_id,
                    collection_id=collection_id,
                    rag_index_id=self.rag_index_record.id,
                    chunk_count=len(chunks),
                    indexed_at=timestamp,
                )
                session.merge(rag_status)

                logger.info(
                    f"Marked document as indexed in rag_document_status: doc_id={document_id}, coll_id={collection_id}, chunks={len(chunks)}"
                )

                # Also update DocumentCollection table for backward compatibility
                session.query(DocumentCollection).filter_by(
                    document_id=document_id, collection_id=collection_id
                ).update(
                    {
                        "indexed": True,
                        "chunk_count": len(chunks),
                        "last_indexed_at": timestamp,
                    }
                )

                logger.info(
                    "Also updated DocumentCollection.indexed for backward compatibility"
                )

                # Update RAGIndex statistics (only if not already indexed,
                # to avoid double-counting on re-index)
                rag_index_obj = (
                    session.query(RAGIndex)
                    .filter_by(id=self.rag_index_record.id)
                    .first()
                )
                if rag_index_obj and not was_already_indexed:
                    rag_index_obj.chunk_count += len(chunks)
                    rag_index_obj.total_documents += 1
                    rag_index_obj.last_updated_at = datetime.now(UTC)
                    logger.info(
                        f"Updated RAGIndex stats: chunk_count +{len(chunks)}, total_documents +1"
                    )

                # Flush ORM changes to database before commit
                session.flush()
                logger.info(f"Flushed ORM changes for document {document_id}")

                # Commit the transaction
                session.commit()

                # WAL checkpoint after commit to ensure persistence
                session.execute(text("PRAGMA wal_checkpoint(FULL)"))

                logger.info(
                    f"Successfully indexed document {document_id} for collection {collection_id} "
                    f"with {len(chunks)} chunks"
                )

                return {
                    "status": "success",
                    "chunk_count": len(chunks),
                    "embedding_ids": embedding_ids,
                }

            except Exception as e:
                logger.exception(
                    f"Error indexing document {document_id} for collection {collection_id}"
                )
                return {
                    "status": "error",
                    "error": f"Operation failed: {type(e).__name__}",
                }
792 def index_all_documents(
793 self,
794 collection_id: str,
795 force_reindex: bool = False,
796 progress_callback=None,
797 ) -> Dict[str, Any]:
798 """
799 Index all documents in a collection into RAG.
801 Args:
802 collection_id: UUID of the collection to index
803 force_reindex: Whether to force reindexing already indexed documents
804 progress_callback: Optional callback function called after each document with (current, total, doc_title, status)
806 Returns:
807 Dict with counts of successful, skipped, and failed documents
808 """
809 with get_user_db_session(self.username, self.db_password) as session:
810 # Get all DocumentCollection entries for this collection
811 query = session.query(DocumentCollection).filter_by(
812 collection_id=collection_id
813 )
815 if not force_reindex:
816 # Only index documents that haven't been indexed yet
817 query = query.filter_by(indexed=False)
819 doc_collections = query.all()
821 if not doc_collections:
822 return {
823 "status": "info",
824 "message": "No documents to index",
825 "successful": 0,
826 "skipped": 0,
827 "failed": 0,
828 }
830 results = {"successful": 0, "skipped": 0, "failed": 0, "errors": []}
831 total = len(doc_collections)
833 for idx, doc_collection in enumerate(doc_collections, 1):
834 # Get the document for title info
835 document = (
836 session.query(Document)
837 .filter_by(id=doc_collection.document_id)
838 .first()
839 )
840 title = document.title if document else "Unknown"
842 result = self.index_document(
843 doc_collection.document_id, collection_id, force_reindex
844 )
846 if result["status"] == "success":
847 results["successful"] += 1
848 elif result["status"] == "skipped":
849 results["skipped"] += 1
850 else:
851 results["failed"] += 1
852 results["errors"].append(
853 {
854 "doc_id": doc_collection.document_id,
855 "title": title,
856 "error": result.get("error"),
857 }
858 )
860 # Call progress callback if provided
861 if progress_callback:
862 progress_callback(idx, total, title, result["status"])
864 logger.info(
865 f"Indexed collection {collection_id}: "
866 f"{results['successful']} successful, "
867 f"{results['skipped']} skipped, "
868 f"{results['failed']} failed"
869 )
871 return results
873 def remove_document_from_rag(
874 self, document_id: str, collection_id: str
875 ) -> Dict[str, Any]:
876 """
877 Remove a document's chunks from RAG for a specific collection.
879 Args:
880 document_id: UUID of the Document to remove
881 collection_id: UUID of the Collection to remove from
883 Returns:
884 Dict with status and count of removed chunks
885 """
886 with get_user_db_session(self.username, self.db_password) as session:
887 # Get the DocumentCollection entry
888 doc_collection = (
889 session.query(DocumentCollection)
890 .filter_by(document_id=document_id, collection_id=collection_id)
891 .first()
892 )
894 if not doc_collection:
895 return {
896 "status": "error",
897 "error": "Document not found in collection",
898 }
900 try:
901 # Get collection name in the format collection_<uuid>
902 collection = (
903 session.query(Collection)
904 .filter_by(id=collection_id)
905 .first()
906 )
907 # Use collection_<uuid> format for internal storage
908 collection_name = (
909 f"collection_{collection_id}" if collection else "unknown"
910 )
912 # Delete chunks from database
913 deleted_count = self.embedding_manager._delete_chunks_from_db(
914 collection_name=collection_name,
915 source_id=document_id,
916 )
918 # Update DocumentCollection RAG status
919 doc_collection.indexed = False
920 doc_collection.chunk_count = 0
921 doc_collection.last_indexed_at = None
922 session.commit()
924 logger.info(
925 f"Removed {deleted_count} chunks for document {document_id} from collection {collection_id}"
926 )
928 return {"status": "success", "deleted_count": deleted_count}
930 except Exception as e:
931 logger.exception(
932 f"Error removing document {document_id} from collection {collection_id}"
933 )
934 return {
935 "status": "error",
936 "error": f"Operation failed: {type(e).__name__}",
937 }
939 def index_documents_batch(
940 self,
941 doc_info: List[tuple],
942 collection_id: str,
943 force_reindex: bool = False,
944 ) -> Dict[str, Dict[str, Any]]:
945 """
946 Index multiple documents in a batch for a specific collection.
948 Args:
949 doc_info: List of (doc_id, title) tuples
950 collection_id: UUID of the collection to index for
951 force_reindex: Whether to force reindexing even if already indexed
953 Returns:
954 Dict mapping doc_id to individual result
955 """
956 results = {}
957 doc_ids = [doc_id for doc_id, _ in doc_info]
959 # Use single database session for querying
960 with get_user_db_session(self.username, self.db_password) as session:
961 # Pre-load all documents for this batch
962 documents = (
963 session.query(Document).filter(Document.id.in_(doc_ids)).all()
964 )
966 # Create lookup for quick access
967 doc_lookup = {doc.id: doc for doc in documents}
969 # Pre-load DocumentCollection entries
970 doc_collections = (
971 session.query(DocumentCollection)
972 .filter(
973 DocumentCollection.document_id.in_(doc_ids),
974 DocumentCollection.collection_id == collection_id,
975 )
976 .all()
977 )
978 doc_collection_lookup = {
979 dc.document_id: dc for dc in doc_collections
980 }
982 # Process each document in the batch
983 for doc_id, title in doc_info:
984 document = doc_lookup.get(doc_id)
986 if not document:
987 results[doc_id] = {
988 "status": "error",
989 "error": "Document not found",
990 }
991 continue
993 # Check if already indexed via DocumentCollection
994 doc_collection = doc_collection_lookup.get(doc_id)
995 if (
996 doc_collection
997 and doc_collection.indexed
998 and not force_reindex
999 ):
1000 results[doc_id] = {
1001 "status": "skipped",
1002 "message": "Document already indexed for this collection",
1003 "chunk_count": doc_collection.chunk_count,
1004 }
1005 continue
1007 # Validate text content
1008 if not document.text_content:
1009 results[doc_id] = {
1010 "status": "error",
1011 "error": "Document has no text content",
1012 }
1013 continue
1015 # Index the document
1016 try:
1017 result = self.index_document(
1018 doc_id, collection_id, force_reindex
1019 )
1020 results[doc_id] = result
1021 except Exception as e:
1022 logger.exception(
1023 f"Error indexing document {doc_id} in batch"
1024 )
1025 results[doc_id] = {
1026 "status": "error",
1027 "error": f"Indexing failed: {type(e).__name__}",
1028 }
1030 return results
1032 def get_rag_stats(
1033 self, collection_id: Optional[str] = None
1034 ) -> Dict[str, Any]:
1035 """
1036 Get RAG statistics for a collection.
1038 Args:
1039 collection_id: UUID of the collection (defaults to Library)
1041 Returns:
1042 Dict with counts and metadata about indexed documents
1043 """
1044 with get_user_db_session(self.username, self.db_password) as session:
1045 # Get collection ID (default to Library)
1046 if not collection_id: 1046 ↛ 1047line 1046 didn't jump to line 1047 because the condition on line 1046 was never true
1047 from ...database.library_init import get_default_library_id
1049 collection_id = get_default_library_id(
1050 self.username, self.db_password
1051 )
1053 # Count total documents in collection
1054 total_docs = (
1055 session.query(DocumentCollection)
1056 .filter_by(collection_id=collection_id)
1057 .count()
1058 )
1060 # Count indexed documents from rag_document_status table
1061 from ...database.models.library import RagDocumentStatus
1063 indexed_docs = (
1064 session.query(RagDocumentStatus)
1065 .filter_by(collection_id=collection_id)
1066 .count()
1067 )
1069 # Count total chunks from rag_document_status table
1070 total_chunks = (
1071 session.query(func.sum(RagDocumentStatus.chunk_count))
1072 .filter_by(collection_id=collection_id)
1073 .scalar()
1074 or 0
1075 )
1077 # Get collection name in the format stored in DocumentChunk (collection_<uuid>)
1078 collection = (
1079 session.query(Collection).filter_by(id=collection_id).first()
1080 )
1081 collection_name = (
1082 f"collection_{collection_id}" if collection else "library"
1083 )
1085 # Get embedding model info from chunks
1086 chunk_sample = (
1087 session.query(DocumentChunk)
1088 .filter_by(collection_name=collection_name)
1089 .first()
1090 )
1092 embedding_info = {}
1093 if chunk_sample:
1094 embedding_info = {
1095 "model": chunk_sample.embedding_model,
1096 "model_type": chunk_sample.embedding_model_type.value
1097 if chunk_sample.embedding_model_type
1098 else None,
1099 "dimension": chunk_sample.embedding_dimension,
1100 }
1102 return {
1103 "total_documents": total_docs,
1104 "indexed_documents": indexed_docs,
1105 "unindexed_documents": total_docs - indexed_docs,
1106 "total_chunks": total_chunks,
1107 "embedding_info": embedding_info,
1108 "chunk_size": self.chunk_size,
1109 "chunk_overlap": self.chunk_overlap,
1110 }
1112 def index_local_file(self, file_path: str) -> Dict[str, Any]:
1113 """
1114 Index a local file from the filesystem into RAG.
1116 Args:
1117 file_path: Path to the file to index
1119 Returns:
1120 Dict with status, chunk_count, and any errors
1121 """
1122 from pathlib import Path
1123 import mimetypes
1125 file_path = Path(file_path)
1127 if not file_path.exists():
1128 return {"status": "error", "error": f"File not found: {file_path}"}
1130 if not file_path.is_file():
1131 return {"status": "error", "error": f"Not a file: {file_path}"}
1133 # Determine file type
1134 mime_type, _ = mimetypes.guess_type(str(file_path))
1136 # Read file content based on type
1137 try:
1138 if file_path.suffix.lower() in [".txt", ".md", ".markdown"]:
1139 # Text files
1140 with open(file_path, "r", encoding="utf-8") as f:
1141 content = f.read()
1142 elif file_path.suffix.lower() in [".html", ".htm"]:
1143 # HTML files - strip tags
1144 from bs4 import BeautifulSoup
1146 with open(file_path, "r", encoding="utf-8") as f:
1147 soup = BeautifulSoup(f.read(), "html.parser")
1148 content = soup.get_text()
1149 elif file_path.suffix.lower() == ".pdf":
1150 # PDF files - extract text
1151 from pypdf import PdfReader
1153 content = ""
1154 with open(file_path, "rb") as f:
1155 pdf_reader = PdfReader(f)
1156 for page in pdf_reader.pages:
1157 content += page.extract_text()
1158 else:
1159 return {
1160 "status": "skipped",
1161 "error": f"Unsupported file type: {file_path.suffix}",
1162 }
1164 if not content or len(content.strip()) < 10:
1165 return {
1166 "status": "error",
1167 "error": "File has no extractable text content",
1168 }
1170 # Create LangChain Document from text
1171 doc = LangchainDocument(
1172 page_content=content,
1173 metadata={
1174 "source": str(file_path),
1175 "source_id": f"local_{file_path.stem}_{hash(str(file_path))}",
1176 "title": file_path.stem,
1177 "document_title": file_path.stem,
1178 "file_type": file_path.suffix.lower(),
1179 "file_size": file_path.stat().st_size,
1180 "source_type": "local_file",
1181 "collection": "local_library",
1182 },
1183 )
1185 # Split into chunks
1186 chunks = self.text_splitter.split_documents([doc])
1187 logger.info(
1188 f"Split local file {file_path} into {len(chunks)} chunks"
1189 )
1191 # Store chunks in database (returns UUID-based IDs)
1192 embedding_ids = self.embedding_manager._store_chunks_to_db(
1193 chunks=chunks,
1194 collection_name="local_library",
1195 source_type="local_file",
1196 source_id=str(file_path),
1197 )
1199 # Load or create FAISS index using default library collection
1200 if self.faiss_index is None:
1201 from ...database.library_init import get_default_library_id
1203 default_collection_id = get_default_library_id(
1204 self.username, self.db_password
1205 )
1206 self.faiss_index = self.load_or_create_faiss_index(
1207 default_collection_id
1208 )
1210 # Filter out chunks that already exist in FAISS and deduplicate
1211 if self.faiss_index is not None: 1211 ↛ 1218line 1211 didn't jump to line 1218 because the condition on line 1211 was always true
1212 existing_ids = (
1213 set(self.faiss_index.docstore._dict.keys())
1214 if hasattr(self.faiss_index, "docstore")
1215 else set()
1216 )
1217 else:
1218 existing_ids = None
1219 new_chunks, new_ids = self._deduplicate_chunks(
1220 chunks, embedding_ids, existing_ids
1221 )
1223 # Add embeddings to FAISS index
1224 if new_chunks: 1224 ↛ 1228line 1224 didn't jump to line 1228 because the condition on line 1224 was always true
1225 self.faiss_index.add_documents(new_chunks, ids=new_ids)
1227 # Save FAISS index
1228 index_path = (
1229 Path(self.rag_index_record.index_path)
1230 if self.rag_index_record
1231 else None
1232 )
1233 if index_path:
1234 self.faiss_index.save_local(
1235 str(index_path.parent), index_name=index_path.stem
1236 )
1237 # Record file integrity
1238 self.integrity_manager.record_file(
1239 index_path,
1240 related_entity_type="rag_index",
1241 related_entity_id=self.rag_index_record.id,
1242 )
1243 logger.info(
1244 f"Saved FAISS index to {index_path} with integrity tracking"
1245 )
1247 logger.info(
1248 f"Successfully indexed local file {file_path} with {len(new_chunks)} new chunks "
1249 f"({len(chunks) - len(new_chunks)} skipped)"
1250 )
1252 return {
1253 "status": "success",
1254 "chunk_count": len(new_chunks),
1255 "embedding_ids": new_ids,
1256 }
1258 except Exception as e:
1259 logger.exception(f"Error indexing local file {file_path}")
1260 return {
1261 "status": "error",
1262 "error": f"Operation failed: {type(e).__name__}",
1263 }
    def index_user_document(
        self, user_doc, collection_name: str, force_reindex: bool = False
    ) -> Dict[str, Any]:
        """
        Index a user-uploaded document into a specific collection.

        Uses the document's pre-extracted ``text_content``, splits it into
        chunks, stores the chunks in the database, and adds their embeddings
        to the collection's FAISS index (loaded/created lazily). With
        ``force_reindex``, chunk IDs already present in FAISS are deleted
        first and then re-added so updated metadata is re-embedded.

        Args:
            user_doc: UserDocument object (provides id, filename, file_type,
                file_size, text_content)
            collection_name: Name of the collection (e.g., "collection_123")
            force_reindex: Whether to force reindexing

        Returns:
            Dict with status, chunk_count, and any errors
        """

        try:
            # Use the pre-extracted text content
            content = user_doc.text_content

            # Reject empty / near-empty documents (< 10 non-space chars).
            if not content or len(content.strip()) < 10:
                return {
                    "status": "error",
                    "error": "Document has no extractable text content",
                }

            # Create LangChain Document
            doc = LangchainDocument(
                page_content=content,
                metadata={
                    "source": f"user_upload_{user_doc.id}",
                    "source_id": user_doc.id,
                    "title": user_doc.filename,
                    "document_title": user_doc.filename,
                    "file_type": user_doc.file_type,
                    "file_size": user_doc.file_size,
                    "collection": collection_name,
                },
            )

            # Split into chunks
            chunks = self.text_splitter.split_documents([doc])
            logger.info(
                f"Split user document {user_doc.filename} into {len(chunks)} chunks"
            )

            # Store chunks in database
            embedding_ids = self.embedding_manager._store_chunks_to_db(
                chunks=chunks,
                collection_name=collection_name,
                source_type="user_document",
                source_id=user_doc.id,
            )

            # Load or create FAISS index for this collection
            if self.faiss_index is None:
                # Extract collection_id from collection_name (format: "collection_<uuid>")
                collection_id = collection_name.removeprefix("collection_")
                self.faiss_index = self.load_or_create_faiss_index(
                    collection_id
                )

            # If force_reindex, remove old chunks from FAISS before adding new ones
            if force_reindex:
                existing_ids = (
                    set(self.faiss_index.docstore._dict.keys())
                    if hasattr(self.faiss_index, "docstore")
                    else set()
                )
                # Only delete IDs that actually exist in the docstore.
                old_chunk_ids = list(
                    {eid for eid in embedding_ids if eid in existing_ids}
                )
                if old_chunk_ids:
                    logger.info(
                        f"Force re-index: removing {len(old_chunk_ids)} existing chunks from FAISS"
                    )
                    self.faiss_index.delete(old_chunk_ids)

            # Filter out chunks that already exist in FAISS (unless force_reindex)
            if not force_reindex:
                existing_ids = (
                    set(self.faiss_index.docstore._dict.keys())
                    if hasattr(self.faiss_index, "docstore")
                    else set()
                )
            else:
                # None presumably disables the "already in FAISS" filter in
                # _deduplicate_chunks (old entries were deleted above) —
                # verify against _deduplicate_chunks.
                existing_ids = None

            # Count duplicate embedding IDs within this batch itself;
            # used only for the log message below.
            unique_count = len(set(embedding_ids))
            batch_dups = len(chunks) - unique_count

            new_chunks, new_ids = self._deduplicate_chunks(
                chunks, embedding_ids, existing_ids
            )

            # Add embeddings to FAISS index
            if new_chunks:
                if force_reindex:
                    logger.info(
                        f"Force re-index: adding {len(new_chunks)} chunks with updated metadata to FAISS index"
                    )
                else:
                    already_exist = unique_count - len(new_chunks)
                    logger.info(
                        f"Adding {len(new_chunks)} new chunks to FAISS index "
                        f"({already_exist} already exist, {batch_dups} batch duplicates removed)"
                    )
                self.faiss_index.add_documents(new_chunks, ids=new_ids)
            else:
                logger.info(
                    f"All {len(chunks)} chunks already exist in FAISS index, skipping"
                )

            # Save FAISS index
            index_path = (
                Path(self.rag_index_record.index_path)
                if self.rag_index_record
                else None
            )
            if index_path:
                self.faiss_index.save_local(
                    str(index_path.parent), index_name=index_path.stem
                )
                # Record file integrity
                self.integrity_manager.record_file(
                    index_path,
                    related_entity_type="rag_index",
                    related_entity_id=self.rag_index_record.id,
                )

            logger.info(
                f"Successfully indexed user document {user_doc.filename} with {len(chunks)} chunks"
            )

            # NOTE(review): chunk_count reports the total split count, not
            # only the chunks newly added to FAISS (index_local_file reports
            # new chunks only) — confirm callers expect this.
            return {
                "status": "success",
                "chunk_count": len(chunks),
                "embedding_ids": embedding_ids,
            }

        except Exception as e:
            logger.exception(
                f"Error indexing user document {user_doc.filename}"
            )
            return {
                "status": "error",
                "error": f"Operation failed: {type(e).__name__}",
            }
1413 def remove_collection_from_index(
1414 self, collection_name: str
1415 ) -> Dict[str, Any]:
1416 """
1417 Remove all documents from a collection from the FAISS index.
1419 Args:
1420 collection_name: Name of the collection (e.g., "collection_123")
1422 Returns:
1423 Dict with status and count of removed chunks
1424 """
1425 from ...database.models import DocumentChunk
1426 from ...database.session_context import get_user_db_session
1428 try:
1429 with get_user_db_session(
1430 self.username, self.db_password
1431 ) as session:
1432 # Get all chunk IDs for this collection
1433 chunks = (
1434 session.query(DocumentChunk)
1435 .filter_by(collection_name=collection_name)
1436 .all()
1437 )
1439 if not chunks:
1440 return {"status": "success", "deleted_count": 0}
1442 chunk_ids = [
1443 f"{collection_name}_{chunk.id}" for chunk in chunks
1444 ]
1446 # Load FAISS index if not already loaded
1447 if self.faiss_index is None:
1448 # Extract collection_id from collection_name (format: "collection_<uuid>")
1449 collection_id = collection_name.removeprefix("collection_")
1450 self.faiss_index = self.load_or_create_faiss_index(
1451 collection_id
1452 )
1454 # Remove from FAISS index
1455 if hasattr(self.faiss_index, "delete"): 1455 ↛ 1479line 1455 didn't jump to line 1479 because the condition on line 1455 was always true
1456 try:
1457 self.faiss_index.delete(chunk_ids)
1459 # Save updated index
1460 index_path = (
1461 Path(self.rag_index_record.index_path)
1462 if self.rag_index_record
1463 else None
1464 )
1465 if index_path:
1466 self.faiss_index.save_local(
1467 str(index_path.parent),
1468 index_name=index_path.stem,
1469 )
1470 # Record file integrity
1471 self.integrity_manager.record_file(
1472 index_path,
1473 related_entity_type="rag_index",
1474 related_entity_id=self.rag_index_record.id,
1475 )
1476 except Exception:
1477 logger.warning("Could not delete chunks from FAISS")
1479 logger.info(
1480 f"Removed {len(chunk_ids)} chunks from collection {collection_name}"
1481 )
1483 return {"status": "success", "deleted_count": len(chunk_ids)}
1485 except Exception as e:
1486 logger.exception(
1487 f"Error removing collection {collection_name} from index"
1488 )
1489 return {
1490 "status": "error",
1491 "error": f"Operation failed: {type(e).__name__}",
1492 }