Coverage for src/local_deep_research/embeddings/providers/implementations/sentence_transformers.py: 81%

28 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2026-01-11 00:51 +0000

1"""Sentence Transformers embedding provider.""" 

2 

3from typing import Any, Dict, List, Optional 

4 

5from langchain_core.embeddings import Embeddings 

6from loguru import logger 

7 

8from ....config.thread_settings import get_setting_from_snapshot 

9from ..base import BaseEmbeddingProvider 

10 

11 

class SentenceTransformersProvider(BaseEmbeddingProvider):
    """
    Embedding provider backed by Sentence Transformers models.

    Builds a LangChain ``SentenceTransformerEmbeddings`` object for a
    HuggingFace sentence-transformers model. The provider is flagged as
    local and as not requiring an API key.
    """

    provider_name = "Sentence Transformers"
    provider_key = "SENTENCE_TRANSFORMERS"
    requires_api_key = False
    supports_local = True
    default_model = "all-MiniLM-L6-v2"

    # Curated model catalogue: model name -> metadata used to build UI labels.
    # Insertion order is the order returned by get_available_models().
    AVAILABLE_MODELS = {
        "all-MiniLM-L6-v2": {
            "dimensions": 384,
            "description": "Fast, lightweight model. Good for general use.",
            "max_seq_length": 256,
        },
        "all-mpnet-base-v2": {
            "dimensions": 768,
            "description": "Higher quality, slower. Best accuracy.",
            "max_seq_length": 384,
        },
        "multi-qa-MiniLM-L6-cos-v1": {
            "dimensions": 384,
            "description": "Optimized for question-answering tasks.",
            "max_seq_length": 512,
        },
        "paraphrase-multilingual-MiniLM-L12-v2": {
            "dimensions": 384,
            "description": "Supports multiple languages.",
            "max_seq_length": 128,
        },
    }

    @staticmethod
    def _explicit_or_setting(
        explicit: Optional[Any],
        key: str,
        fallback: Any,
        settings_snapshot: Optional[Dict[str, Any]],
    ) -> Any:
        """Return ``explicit`` unless it is None; otherwise read ``key`` from settings."""
        if explicit is not None:
            return explicit
        return get_setting_from_snapshot(
            key,
            default=fallback,
            settings_snapshot=settings_snapshot,
        )

    @classmethod
    def create_embeddings(
        cls,
        model: Optional[str] = None,
        settings_snapshot: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Embeddings:
        """
        Build a Sentence Transformers embeddings object.

        Args:
            model: Model name. When None, the
                ``embeddings.sentence_transformers.model`` setting is used,
                falling back to ``default_model``.
            settings_snapshot: Optional settings snapshot passed through to
                the settings lookups.
            **kwargs: Only ``device`` is read. When it is absent or None, the
                ``embeddings.sentence_transformers.device`` setting is used,
                falling back to ``"cpu"``.

        Returns:
            SentenceTransformerEmbeddings instance
        """
        # Imported lazily, inside the call, rather than at module import time.
        from langchain_community.embeddings import (
            SentenceTransformerEmbeddings,
        )

        # An explicit argument wins; settings are consulted only for None.
        model = cls._explicit_or_setting(
            model,
            "embeddings.sentence_transformers.model",
            cls.default_model,
            settings_snapshot,
        )
        device = cls._explicit_or_setting(
            kwargs.get("device"),
            "embeddings.sentence_transformers.device",
            "cpu",
            settings_snapshot,
        )

        logger.info(
            f"Creating SentenceTransformerEmbeddings with model={model}, device={device}"
        )

        model_kwargs = {"device": device}
        return SentenceTransformerEmbeddings(
            model_name=model,
            model_kwargs=model_kwargs,
        )

    @classmethod
    def is_available(
        cls, settings_snapshot: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Report whether this provider can be used.

        Always returns True: sentence-transformers is treated as a required
        dependency, so no check is performed. ``settings_snapshot`` is
        accepted for API consistency with other providers and is not read.
        """
        return True

    @classmethod
    def get_available_models(
        cls, settings_snapshot: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, str]]:
        """
        List the curated Sentence Transformer models as value/label pairs.

        Note: There is no centralized API to query, so the entries come from
        ``AVAILABLE_MODELS``. Users can also specify any model name from
        HuggingFace directly in settings. ``settings_snapshot`` is not read.

        Returns:
            One dict per model with ``value`` (the model name) and ``label``
            (name, dimension count and description).
        """
        options: List[Dict[str, str]] = []
        for name, meta in cls.AVAILABLE_MODELS.items():
            label = f"{name} ({meta['dimensions']}d) - {meta['description']}"
            options.append({"value": name, "label": label})
        return options

62 settings_snapshot: Optional settings snapshot 

63 **kwargs: Additional parameters (device, etc.) 

64 

65 Returns: 

66 SentenceTransformerEmbeddings instance 

67 """ 

68 from langchain_community.embeddings import ( 

69 SentenceTransformerEmbeddings, 

70 ) 

71 

72 # Get model from settings if not specified 

73 if model is None: 73 ↛ 74line 73 didn't jump to line 74 because the condition on line 73 was never true

74 model = get_setting_from_snapshot( 

75 "embeddings.sentence_transformers.model", 

76 default=cls.default_model, 

77 settings_snapshot=settings_snapshot, 

78 ) 

79 

80 # Get device setting (cpu or cuda) 

81 device = kwargs.get("device") 

82 if device is None: 82 ↛ 83line 82 didn't jump to line 83 because the condition on line 82 was never true

83 device = get_setting_from_snapshot( 

84 "embeddings.sentence_transformers.device", 

85 default="cpu", 

86 settings_snapshot=settings_snapshot, 

87 ) 

88 

89 logger.info( 

90 f"Creating SentenceTransformerEmbeddings with model={model}, device={device}" 

91 ) 

92 

93 return SentenceTransformerEmbeddings( 

94 model_name=model, 

95 model_kwargs={"device": device}, 

96 ) 

97 

98 @classmethod 

99 def is_available( 

100 cls, settings_snapshot: Optional[Dict[str, Any]] = None 

101 ) -> bool: 

102 """ 

103 Check if Sentence Transformers is available. 

104 

105 Since sentence-transformers is a required dependency, this always returns True. 

106 This method exists for API consistency with other providers. 

107 """ 

108 return True 

109 

110 @classmethod 

111 def get_available_models( 

112 cls, settings_snapshot: Optional[Dict[str, Any]] = None 

113 ) -> List[Dict[str, str]]: 

114 """ 

115 Get list of available Sentence Transformer models. 

116 

117 Note: Since there's no centralized API for Sentence Transformers, 

118 we return a curated list of commonly used models. Users can also 

119 specify any model name from HuggingFace directly in settings. 

120 """ 

121 return [ 

122 { 

123 "value": model, 

124 "label": f"{model} ({info['dimensions']}d) - {info['description']}", 

125 } 

126 for model, info in cls.AVAILABLE_MODELS.items() 

127 ]