Coverage for src / local_deep_research / embeddings / embeddings_config.py: 60%

49 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2026-01-11 00:51 +0000

1""" 

2Central configuration for embedding providers. 

3 

4This module provides the main get_embeddings() function and availability checks 

5for different embedding providers, similar to llm_config.py. 

6""" 

7 

8from typing import Any, Dict, Optional, Type 

9 

10from langchain_core.embeddings import Embeddings 

11from loguru import logger 

12 

13from ..config.thread_settings import get_setting_from_snapshot 

14from .providers.base import BaseEmbeddingProvider 

15 

# Provider keys accepted by get_embeddings(); any other value is rejected.
VALID_EMBEDDING_PROVIDERS = ["sentence_transformers", "ollama", "openai"]

# Cache for the provider-key -> provider-class mapping. Starts as None and is
# populated on first use by _get_provider_classes().
_PROVIDER_CLASSES: Optional[Dict[str, Type[BaseEmbeddingProvider]]] = None

25 

26 

def _get_provider_classes() -> Dict[str, Type[BaseEmbeddingProvider]]:
    """Return the provider-key -> provider-class mapping, building it once.

    The provider implementations are imported inside the function rather
    than at module import time to avoid circular imports. The resulting
    dict is stored in the module-level ``_PROVIDER_CLASSES`` and reused on
    every later call.
    """
    global _PROVIDER_CLASSES

    # Fast path: the mapping was already built by an earlier call.
    if _PROVIDER_CLASSES is not None:
        return _PROVIDER_CLASSES

    from .providers.implementations.sentence_transformers import (
        SentenceTransformersProvider,
    )
    from .providers.implementations.ollama import OllamaEmbeddingsProvider
    from .providers.implementations.openai import OpenAIEmbeddingsProvider

    _PROVIDER_CLASSES = {
        "sentence_transformers": SentenceTransformersProvider,
        "ollama": OllamaEmbeddingsProvider,
        "openai": OpenAIEmbeddingsProvider,
    }
    return _PROVIDER_CLASSES

43 

44 

def is_sentence_transformers_available() -> bool:
    """Report whether the Sentence Transformers provider is usable.

    Delegates to the ``is_available()`` classmethod of the provider class
    registered under the ``"sentence_transformers"`` key.
    """
    return _get_provider_classes()["sentence_transformers"].is_available()

49 

50 

def is_ollama_embeddings_available(
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> bool:
    """Report whether the Ollama embeddings provider is usable.

    Args:
        settings_snapshot: Optional settings snapshot, forwarded unchanged
            to the provider's ``is_available()`` check.

    Returns:
        Whatever the ``"ollama"`` provider class reports from
        ``is_available(settings_snapshot)``.
    """
    ollama_provider = _get_provider_classes()["ollama"]
    return ollama_provider.is_available(settings_snapshot)

57 

58 

def is_openai_embeddings_available(
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> bool:
    """Report whether the OpenAI embeddings provider is usable.

    Args:
        settings_snapshot: Optional settings snapshot, forwarded unchanged
            to the provider's ``is_available()`` check.

    Returns:
        Whatever the ``"openai"`` provider class reports from
        ``is_available(settings_snapshot)``.
    """
    openai_provider = _get_provider_classes()["openai"]
    return openai_provider.is_available(settings_snapshot)

65 

66 

def get_available_embedding_providers(
    settings_snapshot: Optional[Dict[str, Any]] = None,
) -> Dict[str, str]:
    """
    Collect the embedding providers that currently report as available.

    Each provider's availability check is run once, in the fixed order
    sentence_transformers -> ollama -> openai; the returned dict keeps
    that order for the providers that passed.

    Args:
        settings_snapshot: Optional settings snapshot, passed to the
            Ollama and OpenAI checks (the Sentence Transformers check
            takes no arguments)

    Returns:
        Dict mapping provider keys to human-readable display names
    """
    # (key, display name, availability) - evaluated top to bottom.
    candidates = (
        (
            "sentence_transformers",
            "Sentence Transformers (Local)",
            is_sentence_transformers_available(),
        ),
        (
            "ollama",
            "Ollama (Local)",
            is_ollama_embeddings_available(settings_snapshot),
        ),
        (
            "openai",
            "OpenAI API",
            is_openai_embeddings_available(settings_snapshot),
        ),
    )
    return {key: label for key, label, usable in candidates if usable}

91 

92 

def get_embedding_function(
    provider: Optional[str] = None,
    model_name: Optional[str] = None,
    settings_snapshot: Optional[Dict[str, Any]] = None,
    **kwargs,
):
    """
    Build an embeddings instance and hand back its document-embedding method.

    This is a thin convenience wrapper around get_embeddings(): it forwards
    every argument (``model_name`` is passed on as ``model``) and returns
    the bound ``embed_documents`` method of the resulting instance.

    Args:
        provider: Embedding provider to use
        model_name: Model name to use
        settings_snapshot: Optional settings snapshot
        **kwargs: Additional provider-specific parameters

    Returns:
        A callable that takes a list of texts and returns embeddings
    """
    return get_embeddings(
        provider=provider,
        model=model_name,
        settings_snapshot=settings_snapshot,
        **kwargs,
    ).embed_documents

118 

119 

def get_embeddings(
    provider: Optional[str] = None,
    model: Optional[str] = None,
    settings_snapshot: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Embeddings:
    """
    Get embeddings instance based on provider and model.

    The provider name is normalized before validation: surrounding
    whitespace and quote characters are stripped and the result is
    lower-cased, so values such as ``' "OpenAI" '`` resolve to ``"openai"``.

    Args:
        provider: Embedding provider to use (if None, read from the
            ``embeddings.provider`` setting, defaulting to
            ``"sentence_transformers"``)
        model: Model name to use; forwarded as-is to the provider's
            ``create_embeddings()``
        settings_snapshot: Optional settings snapshot for thread-safe access
        **kwargs: Additional provider-specific parameters

    Returns:
        A LangChain Embeddings instance

    Raises:
        ValueError: If provider is invalid (including empty, None after the
            settings lookup, or not a string) or not available
        ImportError: If required dependencies are not installed
    """
    # Get provider from settings if not specified
    if provider is None:
        provider = get_setting_from_snapshot(
            "embeddings.provider",
            default="sentence_transformers",
            settings_snapshot=settings_snapshot,
        )

    # Clean and normalize provider. Only strings are normalized: a
    # non-string value (e.g. a malformed settings entry) is left untouched
    # so it fails the validation below with the documented ValueError
    # instead of raising AttributeError on .strip().
    if isinstance(provider, str):
        provider = provider.strip().strip("\"'").strip().lower()

    # Validate provider
    if provider not in VALID_EMBEDDING_PROVIDERS:
        logger.error(f"Invalid embedding provider: {provider}")
        raise ValueError(
            f"Invalid embedding provider: {provider}. "
            f"Must be one of: {VALID_EMBEDDING_PROVIDERS}"
        )

    logger.info(f"Getting embeddings with provider: {provider}, model: {model}")

    # Get provider class and create embeddings
    provider_class = _get_provider_classes().get(provider)

    # Defensive: only reachable if VALID_EMBEDDING_PROVIDERS lists a key
    # that has no registered provider class.
    if not provider_class:
        raise ValueError(f"Unsupported embedding provider: {provider}")

    return provider_class.create_embeddings(
        model=model, settings_snapshot=settings_snapshot, **kwargs
    )

173 )