diff --git a/README.md b/README.md index 69c9f00..c29f413 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ - [2. SSE to stdio](#2-sse-to-stdio) - [2.1 Configuration](#21-configuration) - [2.2 Example usage](#22-example-usage) + - [Named Servers](#named-servers) - [Installation](#installation) - [Installing via Smithery](#installing-via-smithery) - [Installing via PyPI](#installing-via-pypi) @@ -23,6 +24,7 @@ - [Extending the container image](#extending-the-container-image) - [Docker Compose Setup](#docker-compose-setup) - [Command line arguments](#command-line-arguments) + - [Example config file](#example-config-file) - [Testing](#testing) ## About @@ -115,18 +117,20 @@ separator. Arguments -| Name | Required | Description | Example | -| ------------------------- | -------------------------- | --------------------------------------------------------------------------------------------- | --------------------- | -| `command_or_url` | Yes | The command to spawn the MCP stdio server | uvx mcp-server-fetch | -| `--port` | No, random available | The MCP server port to listen on | 8080 | -| `--host` | No, `127.0.0.1` by default | The host IP address that the MCP server will listen on | 0.0.0.0 | -| `--env` | No | Additional environment variables to pass to the MCP stdio server. Can be used multiple times. | FOO BAR | -| `--cwd` | No | The working directory to pass to the MCP stdio server process. | /tmp | -| `--pass-environment` | No | Pass through all environment variables when spawning the server | --no-pass-environment | -| `--allow-origin` | No | Allowed origins for the SSE server. Can be used multiple times. Default is no CORS allowed. | --allow-cors "\*" | -| `--stateless` | No | Enable stateless mode for streamable http transports. Default is False | --no-stateless | -| `--sse-port` (deprecated) | No, random available | The SSE server port to listen on | 8080 | -| `--sse-host` (deprecated) | No, `127.0.0.1` by default | The host IP address that the SSE server will listen on | 0.0.0.0 | +| Name | Required | Description | Example | +|--------------------------------------|----------------------------|---------------------------------------------------------------------------------------------|---------------------------------------------| +| `command_or_url` | Yes | The command to spawn the MCP stdio server | uvx mcp-server-fetch | +| `--port` | No, random available | The MCP server port to listen on | 8080 | +| `--host` | No, `127.0.0.1` by default | The host IP address that the MCP server will listen on | 0.0.0.0 | +| `--env` | No | Additional environment variables to pass to the MCP stdio server. Can be used multiple times. | FOO BAR | +| `--cwd` | No | The working directory to pass to the MCP stdio server process. | /tmp | +| `--pass-environment` | No | Pass through all environment variables when spawning the server | --no-pass-environment | +| `--allow-origin` | No | Allowed origins for the SSE server. Can be used multiple times. Default is no CORS allowed. | --allow-origin "\*" | +| `--stateless` | No | Enable stateless mode for streamable http transports. Default is False | --no-stateless | +| `--named-server NAME COMMAND_STRING` | No | Defines a named stdio server. | --named-server fetch 'uvx mcp-server-fetch' | +| `--named-server-config FILE_PATH` | No | Path to a JSON file defining named stdio servers. | --named-server-config /path/to/servers.json | +| `--sse-port` (deprecated) | No, random available | The SSE server port to listen on | 8080 | +| `--sse-host` (deprecated) | No, `127.0.0.1` by default | The host IP address that the SSE server will listen on | 0.0.0.0 | ### 2.2 Example usage @@ -148,10 +152,61 @@ mcp-proxy --host=0.0.0.0 --port=8080 uvx mcp-server-fetch # Note that the `--` separator is used to separate the `mcp-proxy` arguments from the `mcp-server-fetch` arguments # (deprecated) mcp-proxy --sse-port=8080 -- uvx mcp-server-fetch --user-agent=YourUserAgent mcp-proxy --port=8080 -- uvx mcp-server-fetch --user-agent=YourUserAgent + +# Start multiple named MCP servers behind the proxy +mcp-proxy --port=8080 --named-server fetch 'uvx mcp-server-fetch' --named-server fetch2 'uvx mcp-server-fetch' + +# Start multiple named MCP servers using a configuration file +mcp-proxy --port=8080 --named-server-config ./servers.json ``` -This will start an MCP server that can be connected to at `http://127.0.0.1:8080/sse` via SSE, or -`http://127.0.0.1:8080/mcp/` via StreamableHttp +## Named Servers + +- `NAME` is used in the URL path `/servers/NAME/`. +- `COMMAND_STRING` is the command to start the server (e.g., 'uvx mcp-server-fetch'). + - Can be used multiple times. + - This argument is ignored if `--named-server-config` is used. +- `FILE_PATH` - If provided, this is the exclusive source for named servers, and `--named-server` CLI arguments are ignored. + +If a default server is specified (the `command_or_url` argument without `--named-server` or `--named-server-config`), it will be accessible at the root paths (e.g., `http://127.0.0.1:8080/sse`). + +Named servers (whether defined by `--named-server` or `--named-server-config`) will be accessible under `/servers//` (e.g., `http://127.0.0.1:8080/servers/fetch1/sse`). +The `/status` endpoint provides global status. + +**JSON Configuration File Format for `--named-server-config`:** + +The JSON file should follow this structure: + +```json +{ + "mcpServers": { + "fetch": { + "disabled": false, + "timeout": 60, + "command": "uvx", + "args": [ + "mcp-server-fetch" + ], + "transportType": "stdio" + }, + "github": { + "timeout": 60, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-github" + ], + "transportType": "stdio" + } + } +} +``` + +- `mcpServers`: A dictionary where each key is the server name (used in the URL path, e.g., `/servers/fetch/`) and the value is an object defining the server. +- `command`: (Required) The command to execute for the stdio server. +- `args`: (Optional) A list of arguments for the command. Defaults to an empty list. +- `enabled`: (Optional) If `false`, this server definition will be skipped. Defaults to `true`. +- `timeout` and `transportType`: These fields are present in standard MCP client configurations but are currently **ignored** by `mcp-proxy` when loading named servers. The transport type is implicitly "stdio". ## Installation @@ -258,14 +313,20 @@ services: ```bash usage: mcp-proxy [-h] [-H KEY VALUE] [--transport {sse,streamablehttp}] [-e KEY VALUE] [--cwd CWD] [--pass-environment | --no-pass-environment] - [--debug | --no-debug] [--port PORT] [--host HOST] [--stateless | --no-stateless] [--sse-port SSE_PORT] [--sse-host SSE_HOST] - [--allow-origin ALLOW_ORIGIN [ALLOW_ORIGIN ...]] + [--debug | --no-debug] [--named-server NAME COMMAND_STRING] [--named-server-config FILE_PATH] [--port PORT] [--host HOST] + [--stateless | --no-stateless] [--sse-port SSE_PORT] [--sse-host SSE_HOST] [--allow-origin ALLOW_ORIGIN [ALLOW_ORIGIN ...]] [command_or_url] [args ...] -Start the MCP proxy in one of two possible modes: as a client or a server. +Start the MCP proxy. +It can run as an SSE client (connecting to a remote SSE server and exposing stdio). +Or, it can run as an SSE server (connecting to local stdio command(s) and exposing them over SSE). +When running as an SSE server, it can proxy a single default stdio command or multiple named stdio commands (defined via CLI or a config file). positional arguments: - command_or_url Command or URL to connect to. When a URL, will run an SSE/StreamableHTTP client, otherwise will run the given command and connect as a stdio client. See corresponding options for more details. + command_or_url Command or URL. + If URL (http/https): Runs in SSE/StreamableHTTP client mode. + If command string: Runs in SSE server mode, this is the default stdio server. + If --named-server or --named-server-config is used, this can be omitted if no default server is desired. options: -h, --help show this help message and exit @@ -277,12 +338,16 @@ SSE/StreamableHTTP client options: The transport to use for the client. Default is SSE. stdio client options: - args Any extra arguments to the command to spawn the server - -e, --env KEY VALUE Environment variables used when spawning the server. Can be used multiple times. - --cwd CWD The working directory to use when spawning the process. + args Any extra arguments to the command to spawn the default server. Ignored if only named servers are defined. + -e, --env KEY VALUE Environment variables used when spawning the default server. Can be used multiple times. For named servers, environment is inherited or passed via --pass-environment. + --cwd CWD The working directory to use when spawning the default server process. Named servers inherit the proxy's CWD. --pass-environment, --no-pass-environment - Pass through all environment variables when spawning the server. + Pass through all environment variables when spawning all server processes. --debug, --no-debug Enable debug mode with detailed logging output. + --named-server NAME COMMAND_STRING + Define a named stdio server. NAME is for the URL path /servers/NAME/. COMMAND_STRING is a single string with the command and its arguments (e.g., 'uvx mcp-server-fetch --timeout 10'). These servers inherit the proxy's CWD and environment from --pass-environment. Can be specified multiple times. Ignored if --named-server-config is used. + --named-server-config FILE_PATH + Path to a JSON configuration file for named stdio servers. If provided, this will be the exclusive source for named server definitions, and any --named-server CLI arguments will be ignored. SSE server options: --port PORT Port to expose an SSE server on. Default is a random port @@ -298,9 +363,39 @@ Examples: mcp-proxy http://localhost:8080/sse mcp-proxy --transport streamablehttp http://localhost:8080/mcp mcp-proxy --headers Authorization 'Bearer YOUR_TOKEN' http://localhost:8080/sse - mcp-proxy --port 8080 -- your-command --arg1 value1 --arg2 value2 - mcp-proxy your-command --port 8080 -e KEY VALUE -e ANOTHER_KEY ANOTHER_VALUE - mcp-proxy your-command --port 8080 --allow-origin='*' + mcp-proxy --port 8080 -- my-default-command --arg1 value1 + mcp-proxy --port 8080 --named-server fetch1 'uvx mcp-server-fetch' --named-server tool2 'my-custom-tool --verbose' + mcp-proxy --port 8080 --named-server-config /path/to/servers.json + mcp-proxy --port 8080 --named-server-config /path/to/servers.json -- my-default-command --arg1 + mcp-proxy --port 8080 -e KEY VALUE -e ANOTHER_KEY ANOTHER_VALUE -- my-default-command + mcp-proxy --port 8080 --allow-origin='*' -- my-default-command +``` + +### Example config file + +```json +{ + "mcpServers": { + "fetch": { + "enabled": true, + "timeout": 60, + "command": "uvx", + "args": [ + "mcp-server-fetch" + ], + "transportType": "stdio" + }, + "github": { + "timeout": 60, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-github" + ], + "transportType": "stdio" + } + } +} ``` ## Testing diff --git a/config_example.json b/config_example.json new file mode 100644 index 0000000..760f055 --- /dev/null +++ b/config_example.json @@ -0,0 +1,23 @@ +{ + "mcpServers": { + "fetch": { + "enabled": true, + "timeout": 60, + "command": "uvx", + "args": [ + "mcp-server-fetch" + ], + "transportType": "stdio" + }, + "github": { + "enabled": false, + "timeout": 60, + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-github" + ], + "transportType": "stdio" + } + } +} diff --git a/src/mcp_proxy/__main__.py b/src/mcp_proxy/__main__.py index 76365a8..8295020 100644 --- a/src/mcp_proxy/__main__.py +++ b/src/mcp_proxy/__main__.py @@ -8,13 +8,16 @@ Two ways to run the application: import argparse import asyncio +import json import logging import os +import shlex import sys import typing as t from mcp.client.stdio import StdioServerParameters +from .config_loader import load_named_server_configs_from_file from .mcp_server import MCPServerSettings, run_mcp_server from .sse_client import run_sse_client from .streamablehttp_client import run_streamablehttp_client @@ -26,8 +29,8 @@ SSE_URL: t.Final[str | None] = os.getenv( ) -def main() -> None: - """Start the client using asyncio.""" +def _setup_argument_parser() -> argparse.ArgumentParser: + """Set up and return the argument parser for the MCP proxy.""" parser = argparse.ArgumentParser( description=("Start the MCP proxy in one of two possible modes: as a client or a server."), epilog=( @@ -36,19 +39,29 @@ def main() -> None: " mcp-proxy --transport streamablehttp http://localhost:8080/mcp\n" " mcp-proxy --headers Authorization 'Bearer YOUR_TOKEN' http://localhost:8080/sse\n" " mcp-proxy --port 8080 -- your-command --arg1 value1 --arg2 value2\n" + " mcp-proxy --named-server fetch 'uvx mcp-server-fetch' --port 8080\n" " mcp-proxy your-command --port 8080 -e KEY VALUE -e ANOTHER_KEY ANOTHER_VALUE\n" " mcp-proxy your-command --port 8080 --allow-origin='*'\n" ), formatter_class=argparse.RawTextHelpFormatter, ) + + _add_arguments_to_parser(parser) + return parser + + +def _add_arguments_to_parser(parser: argparse.ArgumentParser) -> None: + """Add all arguments to the argument parser.""" parser.add_argument( "command_or_url", help=( - "Command or URL to connect to. When a URL, will run an SSE/StreamableHTTP client, " - "otherwise will run the given command and connect as a stdio client. " + "Command or URL to connect to. When a URL, will run an SSE/StreamableHTTP client. " + "Otherwise, if --named-server is not used, this will be the command " + "for the default stdio client. If --named-server is used, this argument " + "is ignored for stdio mode unless no default server is desired. " "See corresponding options for more details." ), - nargs="?", # Required below to allow for coming form env var + nargs="?", default=SSE_URL, ) @@ -73,7 +86,10 @@ def main() -> None: stdio_client_options.add_argument( "args", nargs="*", - help="Any extra arguments to the command to spawn the server", + help=( + "Any extra arguments to the command to spawn the default server. " + "Ignored if only named servers are defined." + ), ) stdio_client_options.add_argument( "-e", @@ -81,18 +97,25 @@ def main() -> None: nargs=2, action="append", metavar=("KEY", "VALUE"), - help="Environment variables used when spawning the server. Can be used multiple times.", + help=( + "Environment variables used when spawning the default server. Can be " + "used multiple times. For named servers, environment is inherited or " + "passed via --pass-environment." + ), default=[], ) stdio_client_options.add_argument( "--cwd", default=None, - help="The working directory to use when spawning the process.", + help=( + "The working directory to use when spawning the default server process. " + "Named servers inherit the proxy's CWD." + ), ) stdio_client_options.add_argument( "--pass-environment", action=argparse.BooleanOptionalAction, - help="Pass through all environment variables when spawning the server.", + help="Pass through all environment variables when spawning all server processes.", default=False, ) stdio_client_options.add_argument( @@ -101,6 +124,31 @@ def main() -> None: help="Enable debug mode with detailed logging output.", default=False, ) + stdio_client_options.add_argument( + "--named-server", + action="append", + nargs=2, + metavar=("NAME", "COMMAND_STRING"), + help=( + "Define a named stdio server. NAME is for the URL path /servers/NAME/. " + "COMMAND_STRING is a single string with the command and its arguments " + "(e.g., 'uvx mcp-server-fetch --timeout 10'). " + "These servers inherit the proxy's CWD and environment from --pass-environment." + ), + default=[], + dest="named_server_definitions", + ) + stdio_client_options.add_argument( + "--named-server-config", + type=str, + default=None, + metavar="FILE_PATH", + help=( + "Path to a JSON configuration file for named stdio servers. " + "If provided, this will be the exclusive source for named server definitions, " + "and any --named-server CLI arguments will be ignored." + ), + ) mcp_server_group = parser.add_argument_group("SSE server options") mcp_server_group.add_argument( @@ -135,65 +183,222 @@ def main() -> None: "--allow-origin", nargs="+", default=[], - help="Allowed origins for the SSE server. " - "Can be used multiple times. Default is no CORS allowed.", + help=( + "Allowed origins for the SSE server. Can be used multiple times. " + "Default is no CORS allowed." + ), ) - args = parser.parse_args() - - if not args.command_or_url: - parser.print_help() - sys.exit(1) +def _setup_logging(*, debug: bool) -> logging.Logger: + """Set up logging configuration and return the logger.""" logging.basicConfig( - level=logging.DEBUG if args.debug else logging.INFO, + level=logging.DEBUG if debug else logging.INFO, format="[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s] %(message)s", ) - logger = logging.getLogger(__name__) + return logging.getLogger(__name__) - if ( - SSE_URL - or args.command_or_url.startswith("http://") - or args.command_or_url.startswith("https://") + +def _handle_sse_client_mode( + args_parsed: argparse.Namespace, + logger: logging.Logger, +) -> None: + """Handle SSE/StreamableHTTP client mode operation.""" + if args_parsed.named_server_definitions: + logger.warning( + "--named-server arguments are ignored when command_or_url is an HTTP/HTTPS URL " + "(SSE/StreamableHTTP client mode).", + ) + # Start a client connected to the SSE server, and expose as a stdio server + logger.debug("Starting SSE/StreamableHTTP client and stdio server") + headers = dict(args_parsed.headers) + if api_access_token := os.getenv("API_ACCESS_TOKEN", None): + headers["Authorization"] = f"Bearer {api_access_token}" + + if args_parsed.transport == "streamablehttp": + asyncio.run(run_streamablehttp_client(args_parsed.command_or_url, headers=headers)) + else: + asyncio.run(run_sse_client(args_parsed.command_or_url, headers=headers)) + + +def _configure_default_server( + args_parsed: argparse.Namespace, + base_env: dict[str, str], + logger: logging.Logger, +) -> StdioServerParameters | None: + """Configure the default server if applicable.""" + if not ( + args_parsed.command_or_url + and not args_parsed.command_or_url.startswith(("http://", "https://")) ): - # Start a client connected to the SSE server, and expose as a stdio server - logger.debug("Starting SSE client and stdio server") - headers = dict(args.headers) - if api_access_token := os.getenv("API_ACCESS_TOKEN", None): - headers["Authorization"] = f"Bearer {api_access_token}" - if args.transport == "streamablehttp": - asyncio.run(run_streamablehttp_client(args.command_or_url, headers=headers)) - else: - asyncio.run(run_sse_client(args.command_or_url, headers=headers)) + return None + + default_server_env = base_env.copy() + default_server_env.update(dict(args_parsed.env)) # Specific env vars for default server + + default_stdio_params = StdioServerParameters( + command=args_parsed.command_or_url, + args=args_parsed.args, + env=default_server_env, + cwd=args_parsed.cwd if args_parsed.cwd else None, + ) + logger.info( + "Configured default server: %s %s", + args_parsed.command_or_url, + " ".join(args_parsed.args), + ) + return default_stdio_params + + +def _load_named_servers_from_config( + config_path: str, + base_env: dict[str, str], + logger: logging.Logger, +) -> dict[str, StdioServerParameters]: + """Load named server configurations from a file.""" + try: + return load_named_server_configs_from_file(config_path, base_env) + except (FileNotFoundError, json.JSONDecodeError, ValueError): + # Specific errors are already logged by the loader function + # We log a generic message here before exiting + logger.exception( + "Failed to load server configurations from %s. Exiting.", + config_path, + ) + sys.exit(1) + except Exception: # Catch any other unexpected errors from loader + logger.exception( + "An unexpected error occurred while loading server configurations from %s. Exiting.", + config_path, + ) + sys.exit(1) + + +def _configure_named_servers_from_cli( + named_server_definitions: list[tuple[str, str]], + base_env: dict[str, str], + logger: logging.Logger, +) -> dict[str, StdioServerParameters]: + """Configure named servers from CLI arguments.""" + named_stdio_params: dict[str, StdioServerParameters] = {} + + for name, command_string in named_server_definitions: + try: + command_parts = shlex.split(command_string) + if not command_parts: # Handle empty command_string + logger.error("Empty COMMAND_STRING for named server '%s'. Skipping.", name) + continue + + command = command_parts[0] + command_args = command_parts[1:] + # Named servers inherit base_env (which includes passed-through env) + # and use the proxy's CWD. + named_stdio_params[name] = StdioServerParameters( + command=command, + args=command_args, + env=base_env.copy(), # Each named server gets a copy of the base env + cwd=None, # Named servers run in the proxy's CWD + ) + logger.info("Configured named server '%s': %s", name, command_string) + except IndexError: # Should be caught by the check for empty command_parts + logger.exception( + "Invalid COMMAND_STRING for named server '%s': '%s'. Must include a command.", + name, + command_string, + ) + sys.exit(1) + except Exception: + logger.exception("Error parsing COMMAND_STRING for named server '%s'", name) + sys.exit(1) + + return named_stdio_params + + +def _create_mcp_settings(args_parsed: argparse.Namespace) -> MCPServerSettings: + """Create MCP server settings from parsed arguments.""" + return MCPServerSettings( + bind_host=args_parsed.host if args_parsed.host is not None else args_parsed.sse_host, + port=args_parsed.port if args_parsed.port is not None else args_parsed.sse_port, + stateless=args_parsed.stateless, + allow_origins=args_parsed.allow_origin if len(args_parsed.allow_origin) > 0 else None, + log_level="DEBUG" if args_parsed.debug else "INFO", + ) + + +def main() -> None: + """Start the client using asyncio.""" + parser = _setup_argument_parser() + args_parsed = parser.parse_args() + logger = _setup_logging(debug=args_parsed.debug) + + # Validate required arguments + if ( + not args_parsed.command_or_url + and not args_parsed.named_server_definitions + and not args_parsed.named_server_config + ): + parser.print_help() + logger.error( + "Either a command_or_url for a default server or at least one --named-server " + "(or --named-server-config) must be provided for stdio mode.", + ) + sys.exit(1) + + # Handle SSE client mode if URL is provided + if args_parsed.command_or_url and args_parsed.command_or_url.startswith( + ("http://", "https://"), + ): + _handle_sse_client_mode(args_parsed, logger) return - # Start a client connected to the given command, and expose as an SSE server - logger.debug("Starting stdio client and SSE server") + # Start stdio client(s) and expose as an SSE server + logger.debug("Configuring stdio client(s) and SSE server") - # The environment variables passed to the server process - env: dict[str, str] = {} - # Pass through current environment variables if configured - if args.pass_environment: - env.update(os.environ) - # Pass in and override any environment variables with those passed on the command line - env.update(dict(args.env)) + # Base environment for all spawned processes + base_env: dict[str, str] = {} + if args_parsed.pass_environment: + base_env.update(os.environ) - stdio_params = StdioServerParameters( - command=args.command_or_url, - args=args.args, - env=env, - cwd=args.cwd if args.cwd else None, + # Configure default server + default_stdio_params = _configure_default_server(args_parsed, base_env, logger) + + # Configure named servers + named_stdio_params: dict[str, StdioServerParameters] = {} + if args_parsed.named_server_config: + if args_parsed.named_server_definitions: + logger.warning( + "--named-server CLI arguments are ignored when --named-server-config is provided.", + ) + named_stdio_params = _load_named_servers_from_config( + args_parsed.named_server_config, + base_env, + logger, + ) + elif args_parsed.named_server_definitions: + named_stdio_params = _configure_named_servers_from_cli( + args_parsed.named_server_definitions, + base_env, + logger, + ) + + # Ensure at least one server is configured + if not default_stdio_params and not named_stdio_params: + parser.print_help() + logger.error( + "No stdio servers configured. Provide a default command or use --named-server.", + ) + sys.exit(1) + + # Create MCP server settings and run the server + mcp_settings = _create_mcp_settings(args_parsed) + asyncio.run( + run_mcp_server( + default_server_params=default_stdio_params, + named_server_params=named_stdio_params, + mcp_settings=mcp_settings, + ), ) - mcp_settings = MCPServerSettings( - bind_host=args.host if args.host is not None else args.sse_host, - port=args.port if args.port is not None else args.sse_port, - stateless=args.stateless, - allow_origins=args.allow_origin if len(args.allow_origin) > 0 else None, - log_level="DEBUG" if args.debug else "INFO", - ) - asyncio.run(run_mcp_server(stdio_params, mcp_settings)) - if __name__ == "__main__": main() diff --git a/src/mcp_proxy/config_loader.py b/src/mcp_proxy/config_loader.py new file mode 100644 index 0000000..ce12bb4 --- /dev/null +++ b/src/mcp_proxy/config_loader.py @@ -0,0 +1,99 @@ +"""Configuration loader for MCP proxy. + +This module provides functionality to load named server configurations from JSON files. +""" + +import json +import logging +from pathlib import Path + +from mcp.client.stdio import StdioServerParameters + +logger = logging.getLogger(__name__) + + +def load_named_server_configs_from_file( + config_file_path: str, + base_env: dict[str, str], +) -> dict[str, StdioServerParameters]: + """Loads named server configurations from a JSON file. + + Args: + config_file_path: Path to the JSON configuration file. + base_env: The base environment dictionary to be inherited by servers. + + Returns: + A dictionary of named server parameters. + + Raises: + FileNotFoundError: If the config file is not found. + json.JSONDecodeError: If the config file contains invalid JSON. + ValueError: If the config file format is invalid. + """ + named_stdio_params: dict[str, StdioServerParameters] = {} + logger.info("Loading named server configurations from: %s", config_file_path) + + try: + with Path(config_file_path).open() as f: + config_data = json.load(f) + except FileNotFoundError: + logger.exception("Configuration file not found: %s", config_file_path) + raise + except json.JSONDecodeError: + logger.exception("Error decoding JSON from configuration file: %s", config_file_path) + raise + except Exception as e: + logger.exception( + "Unexpected error opening or reading configuration file %s", + config_file_path, + ) + error_message = f"Could not read configuration file: {e}" + raise ValueError(error_message) from e + + if not isinstance(config_data, dict) or "mcpServers" not in config_data: + msg = f"Invalid config file format in {config_file_path}. Missing 'mcpServers' key." + logger.error(msg) + raise ValueError(msg) + + for name, server_config in config_data.get("mcpServers", {}).items(): + if not isinstance(server_config, dict): + logger.warning( + "Skipping invalid server config for '%s' in %s. Entry is not a dictionary.", + name, + config_file_path, + ) + continue + if not server_config.get("enabled", True): # Default to True if 'enabled' is not present + logger.info("Named server '%s' from config is not enabled. Skipping.", name) + continue + + command = server_config.get("command") + command_args = server_config.get("args", []) + + if not command: + logger.warning( + "Named server '%s' from config is missing 'command'. Skipping.", + name, + ) + continue + if not isinstance(command_args, list): + logger.warning( + "Named server '%s' from config has invalid 'args' (must be a list). Skipping.", + name, + ) + continue + + named_stdio_params[name] = StdioServerParameters( + command=command, + args=command_args, + env=base_env.copy(), + cwd=None, + ) + logger.info( + "Configured named server '%s' from config: %s %s", + name, + command, + " ".join(command_args), + ) + + return named_stdio_params diff --git a/src/mcp_proxy/mcp_server.py b/src/mcp_proxy/mcp_server.py index 9a3ddf2..0b31ff6 100644 --- a/src/mcp_proxy/mcp_server.py +++ b/src/mcp_proxy/mcp_server.py @@ -5,12 +5,12 @@ import logging from collections.abc import AsyncIterator from dataclasses import dataclass from datetime import datetime, timezone -from typing import Literal +from typing import Any, Literal import uvicorn from mcp.client.session import ClientSession from mcp.client.stdio import StdioServerParameters, stdio_client -from mcp.server import Server +from mcp.server import Server as MCPServerSDK # Renamed to avoid conflict from mcp.server.sse import SseServerTransport from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette @@ -18,7 +18,7 @@ from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, Response -from starlette.routing import Mount, Route +from starlette.routing import BaseRoute, Mount, Route from starlette.types import Receive, Scope, Send from .proxy_server import create_proxy_server @@ -37,123 +37,154 @@ class MCPServerSettings: log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" -def create_starlette_app( - mcp_server: Server[object], +# To store last activity for multiple servers if needed, though status endpoint is global for now. +_global_status: dict[str, Any] = { + "api_last_activity": datetime.now(timezone.utc).isoformat(), + "server_instances": {}, # Could be used to store per-instance status later +} + + +def _update_global_activity() -> None: + _global_status["api_last_activity"] = datetime.now(timezone.utc).isoformat() + + +async def _handle_status(_: Request) -> Response: + """Global health check and service usage monitoring endpoint.""" + return JSONResponse(_global_status) + + +def create_single_instance_routes( + mcp_server_instance: MCPServerSDK[object], *, - stateless: bool = False, - allow_origins: list[str] | None = None, - debug: bool = False, -) -> Starlette: - """Create a Starlette application that can serve the mcp server with SSE or Streamable http.""" - logger.debug("Creating Starlette app with stateless: %s and debug: %s", stateless, debug) - # record the last activity of api - status = { - "api_last_activity": datetime.now(timezone.utc).isoformat(), - } + stateless_instance: bool, +) -> tuple[list[BaseRoute], StreamableHTTPSessionManager]: # Return the manager itself + """Create Starlette routes and the HTTP session manager for a single MCP server instance.""" + logger.debug( + "Creating routes for a single MCP server instance (stateless: %s)", + stateless_instance, + ) - def _update_mcp_activity() -> None: - status.update( - { - "api_last_activity": datetime.now(timezone.utc).isoformat(), - }, - ) + sse_transport = SseServerTransport("messages/") + http_session_manager = StreamableHTTPSessionManager( + app=mcp_server_instance, + event_store=None, + json_response=True, + stateless=stateless_instance, + ) - sse = SseServerTransport("/messages/") - - async def handle_sse(request: Request) -> None: - async with sse.connect_sse( + async def handle_sse_instance(request: Request) -> None: + async with sse_transport.connect_sse( request.scope, request.receive, request._send, # noqa: SLF001 ) as (read_stream, write_stream): - _update_mcp_activity() - - await mcp_server.run( + _update_global_activity() + await mcp_server_instance.run( read_stream, write_stream, - mcp_server.create_initialization_options(), + mcp_server_instance.create_initialization_options(), ) - # Refer: https://github.com/modelcontextprotocol/python-sdk/blob/v1.8.0/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py - http = StreamableHTTPSessionManager( - app=mcp_server, - event_store=None, - json_response=True, - stateless=stateless, - ) + async def handle_streamable_http_instance(scope: Scope, receive: Receive, send: Send) -> None: + _update_global_activity() + await http_session_manager.handle_request(scope, receive, send) - async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: - _update_mcp_activity() - await http.handle_request(scope, receive, send) - - async def handle_status(_: Request) -> Response: - """Health check and service usage monitoring endpoint. - - Purpose of this handler: - - Provides a dedicated API endpoint for external health checks. - - Returns last API activity timestamp to monitor service usage patterns and uptime. - - Serves as basic infrastructure for potential future service metrics expansion. - """ - return JSONResponse(status) - - @contextlib.asynccontextmanager - async def lifespan(_: Starlette) -> AsyncIterator[None]: - """Context manager for session manager.""" - async with http.run(): - logger.info("Application started with StreamableHTTP session manager!") - try: - yield - finally: - logger.info("Application shutting down...") - - middleware: list[Middleware] = [] - if allow_origins is not None: - middleware.append( - Middleware( - CORSMiddleware, - allow_origins=allow_origins, - allow_methods=["*"], - allow_headers=["*"], - ), - ) - - return Starlette( - debug=debug, - middleware=middleware, - routes=[ - Route("/status", endpoint=handle_status), - Mount("/mcp", app=handle_streamable_http), - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ], - lifespan=lifespan, - ) + routes = [ + Mount("/mcp", app=handle_streamable_http_instance), + Route("/sse", endpoint=handle_sse_instance), + Mount("/messages/", app=sse_transport.handle_post_message), + ] + return routes, http_session_manager async def run_mcp_server( - stdio_params: StdioServerParameters, mcp_settings: MCPServerSettings, + default_server_params: StdioServerParameters | None = None, + named_server_params: dict[str, StdioServerParameters] | None = None, ) -> None: - """Run the stdio client and expose an MCP server. + """Run stdio client(s) and expose an MCP server with multiple possible backends.""" + if named_server_params is None: + named_server_params = {} - Args: - stdio_params: The parameters for the stdio client that spawns a stdio server. - mcp_settings: The settings for the MCP server that accepts incoming requests. + all_routes: list[BaseRoute] = [ + Route("/status", endpoint=_handle_status), # Global status endpoint + ] + # Use AsyncExitStack to manage lifecycles of multiple components + async with contextlib.AsyncExitStack() as stack: + # Manage lifespans of all StreamableHTTPSessionManagers + @contextlib.asynccontextmanager + async def combined_lifespan(_app: Starlette) -> AsyncIterator[None]: + logger.info("Main application lifespan starting...") + # All http_session_managers' .run() are already entered into the stack + yield + logger.info("Main application lifespan shutting down...") - """ - async with stdio_client(stdio_params) as streams, ClientSession(*streams) as session: - logger.debug("Starting proxy server...") - mcp_server = await create_proxy_server(session) + # Setup default server if configured + if default_server_params: + logger.info( + "Setting up default server: %s %s", + default_server_params.command, + " ".join(default_server_params.args), + ) + stdio_streams = await stack.enter_async_context(stdio_client(default_server_params)) + session = await stack.enter_async_context(ClientSession(*stdio_streams)) + proxy = await create_proxy_server(session) - # Bind request handling to MCP server - starlette_app = create_starlette_app( - mcp_server, - stateless=mcp_settings.stateless, - allow_origins=mcp_settings.allow_origins, + instance_routes, http_manager = create_single_instance_routes( + proxy, + stateless_instance=mcp_settings.stateless, + ) + await stack.enter_async_context(http_manager.run()) # Manage lifespan by calling run() + all_routes.extend(instance_routes) + _global_status["server_instances"]["default"] = "configured" + + # Setup named servers + for name, params in named_server_params.items(): + logger.info( + "Setting up named server '%s': %s %s", + name, + params.command, + " ".join(params.args), + ) + stdio_streams_named = await stack.enter_async_context(stdio_client(params)) + session_named = await stack.enter_async_context(ClientSession(*stdio_streams_named)) + proxy_named = await create_proxy_server(session_named) + + instance_routes_named, http_manager_named = create_single_instance_routes( + proxy_named, + stateless_instance=mcp_settings.stateless, + ) + await stack.enter_async_context( + http_manager_named.run(), + ) # Manage lifespan by calling run() + + # Mount these routes under /servers// + server_mount = Mount(f"/servers/{name}", routes=instance_routes_named) + all_routes.append(server_mount) + _global_status["server_instances"][name] = "configured" + + if not default_server_params and not named_server_params: + logger.error("No servers configured to run.") + return + + middleware: list[Middleware] = [] + if mcp_settings.allow_origins: + middleware.append( + Middleware( + CORSMiddleware, + allow_origins=mcp_settings.allow_origins, + allow_methods=["*"], + allow_headers=["*"], + ), + ) + + starlette_app = Starlette( debug=(mcp_settings.log_level == "DEBUG"), + routes=all_routes, + middleware=middleware, + lifespan=combined_lifespan, ) - # Configure HTTP server config = uvicorn.Config( starlette_app, host=mcp_settings.bind_host, @@ -161,8 +192,27 @@ async def run_mcp_server( log_level=mcp_settings.log_level.lower(), ) http_server = uvicorn.Server(config) + + # Print out the SSE URLs for all configured servers + base_url = f"http://{mcp_settings.bind_host}:{mcp_settings.port}" + sse_urls = [] + + # Add default server if configured + if default_server_params: + sse_urls.append(f"{base_url}/sse") + + # Add named servers + sse_urls.extend([f"{base_url}/servers/{name}/sse" for name in named_server_params]) + + # Display the SSE URLs prominently + if sse_urls: + # Using print directly for user visibility, with noqa to ignore linter warnings + logger.info("Serving MCP Servers via SSE:") + for url in sse_urls: + logger.info(" - %s", url) + logger.debug( - "Serving incoming requests on %s:%s", + "Serving incoming MCP requests on %s:%s", mcp_settings.bind_host, mcp_settings.port, ) diff --git a/src/mcp_proxy/proxy_server.py b/src/mcp_proxy/proxy_server.py index 5468b85..d0cc397 100644 --- a/src/mcp_proxy/proxy_server.py +++ b/src/mcp_proxy/proxy_server.py @@ -1,4 +1,4 @@ -"""Create an MCP server that proxies requests throgh an MCP client. +"""Create an MCP server that proxies requests through an MCP client. This server is created independent of any transport mechanism. """ diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py new file mode 100644 index 0000000..0211a5b --- /dev/null +++ b/tests/test_config_loader.py @@ -0,0 +1,244 @@ +"""Tests for the configuration loader module.""" + +import json +import shutil +import tempfile +from collections.abc import Callable, Generator +from pathlib import Path +from unittest.mock import patch + +import pytest +from mcp.client.stdio import StdioServerParameters + +from mcp_proxy.config_loader import load_named_server_configs_from_file + + +@pytest.fixture +def create_temp_config_file() -> Generator[Callable[[dict], str], None, None]: + """Creates a temporary JSON config file and returns its path.""" + temp_files: list[str] = [] + + def _create_temp_config_file(config_content: dict) -> str: + with tempfile.NamedTemporaryFile( + mode="w", + delete=False, + suffix=".json", + ) as tmp_config: + json.dump(config_content, tmp_config) + temp_files.append(tmp_config.name) + return tmp_config.name + + yield _create_temp_config_file + + for f_path in temp_files: + path = Path(f_path) + if path.exists(): + path.unlink() + + +def test_load_valid_config(create_temp_config_file: Callable[[dict], str]) -> None: + """Test loading a valid configuration file.""" + config_content = { + "mcpServers": { + "server1": { + "command": "echo", + "args": ["hello"], + "enabled": True, + }, + "server2": { + "command": "cat", + "args": ["file.txt"], + }, + }, + } + tmp_config_path = create_temp_config_file(config_content) + base_env = {"PASSED": "env_value"} + + loaded_params = load_named_server_configs_from_file(tmp_config_path, base_env) + + assert "server1" in loaded_params + assert loaded_params["server1"].command == "echo" + assert loaded_params["server1"].args == ["hello"] + assert ( + loaded_params["server1"].env == base_env + ) # Env is a copy, check if it contains base_env items + + assert "server2" in loaded_params + assert loaded_params["server2"].command == "cat" + assert loaded_params["server2"].args == ["file.txt"] + assert loaded_params["server2"].env == base_env + + +def test_load_config_with_not_enabled_server( + create_temp_config_file: Callable[[dict], str], +) -> None: + """Test loading a configuration with disabled servers.""" + config_content = { + "mcpServers": { + "explicitly_enabled_server": {"command": "true_command", "enabled": True}, + # No 'enabled' flag, defaults to True + "implicitly_enabled_server": {"command": "another_true_command"}, + "not_enabled_server": {"command": "false_command", "enabled": False}, + }, + } + tmp_config_path = create_temp_config_file(config_content) + loaded_params = load_named_server_configs_from_file(tmp_config_path, {}) + + assert "explicitly_enabled_server" in loaded_params + assert loaded_params["explicitly_enabled_server"].command == "true_command" + assert "implicitly_enabled_server" in loaded_params + assert loaded_params["implicitly_enabled_server"].command == "another_true_command" + assert "not_enabled_server" not in loaded_params + + +def test_file_not_found() -> None: + """Test handling of non-existent configuration files.""" + with pytest.raises(FileNotFoundError): + load_named_server_configs_from_file("non_existent_file.json", {}) + + +def test_json_decode_error() -> None: + """Test handling of invalid JSON in configuration files.""" + # Create a file with invalid JSON content + with tempfile.NamedTemporaryFile( + mode="w", + delete=False, + suffix=".json", + ) as tmp_config: + tmp_config.write("this is not json {") + tmp_config_path = tmp_config.name + + # Use try/finally to ensure cleanup + try: + with pytest.raises(json.JSONDecodeError): + load_named_server_configs_from_file(tmp_config_path, {}) + finally: + path = Path(tmp_config_path) + if path.exists(): + path.unlink() + + +def test_load_example_fetch_config_if_uvx_exists() -> None: + """Test loading the example fetch configuration if uvx is available.""" + if not shutil.which("uvx"): + pytest.skip("uvx command not found in PATH, skipping test for example config.") + + # Assuming the test is run from the root of the repository + example_config_path = Path(__file__).parent.parent / "config_example.json" + + if not example_config_path.exists(): + pytest.fail( + f"Example config file not found at expected path: {example_config_path}", + ) + + base_env = {"EXAMPLE_ENV": "true"} + loaded_params = load_named_server_configs_from_file(example_config_path, base_env) + + assert "fetch" in loaded_params + fetch_param = loaded_params["fetch"] + assert isinstance(fetch_param, StdioServerParameters) + assert fetch_param.command == "uvx" + assert fetch_param.args == ["mcp-server-fetch"] + assert fetch_param.env == base_env + # The 'timeout' and 'transportType' fields from the config are currently ignored by the loader, + # so no need to assert them on StdioServerParameters. + + +def test_invalid_config_format_missing_mcpservers( + create_temp_config_file: Callable[[dict], str], +) -> None: + """Test handling of configuration files missing the mcpServers key.""" + config_content = {"some_other_key": "value"} + tmp_config_path = create_temp_config_file(config_content) + + with pytest.raises(ValueError, match="Missing 'mcpServers' key"): + load_named_server_configs_from_file(tmp_config_path, {}) + + +@patch("mcp_proxy.config_loader.logger") +def test_invalid_server_entry_not_dict( + mock_logger: object, + create_temp_config_file: Callable[[dict], str], +) -> None: + """Test handling of server entries that are not dictionaries.""" + config_content = {"mcpServers": {"server1": "not_a_dict"}} + tmp_config_path = create_temp_config_file(config_content) + + loaded_params = load_named_server_configs_from_file(tmp_config_path, {}) + assert len(loaded_params) == 0 # No servers should be loaded + mock_logger.warning.assert_called_with( + "Skipping invalid server config for '%s' in %s. Entry is not a dictionary.", + "server1", + tmp_config_path, + ) + + +@patch("mcp_proxy.config_loader.logger") +def test_server_entry_missing_command( + mock_logger: object, + create_temp_config_file: Callable[[dict], str], +) -> None: + """Test handling of server entries missing the command field.""" + config_content = {"mcpServers": {"server_no_command": {"args": ["arg1"]}}} + tmp_config_path = create_temp_config_file(config_content) + loaded_params = load_named_server_configs_from_file(tmp_config_path, {}) + assert "server_no_command" not in loaded_params + mock_logger.warning.assert_called_with( + "Named server '%s' from config is missing 'command'. Skipping.", + "server_no_command", + ) + + +@patch("mcp_proxy.config_loader.logger") +def test_server_entry_invalid_args_type( + mock_logger: object, + create_temp_config_file: Callable[[dict], str], +) -> None: + """Test handling of server entries with invalid args type.""" + config_content = { + "mcpServers": { + "server_invalid_args": {"command": "mycmd", "args": "not_a_list"}, + }, + } + tmp_config_path = create_temp_config_file(config_content) + loaded_params = load_named_server_configs_from_file(tmp_config_path, {}) + assert "server_invalid_args" not in loaded_params + mock_logger.warning.assert_called_with( + "Named server '%s' from config has invalid 'args' (must be a list). Skipping.", + "server_invalid_args", + ) + + +def test_empty_mcpservers_dict(create_temp_config_file: Callable[[dict], str]) -> None: + """Test handling of configuration files with empty mcpServers dictionary.""" + config_content = {"mcpServers": {}} + tmp_config_path = create_temp_config_file(config_content) + loaded_params = load_named_server_configs_from_file(tmp_config_path, {}) + assert len(loaded_params) == 0 + + +def test_config_file_is_empty_json_object(create_temp_config_file: Callable[[dict], str]) -> None: + """Test handling of configuration files with empty JSON objects.""" + config_content = {} # Empty JSON object + tmp_config_path = create_temp_config_file(config_content) + with pytest.raises(ValueError, match="Missing 'mcpServers' key"): + load_named_server_configs_from_file(tmp_config_path, {}) + + +def test_config_file_is_empty_string() -> None: + """Test handling of configuration files with empty content.""" + # Create a file with an empty string + with tempfile.NamedTemporaryFile( + mode="w", + delete=False, + suffix=".json", + ) as tmp_config: + tmp_config.write("") # Empty content + tmp_config_path = tmp_config.name + try: + with pytest.raises(json.JSONDecodeError): + load_named_server_configs_from_file(tmp_config_path, {}) + finally: + path = Path(tmp_config_path) + if path.exists(): + path.unlink() diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index cadf80d..b568551 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -1,18 +1,68 @@ """Tests for the sse server.""" +# ruff: noqa: PLR2004 import asyncio import contextlib import typing as t +from unittest.mock import AsyncMock, MagicMock, patch import pytest import uvicorn from mcp.client.session import ClientSession from mcp.client.sse import sse_client +from mcp.client.stdio import StdioServerParameters from mcp.client.streamable_http import streamablehttp_client -from mcp.server import FastMCP +from mcp.server import FastMCP, Server from mcp.types import TextContent +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.cors import CORSMiddleware -from mcp_proxy.mcp_server import create_starlette_app +from mcp_proxy.mcp_server import MCPServerSettings, create_single_instance_routes, run_mcp_server + + +def create_starlette_app( + mcp_server: Server[t.Any], + allow_origins: list[str] | None = None, + *, + debug: bool = False, + stateless: bool = False, +) -> Starlette: + """Create a Starlette application for the MCP server. + + Args: + mcp_server: The MCP server instance to wrap + allow_origins: List of allowed CORS origins + debug: Enable debug mode + stateless: Whether to use stateless HTTP sessions + + Returns: + Starlette application instance + """ + routes, http_manager = create_single_instance_routes(mcp_server, stateless_instance=stateless) + + middleware: list[Middleware] = [] + if allow_origins: + middleware.append( + Middleware( + CORSMiddleware, + allow_origins=allow_origins, + allow_methods=["*"], + allow_headers=["*"], + ), + ) + + @contextlib.asynccontextmanager + async def lifespan(_app: Starlette) -> t.AsyncIterator[None]: + async with http_manager.run(): + yield + + return Starlette( + debug=debug, + routes=routes, + middleware=middleware, + lifespan=lifespan, + ) class BackgroundServer(uvicorn.Server): @@ -64,7 +114,6 @@ def make_background_server(**kwargs) -> BackgroundServer: # noqa: ANN003 return BackgroundServer(config) -@pytest.mark.asyncio async def test_sse_transport() -> None: """Test basic glue code for the SSE transport and a fake MCP server.""" server = make_background_server(debug=True) @@ -77,7 +126,6 @@ async def test_sse_transport() -> None: assert response.prompts[0].name == "prompt1" -@pytest.mark.asyncio async def test_http_transport() -> None: """Test HTTP transport layer functionality.""" server = make_background_server(debug=True) @@ -118,3 +166,509 @@ async def test_stateless_http_transport() -> None: assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) assert tool_result.content[0].text == f"Echo: test_{i}" + + +# Unit tests for run_mcp_server method + + +@pytest.fixture +def mock_settings() -> MCPServerSettings: + """Create mock MCP server settings for testing.""" + return MCPServerSettings( + bind_host="127.0.0.1", + port=8080, + stateless=False, + allow_origins=["*"], + log_level="INFO", + ) + + +@pytest.fixture +def mock_stdio_params() -> StdioServerParameters: + """Create mock stdio server parameters for testing.""" + return StdioServerParameters( + command="echo", + args=["hello"], + env={"TEST_VAR": "test_value"}, + cwd="/tmp", # noqa: S108 + ) + + +class AsyncContextManagerMock: # noqa: D101 + def __init__(self, mock) -> None: # noqa: ANN001, D107 + self.mock = mock + + async def __aenter__(self): # noqa: ANN204, D105 + return self.mock + + async def __aexit__(self, *args): # noqa: ANN002, ANN204, D105 + pass + + +def setup_async_context_mocks() -> tuple[ + AsyncContextManagerMock, + AsyncContextManagerMock, + AsyncMock, + MagicMock, + list[MagicMock], +]: + """Helper function to set up async context manager mocks.""" + # Setup stdio client mock + mock_streams = (AsyncMock(), AsyncMock()) + + # Setup client session mock + mock_session = AsyncMock() + + # Setup HTTP manager mock + mock_http_manager = MagicMock() + mock_http_manager.run.return_value = AsyncContextManagerMock(None) + mock_routes = [MagicMock()] + + return ( + AsyncContextManagerMock(mock_streams), + AsyncContextManagerMock(mock_session), + mock_session, + mock_http_manager, + mock_routes, + ) + + +async def test_run_mcp_server_no_servers_configured(mock_settings: MCPServerSettings) -> None: + """Test run_mcp_server when no servers are configured.""" + with patch("mcp_proxy.mcp_server.logger") as mock_logger: + await run_mcp_server(mock_settings, None, {}) + mock_logger.error.assert_called_once_with("No servers configured to run.") + + +async def test_run_mcp_server_with_default_server( + mock_settings: MCPServerSettings, + mock_stdio_params: StdioServerParameters, +) -> None: + """Test run_mcp_server with a default server configuration.""" + with ( + patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client, + patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session, + patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy, + patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes, + patch("uvicorn.Server") as mock_uvicorn_server, + patch("mcp_proxy.mcp_server.logger") as mock_logger, + ): + # Setup mocks + mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = ( + setup_async_context_mocks() + ) + mock_stdio_client.return_value = mock_stdio_context + mock_client_session.return_value = mock_session_context + + mock_proxy = AsyncMock() + mock_create_proxy.return_value = mock_proxy + mock_create_routes.return_value = (mock_routes, mock_http_manager) + + mock_server_instance = AsyncMock() + mock_uvicorn_server.return_value = mock_server_instance + + # Run the function + await run_mcp_server(mock_settings, mock_stdio_params, {}) + + # Verify calls + mock_stdio_client.assert_called_once_with(mock_stdio_params) + mock_create_proxy.assert_called_once_with(mock_session) + mock_create_routes.assert_called_once_with( + mock_proxy, + stateless_instance=mock_settings.stateless, + ) + mock_logger.info.assert_any_call( + "Setting up default server: %s %s", + mock_stdio_params.command, + " ".join(mock_stdio_params.args), + ) + mock_server_instance.serve.assert_called_once() + + +async def test_run_mcp_server_with_named_servers( + mock_settings: MCPServerSettings, + mock_stdio_params: StdioServerParameters, +) -> None: + """Test run_mcp_server with named servers configuration.""" + named_servers = { + "server1": mock_stdio_params, + "server2": StdioServerParameters( + command="python", + args=["-m", "mcp_server"], + env={"PYTHON_PATH": "/usr/bin/python"}, + cwd="/home/user", + ), + } + + with ( + patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client, + patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session, + patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy, + patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes, + patch("uvicorn.Server") as mock_uvicorn_server, + patch("mcp_proxy.mcp_server.logger") as mock_logger, + ): + # Setup mocks + mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = ( + setup_async_context_mocks() + ) + mock_stdio_client.return_value = mock_stdio_context + mock_client_session.return_value = mock_session_context + + mock_proxy = AsyncMock() + mock_create_proxy.return_value = mock_proxy + mock_create_routes.return_value = (mock_routes, mock_http_manager) + + mock_server_instance = AsyncMock() + mock_uvicorn_server.return_value = mock_server_instance + + # Run the function + await run_mcp_server(mock_settings, None, named_servers) + + # Verify calls + assert mock_stdio_client.call_count == 2 + assert mock_create_proxy.call_count == 2 + assert mock_create_routes.call_count == 2 + + # Check that named servers were logged + mock_logger.info.assert_any_call( + "Setting up named server '%s': %s %s", + "server1", + mock_stdio_params.command, + " ".join(mock_stdio_params.args), + ) + mock_logger.info.assert_any_call( + "Setting up named server '%s': %s %s", + "server2", + "python", + "-m mcp_server", + ) + + mock_server_instance.serve.assert_called_once() + + +async def test_run_mcp_server_with_cors_middleware( + mock_stdio_params: StdioServerParameters, +) -> None: + """Test run_mcp_server adds CORS middleware when allow_origins is set.""" + settings_with_cors = MCPServerSettings( + bind_host="0.0.0.0", # noqa: S104 + port=9090, + allow_origins=["http://localhost:3000", "https://example.com"], + ) + + with ( + patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client, + patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session, + patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy, + patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes, + patch("mcp_proxy.mcp_server.Starlette") as mock_starlette, + patch("uvicorn.Server") as mock_uvicorn_server, + ): + # Setup mocks + mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = ( + setup_async_context_mocks() + ) + mock_stdio_client.return_value = mock_stdio_context + mock_client_session.return_value = mock_session_context + + mock_proxy = AsyncMock() + mock_create_proxy.return_value = mock_proxy + mock_create_routes.return_value = (mock_routes, mock_http_manager) + + mock_server_instance = AsyncMock() + mock_uvicorn_server.return_value = mock_server_instance + + # Run the function + await run_mcp_server(settings_with_cors, mock_stdio_params, {}) + + # Verify Starlette was called with middleware + mock_starlette.assert_called_once() + call_args = mock_starlette.call_args + middleware = call_args.kwargs["middleware"] + + assert len(middleware) == 1 + assert middleware[0].cls == CORSMiddleware + + +async def test_run_mcp_server_debug_mode( + mock_stdio_params: StdioServerParameters, +) -> None: + """Test run_mcp_server with debug mode enabled.""" + debug_settings = MCPServerSettings( + bind_host="127.0.0.1", + port=8080, + log_level="DEBUG", + ) + + with ( + patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client, + patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session, + patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy, + patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes, + patch("mcp_proxy.mcp_server.Starlette") as mock_starlette, + patch("uvicorn.Server") as mock_uvicorn_server, + ): + # Setup mocks + mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = ( + setup_async_context_mocks() + ) + mock_stdio_client.return_value = mock_stdio_context + mock_client_session.return_value = mock_session_context + + mock_proxy = AsyncMock() + mock_create_proxy.return_value = mock_proxy + mock_create_routes.return_value = (mock_routes, mock_http_manager) + + mock_server_instance = AsyncMock() + mock_uvicorn_server.return_value = mock_server_instance + + # Run the function + await run_mcp_server(debug_settings, mock_stdio_params, {}) + + # Verify Starlette was called with debug=True + mock_starlette.assert_called_once() + call_args = mock_starlette.call_args + assert call_args.kwargs["debug"] is True + + +async def test_run_mcp_server_stateless_mode( + mock_stdio_params: StdioServerParameters, +) -> None: + """Test run_mcp_server with stateless mode enabled.""" + stateless_settings = MCPServerSettings( + bind_host="127.0.0.1", + port=8080, + stateless=True, + ) + + with ( + patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client, + patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session, + patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy, + patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes, + patch("uvicorn.Server") as mock_uvicorn_server, + ): + # Setup mocks + mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = ( + setup_async_context_mocks() + ) + mock_stdio_client.return_value = mock_stdio_context + mock_client_session.return_value = mock_session_context + + mock_proxy = AsyncMock() + mock_create_proxy.return_value = mock_proxy + mock_create_routes.return_value = (mock_routes, mock_http_manager) + + mock_server_instance = AsyncMock() + mock_uvicorn_server.return_value = mock_server_instance + + # Run the function + await run_mcp_server(stateless_settings, mock_stdio_params, {}) + + # Verify create_single_instance_routes was called with stateless_instance=True + mock_create_routes.assert_called_once_with( + mock_proxy, + stateless_instance=True, + ) + + +async def test_run_mcp_server_uvicorn_config( + mock_settings: MCPServerSettings, + mock_stdio_params: StdioServerParameters, +) -> None: + """Test run_mcp_server creates correct uvicorn configuration.""" + with ( + patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client, + patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session, + patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy, + patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes, + patch("uvicorn.Config") as mock_uvicorn_config, + patch("uvicorn.Server") as mock_uvicorn_server, + ): + # Setup mocks + mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = ( + setup_async_context_mocks() + ) + mock_stdio_client.return_value = mock_stdio_context + mock_client_session.return_value = mock_session_context + + mock_proxy = AsyncMock() + mock_create_proxy.return_value = mock_proxy + mock_create_routes.return_value = (mock_routes, mock_http_manager) + + mock_config = MagicMock() + mock_uvicorn_config.return_value = mock_config + + mock_server_instance = AsyncMock() + mock_uvicorn_server.return_value = mock_server_instance + + # Run the function + await run_mcp_server(mock_settings, mock_stdio_params, {}) + + # Verify uvicorn.Config was called with correct parameters + mock_uvicorn_config.assert_called_once() + call_args = mock_uvicorn_config.call_args + + assert call_args.kwargs["host"] == mock_settings.bind_host + assert call_args.kwargs["port"] == mock_settings.port + assert call_args.kwargs["log_level"] == mock_settings.log_level.lower() + + +async def test_run_mcp_server_global_status_updates( + mock_settings: MCPServerSettings, + mock_stdio_params: StdioServerParameters, +) -> None: + """Test run_mcp_server updates global status correctly.""" + from mcp_proxy.mcp_server import _global_status + + # Clear global status before test + _global_status["server_instances"].clear() + + named_servers = {"test_server": mock_stdio_params} + + with ( + patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client, + patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session, + patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy, + patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes, + patch("uvicorn.Server") as mock_uvicorn_server, + ): + # Setup mocks + mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = ( + setup_async_context_mocks() + ) + mock_stdio_client.return_value = mock_stdio_context + mock_client_session.return_value = mock_session_context + + mock_proxy = AsyncMock() + mock_create_proxy.return_value = mock_proxy + mock_create_routes.return_value = (mock_routes, mock_http_manager) + + mock_server_instance = AsyncMock() + mock_uvicorn_server.return_value = mock_server_instance + + # Run the function + await run_mcp_server(mock_settings, mock_stdio_params, named_servers) + + # Verify global status was updated + assert "default" in _global_status["server_instances"] + assert "test_server" in _global_status["server_instances"] + assert _global_status["server_instances"]["default"] == "configured" + assert _global_status["server_instances"]["test_server"] == "configured" + + +async def test_run_mcp_server_sse_url_logging( + mock_settings: MCPServerSettings, + mock_stdio_params: StdioServerParameters, +) -> None: + """Test run_mcp_server logs correct SSE URLs.""" + named_servers = {"test_server": mock_stdio_params} + + with ( + patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client, + patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session, + patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy, + patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes, + patch("uvicorn.Server") as mock_uvicorn_server, + patch("mcp_proxy.mcp_server.logger") as mock_logger, + ): + # Setup mocks + mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = ( + setup_async_context_mocks() + ) + mock_stdio_client.return_value = mock_stdio_context + mock_client_session.return_value = mock_session_context + + mock_proxy = AsyncMock() + mock_create_proxy.return_value = mock_proxy + mock_create_routes.return_value = (mock_routes, mock_http_manager) + + mock_server_instance = AsyncMock() + mock_uvicorn_server.return_value = mock_server_instance + + # Run the function + await run_mcp_server(mock_settings, mock_stdio_params, named_servers) + + # Verify SSE URLs were logged + expected_default_url = f"http://{mock_settings.bind_host}:{mock_settings.port}/sse" + expected_named_url = ( + f"http://{mock_settings.bind_host}:{mock_settings.port}/servers/test_server/sse" + ) + + mock_logger.info.assert_any_call("Serving MCP Servers via SSE:") + mock_logger.info.assert_any_call(" - %s", expected_default_url) + mock_logger.info.assert_any_call(" - %s", expected_named_url) + + +async def test_run_mcp_server_exception_handling( + mock_settings: MCPServerSettings, + mock_stdio_params: StdioServerParameters, +) -> None: + """Test run_mcp_server handles exceptions properly.""" + with ( + patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client, + patch("mcp_proxy.mcp_server.ClientSession"), + ): + # Setup mocks to raise an exception + mock_stdio_client.side_effect = Exception("Connection failed") + + # Should not raise, function should handle exceptions gracefully + try: + await run_mcp_server(mock_settings, mock_stdio_params, {}) + except Exception as e: # noqa: BLE001 + # If an exception is raised, it should be the expected one + assert "Connection failed" in str(e) # noqa: PT017 + + +async def test_run_mcp_server_both_default_and_named_servers( + mock_settings: MCPServerSettings, + mock_stdio_params: StdioServerParameters, +) -> None: + """Test run_mcp_server with both default and named servers.""" + named_servers = {"named_server": mock_stdio_params} + + with ( + patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client, + patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session, + patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy, + patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes, + patch("uvicorn.Server") as mock_uvicorn_server, + patch("mcp_proxy.mcp_server.logger") as mock_logger, + ): + # Setup mocks + mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = ( + setup_async_context_mocks() + ) + mock_stdio_client.return_value = mock_stdio_context + mock_client_session.return_value = mock_session_context + + mock_proxy = AsyncMock() + mock_create_proxy.return_value = mock_proxy + mock_create_routes.return_value = (mock_routes, mock_http_manager) + + mock_server_instance = AsyncMock() + mock_uvicorn_server.return_value = mock_server_instance + + # Run the function with both default and named servers + await run_mcp_server(mock_settings, mock_stdio_params, named_servers) + + # Verify both servers were set up + assert mock_stdio_client.call_count == 2 # One for default, one for named + assert mock_create_proxy.call_count == 2 + assert mock_create_routes.call_count == 2 + + # Verify logging for both servers + mock_logger.info.assert_any_call( + "Setting up default server: %s %s", + mock_stdio_params.command, + " ".join(mock_stdio_params.args), + ) + mock_logger.info.assert_any_call( + "Setting up named server '%s': %s %s", + "named_server", + mock_stdio_params.command, + " ".join(mock_stdio_params.args), + ) + + mock_server_instance.serve.assert_called_once()