Coverage for src / local_deep_research / embeddings / embeddings_config.py: 60%
49 statements
« prev ^ index » next coverage.py v7.12.0, created at 2026-01-11 00:51 +0000
"""
Central configuration for embedding providers.

This module provides the main get_embeddings() function and availability checks
for different embedding providers, similar to llm_config.py.
"""
8from typing import Any, Dict, Optional, Type
10from langchain_core.embeddings import Embeddings
11from loguru import logger
13from ..config.thread_settings import get_setting_from_snapshot
14from .providers.base import BaseEmbeddingProvider
# Provider keys accepted by get_embeddings(); any other value is rejected there.
# Kept as a list because its repr is embedded in the error message.
VALID_EMBEDDING_PROVIDERS = ["sentence_transformers", "ollama", "openai"]

# Cache mapping provider key -> provider class. Starts out as None and is
# populated by _get_provider_classes() the first time it is called.
_PROVIDER_CLASSES: Optional[Dict[str, Type[BaseEmbeddingProvider]]] = None
def _get_provider_classes() -> Dict[str, Type[BaseEmbeddingProvider]]:
    """
    Return the provider-key -> provider-class mapping, building it on first use.

    The provider implementations are imported inside this function rather than
    at module import time, which (per the original design note) avoids circular
    imports. The mapping is stored in the module-level ``_PROVIDER_CLASSES``
    so the imports run only while that global is still ``None``.

    Returns:
        Dict mapping each provider key to its provider class.
    """
    global _PROVIDER_CLASSES

    # Already built: hand back the cached mapping.
    if _PROVIDER_CLASSES is not None:
        return _PROVIDER_CLASSES

    # Deferred imports; order matches the key order of the mapping below.
    from .providers.implementations.sentence_transformers import (
        SentenceTransformersProvider,
    )
    from .providers.implementations.ollama import OllamaEmbeddingsProvider
    from .providers.implementations.openai import OpenAIEmbeddingsProvider

    _PROVIDER_CLASSES = {
        "sentence_transformers": SentenceTransformersProvider,
        "ollama": OllamaEmbeddingsProvider,
        "openai": OpenAIEmbeddingsProvider,
    }
    return _PROVIDER_CLASSES
def is_sentence_transformers_available() -> bool:
    """
    Report whether the Sentence Transformers provider is available.

    Returns:
        The result of the provider class's ``is_available()`` check, called
        without a settings snapshot.
    """
    sentence_transformers_cls = _get_provider_classes()["sentence_transformers"]
    return sentence_transformers_cls.is_available()
def is_ollama_embeddings_available(
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> bool:
    """
    Report whether the Ollama embeddings provider is available.

    Args:
        settings_snapshot: Optional settings snapshot, forwarded unchanged to
            the provider's availability check.

    Returns:
        The result of the provider class's ``is_available()`` check.
    """
    ollama_cls = _get_provider_classes()["ollama"]
    return ollama_cls.is_available(settings_snapshot)
def is_openai_embeddings_available(
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> bool:
    """
    Report whether the OpenAI embeddings provider is available.

    Args:
        settings_snapshot: Optional settings snapshot, forwarded unchanged to
            the provider's availability check.

    Returns:
        The result of the provider class's ``is_available()`` check.
    """
    openai_cls = _get_provider_classes()["openai"]
    return openai_cls.is_available(settings_snapshot)
def get_available_embedding_providers(
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> Dict[str, str]:
    """
    Collect the embedding providers that currently report as available.

    Args:
        settings_snapshot: Optional settings snapshot, passed to the Ollama
            and OpenAI availability checks.

    Returns:
        Dict mapping provider keys to display names, containing only the
        providers whose availability check succeeded.
    """
    # (key, display name, availability probe). The probes are wrapped in
    # lambdas so each check runs lazily, in this order, during the
    # comprehension below.
    probes = (
        (
            "sentence_transformers",
            "Sentence Transformers (Local)",
            lambda: is_sentence_transformers_available(),
        ),
        (
            "ollama",
            "Ollama (Local)",
            lambda: is_ollama_embeddings_available(settings_snapshot),
        ),
        (
            "openai",
            "OpenAI API",
            lambda: is_openai_embeddings_available(settings_snapshot),
        ),
    )
    return {key: label for key, label, probe in probes if probe()}
def get_embedding_function(
    provider: Optional[str] = None,
    model_name: Optional[str] = None,
    settings_snapshot: Optional[Dict[str, Any]] = None,
    **kwargs,
):
    """
    Return the ``embed_documents`` callable of a configured embeddings object.

    This is a thin wrapper around ``get_embeddings()``; note that
    ``model_name`` here is forwarded as that function's ``model`` argument.

    Args:
        provider: Embedding provider to use
        model_name: Model name to use
        settings_snapshot: Optional settings snapshot
        **kwargs: Additional provider-specific parameters

    Returns:
        A callable that takes a list of texts and returns embeddings
    """
    return get_embeddings(
        provider=provider,
        model=model_name,
        settings_snapshot=settings_snapshot,
        **kwargs,
    ).embed_documents
def get_embeddings(
    provider: Optional[str] = None,
    model: Optional[str] = None,
    settings_snapshot: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Embeddings:
    """
    Build an embeddings instance for the requested provider and model.

    Args:
        provider: Embedding provider key. When None, the value of the
            "embeddings.provider" setting is used (default
            "sentence_transformers").
        model: Model name, forwarded as-is to the provider's
            ``create_embeddings()``.
        settings_snapshot: Optional settings snapshot, used for the settings
            lookup and forwarded to the provider.
        **kwargs: Additional provider-specific parameters, forwarded to the
            provider.

    Returns:
        Whatever the selected provider class's ``create_embeddings()``
        returns (annotated as a LangChain Embeddings instance).

    Raises:
        ValueError: If the normalized provider is not listed in
            VALID_EMBEDDING_PROVIDERS, or has no registered provider class.
        ImportError: May propagate from the lazy provider imports or from the
            provider itself when its dependencies are missing.
    """
    # No explicit provider: fall back to the configured one.
    if provider is None:
        provider = get_setting_from_snapshot(
            "embeddings.provider",
            default="sentence_transformers",
            settings_snapshot=settings_snapshot,
        )

    # Normalize: trim whitespace, peel surrounding quote characters, trim
    # again, then lowercase. Falsy values are left untouched and fail the
    # validation below.
    if provider:
        unquoted = provider.strip().strip("\"'")
        provider = unquoted.strip().lower()

    # Reject anything outside the known provider keys.
    if provider not in VALID_EMBEDDING_PROVIDERS:
        logger.error(f"Invalid embedding provider: {provider}")
        raise ValueError(
            f"Invalid embedding provider: {provider}. "
            f"Must be one of: {VALID_EMBEDDING_PROVIDERS}"
        )

    logger.info(f"Getting embeddings with provider: {provider}, model: {model}")

    # Resolve the provider class and delegate construction to it.
    provider_class = _get_provider_classes().get(provider)
    if provider_class:
        return provider_class.create_embeddings(
            model=model, settings_snapshot=settings_snapshot, **kwargs
        )

    # Only reachable if a valid key has no entry in the class mapping.
    raise ValueError(f"Unsupported embedding provider: {provider}")