Coverage for src / local_deep_research / llm / llm_registry.py: 100%

49 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2026-01-11 00:51 +0000

1"""Registry for custom LangChain LLMs. 

2 

3This module provides a global registry for registering and managing custom LangChain 

4LLMs that can be used with Local Deep Research. 

5""" 

6 

7import threading 

8from typing import Callable, Dict, Optional, Union 

9 

10from langchain.chat_models.base import BaseChatModel 

11from loguru import logger 

12 

13 

class LLMRegistry:
    """Thread-safe registry for custom LangChain LLMs.

    Entries are keyed by the lowercased name, so registration, lookup and
    removal are all case-insensitive. Each entry is either a ``BaseChatModel``
    instance or a factory callable returning one; the registry stores the
    value exactly as given and never calls factories itself.

    Every read and write of the internal mapping happens under a single
    ``threading.Lock``. Log messages are emitted only after the lock has been
    released, so slow log sinks never stall other threads using the registry
    (and a sink can never deadlock by re-entering it).
    """

    def __init__(self):
        # normalized (lowercased) name -> LLM instance or factory
        self._llms: Dict[
            str, Union[BaseChatModel, Callable[..., BaseChatModel]]
        ] = {}
        # Guards every access to ``self._llms``. ``threading.Lock`` is not
        # reentrant, so nothing that might call back into the registry
        # (e.g. logging) is done while it is held.
        self._lock = threading.Lock()

    @staticmethod
    def _normalize(name: str) -> str:
        """Return the storage key for *name* (its lowercased form)."""
        return name.lower()

    def register(
        self, name: str, llm: Union[BaseChatModel, Callable[..., BaseChatModel]]
    ) -> None:
        """Register a custom LLM, replacing any entry with the same name.

        Args:
            name: Unique name for the LLM (case-insensitive); stored lowercased.
            llm: Either a BaseChatModel instance or a factory function that
                returns one. Stored as-is.
        """
        normalized_name = self._normalize(name)
        with self._lock:
            # Check and write under the same lock acquisition so the
            # "was this an overwrite?" answer matches what was replaced.
            replaced = normalized_name in self._llms
            self._llms[normalized_name] = llm
        # Log outside the critical section (see class docstring).
        if replaced:
            logger.warning(f"Overwriting existing LLM: {name}")
        logger.info(
            f"Registered custom LLM: {name} (normalized: {normalized_name})"
        )

    def unregister(self, name: str) -> None:
        """Unregister a custom LLM.

        Unknown names are ignored silently (nothing is logged).

        Args:
            name: Name of the LLM to unregister (case-insensitive)
        """
        normalized_name = self._normalize(name)
        with self._lock:
            removed = normalized_name in self._llms
            if removed:
                del self._llms[normalized_name]
        if removed:
            logger.info(f"Unregistered custom LLM: {name}")

    def get(
        self, name: str
    ) -> Optional[Union[BaseChatModel, Callable[..., BaseChatModel]]]:
        """Get a registered LLM.

        Args:
            name: Name of the LLM to retrieve (case-insensitive)

        Returns:
            The stored LLM instance/factory (a factory is not called), or
            None if not found
        """
        normalized_name = self._normalize(name)
        with self._lock:
            return self._llms.get(normalized_name)

    def is_registered(self, name: str) -> bool:
        """Check if an LLM is registered.

        Args:
            name: Name to check (case-insensitive)

        Returns:
            True if registered, False otherwise
        """
        normalized_name = self._normalize(name)
        with self._lock:
            return normalized_name in self._llms

    def list_registered(self) -> list[str]:
        """Get list of all registered LLM names.

        Returns:
            A new list of the registered names in their normalized
            (lowercased) form; mutating it does not affect the registry.
        """
        with self._lock:
            return list(self._llms)

    def clear(self) -> None:
        """Clear all registered LLMs."""
        with self._lock:
            self._llms.clear()
        logger.info("Cleared all registered custom LLMs")


# Global registry instance shared by the module-level helper functions.
_llm_registry = LLMRegistry()

100 

101 

# Public API functions
def register_llm(
    name: str, llm: Union[BaseChatModel, Callable[..., BaseChatModel]]
) -> None:
    """Register a custom LLM in the global registry.

    Names are case-insensitive: they are stored lowercased, and registering a
    name that already exists (in any casing) replaces the previous entry and
    logs a warning.

    Args:
        name: Unique name for the LLM (case-insensitive).
        llm: Either a BaseChatModel instance or a factory function that
            returns one.
    """
    _llm_registry.register(name, llm)

113 

114 

def unregister_llm(name: str) -> None:
    """Unregister a custom LLM from the global registry.

    Unknown names are ignored silently.

    Args:
        name: Name of the LLM to unregister (case-insensitive).
    """
    _llm_registry.unregister(name)

122 

123 

def get_llm_from_registry(
    name: str,
) -> Optional[Union[BaseChatModel, Callable[..., BaseChatModel]]]:
    """Get a registered LLM from the global registry.

    Args:
        name: Name of the LLM to retrieve (case-insensitive).

    Returns:
        The stored LLM instance or factory exactly as registered (a factory is
        returned as-is, not called), or None if no LLM has that name.
    """
    return _llm_registry.get(name)

136 

137 

def is_llm_registered(name: str) -> bool:
    """Check if an LLM is registered in the global registry.

    Args:
        name: Name to check (case-insensitive).

    Returns:
        True if an LLM is registered under that name in any casing,
        False otherwise.
    """
    return _llm_registry.is_registered(name)

148 

149 

def list_registered_llms() -> list[str]:
    """Get list of all registered LLM names.

    Returns:
        A new list of registered names in their normalized (lowercased) form,
        which may differ from the casing passed to ``register_llm``.
    """
    return _llm_registry.list_registered()

157 

158 

def clear_llm_registry() -> None:
    """Remove every LLM from the global registry.

    The registry is module-level state, so this affects every caller that
    uses the helpers in this module.
    """
    registry = _llm_registry
    registry.clear()