Coverage for src / local_deep_research / llm / providers / auto_discovery.py: 90%

107 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2026-01-11 00:51 +0000

1"""Auto-discovery system for OpenAI-compatible providers.""" 

2 

3import importlib 

4import inspect 

5from pathlib import Path 

6from typing import Dict, List, Optional 

7 

8from loguru import logger 

9 

10from .openai_base import OpenAICompatibleProvider 

11 

12 

class ProviderInfo:
    """Metadata describing a single discovered provider class.

    Every value is read from class-level attributes of the wrapped provider.
    Only ``provider_name`` is mandatory; all other attributes fall back to
    the defaults applied in ``__init__``.
    """

    def __init__(self, provider_class):
        """Collect metadata from *provider_class*.

        Args:
            provider_class: The provider class to describe. It must expose
                ``provider_name``; everything else is optional.

        Raises:
            AttributeError: If ``provider_class`` has no ``provider_name``.
        """
        self.provider_class = provider_class

        # Without an explicit key, derive one from the class name:
        # drop "Provider" and upper-case the remainder.
        fallback_key = provider_class.__name__.replace("Provider", "").upper()
        self.provider_key = getattr(
            provider_class, "provider_key", fallback_key
        )

        self.provider_name = provider_class.provider_name
        self.company_name = getattr(
            provider_class, "company_name", provider_class.provider_name
        )

        # Optional descriptive attributes and the value used when absent.
        optional_defaults = (
            ("region", "Unknown"),
            ("country", "Unknown"),
            ("gdpr_compliant", False),
            ("data_location", "Unknown"),
            ("is_cloud", True),
        )
        for attr_name, fallback in optional_defaults:
            setattr(
                self, attr_name, getattr(provider_class, attr_name, fallback)
            )

        # Providers that do not implement requires_auth_for_models() are
        # treated as requiring authentication.
        self.requires_auth_for_models = (
            provider_class.requires_auth_for_models()
            if hasattr(provider_class, "requires_auth_for_models")
            else True
        )

        # Built last because it depends on the attributes set above.
        self.display_name = self._generate_display_name()

    def _generate_display_name(self):
        """Build the human-readable label shown for this provider.

        Returns:
            The provider name followed, where applicable, by a parenthesised
            location summary, a GDPR marker and a cloud/local marker, joined
            by single spaces.
        """
        location_bits = []
        if self.region != "Unknown":
            location_bits.append(self.region)

        # Mention the data location only when it is known and adds
        # information beyond the provider's country.
        if self.data_location in ("Multiple", "Worldwide"):
            location_bits.append("Data: Worldwide")
        elif self.data_location not in ("Unknown", self.country):
            location_bits.append(f"Data: {self.data_location}")

        segments = [self.provider_name]
        if location_bits:
            segments.append(f"({', '.join(location_bits)})")

        # GDPR compliance is only highlighted for EU-region providers.
        if self.gdpr_compliant and self.region == "EU":
            segments.append("🔒 GDPR")

        segments.append("☁️ Cloud" if self.is_cloud else "💻 Local")
        return " ".join(segments)

    def to_dict(self):
        """Serialise the provider metadata for API responses.

        Returns:
            A dict with ``value`` (the provider key), ``label`` (the display
            name) and the descriptive attributes ``is_cloud``, ``region``,
            ``country``, ``gdpr_compliant`` and ``data_location``.
        """
        payload = {"value": self.provider_key, "label": self.display_name}
        for field_name in (
            "is_cloud",
            "region",
            "country",
            "gdpr_compliant",
            "data_location",
        ):
            payload[field_name] = getattr(self, field_name)
        return payload

90 

91 

class ProviderDiscovery:
    """Singleton that discovers and indexes provider classes.

    Discovery imports every non-underscore module found in the
    ``implementations`` directory next to this file, wraps each matching
    class in a ``ProviderInfo`` and calls the module's
    ``register_<module>_provider`` function when the module defines one.
    """

    _instance = None
    # Registry of provider key -> ProviderInfo. Defined on the class, so it
    # is shared by every reference to the singleton.
    _providers: Dict[str, ProviderInfo] = {}

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._discovered = False
        return cls._instance

    def discover_providers(
        self, force_refresh: bool = False
    ) -> Dict[str, ProviderInfo]:
        """Discover all providers in the implementations directory.

        Args:
            force_refresh: Force re-discovery even if already done.

        Returns:
            Dictionary mapping provider keys to ProviderInfo objects. The
            same dict object is returned (and mutated) on every call.
        """
        if self._discovered and not force_refresh:
            return self._providers

        self._providers.clear()
        implementations_dir = Path(__file__).parent / "implementations"

        if not implementations_dir.exists():
            logger.warning(
                f"Implementations directory not found: {implementations_dir}"
            )
            return self._providers

        logger.info(f"Scanning directory: {implementations_dir}")
        # Sorted so that registration order, and the winner of any provider
        # key collision, does not depend on filesystem listing order.
        for file_path in sorted(implementations_dir.glob("*.py")):
            # Skip special files (like __init__.py)
            if file_path.name.startswith("_"):
                continue
            self._load_module_providers(file_path)

        self._discovered = True
        logger.info(f"Discovered {len(self._providers)} providers")
        return self._providers

    def _load_module_providers(self, file_path: Path) -> None:
        """Import one implementation module and index its provider classes."""
        module_name = file_path.stem
        logger.debug(f"Processing module: {module_name} from {file_path}")
        # Deliberately broad: a single broken implementation module must not
        # abort discovery of the remaining ones. The failure is logged with
        # its traceback.
        try:
            module = importlib.import_module(
                f".implementations.{module_name}",
                package="local_deep_research.llm.providers",
            )

            logger.debug(
                f"Inspecting module {module_name} for Provider classes"
            )
            # getmembers() already filters with inspect.isclass, so no
            # further class check is needed.
            # NOTE(review): this also yields classes the module merely
            # imports; confirm whether filtering on obj.__module__ is wanted.
            for name, obj in inspect.getmembers(module, inspect.isclass):
                logger.debug(
                    f" Found class: {name}, bases: {obj.__bases__}"
                )
                # A provider class ends with "Provider", declares
                # provider_name, and is not the shared base class itself.
                is_provider = (
                    name.endswith("Provider")
                    and hasattr(obj, "provider_name")
                    and obj is not OpenAICompatibleProvider
                )
                if not is_provider:
                    continue

                provider_info = ProviderInfo(obj)
                self._providers[provider_info.provider_key] = provider_info
                self._auto_register(module, module_name, provider_info)
                logger.info(
                    f"Discovered provider: {provider_info.provider_key} from {module_name}.py"
                )
        except Exception as e:
            logger.exception(
                f"Error loading provider from {module_name}: {e}"
            )

    @staticmethod
    def _auto_register(module, module_name: str, provider_info) -> None:
        """Call the module's ``register_<module>_provider`` hook if present."""
        register_func_name = f"register_{module_name}_provider"
        # Look the hook up separately from calling it, so an AttributeError
        # raised *inside* the hook is not mistaken for a missing hook.
        register_func = getattr(module, register_func_name, None)
        if register_func is None:
            logger.warning(
                f"Provider {provider_info.provider_key} from {module_name}.py "
                f"does not have a {register_func_name} function"
            )
            return

        register_func()
        logger.info(f"Auto-registered provider: {provider_info.provider_key}")

    def get_provider_info(self, provider_key: str) -> Optional[ProviderInfo]:
        """Get information about a specific provider.

        Args:
            provider_key: The provider key (e.g., 'IONOS', 'GOOGLE'). It is
                looked up upper-cased first, then exactly as given.

        Returns:
            ProviderInfo object or None if not found.
        """
        if not self._discovered:
            self.discover_providers()

        provider_info = self._providers.get(provider_key.upper())
        if provider_info is None:
            # Keys are stored exactly as the provider class declares them,
            # so fall back to an exact match for non-uppercase keys.
            provider_info = self._providers.get(provider_key)
        return provider_info

    def get_provider_options(self) -> List[Dict]:
        """Get list of provider options for UI dropdowns.

        Returns:
            List of the ``ProviderInfo.to_dict()`` dictionaries, sorted by
            their 'label' value.
        """
        if not self._discovered:
            self.discover_providers()

        return sorted(
            (info.to_dict() for info in self._providers.values()),
            key=lambda option: option["label"],
        )

    def get_provider_class(self, provider_key: str):
        """Get the provider class for a given key.

        Args:
            provider_key: The provider key (e.g., 'IONOS', 'GOOGLE').

        Returns:
            Provider class or None if not found.
        """
        provider_info = self.get_provider_info(provider_key)
        return provider_info.provider_class if provider_info else None

233 

234 

# Global instance: the module-level ProviderDiscovery object that the
# convenience functions below delegate to.
provider_discovery = ProviderDiscovery()

237 

238 

def discover_providers(force_refresh: bool = False) -> Dict[str, ProviderInfo]:
    """Run provider discovery through the module-level discovery instance.

    Args:
        force_refresh: When True, re-run discovery even if it already ran.

    Returns:
        Mapping of provider key to its ProviderInfo object.
    """
    registry = provider_discovery.discover_providers(
        force_refresh=force_refresh
    )
    return registry

249 

250 

def get_discovered_provider_options() -> List[Dict]:
    """Return the discovered providers as dropdown options for the UI.

    Returns:
        List of dictionaries, each with at least 'value' and 'label' keys.
    """
    dropdown_options = provider_discovery.get_provider_options()
    return dropdown_options

258 

259 

def get_provider_class(provider_key: str):
    """Look up a provider class through the module-level discovery instance.

    Args:
        provider_key: The provider key (e.g., 'IONOS', 'GOOGLE').

    Returns:
        The matching provider class, or None when no provider has that key.
    """
    matched_class = provider_discovery.get_provider_class(provider_key)
    return matched_class