Coverage for src / local_deep_research / llm / llm_registry.py: 100%

49 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:55 +0000

1"""Registry for custom LangChain LLMs. 

2 

3This module provides a global registry for registering and managing custom LangChain 

4LLMs that can be used with Local Deep Research. 

5""" 

6 

7import threading 

8from typing import Callable, Dict, Optional, Union 

9 

10from langchain.chat_models.base import BaseChatModel 

11from loguru import logger 

12 

13 

class LLMRegistry:
    """Lock-guarded, name-keyed store for custom LangChain LLMs.

    Names are lowercased before they touch the underlying dict, so
    registration, lookup and removal are all case-insensitive. Every public
    method holds ``self._lock`` while it reads or mutates the dict.
    """

    def __init__(self):
        self._lock = threading.Lock()
        # Lowercased name -> LLM instance, or a factory callable producing one.
        self._llms: Dict[
            str, Union[BaseChatModel, Callable[..., BaseChatModel]]
        ] = {}

    def register(
        self, name: str, llm: Union[BaseChatModel, Callable[..., BaseChatModel]]
    ) -> None:
        """Store ``llm`` under ``name``, replacing any previous entry.

        Args:
            name: Identifier for the LLM; lowercased before storage.
            llm: A BaseChatModel instance, or a callable that returns one.
        """
        with self._lock:
            normalized_name = name.lower()
            replacing = normalized_name in self._llms
            if replacing:
                # Overwriting is permitted, but is surfaced in the logs.
                logger.warning(f"Overwriting existing LLM: {name}")
            self._llms[normalized_name] = llm
            logger.info(
                f"Registered custom LLM: {name} (normalized: {normalized_name})"
            )

    # Rounds out the registry's CRUD surface.
    # Tests rely on it to verify cleanup behavior.
    def unregister(self, name: str) -> None:
        """Remove the entry stored under ``name``, if there is one.

        An unknown name is ignored silently: nothing is removed and nothing
        is logged.

        Args:
            name: Identifier of the LLM to drop (case-insensitive).
        """
        with self._lock:
            key = name.lower()
            if key not in self._llms:
                return
            del self._llms[key]
            logger.info(f"Unregistered custom LLM: {name}")

    def get(
        self, name: str
    ) -> Optional[Union[BaseChatModel, Callable[..., BaseChatModel]]]:
        """Look up the LLM stored under ``name``.

        Args:
            name: Identifier of the LLM to fetch (case-insensitive).

        Returns:
            The stored instance or factory, or None when the name is unknown.
        """
        with self._lock:
            return self._llms.get(name.lower())

    def is_registered(self, name: str) -> bool:
        """Report whether an LLM is stored under ``name``.

        Args:
            name: Identifier to test (case-insensitive).

        Returns:
            True when an entry exists for the name, False otherwise.
        """
        with self._lock:
            return name.lower() in self._llms

    # Test assertions use this to inspect registry state;
    # it is also part of the public API for plugin authors.
    def list_registered(self) -> list[str]:
        """Return the names of all stored LLMs.

        Returns:
            A new list of the registry's keys, i.e. the lowercased names.
        """
        with self._lock:
            return list(self._llms)

    # Called from autouse fixtures in 7+ test files for test isolation
    # (64+ tests depend on it to reset global state between runs).
    def clear(self) -> None:
        """Drop every stored LLM."""
        with self._lock:
            self._llms.clear()
            logger.info("Cleared all registered custom LLMs")

102 

103 

# Process-wide registry shared by the module-level helper functions below.
_llm_registry = LLMRegistry()

106 

107 

108# Public API functions 

def register_llm(
    name: str, llm: Union[BaseChatModel, Callable[..., BaseChatModel]]
) -> None:
    """Add a custom LLM to the module-wide registry.

    Delegates to ``LLMRegistry.register`` on the global instance, which
    lowercases the name and overwrites any existing entry.

    Args:
        name: Identifier for the LLM
        llm: A BaseChatModel instance, or a factory callable returning one
    """
    registry = _llm_registry
    registry.register(name, llm)

119 

120 

def unregister_llm(name: str) -> None:
    """Remove a custom LLM from the module-wide registry.

    Delegates to ``LLMRegistry.unregister`` on the global instance; a name
    that is not registered is ignored.

    Args:
        name: Identifier of the LLM to remove
    """
    registry = _llm_registry
    registry.unregister(name)

128 

129 

def get_llm_from_registry(
    name: str,
) -> Optional[Union[BaseChatModel, Callable[..., BaseChatModel]]]:
    """Fetch a custom LLM from the module-wide registry.

    Delegates to ``LLMRegistry.get`` on the global instance.

    Args:
        name: Identifier of the LLM to fetch

    Returns:
        The stored instance or factory, or None when the name is unknown
    """
    found = _llm_registry.get(name)
    return found

142 

143 

def is_llm_registered(name: str) -> bool:
    """Report whether the module-wide registry holds an LLM under ``name``.

    Delegates to ``LLMRegistry.is_registered`` on the global instance.

    Args:
        name: Identifier to test

    Returns:
        True when an entry exists for the name, False otherwise
    """
    present = _llm_registry.is_registered(name)
    return present

154 

155 

def list_registered_llms() -> list[str]:
    """Return the names of every LLM in the module-wide registry.

    Delegates to ``LLMRegistry.list_registered`` on the global instance.

    Returns:
        List of registered LLM names, as stored (lowercased) by the registry
    """
    names = _llm_registry.list_registered()
    return names

163 

164 

def clear_llm_registry() -> None:
    """Empty the module-wide registry of all custom LLMs."""
    registry = _llm_registry
    registry.clear()