Coverage for src / local_deep_research / llm / llm_registry.py: 100%
49 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
1"""Registry for custom LangChain LLMs.
3This module provides a global registry for registering and managing custom LangChain
4LLMs that can be used with Local Deep Research.
5"""
7import threading
8from typing import Callable, Dict, Optional, Union
10from langchain.chat_models.base import BaseChatModel
11from loguru import logger
class LLMRegistry:
    """Name-keyed store of custom LangChain chat models.

    Each entry is either a ready-made ``BaseChatModel`` instance or a factory
    callable that returns one. Names are lowercased before being used as
    keys, so registration and lookup are case-insensitive. Every method that
    touches the underlying mapping does so while holding one
    ``threading.Lock``.
    """

    def __init__(self):
        # Lowercased name -> registered model instance or factory callable.
        self._llms: Dict[
            str, Union[BaseChatModel, Callable[..., BaseChatModel]]
        ] = {}
        # Held around every read and write of ``_llms``.
        self._lock = threading.Lock()

    @staticmethod
    def _normalize(name: str) -> str:
        """Return the lowercase form of ``name`` used as the storage key."""
        return name.lower()

    def register(
        self, name: str, llm: Union[BaseChatModel, Callable[..., BaseChatModel]]
    ) -> None:
        """Store ``llm`` under ``name``, replacing any existing entry.

        A warning is logged when an entry with the same (lowercased) name is
        already present; an info message is logged for every registration.

        Args:
            name: Unique name for the LLM (case-insensitive).
            llm: A ``BaseChatModel`` instance, or a factory callable that
                returns one.
        """
        with self._lock:
            normalized_name = self._normalize(name)
            already_present = normalized_name in self._llms
            if already_present:
                logger.warning(f"Overwriting existing LLM: {name}")
            self._llms[normalized_name] = llm
            logger.info(
                f"Registered custom LLM: {name} (normalized: {normalized_name})"
            )

    # Completes the CRUD API surface for the registry.
    # Used in tests to verify cleanup behavior.
    def unregister(self, name: str) -> None:
        """Remove the entry stored under ``name``, if there is one.

        Unknown names are ignored silently: nothing is raised and nothing is
        logged.

        Args:
            name: Name of the LLM to unregister (case-insensitive).
        """
        with self._lock:
            key = self._normalize(name)
            try:
                del self._llms[key]
            except KeyError:
                # Nothing was registered under this name; stay silent.
                return
            logger.info(f"Unregistered custom LLM: {name}")

    def get(
        self, name: str
    ) -> Optional[Union[BaseChatModel, Callable[..., BaseChatModel]]]:
        """Look up the entry stored under ``name``.

        Args:
            name: Name of the LLM to retrieve (case-insensitive).

        Returns:
            The registered instance or factory, or ``None`` when no entry
            exists for that name.
        """
        with self._lock:
            return self._llms.get(self._normalize(name))

    def is_registered(self, name: str) -> bool:
        """Report whether an entry exists for ``name``.

        Args:
            name: Name to check (case-insensitive).

        Returns:
            ``True`` if an entry is stored under the lowercased name,
            ``False`` otherwise.
        """
        with self._lock:
            return self._normalize(name) in self._llms

    # Used in test assertions to verify registry state;
    # part of public API for plugin authors.
    def list_registered(self) -> list[str]:
        """Return the names of all registered LLMs.

        Returns:
            A new list of the stored keys, i.e. the lowercased names.
        """
        with self._lock:
            return list(self._llms)

    # Used in 7+ test files' autouse fixtures for test isolation
    # (64+ tests depend on this to reset global state between runs).
    def clear(self) -> None:
        """Remove every registered LLM and log that the registry was emptied."""
        with self._lock:
            self._llms.clear()
            logger.info("Cleared all registered custom LLMs")
# Single module-level LLMRegistry instance; the public functions in this
# module all operate on it.
_llm_registry = LLMRegistry()
108# Public API functions
def register_llm(
    name: str, llm: Union[BaseChatModel, Callable[..., BaseChatModel]]
) -> None:
    """Add a custom LLM to the module-level registry.

    Thin wrapper that forwards both arguments to ``_llm_registry.register``.

    Args:
        name: Unique name for the LLM.
        llm: A ``BaseChatModel`` instance, or a factory callable that
            returns one.
    """
    registry = _llm_registry
    registry.register(name, llm)
def unregister_llm(name: str) -> None:
    """Remove a custom LLM from the module-level registry.

    Thin wrapper that forwards ``name`` to ``_llm_registry.unregister``.

    Args:
        name: Name of the LLM to unregister.
    """
    registry = _llm_registry
    registry.unregister(name)
def get_llm_from_registry(
    name: str,
) -> Optional[Union[BaseChatModel, Callable[..., BaseChatModel]]]:
    """Fetch a registered LLM from the module-level registry.

    Thin wrapper that forwards ``name`` to ``_llm_registry.get``.

    Args:
        name: Name of the LLM to retrieve.

    Returns:
        The registered instance or factory, or ``None`` if nothing is
        registered under that name.
    """
    entry = _llm_registry.get(name)
    return entry
def is_llm_registered(name: str) -> bool:
    """Report whether the module-level registry has an entry for ``name``.

    Thin wrapper that forwards ``name`` to ``_llm_registry.is_registered``.

    Args:
        name: Name to check.

    Returns:
        ``True`` if registered, ``False`` otherwise.
    """
    present = _llm_registry.is_registered(name)
    return present
def list_registered_llms() -> list[str]:
    """Return the names of all LLMs in the module-level registry.

    Thin wrapper around ``_llm_registry.list_registered``.

    Returns:
        List of registered LLM names.
    """
    names = _llm_registry.list_registered()
    return names
def clear_llm_registry() -> None:
    """Remove every registered LLM from the module-level registry.

    Thin wrapper around ``_llm_registry.clear``.
    """
    registry = _llm_registry
    registry.clear()