Coverage for src/local_deep_research/embeddings/providers/base.py: 92%
34 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 23:15 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-03 23:15 +0000
1"""Base class for embedding providers."""
3from abc import ABC, abstractmethod
4from typing import Any, Dict, List, Optional
6from langchain_core.embeddings import Embeddings
class BaseEmbeddingProvider(ABC):
    """
    Common contract shared by every embedding provider.

    Concrete providers subclass this and supply the two abstract hooks
    (``create_embeddings`` and ``is_available``). The remaining
    classmethods ship with default implementations that subclasses may
    override. The interface is kept similar to the LLM provider system
    so both kinds of provider are used in a consistent way.
    """

    # Class-level metadata; subclasses are expected to replace these.
    provider_name = "base"  # Human-readable name shown in logs/UI
    provider_key = "BASE"  # Unique identifier, written in uppercase
    requires_api_key = False  # True when the provider needs an API key
    supports_local = False  # True when the provider runs locally
    default_model = None  # Default embedding model

    @classmethod
    @abstractmethod
    def create_embeddings(
        cls,
        model: Optional[str] = None,
        settings_snapshot: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Embeddings:
        """
        Build the embeddings object for this provider.

        Args:
            model: Name of the embedding model to use.
            settings_snapshot: Optional settings snapshot for
                thread-safe access.
            **kwargs: Extra provider-specific parameters.

        Returns:
            A LangChain ``Embeddings`` instance.

        Raises:
            ValueError: If required configuration is missing.
            ImportError: If required dependencies are not installed.
        """
        ...

    @classmethod
    @abstractmethod
    def is_available(
        cls, settings_snapshot: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Report whether this provider is available and properly configured.

        Args:
            settings_snapshot: Optional settings snapshot for
                thread-safe access.

        Returns:
            True when the provider can be used, False otherwise.
        """
        ...

    @classmethod
    def get_available_models(
        cls, settings_snapshot: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        List the embedding models this provider can offer.

        Overrides should return every model the backend reports.
        Filtering by model name is unreliable, because users may load
        custom or renamed embedding models. Leave the selection to the
        user and tag entries only when a real capability signal exists.

        Args:
            settings_snapshot: Optional settings snapshot.

        Returns:
            A list of dicts, one per model, each carrying ``value`` and
            ``label`` string keys. An optional ``is_embedding`` (bool)
            key may be added when the provider can detect embedding
            capability from the backend (e.g. Ollama's ``/api/show``
            capabilities). The base implementation returns an empty list.
        """
        no_models: List[Dict[str, Any]] = []
        return no_models

    @classmethod
    def is_embedding_model(
        cls,
        model: str,
        settings_snapshot: Optional[Dict[str, Any]] = None,
    ) -> Optional[bool]:
        """
        Tell whether a particular model supports embeddings.

        Providers able to distinguish embedding models from chat/LLM
        models should override this method.

        Args:
            model: Model identifier.
            settings_snapshot: Optional settings snapshot.

        Returns:
            True if the model supports embeddings, False if it does
            not, and None when the provider cannot determine this. The
            base implementation always returns None.
        """
        return None

    @classmethod
    def get_model_info(cls, model: str) -> Optional[Dict[str, Any]]:
        """
        Look up metadata for a single model.

        Args:
            model: Model identifier.

        Returns:
            A dict of model metadata (dimensions, description, etc.) or
            None. The base implementation always returns None.
        """
        return None

    @classmethod
    def validate_config(
        cls, settings_snapshot: Optional[Dict[str, Any]] = None
    ) -> tuple[bool, Optional[str]]:
        """
        Check the provider configuration.

        Validity is decided solely by ``is_available``.

        Args:
            settings_snapshot: Optional settings snapshot, forwarded to
                ``is_available``.

        Returns:
            A ``(is_valid, error_message)`` tuple; ``error_message`` is
            None when the configuration is valid.
        """
        usable = cls.is_available(settings_snapshot)
        if usable:
            return True, None

        problem = f"{cls.provider_name} is not available or not configured"
        return False, problem

    @classmethod
    def get_provider_info(cls) -> Dict[str, Any]:
        """
        Collect this provider's class-level metadata.

        Returns:
            A dict with the keys ``name``, ``key``, ``requires_api_key``,
            ``supports_local`` and ``default_model``, read from the
            corresponding class attributes.
        """
        return dict(
            name=cls.provider_name,
            key=cls.provider_key,
            requires_api_key=cls.requires_api_key,
            supports_local=cls.supports_local,
            default_model=cls.default_model,
        )