Coverage for src / local_deep_research / mcp / client.py: 71%

185 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:55 +0000

1""" 

2MCP Client utilities for connecting to MCP servers. 

3 

4This module provides a wrapper around the MCP client SDK for connecting 

5to and calling tools on MCP servers. 

6""" 

7 

8import asyncio 

9from contextlib import AsyncExitStack, asynccontextmanager 

10from pathlib import Path 

11from typing import Any, Dict, List, Optional 

12 

13from tenacity import ( 

14 AsyncRetrying, 

15 retry_if_exception_message, 

16 retry_if_exception_type, 

17 stop_after_attempt, 

18 wait_exponential, 

19) 

20 

21from loguru import logger 

22 

# Security allow-list: base names of the executables that may be launched as
# MCP servers (checked in MCPClient._validate_server_config).
ALLOWED_COMMANDS = {"node", "npx", "python", "python3", "uv", "uvx", "docker"}

# Exponential-backoff settings used while initializing a freshly spawned
# server subprocess, which may not answer immediately.
INIT_RETRY_ATTEMPTS = 5
INIT_RETRY_BASE_DELAY = 0.1  # seconds; backoff multiplier
INIT_RETRY_MAX_DELAY = 2.0  # seconds; cap on any single wait
INIT_RETRY_BACKOFF_FACTOR = 2.0  # exponent base of the backoff

# The MCP SDK ships in the optional [mcp] extras group. Import it defensively
# and publish MCP_AVAILABLE so code can check before building a client.
try:
    from mcp import ClientSession
    from mcp.client.stdio import StdioServerParameters, stdio_client
except ImportError:
    MCP_AVAILABLE = False
    logger.warning(
        "MCP client not available. Install with: pip install mcp[cli]"
    )
else:
    MCP_AVAILABLE = True

44 

45 

class MCPClientError(Exception):
    """Raised when an MCP client operation in this module fails.

    Covers configuration validation, connecting, listing tools, calling
    tools, and timeouts of synchronously-run async operations.
    """

50 

51 

class MCPClient:
    """
    Client for connecting to and calling tools on a single MCP server.

    The server is launched as a stdio subprocess from ``server_config``, and
    the connection is scoped by the ``connect()`` async context manager.

    Usage:
        async with MCPClient(config).connect() as client:
            tools = await client.list_tools()
            result = await client.call_tool("tool_name", {"arg": "value"})
    """

    def __init__(
        self,
        server_config: Dict[str, Any],
        timeout: float = 60.0,
    ):
        """
        Initialize MCP client.

        Args:
            server_config: Server configuration with keys:
                - command: str - Executable to run (e.g., "python"); its base
                  name must be in ALLOWED_COMMANDS
                - args: List[str] - Command arguments (optional)
                - env: Dict[str, str] - Environment variables (optional)
                - name: str - Server name for logging (optional, defaults
                  to "unnamed")
            timeout: Timeout in seconds for list_tools/call_tool; each
                session-initialization attempt uses min(5.0, timeout).

        Raises:
            MCPClientError: If the MCP SDK is not installed or the
                configuration fails validation.
        """
        if not MCP_AVAILABLE:
            raise MCPClientError(
                "MCP client not available. Install with: pip install mcp[cli]"
            )

        # Validate a shallow copy: validation normalizes 'command' in place,
        # and that normalization must not leak into the caller's dict.
        config = dict(server_config)
        self._validate_server_config(config)

        self.config = config
        self.timeout = timeout
        self.name = config.get("name", "unnamed")
        self._session: Optional[ClientSession] = None
        self._read_stream = None
        self._write_stream = None
        self._connected = False

    def _validate_server_config(self, config: Dict[str, Any]) -> None:
        """
        Validate server configuration for security.

        Normalizes ``config["command"]`` in place (surrounding whitespace
        stripped), so that ``connect()`` launches exactly the validated name.

        Args:
            config: Server configuration dictionary (mutated as described).

        Raises:
            MCPClientError: If 'command' is missing, not a string, contains
                whitespace, or is not allow-listed; if 'args' is not a list of
                strings; or if 'env' is not a dict of string keys and values.
        """
        raw_command = config.get("command")
        if raw_command is not None and not isinstance(raw_command, str):
            raise MCPClientError("Server 'command' must be a string")
        command = (raw_command or "").strip()
        if not command:
            raise MCPClientError("Server config missing 'command'")

        # Reject any interior whitespace (not only space/tab): arguments
        # belong in 'args', never smuggled into the executable name.
        if any(ch.isspace() for ch in command):
            raise MCPClientError(
                "'command' must be a single executable name without spaces; "
                "pass arguments via the 'args' field"
            )

        # Write stripped command back for use by connect()
        config["command"] = command

        # Only the base name is checked, so paths like /usr/bin/node pass.
        # NOTE(review): this also accepts an executable with an allowed name
        # from any directory; tighten if configs can come from untrusted input.
        base_cmd = Path(command).name
        if base_cmd not in ALLOWED_COMMANDS:
            raise MCPClientError(
                f"Command '{base_cmd}' not in allowed list: {ALLOWED_COMMANDS}. "
                f"Add to ALLOWED_COMMANDS if this is a trusted command."
            )

        # Validate args are a list of strings
        args = config.get("args", [])
        if not isinstance(args, list):
            raise MCPClientError("Server 'args' must be a list")
        if not all(isinstance(a, str) for a in args):
            raise MCPClientError("Server 'args' must contain only strings")

        # Validate env is a dict of strings if provided
        env = config.get("env")
        if env is not None:
            if not isinstance(env, dict):
                raise MCPClientError("Server 'env' must be a dictionary")
            if not all(
                isinstance(k, str) and isinstance(v, str)
                for k, v in env.items()
            ):
                raise MCPClientError(
                    "Server 'env' must contain only string keys and values"
                )

    @asynccontextmanager
    async def connect(self):
        """
        Connect to the MCP server as an async context manager.

        Launches the configured command over stdio and opens a ClientSession
        on its streams. It initializes the session (with retries) and then
        yields. Connection state is reset however the context exits.

        Yields:
            self: The connected client instance

        Raises:
            MCPClientError: "Connection failed: ..." if the server cannot be
                started or initialized. An exception raised after the
                connection was established (in the caller's ``async with``
                body or during teardown) is re-raised unchanged if it already
                is an MCPClientError. Otherwise it is wrapped in an
                MCPClientError describing a session error, not a connection
                failure.
        """
        server_params = StdioServerParameters(
            command=self.config["command"],
            args=self.config.get("args", []),
            env=self.config.get("env"),
        )

        logger.info(f"Connecting to MCP server '{self.name}'...")

        # Set just before yielding, so the handler below can tell connection
        # failures apart from errors raised while the caller held the session.
        established = False
        try:
            async with stdio_client(server_params) as (
                read_stream,
                write_stream,
            ):
                self._read_stream = read_stream
                self._write_stream = write_stream

                # ClientSession must be used as async context manager
                async with ClientSession(read_stream, write_stream) as session:
                    self._session = session

                    await self._initialize_with_retry()

                    self._connected = True
                    logger.info(f"Connected to MCP server '{self.name}'")

                    established = True
                    yield self

        except Exception as e:
            if not established:
                logger.exception(
                    f"Failed to connect to MCP server '{self.name}'"
                )
                raise MCPClientError(f"Connection failed: {e}") from e
            if isinstance(e, MCPClientError):
                # Already reported where it was raised; avoid re-wrapping.
                raise
            # NOTE(review): the transport/session context managers may wrap
            # body exceptions (e.g. into an exception group) before they reach
            # this handler; confirm against the MCP SDK version in use.
            logger.exception(
                f"Error during session with MCP server '{self.name}'"
            )
            raise MCPClientError(
                f"Session with '{self.name}' failed: {e}"
            ) from e
        finally:
            self._connected = False
            self._session = None
            self._read_stream = None
            self._write_stream = None

    async def _initialize_with_retry(self) -> None:
        """Initialize the MCP session, retrying while the subprocess starts.

        Each attempt waits at most ``min(5.0, self.timeout)`` seconds.
        Timeouts are retried with exponential backoff, as are errors whose
        message contains "before initialization" (case-insensitive), for up to
        INIT_RETRY_ATTEMPTS attempts. Final failures of either kind are
        wrapped in MCPClientError; other errors propagate unchanged.

        Raises:
            MCPClientError: If no session is set, or initialization keeps
                timing out or failing as described above.
        """
        try:
            async for attempt in AsyncRetrying(
                stop=stop_after_attempt(INIT_RETRY_ATTEMPTS),
                wait=wait_exponential(
                    multiplier=INIT_RETRY_BASE_DELAY,
                    max=INIT_RETRY_MAX_DELAY,
                    exp_base=INIT_RETRY_BACKOFF_FACTOR,
                ),
                retry=(
                    retry_if_exception_type(asyncio.TimeoutError)
                    # tenacity applies `match` with re.match (anchored at the
                    # start of the message). The leading `.*` (with DOTALL)
                    # finds the phrase anywhere, mirroring the substring check
                    # in the final-failure handler below.
                    | retry_if_exception_message(
                        match=r"(?is).*before initialization"
                    )
                ),
                before_sleep=lambda rs: logger.debug(
                    f"Server '{self.name}' not ready "
                    f"(attempt {rs.attempt_number}/{INIT_RETRY_ATTEMPTS}), retrying..."
                ),
                reraise=True,
            ):
                with attempt:
                    if self._session is None:
                        raise MCPClientError("Session not initialized")  # noqa: TRY301 — re-raised by except MCPClientError
                    await asyncio.wait_for(
                        self._session.initialize(),
                        timeout=min(5.0, self.timeout),
                    )
        except asyncio.TimeoutError as e:
            raise MCPClientError(
                f"Timeout initializing connection to '{self.name}' "
                f"after {INIT_RETRY_ATTEMPTS} attempts"
            ) from e
        except MCPClientError:
            raise
        except Exception as e:
            if "before initialization" in str(e).lower():
                raise MCPClientError(
                    f"Failed to initialize connection to '{self.name}' "
                    f"after {INIT_RETRY_ATTEMPTS} attempts: {e}"
                ) from e
            raise

    async def list_tools(self) -> List[Dict[str, Any]]:
        """
        List available tools on the connected server.

        Returns:
            One dict per tool with keys "name", "description" ("" when the
            tool has none) and "input_schema" ({} when the tool object has no
            inputSchema attribute).

        Raises:
            MCPClientError: If not connected, on timeout, or on any failure
                while requesting or converting the tool list.
        """
        if not self._connected or not self._session:
            raise MCPClientError("Not connected to server")

        try:
            result = await asyncio.wait_for(
                self._session.list_tools(),
                timeout=self.timeout,
            )
            # Conversion stays inside the try so malformed tool objects are
            # reported as MCPClientError too.
            return [
                {
                    "name": tool.name,
                    "description": tool.description or "",
                    "input_schema": getattr(tool, "inputSchema", {}),
                }
                for tool in result.tools
            ]
        except asyncio.TimeoutError as e:
            logger.warning(
                f"Timeout listing tools from '{self.name}' after {self.timeout}s"
            )
            raise MCPClientError(
                f"Timeout listing tools after {self.timeout}s"
            ) from e
        except Exception as e:
            logger.exception(f"Failed to list tools from '{self.name}'")
            raise MCPClientError(f"Failed to list tools: {e}") from e

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten tool-result content into newline-joined text.

        ``None`` becomes "". A string (or any non-iterable value) is used
        whole; strings are iterable, so they must not be split per character.
        For other iterables, each item contributes its ``text`` attribute,
        else ``str(item.data)``, else ``str(item)``.
        """
        if content is None:
            return ""
        if isinstance(content, (str, bytes)) or not hasattr(
            content, "__iter__"
        ):
            return str(content)
        parts = []
        for item in content:
            if hasattr(item, "text"):
                parts.append(item.text)
            elif hasattr(item, "data"):
                parts.append(str(item.data))
            else:
                parts.append(str(item))
        return "\n".join(parts)

    async def call_tool(
        self,
        name: str,
        arguments: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Call a tool on the connected server.

        Args:
            name: Name of the tool to call
            arguments: Arguments to pass to the tool (None sends {})

        Returns:
            {"status": "error", "error": <text>} when the server flags the
            result as an error. Otherwise {"status": "success", "content":
            <text>, "raw": <original content object>}.

        Raises:
            MCPClientError: If not connected, on timeout, or on any failure
                while calling the tool or processing its result.
        """
        if not self._connected or not self._session:
            raise MCPClientError("Not connected to server")

        try:
            logger.debug(
                f"Calling tool '{name}' on '{self.name}' with args: {arguments}"
            )

            result = await asyncio.wait_for(
                self._session.call_tool(name, arguments or {}),
                timeout=self.timeout,
            )

            if result.isError:
                # Prefer readable text; fall back to str() when nothing
                # textual could be extracted (e.g. None or an empty list).
                return {
                    "status": "error",
                    "error": self._content_to_text(result.content)
                    or str(result.content),
                }

            return {
                "status": "success",
                "content": self._content_to_text(result.content),
                "raw": result.content,
            }

        except asyncio.TimeoutError as e:
            logger.warning(
                f"Tool call '{name}' timed out after {self.timeout}s"
            )
            raise MCPClientError(
                f"Tool call timed out after {self.timeout}s"
            ) from e
        except Exception as e:
            logger.exception(f"Failed to call tool '{name}' on '{self.name}'")
            raise MCPClientError(f"Tool call failed: {e}") from e

343 

344 

class MCPClientManager:
    """
    Manager for multiple MCP client connections.

    Connects to each configured server in turn and keeps the ones that
    connect successfully. Tool listing and tool calls are routed to them by
    server name.
    """

    def __init__(self, server_configs: List[Dict[str, Any]]):
        """
        Initialize the manager with server configurations.

        Args:
            server_configs: List of server configurations, each in the format
                accepted by ``MCPClient``.
        """
        self.server_configs = server_configs
        # Connected clients keyed by configured name; filled by connect_all()
        # and cleared when its context exits.
        self._clients: Dict[str, MCPClient] = {}

    @asynccontextmanager
    async def connect_all(self):
        """
        Connect to all configured MCP servers.

        Servers that fail to connect are logged and skipped. Every
        successfully connected server is disconnected when the context exits.

        Yields:
            self: The manager, with every reachable server connected.

        Raises:
            MCPClientError: If any configuration fails ``MCPClient``
                validation. All clients are constructed before any server is
                started, so an invalid config aborts the whole call instead of
                being skipped.
        """
        clients = [MCPClient(config) for config in self.server_configs]

        # Connect sequentially to avoid overwhelming the system.
        # AsyncExitStack tears down every successfully entered connection.
        async with AsyncExitStack() as stack:
            for client in clients:
                ctx = None
                try:
                    ctx = client.connect()
                    connected = await stack.enter_async_context(ctx)
                except Exception:
                    logger.warning(
                        f"Failed to connect to server '{client.name}'. Skipping."
                    )
                    # AsyncExitStack only tracks contexts that were entered,
                    # so clean up a context whose __aenter__ failed by hand.
                    # MCPClient.connect() unwinds its own transport/session
                    # (via `async with`) when it fails, so this is expected
                    # to be a no-op; it is kept as a safety net.
                    if ctx is not None:
                        try:
                            await ctx.__aexit__(None, None, None)
                        except Exception:
                            # loguru ignores stdlib-style exc_info=...;
                            # opt(exception=True) attaches the traceback.
                            logger.opt(exception=True).debug(
                                "best-effort cleanup of partially-entered async context"
                            )
                    continue

                if client.name in self._clients:
                    # Names default to "unnamed", so collisions are easy.
                    # The later client wins the name; the earlier one stays
                    # connected until exit but is no longer reachable by name.
                    logger.warning(
                        f"Duplicate MCP server name '{client.name}': the "
                        f"earlier server with this name can no longer be "
                        f"addressed by name."
                    )
                self._clients[client.name] = connected

            try:
                yield self
            finally:
                self._clients.clear()

    async def list_all_tools(self) -> Dict[str, List[Dict[str, Any]]]:
        """
        List tools from all connected servers.

        Returns:
            Dictionary mapping server name to its list of tools; a server
            whose listing raises MCPClientError maps to an empty list.
        """
        all_tools = {}
        for name, client in self._clients.items():
            try:
                all_tools[name] = await client.list_tools()
            except MCPClientError:
                logger.warning(f"Failed to list tools from '{name}'")
                all_tools[name] = []
        return all_tools

    async def call_tool(
        self,
        server_name: str,
        tool_name: str,
        arguments: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Call a tool on a specific server.

        Args:
            server_name: Name of the server
            tool_name: Name of the tool to call
            arguments: Arguments to pass to the tool

        Returns:
            Tool result, as returned by ``MCPClient.call_tool``.

        Raises:
            MCPClientError: If no connected server has that name, or the call
                itself fails.
        """
        if server_name not in self._clients:
            raise MCPClientError(f"Server '{server_name}' not connected")

        return await self._clients[server_name].call_tool(tool_name, arguments)

    def get_connected_servers(self) -> List[str]:
        """Get list of connected server names."""
        return list(self._clients.keys())

447 

448 

def run_async(coro, timeout: float = 300.0):
    """
    Run an async coroutine synchronously.

    Helper for running async MCP operations from sync code. When no event
    loop is running in the current thread, the coroutine runs via
    asyncio.run(). Otherwise it runs on a fresh loop in a single worker
    thread (to avoid nesting asyncio.run), and this call blocks the current
    thread until the result is available.

    Args:
        coro: The coroutine to run
        timeout: Maximum time to wait in seconds (default 5 minutes)

    Returns:
        The result of the coroutine

    Raises:
        MCPClientError: If the operation times out. Any other exception
            raised by the coroutine propagates unchanged.
    """
    import concurrent.futures

    # Before Python 3.11 these are three distinct classes; from 3.11 on they
    # are aliases of the builtin TimeoutError. Catch all of them so the
    # MCPClientError contract holds on every version.
    timeout_errors = (
        TimeoutError,
        asyncio.TimeoutError,
        concurrent.futures.TimeoutError,
    )

    # Keep ONLY the loop probe inside this try: a RuntimeError raised by the
    # coroutine itself must not be mistaken for "no running event loop".
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_running = False
    else:
        loop_running = True

    if not loop_running:
        # No running event loop - safe to use asyncio.run()
        try:
            return asyncio.run(asyncio.wait_for(coro, timeout=timeout))
        except timeout_errors as e:
            raise MCPClientError(
                f"Async operation timed out after {timeout}s"
            ) from e

    # Inside a running loop: run on a separate thread with its own loop.
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(
            asyncio.run, asyncio.wait_for(coro, timeout=timeout)
        )
        # Grace period beyond the inner wait_for, so its own timeout normally
        # fires first and cancels the coroutine inside the worker's loop.
        return future.result(timeout=timeout + 5.0)
    except timeout_errors as e:
        raise MCPClientError(
            f"Async operation timed out after {timeout}s"
        ) from e
    finally:
        # Do not join the worker: if it is stuck past the deadline, waiting
        # here would block indefinitely and defeat the outer timeout.
        executor.shutdown(wait=False)