Coverage for src / local_deep_research / web_search_engines / rate_limiting / llm / wrapper.py: 93%

100 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:55 +0000

1""" 

2Rate-limited wrapper for LLM calls. 

3""" 

4 

5from typing import Optional 

6from urllib.parse import urlparse 

7 

8from loguru import logger 

9from tenacity import ( 

10 retry, 

11 stop_after_attempt, 

12 retry_if_exception, 

13) 

14from tenacity.wait import wait_base 

15 

16from ....llm.providers.base import normalize_provider 

17 

18from ..tracker import get_tracker 

19from ..exceptions import RateLimitError 

20from .detection import is_llm_rate_limit_error, extract_retry_after 

21 

22 

class AdaptiveLLMWait(wait_base):
    """Tenacity wait strategy that adapts its delay to observed rate limits.

    The delay comes from the shared tracker for the given engine key; if the
    most recent failure carried a Retry-After hint, the larger of the two
    values wins.
    """

    def __init__(self, tracker, engine_type: str):
        self.tracker = tracker
        self.engine_type = engine_type
        # Most recent exception seen by __call__, kept so a Retry-After
        # header can be honored on the next wait computation.
        self.last_error = None

    def __call__(self, retry_state) -> float:
        """Return the number of seconds tenacity should sleep before retrying."""
        # Remember the failure (if any) for Retry-After extraction below.
        outcome = retry_state.outcome
        if outcome and outcome.failed:
            self.last_error = outcome.exception()

        # Baseline: adaptive wait suggested by the tracker for this engine.
        delay: float = self.tracker.get_wait_time(self.engine_type)

        # Honor an explicit Retry-After from the server when it is longer.
        if self.last_error:
            hinted = extract_retry_after(self.last_error)
            if hinted > 0:
                delay = max(delay, float(hinted))

        logger.info(
            f"LLM rate limit wait for {self.engine_type}: {delay:.2f}s"
        )
        return delay

49 

50 

def create_rate_limited_llm_wrapper(base_llm, provider: Optional[str] = None):
    """
    Create a rate-limited wrapper around an LLM instance.

    Args:
        base_llm: The base LLM instance to wrap
        provider: Optional provider name (e.g., 'openai', 'anthropic')

    Returns:
        A wrapped LLM instance with rate limiting capabilities
    """

    class RateLimitedLLMWrapper:
        """Wrapper that adds rate limiting to LLM calls.

        Unknown attributes are transparently delegated to the wrapped LLM,
        so this object can be used anywhere the original LLM was.
        """

        def __init__(self, llm, provider_name: Optional[str] = None):
            self.base_llm = llm
            self.provider = provider_name
            self.rate_limiter = None

            # Only setup rate limiting if enabled
            if self._should_rate_limit():
                self.rate_limiter = get_tracker()
                logger.info(
                    f"Rate limiting enabled for LLM provider: {self._get_rate_limit_key()}"
                )

        def _should_rate_limit(self) -> bool:
            """Check if rate limiting should be applied to this LLM."""
            # Rate limiting for LLMs is currently disabled by default
            # TODO: Pass settings_snapshot to enable proper configuration
            return False

        def _check_if_local_model(self) -> bool:
            """Check if the LLM is a local model that shouldn't be rate limited."""
            # Local deployments have no provider-side quotas, so skip them.
            local_providers = [
                "ollama",
                "lmstudio",
                "llamacpp",
                "local",
                "none",
            ]
            if normalize_provider(self.provider) in local_providers:
                logger.debug(
                    f"Skipping rate limiting for local provider: {self.provider}"
                )
                return True

            # Check if base URL indicates local model
            if hasattr(self.base_llm, "base_url"):
                base_url = str(self.base_llm.base_url)
                if any(
                    local in base_url
                    for local in ["localhost", "127.0.0.1", "0.0.0.0"]
                ):
                    logger.debug(
                        f"Skipping rate limiting for local URL: {base_url}"
                    )
                    return True

            return False

        def _get_rate_limit_key(self) -> str:
            """Build composite key: provider-url-model"""
            provider = self.provider or "unknown"

            # Extract URL from the LLM itself, or its underlying client.
            url = "unknown"
            if hasattr(self.base_llm, "base_url"):
                url = str(self.base_llm.base_url)
            elif hasattr(self.base_llm, "_client") and hasattr(
                self.base_llm._client, "base_url"
            ):
                url = str(self.base_llm._client.base_url)

            # Clean URL: remove protocol and trailing slashes
            if url != "unknown":
                parsed = urlparse(url)
                url = parsed.netloc or parsed.path
                url = url.rstrip("/")

            # Extract model (different LLM classes expose different attrs)
            model = "unknown"
            if hasattr(self.base_llm, "model_name"):
                model = str(self.base_llm.model_name)
            elif hasattr(self.base_llm, "model"):
                model = str(self.base_llm.model)

            # Clean model name so the key is safe as a single token
            model = model.replace("/", "-").replace(":", "-")

            return f"{provider}-{url}-{model}"

        def invoke(self, *args, **kwargs):
            """Invoke the LLM with rate limiting if enabled.

            With rate limiting active, rate-limit errors are retried up to
            3 times with an adaptive wait, and each outcome is recorded in
            the tracker. Without it, the call goes straight through.
            """
            if self.rate_limiter:
                rate_limit_key = self._get_rate_limit_key()

                # Define retry logic: only retry on recognized rate-limit
                # errors, waiting per the adaptive tracker.
                @retry(
                    wait=AdaptiveLLMWait(self.rate_limiter, rate_limit_key),
                    stop=stop_after_attempt(3),
                    retry=retry_if_exception(is_llm_rate_limit_error),
                )
                def _invoke_with_retry():
                    return self._do_invoke(*args, **kwargs)

                try:
                    result = _invoke_with_retry()

                    # Record successful attempt
                    self.rate_limiter.record_outcome(
                        engine_type=rate_limit_key,
                        wait_time=0,  # First attempt had no wait
                        success=True,
                        retry_count=0,
                    )

                    return result

                except Exception as e:
                    # Only record rate limit failures, not general failures
                    if is_llm_rate_limit_error(e):
                        self.rate_limiter.record_outcome(
                            engine_type=rate_limit_key,
                            wait_time=0,
                            success=False,
                            retry_count=0,
                        )
                    raise
            else:
                # No rate limiting, just invoke directly
                return self._do_invoke(*args, **kwargs)

        def _do_invoke(self, *args, **kwargs):
            """Actually invoke the LLM, wrapping rate-limit errors."""
            try:
                return self.base_llm.invoke(*args, **kwargs)
            except Exception as e:
                # Check if it's a rate limit error and wrap it, chaining
                # the original exception so the cause isn't lost.
                if is_llm_rate_limit_error(e):
                    logger.warning("LLM rate limit error detected")
                    raise RateLimitError(f"LLM rate limit: {str(e)}") from e
                raise

        # Pass through any other attributes to the base LLM
        def __getattr__(self, name):
            # Guard: if "base_llm" itself is missing (e.g. the instance was
            # created without running __init__, as during unpickling), fail
            # with AttributeError instead of infinite recursion.
            if name == "base_llm":
                raise AttributeError(name)
            return getattr(self.base_llm, name)

        def close(self):
            """Close underlying HTTP clients held by this LLM. Idempotent."""
            try:
                from ....utilities.llm_utils import _close_base_llm

                _close_base_llm(self.base_llm)
            except Exception:
                logger.debug(
                    "best-effort cleanup of HTTP clients on shutdown",
                    exc_info=True,
                )

        def __str__(self):
            return f"RateLimited({str(self.base_llm)})"

        def __repr__(self):
            return f"RateLimitedLLMWrapper({repr(self.base_llm)})"

    return RateLimitedLLMWrapper(base_llm, provider)