Coverage for src/local_deep_research/web_search_engines/rate_limiting/llm/wrapper.py: 42%

94 statements  

coverage.py v7.12.0, created at 2026-01-11 00:51 +0000

1""" 

2Rate-limited wrapper for LLM calls. 

3""" 

4 

5from typing import Optional 

6from urllib.parse import urlparse 

7 

8from loguru import logger 

9from tenacity import ( 

10 retry, 

11 stop_after_attempt, 

12 retry_if_exception, 

13) 

14from tenacity.wait import wait_base 

15 

16from ..tracker import get_tracker 

17from ..exceptions import RateLimitError 

18from .detection import is_llm_rate_limit_error, extract_retry_after 

19 

20 

21class AdaptiveLLMWait(wait_base): 

22 """Adaptive wait strategy for LLM rate limiting.""" 

23 

24 def __init__(self, tracker, engine_type: str): 

25 self.tracker = tracker 

26 self.engine_type = engine_type 

27 self.last_error = None 

28 

29 def __call__(self, retry_state) -> float: 

30 # Store the error for potential retry-after extraction 

31 if retry_state.outcome and retry_state.outcome.failed: 

32 self.last_error = retry_state.outcome.exception() 

33 

34 # Get adaptive wait time from tracker 

35 wait_time = self.tracker.get_wait_time(self.engine_type) 

36 

37 # If we have a retry-after from the error, use it 

38 if self.last_error: 

39 retry_after = extract_retry_after(self.last_error) 

40 if retry_after > 0: 

41 wait_time = max(wait_time, retry_after) 

42 

43 logger.info( 

44 f"LLM rate limit wait for {self.engine_type}: {wait_time:.2f}s" 

45 ) 

46 return wait_time 

47 

48 
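

# Illustrative sketch: AdaptiveLLMWait drops into tenacity like any built-in
# wait strategy. The key string below is a hypothetical provider-url-model
# value of the form built by RateLimitedLLMWrapper._get_rate_limit_key():
#
#     @retry(
#         wait=AdaptiveLLMWait(get_tracker(), "openai-api.openai.com-gpt-4o"),
#         stop=stop_after_attempt(3),
#         retry=retry_if_exception(is_llm_rate_limit_error),
#     )
#     def _call_model():
#         ...
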

def create_rate_limited_llm_wrapper(base_llm, provider: Optional[str] = None):
    """
    Create a rate-limited wrapper around an LLM instance.

    Args:
        base_llm: The base LLM instance to wrap
        provider: Optional provider name (e.g., 'openai', 'anthropic')

    Returns:
        A wrapped LLM instance with rate limiting capabilities
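
    Example:
        Illustrative sketch; assumes a LangChain-style client (here a
        hypothetical ``ChatOpenAI``) whose ``invoke`` method performs the call::

            llm = create_rate_limited_llm_wrapper(
                ChatOpenAI(model="gpt-4o"), provider="openai"
            )
            response = llm.invoke("Hello")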

    """

    class RateLimitedLLMWrapper:
        """Wrapper that adds rate limiting to LLM calls."""

        def __init__(self, llm, provider_name: Optional[str] = None):
            self.base_llm = llm
            self.provider = provider_name
            self.rate_limiter = None

            # Only set up rate limiting if enabled
            if self._should_rate_limit():  # coverage: branch never taken (currently always False)
                self.rate_limiter = get_tracker()
                logger.info(
                    f"Rate limiting enabled for LLM provider: {self._get_rate_limit_key()}"
                )

        def _should_rate_limit(self) -> bool:
            """Check if rate limiting should be applied to this LLM."""
            # Rate limiting for LLMs is currently disabled by default
            # TODO: Pass settings_snapshot to enable proper configuration
            return False

        def _check_if_local_model(self) -> bool:
            """Check if the LLM is a local model that shouldn't be rate limited."""
            # Don't rate limit local models
            local_providers = [
                "ollama",
                "lmstudio",
                "llamacpp",
                "vllm",
                "local",
                "none",
            ]
            if self.provider and self.provider.lower() in local_providers:
                logger.debug(
                    f"Skipping rate limiting for local provider: {self.provider}"
                )
                return True

            # Check if base URL indicates local model
            if hasattr(self.base_llm, "base_url"):
                base_url = str(self.base_llm.base_url)
                if any(
                    local in base_url
                    for local in ["localhost", "127.0.0.1", "0.0.0.0"]
                ):
                    logger.debug(
                        f"Skipping rate limiting for local URL: {base_url}"
                    )
                    return True

            return False
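
        # Note: for example, a client whose base_url is "http://localhost:11434"
        # (the default Ollama address) is treated as local and skips rate limiting.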

        def _get_rate_limit_key(self) -> str:
            """Build composite key: provider-url-model"""
            provider = self.provider or "unknown"

            # Extract URL
            url = "unknown"
            if hasattr(self.base_llm, "base_url"):  # coverage: condition always true
                url = str(self.base_llm.base_url)
            elif hasattr(self.base_llm, "_client") and hasattr(
                self.base_llm._client, "base_url"
            ):
                url = str(self.base_llm._client.base_url)

            # Clean URL: remove protocol and trailing slashes
            if url != "unknown":  # coverage: condition always true
                parsed = urlparse(url)
                url = parsed.netloc or parsed.path
                url = url.rstrip("/")

            # Extract model
            model = "unknown"
            if hasattr(self.base_llm, "model_name"):  # coverage: condition always true
                model = str(self.base_llm.model_name)
            elif hasattr(self.base_llm, "model"):
                model = str(self.base_llm.model)

            # Clean model name
            model = model.replace("/", "-").replace(":", "-")

            key = f"{provider}-{url}-{model}"
            return key
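
        # Example of the resulting key (hypothetical values): provider "openai",
        # base_url "https://api.openai.com/v1", and model_name "gpt-4o" produce
        # "openai-api.openai.com-gpt-4o".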

        def invoke(self, *args, **kwargs):
            """Invoke the LLM with rate limiting if enabled."""
            if self.rate_limiter:  # coverage: branch never taken (rate limiting disabled)
                rate_limit_key = self._get_rate_limit_key()

                # Define retry logic
                @retry(
                    wait=AdaptiveLLMWait(self.rate_limiter, rate_limit_key),
                    stop=stop_after_attempt(3),
                    retry=retry_if_exception(is_llm_rate_limit_error),
                )
                def _invoke_with_retry():
                    return self._do_invoke(*args, **kwargs)

                try:
                    result = _invoke_with_retry()

                    # Record successful attempt
                    self.rate_limiter.record_outcome(
                        engine_type=rate_limit_key,
                        wait_time=0,  # First attempt had no wait
                        success=True,
                        retry_count=0,
                    )

                    return result

                except Exception as e:
                    # Only record rate limit failures, not general failures
                    if is_llm_rate_limit_error(e):
                        self.rate_limiter.record_outcome(
                            engine_type=rate_limit_key,
                            wait_time=0,
                            success=False,
                            retry_count=0,
                        )
                    raise
            else:
                # No rate limiting, just invoke directly
                return self._do_invoke(*args, **kwargs)

        def _do_invoke(self, *args, **kwargs):
            """Actually invoke the LLM."""
            try:
                return self.base_llm.invoke(*args, **kwargs)
            except Exception as e:
                # Check if it's a rate limit error and wrap it
                if is_llm_rate_limit_error(e):
                    logger.warning(f"LLM rate limit error detected: {e}")
                    raise RateLimitError(f"LLM rate limit: {e}") from e
                raise

        # Pass through any other attributes to the base LLM
        def __getattr__(self, name):
            return getattr(self.base_llm, name)

        def __str__(self):
            return f"RateLimited({self.base_llm})"

        def __repr__(self):
            return f"RateLimitedLLMWrapper({self.base_llm!r})"

    return RateLimitedLLMWrapper(base_llm, provider)
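

# Usage sketch (illustrative; assumes a LangChain-style client such as
# langchain_openai.ChatOpenAI, which exposes ``model_name`` and ``invoke``):
#
#     from langchain_openai import ChatOpenAI
#
#     base = ChatOpenAI(model="gpt-4o")
#     llm = create_rate_limited_llm_wrapper(base, provider="openai")
#     response = llm.invoke("ping")  # delegates to base.invoke(); with rate
#                                    # limiting enabled, retries on rate-limit errors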