Coverage for src / local_deep_research / web_search_engines / rate_limiting / llm / wrapper.py: 93%
100 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
1"""
2Rate-limited wrapper for LLM calls.
3"""
5from typing import Optional
6from urllib.parse import urlparse
8from loguru import logger
9from tenacity import (
10 retry,
11 stop_after_attempt,
12 retry_if_exception,
13)
14from tenacity.wait import wait_base
16from ....llm.providers.base import normalize_provider
18from ..tracker import get_tracker
19from ..exceptions import RateLimitError
20from .detection import is_llm_rate_limit_error, extract_retry_after
class AdaptiveLLMWait(wait_base):
    """Tenacity wait strategy that adapts to observed LLM rate limiting.

    Combines the tracker's learned wait time for an engine with any
    explicit Retry-After hint extracted from the most recent error,
    taking whichever is larger.
    """

    def __init__(self, tracker, engine_type: str):
        self.tracker = tracker
        self.engine_type = engine_type
        # Most recent failure, kept so a Retry-After header can be honored.
        self.last_error = None

    def __call__(self, retry_state) -> float:
        """Return the number of seconds tenacity should sleep before retrying."""
        outcome = retry_state.outcome
        if outcome and outcome.failed:
            self.last_error = outcome.exception()

        # Start from the tracker's adaptive estimate for this engine.
        delay: float = self.tracker.get_wait_time(self.engine_type)

        # Respect an explicit server-provided Retry-After if it is longer.
        if self.last_error:
            hinted = extract_retry_after(self.last_error)
            if hinted > 0:
                delay = max(delay, float(hinted))

        logger.info(
            f"LLM rate limit wait for {self.engine_type}: {delay:.2f}s"
        )
        return delay
def create_rate_limited_llm_wrapper(base_llm, provider: Optional[str] = None):
    """
    Create a rate-limited wrapper around an LLM instance.

    Args:
        base_llm: The base LLM instance to wrap
        provider: Optional provider name (e.g., 'openai', 'anthropic')

    Returns:
        A wrapped LLM instance with rate limiting capabilities
    """

    class RateLimitedLLMWrapper:
        """Wrapper that adds rate limiting to LLM calls.

        Unknown attribute access is delegated to the wrapped LLM via
        ``__getattr__``, so the wrapper is a drop-in replacement.
        """

        def __init__(self, llm, provider_name: Optional[str] = None):
            self.base_llm = llm
            self.provider = provider_name
            self.rate_limiter = None

            # Only set up rate limiting if enabled
            if self._should_rate_limit():
                self.rate_limiter = get_tracker()
                logger.info(
                    f"Rate limiting enabled for LLM provider: {self._get_rate_limit_key()}"
                )

        def _should_rate_limit(self) -> bool:
            """Check if rate limiting should be applied to this LLM."""
            # Rate limiting for LLMs is currently disabled by default
            # TODO: Pass settings_snapshot to enable proper configuration
            return False

        def _check_if_local_model(self) -> bool:
            """Check if the LLM is a local model that shouldn't be rate limited.

            NOTE(review): not invoked anywhere in this module yet; presumably
            intended for use by _should_rate_limit once configuration is wired
            in (see TODO above).
            """
            # Don't rate limit local models
            local_providers = [
                "ollama",
                "lmstudio",
                "llamacpp",
                "local",
                "none",
            ]
            if normalize_provider(self.provider) in local_providers:
                logger.debug(
                    f"Skipping rate limiting for local provider: {self.provider}"
                )
                return True

            # Check if base URL indicates local model
            if hasattr(self.base_llm, "base_url"):
                base_url = str(self.base_llm.base_url)
                if any(
                    local in base_url
                    for local in ["localhost", "127.0.0.1", "0.0.0.0"]
                ):
                    logger.debug(
                        f"Skipping rate limiting for local URL: {base_url}"
                    )
                    return True

            return False

        def _get_rate_limit_key(self) -> str:
            """Build composite key: provider-url-model"""
            provider = self.provider or "unknown"

            # Extract URL from the LLM or its underlying client
            url = "unknown"
            if hasattr(self.base_llm, "base_url"):
                url = str(self.base_llm.base_url)
            elif hasattr(self.base_llm, "_client") and hasattr(
                self.base_llm._client, "base_url"
            ):
                url = str(self.base_llm._client.base_url)

            # Clean URL: remove protocol and trailing slashes
            if url != "unknown":
                parsed = urlparse(url)
                url = parsed.netloc or parsed.path
                url = url.rstrip("/")

            # Extract model name (attribute name varies by LLM class)
            model = "unknown"
            if hasattr(self.base_llm, "model_name"):
                model = str(self.base_llm.model_name)
            elif hasattr(self.base_llm, "model"):
                model = str(self.base_llm.model)

            # Clean model name so the key has a single separator style
            model = model.replace("/", "-").replace(":", "-")

            return f"{provider}-{url}-{model}"

        def invoke(self, *args, **kwargs):
            """Invoke the LLM with rate limiting if enabled."""
            if self.rate_limiter:
                rate_limit_key = self._get_rate_limit_key()

                # Define retry logic: adaptive waits, at most 3 attempts,
                # retry only on recognized rate-limit errors.
                @retry(
                    wait=AdaptiveLLMWait(self.rate_limiter, rate_limit_key),
                    stop=stop_after_attempt(3),
                    retry=retry_if_exception(is_llm_rate_limit_error),
                )
                def _invoke_with_retry():
                    return self._do_invoke(*args, **kwargs)

                try:
                    result = _invoke_with_retry()

                    # Record successful attempt
                    self.rate_limiter.record_outcome(
                        engine_type=rate_limit_key,
                        wait_time=0,  # First attempt had no wait
                        success=True,
                        retry_count=0,
                    )

                    return result

                except Exception as e:
                    # Only record rate limit failures, not general failures
                    if is_llm_rate_limit_error(e):
                        self.rate_limiter.record_outcome(
                            engine_type=rate_limit_key,
                            wait_time=0,
                            success=False,
                            retry_count=0,
                        )
                    raise
            else:
                # No rate limiting, just invoke directly
                return self._do_invoke(*args, **kwargs)

        def _do_invoke(self, *args, **kwargs):
            """Actually invoke the LLM.

            Raises:
                RateLimitError: if the underlying error is a rate limit
                    (chained to the original exception for debugging).
            """
            try:
                return self.base_llm.invoke(*args, **kwargs)
            except Exception as e:
                # Check if it's a rate limit error and wrap it.
                # FIX: chain with `from e` so the provider's original
                # traceback and details are preserved on the wrapped error.
                if is_llm_rate_limit_error(e):
                    logger.warning("LLM rate limit error detected")
                    raise RateLimitError(f"LLM rate limit: {str(e)}") from e
                raise

        # Pass through any other attributes to the base LLM
        def __getattr__(self, name):
            return getattr(self.base_llm, name)

        def close(self):
            """Close underlying HTTP clients held by this LLM. Idempotent."""
            try:
                from ....utilities.llm_utils import _close_base_llm

                _close_base_llm(self.base_llm)
            except Exception:
                # Best-effort cleanup: never let shutdown fail on this.
                logger.debug(
                    "best-effort cleanup of HTTP clients on shutdown",
                    exc_info=True,
                )

        def __str__(self):
            return f"RateLimited({str(self.base_llm)})"

        def __repr__(self):
            return f"RateLimitedLLMWrapper({repr(self.base_llm)})"

    return RateLimitedLLMWrapper(base_llm, provider)