Coverage for src / local_deep_research / embeddings / providers / base.py: 92%

34 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-25 01:07 +0000

1"""Base class for embedding providers.""" 

2 

3from abc import ABC, abstractmethod 

4from typing import Any, Dict, List, Optional 

5 

6from langchain_core.embeddings import Embeddings 

7 

8 

class BaseEmbeddingProvider(ABC):
    """
    Common interface implemented by every embedding provider.

    A concrete provider subclasses this, overrides the class-level metadata
    attributes below, and implements ``create_embeddings`` and
    ``is_available``. The other classmethods have conservative defaults that
    a provider may override. The interface is intended to mirror the LLM
    provider system.
    """

    # --- Provider metadata; subclasses are expected to override these. ---
    provider_name = "base"  # Human-readable name shown in logs / UI
    provider_key = "BASE"  # Unique provider identifier (uppercase)
    requires_api_key = False  # Set True if an API key must be configured
    supports_local = False  # Set True if the provider runs locally
    default_model = None  # Default embedding model (none declared in base)

    @classmethod
    @abstractmethod
    def create_embeddings(
        cls,
        model: Optional[str] = None,
        settings_snapshot: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Embeddings:
        """
        Build a LangChain embeddings object backed by this provider.

        Args:
            model: Name of the embedding model to use
            settings_snapshot: Optional settings snapshot, used for
                thread-safe access to configuration
            **kwargs: Extra provider-specific parameters

        Returns:
            An instance of LangChain's ``Embeddings`` interface

        Raises:
            ValueError: When required configuration is missing
            ImportError: When a dependency the provider needs is not installed
        """
        ...

    @classmethod
    @abstractmethod
    def is_available(
        cls, settings_snapshot: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Report whether this provider is usable with the current configuration.

        Args:
            settings_snapshot: Optional settings snapshot, used for
                thread-safe access to configuration

        Returns:
            True when the provider can be used, otherwise False
        """
        ...

    @classmethod
    def get_available_models(
        cls, settings_snapshot: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, str]]:
        """
        List the embedding models this provider offers.

        Overrides should return only models that can produce embeddings,
        excluding chat/completion-only models where the provider can tell
        them apart.

        Args:
            settings_snapshot: Optional settings snapshot

        Returns:
            One dict per model, each carrying 'value' and 'label' keys and
            optionally an 'is_embedding' (bool) key. The base
            implementation returns an empty list.
        """
        return []

    @classmethod
    def is_embedding_model(
        cls,
        model: str,
        settings_snapshot: Optional[Dict[str, Any]] = None,
    ) -> Optional[bool]:
        """
        Decide whether the given model supports embeddings.

        Providers able to distinguish embedding models from chat/LLM models
        should override this; the base implementation always answers
        "unknown".

        Args:
            model: Model identifier
            settings_snapshot: Optional settings snapshot

        Returns:
            True if the model supports embeddings, False if it does not,
            or None when the provider cannot tell.
        """
        return None

    @classmethod
    def get_model_info(cls, model: str) -> Optional[Dict[str, Any]]:
        """
        Look up metadata for one model.

        Args:
            model: Model identifier

        Returns:
            A dict of model metadata (e.g. dimensions, description), or None.
            The base implementation always returns None.
        """
        return None

    @classmethod
    def validate_config(
        cls, settings_snapshot: Optional[Dict[str, Any]] = None
    ) -> tuple[bool, Optional[str]]:
        """
        Check that the provider is configured well enough to be used.

        The base check simply delegates to ``is_available``.

        Args:
            settings_snapshot: Optional settings snapshot

        Returns:
            ``(True, None)`` when the provider is available, otherwise
            ``(False, <error message naming the provider>)``.
        """
        available = cls.is_available(settings_snapshot)
        if available:
            return True, None
        return (
            False,
            f"{cls.provider_name} is not available or not configured",
        )

    @classmethod
    def get_provider_info(cls) -> Dict[str, Any]:
        """
        Summarize this provider's class-level metadata.

        Returns:
            A new dict with the keys 'name', 'key', 'requires_api_key',
            'supports_local' and 'default_model', taken from the
            corresponding class attributes.
        """
        return dict(
            name=cls.provider_name,
            key=cls.provider_key,
            requires_api_key=cls.requires_api_key,
            supports_local=cls.supports_local,
            default_model=cls.default_model,
        )