Coverage for src / local_deep_research / llm / llm_registry.py: 100%
49 statements
« prev ^ index » next coverage.py v7.12.0, created at 2026-01-11 00:51 +0000
« prev ^ index » next coverage.py v7.12.0, created at 2026-01-11 00:51 +0000
"""Registry for custom LangChain LLMs.

This module provides a global registry for registering and managing custom LangChain
LLMs that can be used with Local Deep Research.
"""
7import threading
8from typing import Callable, Dict, Optional, Union
10from langchain.chat_models.base import BaseChatModel
11from loguru import logger
class LLMRegistry:
    """Name-keyed store of custom LangChain chat models.

    Each entry is either a ready ``BaseChatModel`` instance or a callable
    that produces one. Names are lowercased before they are used as keys,
    so lookups are case-insensitive. Every access to the underlying dict is
    made while holding a ``threading.Lock``.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._llms: Dict[
            str, Union[BaseChatModel, Callable[..., BaseChatModel]]
        ] = {}

    def register(
        self, name: str, llm: Union[BaseChatModel, Callable[..., BaseChatModel]]
    ) -> None:
        """Store ``llm`` under ``name``, replacing any previous entry.

        A warning is logged when an entry with the same (lowercased) name
        already exists; the new value then takes its place.

        Args:
            name: Name for the LLM; lowercased before being used as the key
            llm: A BaseChatModel instance, or a factory callable returning one
        """
        normalized_name = name.lower()
        with self._lock:
            already_present = normalized_name in self._llms
            if already_present:
                logger.warning(f"Overwriting existing LLM: {name}")
            self._llms[normalized_name] = llm
            logger.info(
                f"Registered custom LLM: {name} (normalized: {normalized_name})"
            )

    def unregister(self, name: str) -> None:
        """Remove the entry stored under ``name``, if there is one.

        Unknown names are ignored silently: nothing is raised or logged.

        Args:
            name: Name of the LLM to remove (case-insensitive)
        """
        key = name.lower()
        with self._lock:
            try:
                del self._llms[key]
            except KeyError:
                # Nothing stored under this name; stay silent.
                return
            logger.info(f"Unregistered custom LLM: {name}")

    def get(
        self, name: str
    ) -> Optional[Union[BaseChatModel, Callable[..., BaseChatModel]]]:
        """Look up the entry stored under ``name``.

        Args:
            name: Name of the LLM to fetch (case-insensitive)

        Returns:
            The stored instance or factory, or None when the name is unknown
        """
        key = name.lower()
        with self._lock:
            return self._llms.get(key)

    def is_registered(self, name: str) -> bool:
        """Tell whether an entry exists for ``name``.

        Args:
            name: Name to look for (case-insensitive)

        Returns:
            True when an entry exists, False otherwise
        """
        key = name.lower()
        with self._lock:
            return key in self._llms

    def list_registered(self) -> list[str]:
        """Return the names of all stored entries.

        Returns:
            A new list of the stored keys, i.e. the lowercased names
        """
        with self._lock:
            # Iterating a dict yields its keys; list() takes a snapshot.
            return list(self._llms)

    def clear(self) -> None:
        """Drop every stored entry and log that the registry was emptied."""
        with self._lock:
            self._llms.clear()
            logger.info("Cleared all registered custom LLMs")
# Single module-level registry instance used by the public helper functions.
_llm_registry: LLMRegistry = LLMRegistry()
# Public API functions
def register_llm(
    name: str, llm: Union[BaseChatModel, Callable[..., BaseChatModel]]
) -> None:
    """Add a custom LLM to the module-level registry.

    Forwards both arguments unchanged to ``_llm_registry.register``.

    Args:
        name: Name to store the LLM under
        llm: A BaseChatModel instance, or a factory callable
    """
    _llm_registry.register(name=name, llm=llm)
def unregister_llm(name: str) -> None:
    """Remove a custom LLM from the module-level registry.

    Forwards to ``_llm_registry.unregister``.

    Args:
        name: Name of the LLM to remove
    """
    _llm_registry.unregister(name=name)
def get_llm_from_registry(
    name: str,
) -> Optional[Union[BaseChatModel, Callable[..., BaseChatModel]]]:
    """Fetch a custom LLM from the module-level registry.

    Forwards to ``_llm_registry.get``.

    Args:
        name: Name of the LLM to fetch

    Returns:
        The stored instance or factory, or None when the name is unknown
    """
    entry = _llm_registry.get(name)
    return entry
def is_llm_registered(name: str) -> bool:
    """Tell whether the module-level registry holds an entry for ``name``.

    Forwards to ``_llm_registry.is_registered``.

    Args:
        name: Name to look for

    Returns:
        True when an entry exists, False otherwise
    """
    found = _llm_registry.is_registered(name)
    return found
def list_registered_llms() -> list[str]:
    """Return the names held by the module-level registry.

    Forwards to ``_llm_registry.list_registered``.

    Returns:
        List of registered LLM names
    """
    names = _llm_registry.list_registered()
    return names
def clear_llm_registry() -> None:
    """Empty the module-level registry by calling ``_llm_registry.clear``."""
    registry = _llm_registry
    registry.clear()