Coverage for src/local_deep_research/web_search_engines/rate_limiting/llm/wrapper.py: 42%
94 statements
coverage.py v7.12.0, created at 2026-01-11 00:51 +0000
1"""
2Rate-limited wrapper for LLM calls.
3"""
5from typing import Optional
6from urllib.parse import urlparse
8from loguru import logger
9from tenacity import (
10 retry,
11 stop_after_attempt,
12 retry_if_exception,
13)
14from tenacity.wait import wait_base
16from ..tracker import get_tracker
17from ..exceptions import RateLimitError
18from .detection import is_llm_rate_limit_error, extract_retry_after


class AdaptiveLLMWait(wait_base):
    """Adaptive wait strategy for LLM rate limiting."""

    def __init__(self, tracker, engine_type: str):
        self.tracker = tracker
        self.engine_type = engine_type
        self.last_error = None

    def __call__(self, retry_state) -> float:
        # Store the error for potential retry-after extraction
        if retry_state.outcome and retry_state.outcome.failed:
            self.last_error = retry_state.outcome.exception()

        # Get adaptive wait time from tracker
        wait_time = self.tracker.get_wait_time(self.engine_type)

        # If we have a retry-after from the error, use it
        if self.last_error:
            retry_after = extract_retry_after(self.last_error)
            if retry_after > 0:
                wait_time = max(wait_time, retry_after)

        logger.info(
            f"LLM rate limit wait for {self.engine_type}: {wait_time:.2f}s"
        )
        return wait_time
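
# For example (illustrative numbers only): if the tracker's adaptive estimate
# is 1.5s but the failed call's error carried "Retry-After: 30", the max()
# above resolves the wait to 30.0s, so the server's hint is never undercut.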


def create_rate_limited_llm_wrapper(base_llm, provider: Optional[str] = None):
    """
    Create a rate-limited wrapper around an LLM instance.

    Args:
        base_llm: The base LLM instance to wrap
        provider: Optional provider name (e.g., 'openai', 'anthropic')

    Returns:
        A wrapped LLM instance with rate limiting capabilities
    """

    class RateLimitedLLMWrapper:
        """Wrapper that adds rate limiting to LLM calls."""

        def __init__(self, llm, provider_name: Optional[str] = None):
            self.base_llm = llm
            self.provider = provider_name
            self.rate_limiter = None

            # Only set up rate limiting if enabled
            if self._should_rate_limit():  # coverage: condition never true
                self.rate_limiter = get_tracker()
                logger.info(
                    f"Rate limiting enabled for LLM provider: {self._get_rate_limit_key()}"
                )

        def _should_rate_limit(self) -> bool:
            """Check if rate limiting should be applied to this LLM."""
            # Rate limiting for LLMs is currently disabled by default
            # TODO: Pass settings_snapshot to enable proper configuration
            return False

        def _check_if_local_model(self) -> bool:
            """Check if the LLM is a local model that shouldn't be rate limited."""
            # Don't rate limit local models
            local_providers = [
                "ollama",
                "lmstudio",
                "llamacpp",
                "vllm",
                "local",
                "none",
            ]
            if self.provider and self.provider.lower() in local_providers:
                logger.debug(
                    f"Skipping rate limiting for local provider: {self.provider}"
                )
                return True

            # Check if base URL indicates a local model
            if hasattr(self.base_llm, "base_url"):
                base_url = str(self.base_llm.base_url)
                if any(
                    local in base_url
                    for local in ["localhost", "127.0.0.1", "0.0.0.0"]
                ):
                    logger.debug(
                        f"Skipping rate limiting for local URL: {base_url}"
                    )
                    return True

            return False

        def _get_rate_limit_key(self) -> str:
            """Build composite key: provider-url-model"""
            provider = self.provider or "unknown"

            # Extract URL
            url = "unknown"
            if hasattr(self.base_llm, "base_url"):  # coverage: condition always true
                url = str(self.base_llm.base_url)
            elif hasattr(self.base_llm, "_client") and hasattr(
                self.base_llm._client, "base_url"
            ):
                url = str(self.base_llm._client.base_url)

            # Clean URL: remove protocol and trailing slashes
            if url != "unknown":  # coverage: condition always true
                parsed = urlparse(url)
                url = parsed.netloc or parsed.path
                url = url.rstrip("/")

            # Extract model
            model = "unknown"
            if hasattr(self.base_llm, "model_name"):  # coverage: condition always true
                model = str(self.base_llm.model_name)
            elif hasattr(self.base_llm, "model"):
                model = str(self.base_llm.model)

            # Clean model name
            model = model.replace("/", "-").replace(":", "-")

            key = f"{provider}-{url}-{model}"
            return key
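
        # For example (illustrative values only): provider="openai" with
        # base_llm.base_url="https://api.openai.com/v1" and
        # base_llm.model_name="gpt-4o" produces the key
        # "openai-api.openai.com-gpt-4o" (urlparse keeps only the netloc).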

        def invoke(self, *args, **kwargs):
            """Invoke the LLM with rate limiting if enabled."""
            if self.rate_limiter:  # coverage: condition never true
                rate_limit_key = self._get_rate_limit_key()

                # Define retry logic
                @retry(
                    wait=AdaptiveLLMWait(self.rate_limiter, rate_limit_key),
                    stop=stop_after_attempt(3),
                    retry=retry_if_exception(is_llm_rate_limit_error),
                )
                def _invoke_with_retry():
                    return self._do_invoke(*args, **kwargs)

                try:
                    result = _invoke_with_retry()

                    # Record successful attempt
                    self.rate_limiter.record_outcome(
                        engine_type=rate_limit_key,
                        wait_time=0,  # First attempt had no wait
                        success=True,
                        retry_count=0,
                    )

                    return result

                except Exception as e:
                    # Only record rate limit failures, not general failures
                    if is_llm_rate_limit_error(e):
                        self.rate_limiter.record_outcome(
                            engine_type=rate_limit_key,
                            wait_time=0,
                            success=False,
                            retry_count=0,
                        )
                    raise
            else:
                # No rate limiting, just invoke directly
                return self._do_invoke(*args, **kwargs)

        def _do_invoke(self, *args, **kwargs):
            """Actually invoke the LLM."""
            try:
                return self.base_llm.invoke(*args, **kwargs)
            except Exception as e:
                # Check if it's a rate limit error and wrap it,
                # chaining the original exception for traceability
                if is_llm_rate_limit_error(e):
                    logger.warning(f"LLM rate limit error detected: {e}")
                    raise RateLimitError(f"LLM rate limit: {str(e)}") from e
                raise

        # Pass through any other attributes to the base LLM
        def __getattr__(self, name):
            return getattr(self.base_llm, name)

        def __str__(self):
            return f"RateLimited({str(self.base_llm)})"

        def __repr__(self):
            return f"RateLimitedLLMWrapper({repr(self.base_llm)})"

    return RateLimitedLLMWrapper(base_llm, provider)
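

# Usage sketch (illustrative, not part of the module). Any object exposing an
# ``invoke`` method can be wrapped; ``ChatOpenAI`` below is an assumed example:
#
#     from langchain_openai import ChatOpenAI
#
#     base = ChatOpenAI(model="gpt-4o-mini")
#     llm = create_rate_limited_llm_wrapper(base, provider="openai")
#     print(llm)            # -> RateLimited(...)
#     llm.invoke("Hello!")  # delegates to base.invoke; once rate limiting is
#                           # enabled, rate-limit errors are retried with
#                           # adaptive waits (up to 3 attempts)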