Coverage for src/local_deep_research/embeddings/embeddings_config.py: 95%
49 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 23:15 +0000
"""
Central configuration for embedding providers.

This module provides the main get_embeddings() function and availability checks
for different embedding providers, similar to llm_config.py.
"""
8from typing import Any, Dict, Optional, Type
10from langchain_core.embeddings import Embeddings
11from loguru import logger
13from ..config.thread_settings import get_setting_from_snapshot
14from .providers.base import BaseEmbeddingProvider
# Internal: list of provider strings accepted by get_embeddings().
# Not re-exported from embeddings/__init__.py — call sites should not import it.
# Kept module-level so get_embeddings() can list the accepted options in its
# "Invalid embedding provider" error message. Keep this in sync with the keys
# of the registry built by _get_provider_classes().
VALID_EMBEDDING_PROVIDERS = [
    "sentence_transformers",
    "ollama",
    "openai",
]
# Registry of provider key -> provider class. Left as None until the first
# call to _get_provider_classes() fills it in.
_PROVIDER_CLASSES: Optional[Dict[str, Type[BaseEmbeddingProvider]]] = None


def _get_provider_classes() -> Dict[str, Type[BaseEmbeddingProvider]]:
    """Return the provider registry, importing the implementations on first use.

    The implementation modules are imported inside this function rather than
    at module import time to avoid circular imports. Later calls return the
    cached dict.
    """
    global _PROVIDER_CLASSES

    # Fast path: the registry has already been built.
    if _PROVIDER_CLASSES is not None:
        return _PROVIDER_CLASSES

    from .providers.implementations.sentence_transformers import (
        SentenceTransformersProvider,
    )
    from .providers.implementations.ollama import OllamaEmbeddingsProvider
    from .providers.implementations.openai import OpenAIEmbeddingsProvider

    registry = {
        "sentence_transformers": SentenceTransformersProvider,
        "ollama": OllamaEmbeddingsProvider,
        "openai": OpenAIEmbeddingsProvider,
    }
    _PROVIDER_CLASSES = registry
    return registry
def is_sentence_transformers_available() -> bool:
    """Report whether the Sentence Transformers provider can be used.

    Delegates to the ``is_available()`` check of the registered
    ``sentence_transformers`` provider class.
    """
    return _get_provider_classes()["sentence_transformers"].is_available()
def is_ollama_embeddings_available(
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> bool:
    """Report whether Ollama embeddings can be used.

    Args:
        settings_snapshot: Optional settings snapshot, passed through to the
            ``ollama`` provider class's ``is_available()`` check.
    """
    ollama_cls = _get_provider_classes()["ollama"]
    return ollama_cls.is_available(settings_snapshot)
def is_openai_embeddings_available(
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> bool:
    """Report whether OpenAI (or OpenAI-compatible) embeddings can be used.

    Args:
        settings_snapshot: Optional settings snapshot, passed through to the
            ``openai`` provider class's ``is_available()`` check.
    """
    openai_cls = _get_provider_classes()["openai"]
    return openai_cls.is_available(settings_snapshot)
def get_available_embedding_providers(
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> Dict[str, str]:
    """
    List the embedding providers that are currently usable.

    The availability checks run in a fixed order (Sentence Transformers,
    Ollama, OpenAI). Only providers whose check passes are included, and they
    appear in that same order.

    Args:
        settings_snapshot: Optional settings snapshot, forwarded to the Ollama
            and OpenAI availability checks

    Returns:
        Dict mapping provider keys to display names
    """
    candidates = (
        (
            "sentence_transformers",
            "Sentence Transformers (Local)",
            is_sentence_transformers_available,
        ),
        (
            "ollama",
            "Ollama (Local)",
            lambda: is_ollama_embeddings_available(settings_snapshot),
        ),
        (
            # Single entry covers the OpenAI cloud API and any
            # OpenAI-compatible endpoint (LM Studio, vLLM, llama.cpp);
            # the provider class branches on
            # ``embeddings.openai.base_url`` at runtime.
            "openai",
            "OpenAI / OpenAI-Compatible Endpoint",
            lambda: is_openai_embeddings_available(settings_snapshot),
        ),
    )
    return {key: label for key, label, check in candidates if check()}
def get_embedding_function(
    provider: Optional[str] = None,
    model_name: Optional[str] = None,
    settings_snapshot: Optional[Dict[str, Any]] = None,
    **kwargs,
):
    """
    Return a plain callable that embeds a list of texts.

    This is a thin wrapper around get_embeddings(). ``model_name`` is
    forwarded as get_embeddings()'s ``model`` argument, and all other
    arguments pass through unchanged.

    Args:
        provider: Embedding provider to use
        model_name: Model name to use
        settings_snapshot: Optional settings snapshot
        **kwargs: Additional provider-specific parameters

    Returns:
        The ``embed_documents`` method of the resolved embeddings instance.
        Call it with a list of texts to get their embeddings.
    """
    return get_embeddings(
        provider=provider,
        model=model_name,
        settings_snapshot=settings_snapshot,
        **kwargs,
    ).embed_documents
def get_embeddings(
    provider: Optional[str] = None,
    model: Optional[str] = None,
    settings_snapshot: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Embeddings:
    """
    Get embeddings instance based on provider and model.

    Args:
        provider: Embedding provider to use (if None, reads the
            ``embeddings.provider`` setting, defaulting to
            ``"sentence_transformers"``)
        model: Model name to use (if None, uses settings or provider default)
        settings_snapshot: Optional settings snapshot for thread-safe access
        **kwargs: Additional provider-specific parameters, forwarded to the
            provider class's ``create_embeddings``

    Returns:
        A LangChain Embeddings instance

    Raises:
        ValueError: If provider is invalid or not available. This includes
            non-string provider values.
        ImportError: If required dependencies are not installed (raised by
            the provider class when it creates the embeddings)
    """
    # Get provider from settings if not specified
    if provider is None:
        provider = get_setting_from_snapshot(
            "embeddings.provider",
            default="sentence_transformers",
            settings_snapshot=settings_snapshot,
        )

    # Tolerate values padded with whitespace, wrapped in quotes
    # (e.g. '"Ollama"'), or in mixed case. Only strings are normalized;
    # any other type is left untouched so it fails validation below with a
    # ValueError instead of an AttributeError from ``.strip()``.
    if isinstance(provider, str):
        provider = provider.strip().strip("\"'").strip().lower()

    # Validate provider
    if provider not in VALID_EMBEDDING_PROVIDERS:
        logger.error(f"Invalid embedding provider: {provider}")
        raise ValueError(
            f"Invalid embedding provider: {provider}. "
            f"Must be one of: {VALID_EMBEDDING_PROVIDERS}"
        )

    logger.info(f"Getting embeddings with provider: {provider}, model: {model}")

    # Get provider class and create embeddings
    provider_class = _get_provider_classes().get(provider)

    # Defensive: only reachable if VALID_EMBEDDING_PROVIDERS and the
    # registry in _get_provider_classes() drift out of sync.
    if not provider_class:
        raise ValueError(f"Unsupported embedding provider: {provider}")

    return provider_class.create_embeddings(
        model=model, settings_snapshot=settings_snapshot, **kwargs
    )