Coverage for src / local_deep_research / llm / providers / auto_discovery.py: 96%

103 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:55 +0000

1"""Auto-discovery system for OpenAI-compatible providers.""" 

2 

3import importlib 

4import inspect 

5from pathlib import Path 

6from typing import Dict, List, Optional 

7 

8from loguru import logger 

9 

10from .base import BaseLLMProvider, normalize_provider 

11from .openai_base import OpenAICompatibleProvider 

12from ..llm_registry import register_llm 

13 

14 

class ProviderInfo:
    """Metadata record built from a discovered provider class.

    The constructor reads the descriptive attributes off the class once and
    stores them on the instance, together with a ready-made display label.
    """

    def __init__(self, provider_class):
        # Key used when the class does not declare ``provider_key``: the
        # class name with "Provider" removed, upper-cased.
        fallback_key = provider_class.__name__.replace("Provider", "").upper()

        self.provider_class = provider_class
        self.provider_key = getattr(
            provider_class, "provider_key", fallback_key
        )
        # ``provider_name`` is mandatory; a missing attribute raises here.
        self.provider_name = provider_class.provider_name
        self.company_name = getattr(
            provider_class, "company_name", self.provider_name
        )
        # Providers are treated as cloud-hosted unless they say otherwise.
        self.is_cloud = getattr(provider_class, "is_cloud", True)
        self.requires_auth_for_models = (
            provider_class.requires_auth_for_models()
        )

        # The label depends on provider_name and is_cloud, so build it last.
        self.display_name = self._generate_display_name()

    def _generate_display_name(self):
        """Return the provider name followed by a cloud/local badge, if any."""
        # Identity checks on purpose: any value other than the literal
        # True/False (e.g. None) yields the bare provider name.
        if self.is_cloud is True:
            badge = "☁️ Cloud"
        elif self.is_cloud is False:
            badge = "💻 Local"
        else:
            badge = None

        if badge is None:
            pieces = (self.provider_name,)
        else:
            pieces = (self.provider_name, badge)
        return " ".join(pieces)

    def to_dict(self):
        """Return the ``value``/``label``/``is_cloud`` mapping for API responses."""
        return dict(
            value=self.provider_key,
            label=self.display_name,
            is_cloud=self.is_cloud,
        )

57 

58 

class ProviderDiscovery:
    """Singleton that discovers, registers and looks up LLM provider classes.

    Provider modules are searched for in the ``implementations`` directory
    next to this file. Every qualifying class is wrapped in a ``ProviderInfo``
    and registered through ``register_llm``.
    """

    _instance = None
    # Registry of provider key -> ProviderInfo. Declared on the class and
    # cleared/refilled in place by discover_providers().
    _providers: Dict[str, ProviderInfo] = {}
    _discovered: bool = False

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._discovered = False
        return cls._instance

    @staticmethod
    def _is_provider_class(name, obj) -> bool:
        """Tell whether class ``obj`` (bound to ``name``) is a concrete provider."""
        # The two base classes themselves are excluded; only subclasses of
        # BaseLLMProvider whose name ends in "Provider" and that declare a
        # ``provider_name`` qualify.
        return (
            name.endswith("Provider")
            and hasattr(obj, "provider_name")
            and issubclass(obj, BaseLLMProvider)
            and obj is not OpenAICompatibleProvider
            and obj is not BaseLLMProvider
        )

    def _register_module_providers(self, module, module_name: str) -> None:
        """Record and register every provider class found in ``module``."""
        # getmembers() already filters with inspect.isclass, so no second
        # isclass() check is needed inside the loop.
        for name, obj in inspect.getmembers(module, inspect.isclass):
            logger.debug(f" Found class: {name}, bases: {obj.__bases__}")
            if not self._is_provider_class(name, obj):
                continue

            provider_info = ProviderInfo(obj)
            self._providers[provider_info.provider_key] = provider_info

            # Auto-register the provider directly using the class
            register_llm(
                normalize_provider(provider_info.provider_key),
                obj.create_llm,
            )
            logger.info(
                f"Auto-registered provider: {provider_info.provider_key}"
            )

            logger.info(
                f"Discovered provider: {provider_info.provider_key} from {module_name}.py"
            )

    def discover_providers(
        self, force_refresh: bool = False
    ) -> Dict[str, ProviderInfo]:
        """Discover all providers in the providers directory.

        Args:
            force_refresh: Force re-discovery even if already done

        Returns:
            Dictionary mapping provider keys to ProviderInfo objects
        """
        if self._discovered and not force_refresh:
            return self._providers

        self._providers.clear()
        # Scan the implementations subdirectory for providers
        implementations_dir = Path(__file__).parent / "implementations"

        if not implementations_dir.exists():
            logger.warning(
                f"Implementations directory not found: {implementations_dir}"
            )
            # The discovered flag is left untouched here, so a later call
            # scans again.
            return self._providers

        logger.info(f"Scanning directory: {implementations_dir}")
        for file_path in implementations_dir.glob("*.py"):
            # Skip special files (like __init__.py)
            if file_path.name.startswith("_"):
                continue

            module_name = file_path.stem
            logger.debug(f"Processing module: {module_name} from {file_path}")
            try:
                module = importlib.import_module(
                    f".implementations.{module_name}",
                    package="local_deep_research.llm.providers",
                )
                logger.debug(
                    f"Inspecting module {module_name} for Provider classes"
                )
                self._register_module_providers(module, module_name)
            except Exception:
                # A broken provider module is logged and skipped so it
                # cannot prevent the remaining modules from loading.
                logger.exception(f"Error loading provider from {module_name}")

        self._discovered = True
        logger.info(f"Discovered {len(self._providers)} providers")
        return self._providers

    def get_provider_info(self, provider_key: str) -> Optional[ProviderInfo]:
        """Get information about a specific provider.

        Args:
            provider_key: The provider key (e.g., 'IONOS', 'GOOGLE')

        Returns:
            ProviderInfo object or None if not found (including when
            ``provider_key`` is None or not a string)
        """
        if not self._discovered:
            self.discover_providers()
        # An unset or non-string key is simply "not found", not an error.
        if not isinstance(provider_key, str):
            return None
        # NOTE(review): keys are stored exactly as the provider class declares
        # them but looked up upper-cased; a provider declaring a non-uppercase
        # ``provider_key`` would never match here — confirm the convention.
        return self._providers.get(provider_key.upper())

    def get_provider_options(self) -> List[Dict]:
        """Get list of provider options for UI dropdowns.

        Returns:
            List of dictionaries with 'value' and 'label' keys, sorted by label
        """
        if not self._discovered:
            self.discover_providers()

        return sorted(
            (info.to_dict() for info in self._providers.values()),
            key=lambda option: option["label"],
        )

    def get_available_provider_options(
        self, settings_snapshot=None
    ) -> List[Dict]:
        """Get list of available provider options, filtered by availability.

        Filters out providers that are not available (e.g., missing API keys).
        Useful for contexts where only usable providers should be shown
        (e.g., starting a research). For settings/configuration UIs, prefer
        get_provider_options() so users can discover and configure new providers.

        Args:
            settings_snapshot: Settings snapshot for checking provider availability.
                Should be provided to correctly check cloud provider API keys.

        Returns:
            List of dictionaries with 'value' and 'label' keys, sorted by label
        """
        if not self._discovered:
            self.discover_providers()

        options = []
        for provider_info in self._providers.values():
            if not provider_info.provider_class.is_available(
                settings_snapshot=settings_snapshot
            ):
                logger.debug(
                    f"Provider {provider_info.provider_key} filtered out "
                    f"(not available)"
                )
                continue
            options.append(provider_info.to_dict())

        if not options:
            logger.warning(
                "No auto-discovered providers passed availability filter. "
                "Check that API keys are configured for cloud providers."
            )

        # Sort by label
        options.sort(key=lambda option: option["label"])
        return options

    def get_provider_class(self, provider_key: str):
        """Get the provider class for a given key.

        Args:
            provider_key: The provider key (e.g., 'IONOS', 'GOOGLE')

        Returns:
            Provider class or None if not found
        """
        provider_info = self.get_provider_info(provider_key)
        return provider_info.provider_class if provider_info else None

238 

239 

# Global instance: the module-level convenience functions delegate to it.
provider_discovery = ProviderDiscovery()

242 

243 

def discover_providers(force_refresh: bool = False) -> Dict[str, ProviderInfo]:
    """Run provider discovery on the module-level ``provider_discovery`` object.

    Args:
        force_refresh: When true, discovery is repeated even if it already ran

    Returns:
        Mapping of provider key to its ProviderInfo object
    """
    discovered = provider_discovery.discover_providers(force_refresh)
    return discovered

254 

255 

def get_discovered_provider_options() -> List[Dict]:
    """Return the dropdown options for every discovered provider.

    Delegates to ``provider_discovery.get_provider_options()``.

    Returns:
        List of dictionaries with 'value' and 'label' keys
    """
    dropdown_options = provider_discovery.get_provider_options()
    return dropdown_options

263 

264 

def get_available_discovered_provider_options(
    settings_snapshot=None,
) -> List[Dict]:
    """Return dropdown options only for providers that are currently usable.

    Delegates to ``provider_discovery.get_available_provider_options()``,
    which keeps just the providers passing their is_available() check. This
    suits contexts where only usable providers matter (e.g., starting a
    research). Settings/configuration UIs should call
    get_discovered_provider_options() instead, so users can still discover
    and configure new providers.

    Args:
        settings_snapshot: Settings snapshot for checking provider availability

    Returns:
        List of dictionaries with 'value' and 'label' keys
    """
    usable_options = provider_discovery.get_available_provider_options(
        settings_snapshot
    )
    return usable_options

282 

283 

def get_provider_class(provider_key: str):
    """Look up a provider class by key via ``provider_discovery``.

    Args:
        provider_key: The provider key (e.g., 'IONOS', 'GOOGLE')

    Returns:
        The matching provider class, or None if not found
    """
    provider_cls = provider_discovery.get_provider_class(provider_key)
    return provider_cls