Coverage for src / local_deep_research / mcp / client.py: 71%
185 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:55 +0000
1"""
2MCP Client utilities for connecting to MCP servers.
4This module provides a wrapper around the MCP client SDK for connecting
5to and calling tools on MCP servers.
6"""
8import asyncio
9from contextlib import AsyncExitStack, asynccontextmanager
10from pathlib import Path
11from typing import Any, Dict, List, Optional
13from tenacity import (
14 AsyncRetrying,
15 retry_if_exception_message,
16 retry_if_exception_type,
17 stop_after_attempt,
18 wait_exponential,
19)
21from loguru import logger
# Security whitelist: only these executables (matched by base name) may be
# launched as MCP servers.
ALLOWED_COMMANDS = {"node", "npx", "python", "python3", "uv", "uvx", "docker"}

# Retry tuning for the race between subprocess startup and session
# initialization. Delays are expressed in seconds.
INIT_RETRY_ATTEMPTS = 5
INIT_RETRY_BASE_DELAY = 0.1  # seconds
INIT_RETRY_MAX_DELAY = 2.0  # seconds
INIT_RETRY_BACKOFF_FACTOR = 2.0

# The MCP SDK is optional (installed through the [mcp] extras group). The
# import is guarded and the outcome recorded in MCP_AVAILABLE so callers can
# check availability at runtime instead of failing at import time.
try:
    from mcp import ClientSession
    from mcp.client.stdio import StdioServerParameters, stdio_client
except ImportError:
    MCP_AVAILABLE = False
    logger.warning(
        "MCP client not available. Install with: pip install mcp[cli]"
    )
else:
    MCP_AVAILABLE = True
class MCPClientError(Exception):
    """Raised when an MCP client operation fails.

    Used for configuration, connection, timeout and tool-call failures so
    callers can catch a single exception type.
    """
class MCPClient:
    """
    Client for connecting to and calling tools on an MCP server.

    The connection is opened by the ``connect()`` async context manager,
    which starts the configured command through the MCP SDK's stdio
    transport. The instance itself is not an async context manager.

    Usage:
        async with MCPClient(config).connect() as client:
            tools = await client.list_tools()
            result = await client.call_tool("tool_name", {"arg": "value"})
    """

    def __init__(
        self,
        server_config: Dict[str, Any],
        timeout: float = 60.0,
    ):
        """
        Initialize MCP client.

        Args:
            server_config: Server configuration with keys:
                - command: str - Command to run (e.g., "python")
                - args: List[str] - Command arguments
                - env: Dict[str, str] - Environment variables (optional)
                - name: str - Server name for logging (optional)
                Note that validation writes the whitespace-stripped
                command back into this dictionary.
            timeout: Timeout in seconds for ``list_tools`` / ``call_tool``;
                each initialization attempt is capped at
                ``min(5.0, timeout)``.

        Raises:
            MCPClientError: If the MCP SDK is not installed or the
                configuration fails validation.
        """
        if not MCP_AVAILABLE:
            raise MCPClientError(
                "MCP client not available. Install with: pip install mcp[cli]"
            )

        # Validate server configuration for security
        self._validate_server_config(server_config)

        self.config = server_config
        self.timeout = timeout
        self.name = server_config.get("name", "unnamed")
        self._session: Optional[ClientSession] = None
        self._read_stream = None
        self._write_stream = None
        self._connected = False

    def _validate_server_config(self, config: Dict[str, Any]) -> None:
        """
        Validate server configuration for security.

        Checks that ``command`` is a single whitelisted executable, that
        ``args`` is a list of strings and that ``env`` (when given) maps
        strings to strings. The stripped command is written back to
        ``config["command"]``.

        Args:
            config: Server configuration dictionary

        Raises:
            MCPClientError: If configuration is invalid or insecure
        """
        # A missing or None command is reported as "missing"; any other
        # non-string value is rejected explicitly instead of crashing with
        # AttributeError on .strip().
        command = config.get("command") or ""
        if not isinstance(command, str):
            raise MCPClientError("Server 'command' must be a string")
        command = command.strip()
        if not command:
            raise MCPClientError("Server config missing 'command'")

        if " " in command or "\t" in command:
            raise MCPClientError(
                "'command' must be a single executable name without spaces; "
                "pass arguments via the 'args' field"
            )

        # Write stripped command back for use by connect()
        config["command"] = command

        # Extract base command (handle paths like /usr/bin/node)
        base_cmd = Path(command).name
        if base_cmd not in ALLOWED_COMMANDS:
            raise MCPClientError(
                f"Command '{base_cmd}' not in allowed list: {ALLOWED_COMMANDS}. "
                f"Add to ALLOWED_COMMANDS if this is a trusted command."
            )

        # Validate args are a list of strings
        args = config.get("args", [])
        if not isinstance(args, list):
            raise MCPClientError("Server 'args' must be a list")
        if not all(isinstance(a, str) for a in args):
            raise MCPClientError("Server 'args' must contain only strings")

        # Validate env is a dict of strings if provided
        env = config.get("env")
        if env is not None:
            if not isinstance(env, dict):
                raise MCPClientError("Server 'env' must be a dictionary")
            if not all(
                isinstance(k, str) and isinstance(v, str)
                for k, v in env.items()
            ):
                raise MCPClientError(
                    "Server 'env' must contain only string keys and values"
                )

    @asynccontextmanager
    async def connect(self):
        """
        Connect to the MCP server as an async context manager.

        Yields:
            self: The connected client instance

        Raises:
            MCPClientError: If the connection or initialization fails.
        """
        server_params = StdioServerParameters(
            command=self.config["command"],
            args=self.config.get("args", []),
            env=self.config.get("env"),
        )

        logger.info(f"Connecting to MCP server '{self.name}'...")

        # NOTE(review): the ``yield`` sits inside this try block, so an
        # exception raised by the caller's ``async with`` body is also
        # logged and re-raised as "Connection failed". Kept as-is because
        # callers may rely on catching MCPClientError here.
        try:
            async with stdio_client(server_params) as (
                read_stream,
                write_stream,
            ):
                self._read_stream = read_stream
                self._write_stream = write_stream

                # ClientSession must be used as async context manager
                async with ClientSession(read_stream, write_stream) as session:
                    self._session = session

                    await self._initialize_with_retry()

                    self._connected = True
                    logger.info(f"Connected to MCP server '{self.name}'")

                    yield self

        except Exception as e:
            logger.exception(f"Failed to connect to MCP server '{self.name}'")
            raise MCPClientError(f"Connection failed: {e}") from e
        finally:
            # Always reset connection state, whether we exit cleanly or not.
            self._connected = False
            self._session = None
            self._read_stream = None
            self._write_stream = None

    async def _initialize_with_retry(self) -> None:
        """Initialize the MCP session with retry logic for subprocess startup race.

        Retries on TimeoutError and "before initialization" errors using
        exponential backoff, then wraps final failures in MCPClientError.

        Raises:
            MCPClientError: If no session exists, or initialization still
                times out / reports "before initialization" after
                INIT_RETRY_ATTEMPTS attempts.
        """
        try:
            async for attempt in AsyncRetrying(
                stop=stop_after_attempt(INIT_RETRY_ATTEMPTS),
                wait=wait_exponential(
                    multiplier=INIT_RETRY_BASE_DELAY,
                    max=INIT_RETRY_MAX_DELAY,
                    exp_base=INIT_RETRY_BACKOFF_FACTOR,
                ),
                retry=(
                    retry_if_exception_type(asyncio.TimeoutError)
                    # The leading ".*" lets the phrase appear anywhere in
                    # the message even if the pattern is applied anchored
                    # at the start, mirroring the substring check in the
                    # except branch below.
                    | retry_if_exception_message(
                        match=r"(?is).*before initialization"
                    )
                ),
                before_sleep=lambda rs: logger.debug(
                    f"Server '{self.name}' not ready "
                    f"(attempt {rs.attempt_number}/{INIT_RETRY_ATTEMPTS}), retrying..."
                ),
                reraise=True,
            ):
                with attempt:
                    if self._session is None:
                        raise MCPClientError("Session not initialized")  # noqa: TRY301 — re-raised by except MCPClientError
                    await asyncio.wait_for(
                        self._session.initialize(),
                        timeout=min(5.0, self.timeout),
                    )
        except asyncio.TimeoutError as e:
            raise MCPClientError(
                f"Timeout initializing connection to '{self.name}' "
                f"after {INIT_RETRY_ATTEMPTS} attempts"
            ) from e
        except MCPClientError:
            raise
        except Exception as e:
            if "before initialization" in str(e).lower():
                raise MCPClientError(
                    f"Failed to initialize connection to '{self.name}' "
                    f"after {INIT_RETRY_ATTEMPTS} attempts: {e}"
                ) from e
            raise

    async def list_tools(self) -> List[Dict[str, Any]]:
        """
        List available tools on the connected server.

        Returns:
            List of tool definitions with name, description, and input schema

        Raises:
            MCPClientError: If not connected, on timeout, or if the
                request fails.
        """
        if not self._connected or not self._session:
            raise MCPClientError("Not connected to server")

        try:
            result = await asyncio.wait_for(
                self._session.list_tools(),
                timeout=self.timeout,
            )
            return [
                {
                    "name": tool.name,
                    "description": tool.description or "",
                    "input_schema": getattr(tool, "inputSchema", {}),
                }
                for tool in result.tools
            ]
        except asyncio.TimeoutError as e:
            logger.warning(
                f"Timeout listing tools from '{self.name}' after {self.timeout}s"
            )
            raise MCPClientError(
                f"Timeout listing tools after {self.timeout}s"
            ) from e
        except Exception as e:
            logger.exception(f"Failed to list tools from '{self.name}'")
            raise MCPClientError(f"Failed to list tools: {e}") from e

    @staticmethod
    def _flatten_content(raw: Any) -> List[str]:
        """Convert a tool result's content into a list of text fragments.

        ``None`` yields an empty list. A plain string or any non-iterable
        value is treated as one fragment. For other iterables each item
        contributes its ``text`` attribute, else ``str(item.data)``, else
        ``str(item)``.
        """
        if raw is None:
            return []
        # str is iterable, but iterating it would split the payload into
        # single characters; treat it as one fragment instead.
        if isinstance(raw, str) or not hasattr(raw, "__iter__"):
            return [str(raw)]

        fragments = []
        for item in raw:
            if hasattr(item, "text"):
                fragments.append(item.text)
            elif hasattr(item, "data"):
                fragments.append(str(item.data))
            else:
                fragments.append(str(item))
        return fragments

    async def call_tool(
        self,
        name: str,
        arguments: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Call a tool on the connected server.

        Args:
            name: Name of the tool to call
            arguments: Arguments to pass to the tool

        Returns:
            Tool result as a dictionary: ``{"status": "error", "error": ...}``
            when the server flags an error, otherwise
            ``{"status": "success", "content": <joined text>, "raw": ...}``.

        Raises:
            MCPClientError: If not connected, on timeout, or if the call
                fails.
        """
        if not self._connected or not self._session:
            raise MCPClientError("Not connected to server")

        try:
            logger.debug(
                f"Calling tool '{name}' on '{self.name}' with args: {arguments}"
            )

            result = await asyncio.wait_for(
                self._session.call_tool(name, arguments or {}),
                timeout=self.timeout,
            )

            # Parse the result
            if result.isError:
                return {
                    "status": "error",
                    "error": str(result.content),
                }

            # Extract content from the result
            fragments = self._flatten_content(result.content)

            return {
                "status": "success",
                "content": "\n".join(fragments),
                "raw": result.content,
            }

        except asyncio.TimeoutError as e:
            logger.exception(
                f"Tool call '{name}' timed out after {self.timeout}s"
            )
            raise MCPClientError(
                f"Tool call timed out after {self.timeout}s"
            ) from e
        except Exception as e:
            logger.exception(f"Failed to call tool '{name}' on '{self.name}'")
            raise MCPClientError(f"Tool call failed: {e}") from e
class MCPClientManager:
    """
    Coordinates connections to several MCP servers at once.

    Connected clients are stored by server name so their tools can be
    listed together and individual tools can be called by server.
    """

    def __init__(self, server_configs: List[Dict[str, Any]]):
        """
        Store the server configurations; nothing is connected yet.

        Args:
            server_configs: List of server configurations
        """
        self.server_configs = server_configs
        self._clients: Dict[str, MCPClient] = {}

    async def _connect_one(self, exit_stack, mcp_client) -> None:
        """Enter one client's connection on ``exit_stack``, skipping it on failure."""
        pending_ctx = None
        try:
            pending_ctx = mcp_client.connect()
            live_client = await exit_stack.enter_async_context(pending_ctx)
            self._clients[mcp_client.name] = live_client
        except Exception:
            logger.warning(
                f"Failed to connect to server '{mcp_client.name}'. Skipping."
            )
            if pending_ctx is None:
                return
            # The exit stack only tracks contexts that were entered
            # successfully, so a context whose __aenter__ failed gets a
            # best-effort manual exit in case the MCP subprocess is
            # already running.
            try:
                await pending_ctx.__aexit__(None, None, None)
            except Exception:
                logger.debug(
                    "best-effort cleanup of partially-entered async context",
                    exc_info=True,
                )

    @asynccontextmanager
    async def connect_all(self):
        """
        Connect to every configured MCP server.

        Servers that fail to connect are logged and skipped.

        Yields:
            self: The manager with all reachable clients connected
        """
        # Build every client up front; construction validates each config.
        pending_clients = [MCPClient(cfg) for cfg in self.server_configs]

        # Connections are opened one at a time to avoid overwhelming the
        # system; the exit stack tears down whatever was entered.
        async with AsyncExitStack() as exit_stack:
            for mcp_client in pending_clients:
                await self._connect_one(exit_stack, mcp_client)

            try:
                yield self
            finally:
                self._clients.clear()

    async def list_all_tools(self) -> Dict[str, List[Dict[str, Any]]]:
        """
        Collect the tool listings of all connected servers.

        A server whose listing fails with MCPClientError contributes an
        empty list.

        Returns:
            Dictionary mapping server name to list of tools
        """
        tools_by_server = {}
        for server_name, mcp_client in self._clients.items():
            try:
                tools_by_server[server_name] = await mcp_client.list_tools()
            except MCPClientError:
                logger.warning(f"Failed to list tools from '{server_name}'")
                tools_by_server[server_name] = []
        return tools_by_server

    async def call_tool(
        self,
        server_name: str,
        tool_name: str,
        arguments: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Invoke a tool on one named server.

        Args:
            server_name: Name of the server
            tool_name: Name of the tool to call
            arguments: Arguments to pass to the tool

        Returns:
            Tool result

        Raises:
            MCPClientError: If ``server_name`` is not a connected server.
        """
        if server_name in self._clients:
            target = self._clients[server_name]
            return await target.call_tool(tool_name, arguments)

        raise MCPClientError(f"Server '{server_name}' not connected")

    def get_connected_servers(self) -> List[str]:
        """Return the names of the currently connected servers."""
        return list(self._clients)
449def run_async(coro, timeout: float = 300.0):
450 """
451 Run an async coroutine synchronously.
453 Helper for running async MCP operations from sync code.
455 Args:
456 coro: The coroutine to run
457 timeout: Maximum time to wait in seconds (default 5 minutes)
459 Returns:
460 The result of the coroutine
462 Raises:
463 MCPClientError: If the operation times out
464 """
465 import concurrent.futures
467 try:
468 # Check if we're already in an async context
469 asyncio.get_running_loop()
470 # We are in an async context - use thread pool to avoid nesting
471 with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
472 future = executor.submit(
473 asyncio.run, asyncio.wait_for(coro, timeout=timeout)
474 )
475 try:
476 return future.result(timeout=timeout + 5.0)
477 except TimeoutError:
478 raise MCPClientError(
479 f"Async operation timed out after {timeout}s"
480 )
481 except RuntimeError:
482 # No running event loop - safe to use asyncio.run()
483 try:
484 return asyncio.run(asyncio.wait_for(coro, timeout=timeout))
485 except TimeoutError:
486 raise MCPClientError(f"Async operation timed out after {timeout}s")