Coverage for src/local_deep_research/embeddings/embeddings_config.py: 95%

49 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-03 23:15 +0000

1""" 

2Central configuration for embedding providers. 

3 

4This module provides the main get_embeddings() function and availability checks 

5for different embedding providers, similar to llm_config.py. 

6""" 

7 

8from typing import Any, Dict, Optional, Type 

9 

10from langchain_core.embeddings import Embeddings 

11from loguru import logger 

12 

13from ..config.thread_settings import get_setting_from_snapshot 

14from .providers.base import BaseEmbeddingProvider 

15 

# Internal: the provider strings that get_embeddings() accepts.
# Not re-exported from embeddings/__init__.py — call sites should not import it.
# Kept at module level so the ValueError raised by get_embeddings() for an
# unknown provider can list the accepted options.
VALID_EMBEDDING_PROVIDERS = ["sentence_transformers", "ollama", "openai"]

24 

# Cache for the provider-key -> provider-class mapping. Starts as None and is
# filled in by _get_provider_classes() the first time it is called.
_PROVIDER_CLASSES: Optional[Dict[str, Type[BaseEmbeddingProvider]]] = None

27 

28 

def _get_provider_classes() -> Dict[str, Type[BaseEmbeddingProvider]]:
    """Return the provider-key -> provider-class mapping, building it once.

    The provider implementations are imported inside the function rather
    than at module import time to avoid circular imports. The resulting
    mapping is stored in the module-level ``_PROVIDER_CLASSES`` and the same
    dict object is returned on every later call.

    Returns:
        Dict with the keys ``"sentence_transformers"``, ``"ollama"`` and
        ``"openai"``, each mapped to its provider class.
    """
    global _PROVIDER_CLASSES

    # Fast path: mapping was already built by an earlier call.
    if _PROVIDER_CLASSES is not None:
        return _PROVIDER_CLASSES

    from .providers.implementations.sentence_transformers import (
        SentenceTransformersProvider,
    )
    from .providers.implementations.ollama import OllamaEmbeddingsProvider
    from .providers.implementations.openai import OpenAIEmbeddingsProvider

    registry: Dict[str, Type[BaseEmbeddingProvider]] = {
        "sentence_transformers": SentenceTransformersProvider,
        "ollama": OllamaEmbeddingsProvider,
        "openai": OpenAIEmbeddingsProvider,
    }
    _PROVIDER_CLASSES = registry
    return registry

45 

46 

def is_sentence_transformers_available() -> bool:
    """Report whether the Sentence Transformers provider can be used.

    Returns:
        The result of calling ``is_available()`` (with no arguments) on the
        class registered under ``"sentence_transformers"``.
    """
    return _get_provider_classes()["sentence_transformers"].is_available()

51 

52 

def is_ollama_embeddings_available(
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> bool:
    """Report whether the Ollama embeddings provider can be used.

    Args:
        settings_snapshot: Optional settings snapshot, forwarded unchanged
            to the provider's ``is_available()``.

    Returns:
        The result of ``is_available(settings_snapshot)`` on the class
        registered under ``"ollama"``.
    """
    ollama_provider = _get_provider_classes()["ollama"]
    return ollama_provider.is_available(settings_snapshot)

59 

60 

def is_openai_embeddings_available(
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> bool:
    """Report whether the OpenAI embeddings provider can be used.

    Args:
        settings_snapshot: Optional settings snapshot, forwarded unchanged
            to the provider's ``is_available()``.

    Returns:
        The result of ``is_available(settings_snapshot)`` on the class
        registered under ``"openai"``.
    """
    openai_provider = _get_provider_classes()["openai"]
    return openai_provider.is_available(settings_snapshot)

67 

68 

def get_available_embedding_providers(
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> Dict[str, str]:
    """
    Collect the embedding providers that currently report as available.

    All three availability checks are run, in the order sentence_transformers,
    ollama, openai; the returned dict keeps that same order.

    Args:
        settings_snapshot: Optional settings snapshot, passed to the Ollama
            and OpenAI availability checks (the Sentence Transformers check
            takes no arguments).

    Returns:
        Dict mapping provider keys to display names, containing only the
        providers whose availability check returned a truthy value.
    """
    # (provider key, display name, availability result)
    candidates = (
        (
            "sentence_transformers",
            "Sentence Transformers (Local)",
            is_sentence_transformers_available(),
        ),
        (
            "ollama",
            "Ollama (Local)",
            is_ollama_embeddings_available(settings_snapshot),
        ),
        (
            # One entry stands for both the OpenAI cloud API and any
            # OpenAI-compatible endpoint (LM Studio, vLLM, llama.cpp);
            # the provider class branches on
            # ``embeddings.openai.base_url`` at runtime.
            "openai",
            "OpenAI / OpenAI-Compatible Endpoint",
            is_openai_embeddings_available(settings_snapshot),
        ),
    )
    return {key: label for key, label, usable in candidates if usable}

97 

98 

def get_embedding_function(
    provider: Optional[str] = None,
    model_name: Optional[str] = None,
    settings_snapshot: Optional[Dict[str, Any]] = None,
    **kwargs,
):
    """
    Build an embeddings instance and hand back its document-embedding method.

    This is a thin wrapper around get_embeddings(); ``model_name`` is passed
    on as that function's ``model`` argument.

    Args:
        provider: Embedding provider to use
        model_name: Model name to use
        settings_snapshot: Optional settings snapshot
        **kwargs: Additional provider-specific parameters, forwarded as-is

    Returns:
        The ``embed_documents`` attribute of the embeddings instance, i.e. a
        callable that takes a list of texts and returns embeddings
    """
    return get_embeddings(
        provider=provider,
        model=model_name,
        settings_snapshot=settings_snapshot,
        **kwargs,
    ).embed_documents

124 

125 

def get_embeddings(
    provider: Optional[str] = None,
    model: Optional[str] = None,
    settings_snapshot: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Embeddings:
    """
    Get embeddings instance based on provider and model.

    The provider name is normalized before validation: surrounding
    whitespace and surrounding single/double quotes are stripped and the
    result is lower-cased, so ``' "OpenAI" '`` resolves to ``"openai"``.

    Args:
        provider: Embedding provider to use. If None, the
            ``embeddings.provider`` setting is read (default
            ``"sentence_transformers"``).
        model: Model name to use; passed through unchanged to the provider
            class, which decides what None means.
        settings_snapshot: Optional settings snapshot for thread-safe access
        **kwargs: Additional provider-specific parameters, forwarded to the
            provider's ``create_embeddings()``.

    Returns:
        A LangChain Embeddings instance

    Raises:
        ValueError: If the provider (after normalization) is not one of
            VALID_EMBEDDING_PROVIDERS, including when it is not a string at
            all, or if no provider class is registered for it.
        ImportError: Not raised directly here; presumably propagated from
            the provider when its dependencies are not installed.
    """
    # Get provider from settings if not specified
    if provider is None:
        provider = get_setting_from_snapshot(
            "embeddings.provider",
            default="sentence_transformers",
            settings_snapshot=settings_snapshot,
        )

    # Clean and normalize provider. Only real strings are normalized; any
    # other value (e.g. a malformed settings entry) is left untouched so it
    # fails validation below with the documented ValueError rather than an
    # AttributeError from calling .strip() on a non-string.
    if isinstance(provider, str):
        provider = provider.strip().strip("\"'").strip().lower()

    # Validate provider
    if provider not in VALID_EMBEDDING_PROVIDERS:
        logger.error(f"Invalid embedding provider: {provider}")
        raise ValueError(
            f"Invalid embedding provider: {provider}. "
            f"Must be one of: {VALID_EMBEDDING_PROVIDERS}"
        )

    logger.info(f"Getting embeddings with provider: {provider}, model: {model}")

    # Get provider class and create embeddings
    provider_class = _get_provider_classes().get(provider)

    # Defensive: only reachable if VALID_EMBEDDING_PROVIDERS lists a key
    # that the provider-class mapping does not contain.
    if provider_class is None:
        raise ValueError(f"Unsupported embedding provider: {provider}")

    return provider_class.create_embeddings(
        model=model, settings_snapshot=settings_snapshot, **kwargs
    )

179 )