feat: support proxying multiple MCP stdio servers to SSE (#65)

This PR adds support for running multiple MCP (STDIO) servers and
serving them up via a single mcp-proxy instance, each with a named path
in the URL.

Example usage:

```
mcp-proxy --port 8080 --named-server fetch 'uvx mcp-server-fetch' --named-server github 'npx -y @modelcontextprotocol/server-github'
```

Would serve:

- `http://localhost:8080/servers/fetch`
- `http://localhost:8080/servers/github`

I've also added the ability to provide a standard mcp client config file
with accompanying tests.

Please feel free to make any changes as you see fit, or reject the PR if
it does not align with your goals.

Thank you,

---------

Co-authored-by: Magnus Tidemann <magnustidemann@gmail.com>
Co-authored-by: Sergey Parfenyuk <sergey.parfenyuk@gmail.com>
This commit is contained in:
Sam
2025-05-27 19:48:25 +10:00
committed by GitHub
parent f31cd3e73c
commit b25056fadd
8 changed files with 1452 additions and 182 deletions

145
README.md
View File

@@ -14,6 +14,7 @@
- [2. SSE to stdio](#2-sse-to-stdio)
- [2.1 Configuration](#21-configuration)
- [2.2 Example usage](#22-example-usage)
- [Named Servers](#named-servers)
- [Installation](#installation)
- [Installing via Smithery](#installing-via-smithery)
- [Installing via PyPI](#installing-via-pypi)
@@ -23,6 +24,7 @@
- [Extending the container image](#extending-the-container-image)
- [Docker Compose Setup](#docker-compose-setup)
- [Command line arguments](#command-line-arguments)
- [Example config file](#example-config-file)
- [Testing](#testing)
## About
@@ -115,18 +117,20 @@ separator.
Arguments
| Name | Required | Description | Example |
| ------------------------- | -------------------------- | --------------------------------------------------------------------------------------------- | --------------------- |
| `command_or_url` | Yes | The command to spawn the MCP stdio server | uvx mcp-server-fetch |
| `--port` | No, random available | The MCP server port to listen on | 8080 |
| `--host` | No, `127.0.0.1` by default | The host IP address that the MCP server will listen on | 0.0.0.0 |
| `--env` | No | Additional environment variables to pass to the MCP stdio server. Can be used multiple times. | FOO BAR |
| `--cwd` | No | The working directory to pass to the MCP stdio server process. | /tmp |
| `--pass-environment` | No | Pass through all environment variables when spawning the server | --no-pass-environment |
| `--allow-origin` | No | Allowed origins for the SSE server. Can be used multiple times. Default is no CORS allowed. | --allow-cors "\*" |
| `--stateless` | No | Enable stateless mode for streamable http transports. Default is False | --no-stateless |
| `--sse-port` (deprecated) | No, random available | The SSE server port to listen on | 8080 |
| `--sse-host` (deprecated) | No, `127.0.0.1` by default | The host IP address that the SSE server will listen on | 0.0.0.0 |
| Name | Required | Description | Example |
|--------------------------------------|----------------------------|---------------------------------------------------------------------------------------------|---------------------------------------------|
| `command_or_url` | Yes | The command to spawn the MCP stdio server | uvx mcp-server-fetch |
| `--port` | No, random available | The MCP server port to listen on | 8080 |
| `--host` | No, `127.0.0.1` by default | The host IP address that the MCP server will listen on | 0.0.0.0 |
| `--env` | No | Additional environment variables to pass to the MCP stdio server. Can be used multiple times. | FOO BAR |
| `--cwd` | No | The working directory to pass to the MCP stdio server process. | /tmp |
| `--pass-environment` | No | Pass through all environment variables when spawning the server | --no-pass-environment |
| `--allow-origin` | No | Allowed origins for the SSE server. Can be used multiple times. Default is no CORS allowed. | --allow-origin "\*" |
| `--stateless` | No | Enable stateless mode for streamable http transports. Default is False | --no-stateless |
| `--named-server NAME COMMAND_STRING` | No | Defines a named stdio server. | --named-server fetch 'uvx mcp-server-fetch' |
| `--named-server-config FILE_PATH` | No | Path to a JSON file defining named stdio servers. | --named-server-config /path/to/servers.json |
| `--sse-port` (deprecated) | No, random available | The SSE server port to listen on | 8080 |
| `--sse-host` (deprecated) | No, `127.0.0.1` by default | The host IP address that the SSE server will listen on | 0.0.0.0 |
### 2.2 Example usage
@@ -148,10 +152,61 @@ mcp-proxy --host=0.0.0.0 --port=8080 uvx mcp-server-fetch
# Note that the `--` separator is used to separate the `mcp-proxy` arguments from the `mcp-server-fetch` arguments
# (deprecated) mcp-proxy --sse-port=8080 -- uvx mcp-server-fetch --user-agent=YourUserAgent
mcp-proxy --port=8080 -- uvx mcp-server-fetch --user-agent=YourUserAgent
# Start multiple named MCP servers behind the proxy
mcp-proxy --port=8080 --named-server fetch 'uvx mcp-server-fetch' --named-server fetch2 'uvx mcp-server-fetch'
# Start multiple named MCP servers using a configuration file
mcp-proxy --port=8080 --named-server-config ./servers.json
```
This will start an MCP server that can be connected to at `http://127.0.0.1:8080/sse` via SSE, or
`http://127.0.0.1:8080/mcp/` via StreamableHttp
## Named Servers
- `NAME` is used in the URL path `/servers/NAME/`.
- `COMMAND_STRING` is the command to start the server (e.g., 'uvx mcp-server-fetch').
- Can be used multiple times.
- This argument is ignored if `--named-server-config` is used.
- `FILE_PATH` - If provided, this is the exclusive source for named servers, and `--named-server` CLI arguments are ignored.
If a default server is specified (the `command_or_url` argument without `--named-server` or `--named-server-config`), it will be accessible at the root paths (e.g., `http://127.0.0.1:8080/sse`).
Named servers (whether defined by `--named-server` or `--named-server-config`) will be accessible under `/servers/<server-name>/` (e.g., `http://127.0.0.1:8080/servers/fetch1/sse`).
The `/status` endpoint provides global status.
**JSON Configuration File Format for `--named-server-config`:**
The JSON file should follow this structure:
```json
{
"mcpServers": {
"fetch": {
"disabled": false,
"timeout": 60,
"command": "uvx",
"args": [
"mcp-server-fetch"
],
"transportType": "stdio"
},
"github": {
"timeout": 60,
"command": "npx",
"args": [
"-y",
"@modelcontextprotocol/server-github"
],
"transportType": "stdio"
}
}
}
```
- `mcpServers`: A dictionary where each key is the server name (used in the URL path, e.g., `/servers/fetch/`) and the value is an object defining the server.
- `command`: (Required) The command to execute for the stdio server.
- `args`: (Optional) A list of arguments for the command. Defaults to an empty list.
- `enabled`: (Optional) If `false`, this server definition will be skipped. Defaults to `true`.
- `timeout` and `transportType`: These fields are present in standard MCP client configurations but are currently **ignored** by `mcp-proxy` when loading named servers. The transport type is implicitly "stdio".
## Installation
@@ -258,14 +313,20 @@ services:
```bash
usage: mcp-proxy [-h] [-H KEY VALUE] [--transport {sse,streamablehttp}] [-e KEY VALUE] [--cwd CWD] [--pass-environment | --no-pass-environment]
[--debug | --no-debug] [--port PORT] [--host HOST] [--stateless | --no-stateless] [--sse-port SSE_PORT] [--sse-host SSE_HOST]
[--allow-origin ALLOW_ORIGIN [ALLOW_ORIGIN ...]]
[--debug | --no-debug] [--named-server NAME COMMAND_STRING] [--named-server-config FILE_PATH] [--port PORT] [--host HOST]
[--stateless | --no-stateless] [--sse-port SSE_PORT] [--sse-host SSE_HOST] [--allow-origin ALLOW_ORIGIN [ALLOW_ORIGIN ...]]
[command_or_url] [args ...]
Start the MCP proxy in one of two possible modes: as a client or a server.
Start the MCP proxy.
It can run as an SSE client (connecting to a remote SSE server and exposing stdio).
Or, it can run as an SSE server (connecting to local stdio command(s) and exposing them over SSE).
When running as an SSE server, it can proxy a single default stdio command or multiple named stdio commands (defined via CLI or a config file).
positional arguments:
command_or_url Command or URL to connect to. When a URL, will run an SSE/StreamableHTTP client, otherwise will run the given command and connect as a stdio client. See corresponding options for more details.
command_or_url Command or URL.
If URL (http/https): Runs in SSE/StreamableHTTP client mode.
If command string: Runs in SSE server mode, this is the default stdio server.
If --named-server or --named-server-config is used, this can be omitted if no default server is desired.
options:
-h, --help show this help message and exit
@@ -277,12 +338,16 @@ SSE/StreamableHTTP client options:
The transport to use for the client. Default is SSE.
stdio client options:
args Any extra arguments to the command to spawn the server
-e, --env KEY VALUE Environment variables used when spawning the server. Can be used multiple times.
--cwd CWD The working directory to use when spawning the process.
args Any extra arguments to the command to spawn the default server. Ignored if only named servers are defined.
-e, --env KEY VALUE Environment variables used when spawning the default server. Can be used multiple times. For named servers, environment is inherited or passed via --pass-environment.
--cwd CWD The working directory to use when spawning the default server process. Named servers inherit the proxy's CWD.
--pass-environment, --no-pass-environment
Pass through all environment variables when spawning the server.
Pass through all environment variables when spawning all server processes.
--debug, --no-debug Enable debug mode with detailed logging output.
--named-server NAME COMMAND_STRING
Define a named stdio server. NAME is for the URL path /servers/NAME/. COMMAND_STRING is a single string with the command and its arguments (e.g., 'uvx mcp-server-fetch --timeout 10'). These servers inherit the proxy's CWD and environment from --pass-environment. Can be specified multiple times. Ignored if --named-server-config is used.
--named-server-config FILE_PATH
Path to a JSON configuration file for named stdio servers. If provided, this will be the exclusive source for named server definitions, and any --named-server CLI arguments will be ignored.
SSE server options:
--port PORT Port to expose an SSE server on. Default is a random port
@@ -298,9 +363,39 @@ Examples:
mcp-proxy http://localhost:8080/sse
mcp-proxy --transport streamablehttp http://localhost:8080/mcp
mcp-proxy --headers Authorization 'Bearer YOUR_TOKEN' http://localhost:8080/sse
mcp-proxy --port 8080 -- your-command --arg1 value1 --arg2 value2
mcp-proxy your-command --port 8080 -e KEY VALUE -e ANOTHER_KEY ANOTHER_VALUE
mcp-proxy your-command --port 8080 --allow-origin='*'
mcp-proxy --port 8080 -- my-default-command --arg1 value1
mcp-proxy --port 8080 --named-server fetch1 'uvx mcp-server-fetch' --named-server tool2 'my-custom-tool --verbose'
mcp-proxy --port 8080 --named-server-config /path/to/servers.json
mcp-proxy --port 8080 --named-server-config /path/to/servers.json -- my-default-command --arg1
mcp-proxy --port 8080 -e KEY VALUE -e ANOTHER_KEY ANOTHER_VALUE -- my-default-command
mcp-proxy --port 8080 --allow-origin='*' -- my-default-command
```
### Example config file
```json
{
"mcpServers": {
"fetch": {
"enabled": true,
"timeout": 60,
"command": "uvx",
"args": [
"mcp-server-fetch"
],
"transportType": "stdio"
},
"github": {
"timeout": 60,
"command": "npx",
"args": [
"-y",
"@modelcontextprotocol/server-github"
],
"transportType": "stdio"
}
}
}
```
## Testing

23
config_example.json Normal file
View File

@@ -0,0 +1,23 @@
{
"mcpServers": {
"fetch": {
"enabled": true,
"timeout": 60,
"command": "uvx",
"args": [
"mcp-server-fetch"
],
"transportType": "stdio"
},
"github": {
"enabled": false,
"timeout": 60,
"command": "npx",
"args": [
"-y",
"@modelcontextprotocol/server-github"
],
"transportType": "stdio"
}
}
}

View File

@@ -8,13 +8,16 @@ Two ways to run the application:
import argparse
import asyncio
import json
import logging
import os
import shlex
import sys
import typing as t
from mcp.client.stdio import StdioServerParameters
from .config_loader import load_named_server_configs_from_file
from .mcp_server import MCPServerSettings, run_mcp_server
from .sse_client import run_sse_client
from .streamablehttp_client import run_streamablehttp_client
@@ -26,8 +29,8 @@ SSE_URL: t.Final[str | None] = os.getenv(
)
def main() -> None:
"""Start the client using asyncio."""
def _setup_argument_parser() -> argparse.ArgumentParser:
"""Set up and return the argument parser for the MCP proxy."""
parser = argparse.ArgumentParser(
description=("Start the MCP proxy in one of two possible modes: as a client or a server."),
epilog=(
@@ -36,19 +39,29 @@ def main() -> None:
" mcp-proxy --transport streamablehttp http://localhost:8080/mcp\n"
" mcp-proxy --headers Authorization 'Bearer YOUR_TOKEN' http://localhost:8080/sse\n"
" mcp-proxy --port 8080 -- your-command --arg1 value1 --arg2 value2\n"
" mcp-proxy --named-server fetch 'uvx mcp-server-fetch' --port 8080\n"
" mcp-proxy your-command --port 8080 -e KEY VALUE -e ANOTHER_KEY ANOTHER_VALUE\n"
" mcp-proxy your-command --port 8080 --allow-origin='*'\n"
),
formatter_class=argparse.RawTextHelpFormatter,
)
_add_arguments_to_parser(parser)
return parser
def _add_arguments_to_parser(parser: argparse.ArgumentParser) -> None:
"""Add all arguments to the argument parser."""
parser.add_argument(
"command_or_url",
help=(
"Command or URL to connect to. When a URL, will run an SSE/StreamableHTTP client, "
"otherwise will run the given command and connect as a stdio client. "
"Command or URL to connect to. When a URL, will run an SSE/StreamableHTTP client. "
"Otherwise, if --named-server is not used, this will be the command "
"for the default stdio client. If --named-server is used, this argument "
"is ignored for stdio mode unless no default server is desired. "
"See corresponding options for more details."
),
nargs="?", # Required below to allow for coming form env var
nargs="?",
default=SSE_URL,
)
@@ -73,7 +86,10 @@ def main() -> None:
stdio_client_options.add_argument(
"args",
nargs="*",
help="Any extra arguments to the command to spawn the server",
help=(
"Any extra arguments to the command to spawn the default server. "
"Ignored if only named servers are defined."
),
)
stdio_client_options.add_argument(
"-e",
@@ -81,18 +97,25 @@ def main() -> None:
nargs=2,
action="append",
metavar=("KEY", "VALUE"),
help="Environment variables used when spawning the server. Can be used multiple times.",
help=(
"Environment variables used when spawning the default server. Can be "
"used multiple times. For named servers, environment is inherited or "
"passed via --pass-environment."
),
default=[],
)
stdio_client_options.add_argument(
"--cwd",
default=None,
help="The working directory to use when spawning the process.",
help=(
"The working directory to use when spawning the default server process. "
"Named servers inherit the proxy's CWD."
),
)
stdio_client_options.add_argument(
"--pass-environment",
action=argparse.BooleanOptionalAction,
help="Pass through all environment variables when spawning the server.",
help="Pass through all environment variables when spawning all server processes.",
default=False,
)
stdio_client_options.add_argument(
@@ -101,6 +124,31 @@ def main() -> None:
help="Enable debug mode with detailed logging output.",
default=False,
)
stdio_client_options.add_argument(
"--named-server",
action="append",
nargs=2,
metavar=("NAME", "COMMAND_STRING"),
help=(
"Define a named stdio server. NAME is for the URL path /servers/NAME/. "
"COMMAND_STRING is a single string with the command and its arguments "
"(e.g., 'uvx mcp-server-fetch --timeout 10'). "
"These servers inherit the proxy's CWD and environment from --pass-environment."
),
default=[],
dest="named_server_definitions",
)
stdio_client_options.add_argument(
"--named-server-config",
type=str,
default=None,
metavar="FILE_PATH",
help=(
"Path to a JSON configuration file for named stdio servers. "
"If provided, this will be the exclusive source for named server definitions, "
"and any --named-server CLI arguments will be ignored."
),
)
mcp_server_group = parser.add_argument_group("SSE server options")
mcp_server_group.add_argument(
@@ -135,65 +183,222 @@ def main() -> None:
"--allow-origin",
nargs="+",
default=[],
help="Allowed origins for the SSE server. "
"Can be used multiple times. Default is no CORS allowed.",
help=(
"Allowed origins for the SSE server. Can be used multiple times. "
"Default is no CORS allowed."
),
)
args = parser.parse_args()
if not args.command_or_url:
parser.print_help()
sys.exit(1)
def _setup_logging(*, debug: bool) -> logging.Logger:
"""Set up logging configuration and return the logger."""
logging.basicConfig(
level=logging.DEBUG if args.debug else logging.INFO,
level=logging.DEBUG if debug else logging.INFO,
format="[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s] %(message)s",
)
logger = logging.getLogger(__name__)
return logging.getLogger(__name__)
if (
SSE_URL
or args.command_or_url.startswith("http://")
or args.command_or_url.startswith("https://")
def _handle_sse_client_mode(
args_parsed: argparse.Namespace,
logger: logging.Logger,
) -> None:
"""Handle SSE/StreamableHTTP client mode operation."""
if args_parsed.named_server_definitions:
logger.warning(
"--named-server arguments are ignored when command_or_url is an HTTP/HTTPS URL "
"(SSE/StreamableHTTP client mode).",
)
# Start a client connected to the SSE server, and expose as a stdio server
logger.debug("Starting SSE/StreamableHTTP client and stdio server")
headers = dict(args_parsed.headers)
if api_access_token := os.getenv("API_ACCESS_TOKEN", None):
headers["Authorization"] = f"Bearer {api_access_token}"
if args_parsed.transport == "streamablehttp":
asyncio.run(run_streamablehttp_client(args_parsed.command_or_url, headers=headers))
else:
asyncio.run(run_sse_client(args_parsed.command_or_url, headers=headers))
def _configure_default_server(
args_parsed: argparse.Namespace,
base_env: dict[str, str],
logger: logging.Logger,
) -> StdioServerParameters | None:
"""Configure the default server if applicable."""
if not (
args_parsed.command_or_url
and not args_parsed.command_or_url.startswith(("http://", "https://"))
):
# Start a client connected to the SSE server, and expose as a stdio server
logger.debug("Starting SSE client and stdio server")
headers = dict(args.headers)
if api_access_token := os.getenv("API_ACCESS_TOKEN", None):
headers["Authorization"] = f"Bearer {api_access_token}"
if args.transport == "streamablehttp":
asyncio.run(run_streamablehttp_client(args.command_or_url, headers=headers))
else:
asyncio.run(run_sse_client(args.command_or_url, headers=headers))
return None
default_server_env = base_env.copy()
default_server_env.update(dict(args_parsed.env)) # Specific env vars for default server
default_stdio_params = StdioServerParameters(
command=args_parsed.command_or_url,
args=args_parsed.args,
env=default_server_env,
cwd=args_parsed.cwd if args_parsed.cwd else None,
)
logger.info(
"Configured default server: %s %s",
args_parsed.command_or_url,
" ".join(args_parsed.args),
)
return default_stdio_params
def _load_named_servers_from_config(
config_path: str,
base_env: dict[str, str],
logger: logging.Logger,
) -> dict[str, StdioServerParameters]:
"""Load named server configurations from a file."""
try:
return load_named_server_configs_from_file(config_path, base_env)
except (FileNotFoundError, json.JSONDecodeError, ValueError):
# Specific errors are already logged by the loader function
# We log a generic message here before exiting
logger.exception(
"Failed to load server configurations from %s. Exiting.",
config_path,
)
sys.exit(1)
except Exception: # Catch any other unexpected errors from loader
logger.exception(
"An unexpected error occurred while loading server configurations from %s. Exiting.",
config_path,
)
sys.exit(1)
def _configure_named_servers_from_cli(
named_server_definitions: list[tuple[str, str]],
base_env: dict[str, str],
logger: logging.Logger,
) -> dict[str, StdioServerParameters]:
"""Configure named servers from CLI arguments."""
named_stdio_params: dict[str, StdioServerParameters] = {}
for name, command_string in named_server_definitions:
try:
command_parts = shlex.split(command_string)
if not command_parts: # Handle empty command_string
logger.error("Empty COMMAND_STRING for named server '%s'. Skipping.", name)
continue
command = command_parts[0]
command_args = command_parts[1:]
# Named servers inherit base_env (which includes passed-through env)
# and use the proxy's CWD.
named_stdio_params[name] = StdioServerParameters(
command=command,
args=command_args,
env=base_env.copy(), # Each named server gets a copy of the base env
cwd=None, # Named servers run in the proxy's CWD
)
logger.info("Configured named server '%s': %s", name, command_string)
except IndexError: # Should be caught by the check for empty command_parts
logger.exception(
"Invalid COMMAND_STRING for named server '%s': '%s'. Must include a command.",
name,
command_string,
)
sys.exit(1)
except Exception:
logger.exception("Error parsing COMMAND_STRING for named server '%s'", name)
sys.exit(1)
return named_stdio_params
def _create_mcp_settings(args_parsed: argparse.Namespace) -> MCPServerSettings:
"""Create MCP server settings from parsed arguments."""
return MCPServerSettings(
bind_host=args_parsed.host if args_parsed.host is not None else args_parsed.sse_host,
port=args_parsed.port if args_parsed.port is not None else args_parsed.sse_port,
stateless=args_parsed.stateless,
allow_origins=args_parsed.allow_origin if len(args_parsed.allow_origin) > 0 else None,
log_level="DEBUG" if args_parsed.debug else "INFO",
)
def main() -> None:
"""Start the client using asyncio."""
parser = _setup_argument_parser()
args_parsed = parser.parse_args()
logger = _setup_logging(debug=args_parsed.debug)
# Validate required arguments
if (
not args_parsed.command_or_url
and not args_parsed.named_server_definitions
and not args_parsed.named_server_config
):
parser.print_help()
logger.error(
"Either a command_or_url for a default server or at least one --named-server "
"(or --named-server-config) must be provided for stdio mode.",
)
sys.exit(1)
# Handle SSE client mode if URL is provided
if args_parsed.command_or_url and args_parsed.command_or_url.startswith(
("http://", "https://"),
):
_handle_sse_client_mode(args_parsed, logger)
return
# Start a client connected to the given command, and expose as an SSE server
logger.debug("Starting stdio client and SSE server")
# Start stdio client(s) and expose as an SSE server
logger.debug("Configuring stdio client(s) and SSE server")
# The environment variables passed to the server process
env: dict[str, str] = {}
# Pass through current environment variables if configured
if args.pass_environment:
env.update(os.environ)
# Pass in and override any environment variables with those passed on the command line
env.update(dict(args.env))
# Base environment for all spawned processes
base_env: dict[str, str] = {}
if args_parsed.pass_environment:
base_env.update(os.environ)
stdio_params = StdioServerParameters(
command=args.command_or_url,
args=args.args,
env=env,
cwd=args.cwd if args.cwd else None,
# Configure default server
default_stdio_params = _configure_default_server(args_parsed, base_env, logger)
# Configure named servers
named_stdio_params: dict[str, StdioServerParameters] = {}
if args_parsed.named_server_config:
if args_parsed.named_server_definitions:
logger.warning(
"--named-server CLI arguments are ignored when --named-server-config is provided.",
)
named_stdio_params = _load_named_servers_from_config(
args_parsed.named_server_config,
base_env,
logger,
)
elif args_parsed.named_server_definitions:
named_stdio_params = _configure_named_servers_from_cli(
args_parsed.named_server_definitions,
base_env,
logger,
)
# Ensure at least one server is configured
if not default_stdio_params and not named_stdio_params:
parser.print_help()
logger.error(
"No stdio servers configured. Provide a default command or use --named-server.",
)
sys.exit(1)
# Create MCP server settings and run the server
mcp_settings = _create_mcp_settings(args_parsed)
asyncio.run(
run_mcp_server(
default_server_params=default_stdio_params,
named_server_params=named_stdio_params,
mcp_settings=mcp_settings,
),
)
mcp_settings = MCPServerSettings(
bind_host=args.host if args.host is not None else args.sse_host,
port=args.port if args.port is not None else args.sse_port,
stateless=args.stateless,
allow_origins=args.allow_origin if len(args.allow_origin) > 0 else None,
log_level="DEBUG" if args.debug else "INFO",
)
asyncio.run(run_mcp_server(stdio_params, mcp_settings))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,99 @@
"""Configuration loader for MCP proxy.
This module provides functionality to load named server configurations from JSON files.
"""
import json
import logging
from pathlib import Path
from mcp.client.stdio import StdioServerParameters
logger = logging.getLogger(__name__)
def load_named_server_configs_from_file(
config_file_path: str,
base_env: dict[str, str],
) -> dict[str, StdioServerParameters]:
"""Loads named server configurations from a JSON file.
Args:
config_file_path: Path to the JSON configuration file.
base_env: The base environment dictionary to be inherited by servers.
Returns:
A dictionary of named server parameters.
Raises:
FileNotFoundError: If the config file is not found.
json.JSONDecodeError: If the config file contains invalid JSON.
ValueError: If the config file format is invalid.
"""
named_stdio_params: dict[str, StdioServerParameters] = {}
logger.info("Loading named server configurations from: %s", config_file_path)
try:
with Path(config_file_path).open() as f:
config_data = json.load(f)
except FileNotFoundError:
logger.exception("Configuration file not found: %s", config_file_path)
raise
except json.JSONDecodeError:
logger.exception("Error decoding JSON from configuration file: %s", config_file_path)
raise
except Exception as e:
logger.exception(
"Unexpected error opening or reading configuration file %s",
config_file_path,
)
error_message = f"Could not read configuration file: {e}"
raise ValueError(error_message) from e
if not isinstance(config_data, dict) or "mcpServers" not in config_data:
msg = f"Invalid config file format in {config_file_path}. Missing 'mcpServers' key."
logger.error(msg)
raise ValueError(msg)
for name, server_config in config_data.get("mcpServers", {}).items():
if not isinstance(server_config, dict):
logger.warning(
"Skipping invalid server config for '%s' in %s. Entry is not a dictionary.",
name,
config_file_path,
)
continue
if not server_config.get("enabled", True): # Default to True if 'enabled' is not present
logger.info("Named server '%s' from config is not enabled. Skipping.", name)
continue
command = server_config.get("command")
command_args = server_config.get("args", [])
if not command:
logger.warning(
"Named server '%s' from config is missing 'command'. Skipping.",
name,
)
continue
if not isinstance(command_args, list):
logger.warning(
"Named server '%s' from config has invalid 'args' (must be a list). Skipping.",
name,
)
continue
named_stdio_params[name] = StdioServerParameters(
command=command,
args=command_args,
env=base_env.copy(),
cwd=None,
)
logger.info(
"Configured named server '%s' from config: %s %s",
name,
command,
" ".join(command_args),
)
return named_stdio_params

View File

@@ -5,12 +5,12 @@ import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Literal
from typing import Any, Literal
import uvicorn
from mcp.client.session import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.server import Server
from mcp.server import Server as MCPServerSDK # Renamed to avoid conflict
from mcp.server.sse import SseServerTransport
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.applications import Starlette
@@ -18,7 +18,7 @@ from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Mount, Route
from starlette.routing import BaseRoute, Mount, Route
from starlette.types import Receive, Scope, Send
from .proxy_server import create_proxy_server
@@ -37,123 +37,154 @@ class MCPServerSettings:
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
def create_starlette_app(
mcp_server: Server[object],
# To store last activity for multiple servers if needed, though status endpoint is global for now.
_global_status: dict[str, Any] = {
"api_last_activity": datetime.now(timezone.utc).isoformat(),
"server_instances": {}, # Could be used to store per-instance status later
}
def _update_global_activity() -> None:
_global_status["api_last_activity"] = datetime.now(timezone.utc).isoformat()
async def _handle_status(_: Request) -> Response:
"""Global health check and service usage monitoring endpoint."""
return JSONResponse(_global_status)
def create_single_instance_routes(
mcp_server_instance: MCPServerSDK[object],
*,
stateless: bool = False,
allow_origins: list[str] | None = None,
debug: bool = False,
) -> Starlette:
"""Create a Starlette application that can serve the mcp server with SSE or Streamable http."""
logger.debug("Creating Starlette app with stateless: %s and debug: %s", stateless, debug)
# record the last activity of api
status = {
"api_last_activity": datetime.now(timezone.utc).isoformat(),
}
stateless_instance: bool,
) -> tuple[list[BaseRoute], StreamableHTTPSessionManager]: # Return the manager itself
"""Create Starlette routes and the HTTP session manager for a single MCP server instance."""
logger.debug(
"Creating routes for a single MCP server instance (stateless: %s)",
stateless_instance,
)
def _update_mcp_activity() -> None:
status.update(
{
"api_last_activity": datetime.now(timezone.utc).isoformat(),
},
)
sse_transport = SseServerTransport("messages/")
http_session_manager = StreamableHTTPSessionManager(
app=mcp_server_instance,
event_store=None,
json_response=True,
stateless=stateless_instance,
)
sse = SseServerTransport("/messages/")
async def handle_sse(request: Request) -> None:
async with sse.connect_sse(
async def handle_sse_instance(request: Request) -> None:
async with sse_transport.connect_sse(
request.scope,
request.receive,
request._send, # noqa: SLF001
) as (read_stream, write_stream):
_update_mcp_activity()
await mcp_server.run(
_update_global_activity()
await mcp_server_instance.run(
read_stream,
write_stream,
mcp_server.create_initialization_options(),
mcp_server_instance.create_initialization_options(),
)
# Refer: https://github.com/modelcontextprotocol/python-sdk/blob/v1.8.0/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py
http = StreamableHTTPSessionManager(
app=mcp_server,
event_store=None,
json_response=True,
stateless=stateless,
)
async def handle_streamable_http_instance(scope: Scope, receive: Receive, send: Send) -> None:
_update_global_activity()
await http_session_manager.handle_request(scope, receive, send)
async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
_update_mcp_activity()
await http.handle_request(scope, receive, send)
async def handle_status(_: Request) -> Response:
"""Health check and service usage monitoring endpoint.
Purpose of this handler:
- Provides a dedicated API endpoint for external health checks.
- Returns last API activity timestamp to monitor service usage patterns and uptime.
- Serves as basic infrastructure for potential future service metrics expansion.
"""
return JSONResponse(status)
@contextlib.asynccontextmanager
async def lifespan(_: Starlette) -> AsyncIterator[None]:
"""Context manager for session manager."""
async with http.run():
logger.info("Application started with StreamableHTTP session manager!")
try:
yield
finally:
logger.info("Application shutting down...")
middleware: list[Middleware] = []
if allow_origins is not None:
middleware.append(
Middleware(
CORSMiddleware,
allow_origins=allow_origins,
allow_methods=["*"],
allow_headers=["*"],
),
)
return Starlette(
debug=debug,
middleware=middleware,
routes=[
Route("/status", endpoint=handle_status),
Mount("/mcp", app=handle_streamable_http),
Route("/sse", endpoint=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
],
lifespan=lifespan,
)
routes = [
Mount("/mcp", app=handle_streamable_http_instance),
Route("/sse", endpoint=handle_sse_instance),
Mount("/messages/", app=sse_transport.handle_post_message),
]
return routes, http_session_manager
async def run_mcp_server(
stdio_params: StdioServerParameters,
mcp_settings: MCPServerSettings,
default_server_params: StdioServerParameters | None = None,
named_server_params: dict[str, StdioServerParameters] | None = None,
) -> None:
"""Run the stdio client and expose an MCP server.
"""Run stdio client(s) and expose an MCP server with multiple possible backends."""
if named_server_params is None:
named_server_params = {}
Args:
stdio_params: The parameters for the stdio client that spawns a stdio server.
mcp_settings: The settings for the MCP server that accepts incoming requests.
all_routes: list[BaseRoute] = [
Route("/status", endpoint=_handle_status), # Global status endpoint
]
# Use AsyncExitStack to manage lifecycles of multiple components
async with contextlib.AsyncExitStack() as stack:
# Manage lifespans of all StreamableHTTPSessionManagers
@contextlib.asynccontextmanager
async def combined_lifespan(_app: Starlette) -> AsyncIterator[None]:
logger.info("Main application lifespan starting...")
# All http_session_managers' .run() are already entered into the stack
yield
logger.info("Main application lifespan shutting down...")
"""
async with stdio_client(stdio_params) as streams, ClientSession(*streams) as session:
logger.debug("Starting proxy server...")
mcp_server = await create_proxy_server(session)
# Setup default server if configured
if default_server_params:
logger.info(
"Setting up default server: %s %s",
default_server_params.command,
" ".join(default_server_params.args),
)
stdio_streams = await stack.enter_async_context(stdio_client(default_server_params))
session = await stack.enter_async_context(ClientSession(*stdio_streams))
proxy = await create_proxy_server(session)
# Bind request handling to MCP server
starlette_app = create_starlette_app(
mcp_server,
stateless=mcp_settings.stateless,
allow_origins=mcp_settings.allow_origins,
instance_routes, http_manager = create_single_instance_routes(
proxy,
stateless_instance=mcp_settings.stateless,
)
await stack.enter_async_context(http_manager.run()) # Manage lifespan by calling run()
all_routes.extend(instance_routes)
_global_status["server_instances"]["default"] = "configured"
# Setup named servers
for name, params in named_server_params.items():
logger.info(
"Setting up named server '%s': %s %s",
name,
params.command,
" ".join(params.args),
)
stdio_streams_named = await stack.enter_async_context(stdio_client(params))
session_named = await stack.enter_async_context(ClientSession(*stdio_streams_named))
proxy_named = await create_proxy_server(session_named)
instance_routes_named, http_manager_named = create_single_instance_routes(
proxy_named,
stateless_instance=mcp_settings.stateless,
)
await stack.enter_async_context(
http_manager_named.run(),
) # Manage lifespan by calling run()
# Mount these routes under /servers/<name>/
server_mount = Mount(f"/servers/{name}", routes=instance_routes_named)
all_routes.append(server_mount)
_global_status["server_instances"][name] = "configured"
if not default_server_params and not named_server_params:
logger.error("No servers configured to run.")
return
middleware: list[Middleware] = []
if mcp_settings.allow_origins:
middleware.append(
Middleware(
CORSMiddleware,
allow_origins=mcp_settings.allow_origins,
allow_methods=["*"],
allow_headers=["*"],
),
)
starlette_app = Starlette(
debug=(mcp_settings.log_level == "DEBUG"),
routes=all_routes,
middleware=middleware,
lifespan=combined_lifespan,
)
# Configure HTTP server
config = uvicorn.Config(
starlette_app,
host=mcp_settings.bind_host,
@@ -161,8 +192,27 @@ async def run_mcp_server(
log_level=mcp_settings.log_level.lower(),
)
http_server = uvicorn.Server(config)
# Print out the SSE URLs for all configured servers
base_url = f"http://{mcp_settings.bind_host}:{mcp_settings.port}"
sse_urls = []
# Add default server if configured
if default_server_params:
sse_urls.append(f"{base_url}/sse")
# Add named servers
sse_urls.extend([f"{base_url}/servers/{name}/sse" for name in named_server_params])
# Display the SSE URLs prominently
if sse_urls:
# Using print directly for user visibility, with noqa to ignore linter warnings
logger.info("Serving MCP Servers via SSE:")
for url in sse_urls:
logger.info(" - %s", url)
logger.debug(
"Serving incoming requests on %s:%s",
"Serving incoming MCP requests on %s:%s",
mcp_settings.bind_host,
mcp_settings.port,
)

View File

@@ -1,4 +1,4 @@
"""Create an MCP server that proxies requests throgh an MCP client.
"""Create an MCP server that proxies requests through an MCP client.
This server is created independent of any transport mechanism.
"""

244
tests/test_config_loader.py Normal file
View File

@@ -0,0 +1,244 @@
"""Tests for the configuration loader module."""
import json
import shutil
import tempfile
from collections.abc import Callable, Generator
from pathlib import Path
from unittest.mock import patch
import pytest
from mcp.client.stdio import StdioServerParameters
from mcp_proxy.config_loader import load_named_server_configs_from_file
@pytest.fixture
def create_temp_config_file() -> Generator[Callable[[dict], str], None, None]:
"""Creates a temporary JSON config file and returns its path."""
temp_files: list[str] = []
def _create_temp_config_file(config_content: dict) -> str:
with tempfile.NamedTemporaryFile(
mode="w",
delete=False,
suffix=".json",
) as tmp_config:
json.dump(config_content, tmp_config)
temp_files.append(tmp_config.name)
return tmp_config.name
yield _create_temp_config_file
for f_path in temp_files:
path = Path(f_path)
if path.exists():
path.unlink()
def test_load_valid_config(create_temp_config_file: Callable[[dict], str]) -> None:
"""Test loading a valid configuration file."""
config_content = {
"mcpServers": {
"server1": {
"command": "echo",
"args": ["hello"],
"enabled": True,
},
"server2": {
"command": "cat",
"args": ["file.txt"],
},
},
}
tmp_config_path = create_temp_config_file(config_content)
base_env = {"PASSED": "env_value"}
loaded_params = load_named_server_configs_from_file(tmp_config_path, base_env)
assert "server1" in loaded_params
assert loaded_params["server1"].command == "echo"
assert loaded_params["server1"].args == ["hello"]
assert (
loaded_params["server1"].env == base_env
) # Env is a copy, check if it contains base_env items
assert "server2" in loaded_params
assert loaded_params["server2"].command == "cat"
assert loaded_params["server2"].args == ["file.txt"]
assert loaded_params["server2"].env == base_env
def test_load_config_with_not_enabled_server(
create_temp_config_file: Callable[[dict], str],
) -> None:
"""Test loading a configuration with disabled servers."""
config_content = {
"mcpServers": {
"explicitly_enabled_server": {"command": "true_command", "enabled": True},
# No 'enabled' flag, defaults to True
"implicitly_enabled_server": {"command": "another_true_command"},
"not_enabled_server": {"command": "false_command", "enabled": False},
},
}
tmp_config_path = create_temp_config_file(config_content)
loaded_params = load_named_server_configs_from_file(tmp_config_path, {})
assert "explicitly_enabled_server" in loaded_params
assert loaded_params["explicitly_enabled_server"].command == "true_command"
assert "implicitly_enabled_server" in loaded_params
assert loaded_params["implicitly_enabled_server"].command == "another_true_command"
assert "not_enabled_server" not in loaded_params
def test_file_not_found() -> None:
"""Test handling of non-existent configuration files."""
with pytest.raises(FileNotFoundError):
load_named_server_configs_from_file("non_existent_file.json", {})
def test_json_decode_error() -> None:
"""Test handling of invalid JSON in configuration files."""
# Create a file with invalid JSON content
with tempfile.NamedTemporaryFile(
mode="w",
delete=False,
suffix=".json",
) as tmp_config:
tmp_config.write("this is not json {")
tmp_config_path = tmp_config.name
# Use try/finally to ensure cleanup
try:
with pytest.raises(json.JSONDecodeError):
load_named_server_configs_from_file(tmp_config_path, {})
finally:
path = Path(tmp_config_path)
if path.exists():
path.unlink()
def test_load_example_fetch_config_if_uvx_exists() -> None:
"""Test loading the example fetch configuration if uvx is available."""
if not shutil.which("uvx"):
pytest.skip("uvx command not found in PATH, skipping test for example config.")
# Assuming the test is run from the root of the repository
example_config_path = Path(__file__).parent.parent / "config_example.json"
if not example_config_path.exists():
pytest.fail(
f"Example config file not found at expected path: {example_config_path}",
)
base_env = {"EXAMPLE_ENV": "true"}
loaded_params = load_named_server_configs_from_file(example_config_path, base_env)
assert "fetch" in loaded_params
fetch_param = loaded_params["fetch"]
assert isinstance(fetch_param, StdioServerParameters)
assert fetch_param.command == "uvx"
assert fetch_param.args == ["mcp-server-fetch"]
assert fetch_param.env == base_env
# The 'timeout' and 'transportType' fields from the config are currently ignored by the loader,
# so no need to assert them on StdioServerParameters.
def test_invalid_config_format_missing_mcpservers(
create_temp_config_file: Callable[[dict], str],
) -> None:
"""Test handling of configuration files missing the mcpServers key."""
config_content = {"some_other_key": "value"}
tmp_config_path = create_temp_config_file(config_content)
with pytest.raises(ValueError, match="Missing 'mcpServers' key"):
load_named_server_configs_from_file(tmp_config_path, {})
@patch("mcp_proxy.config_loader.logger")
def test_invalid_server_entry_not_dict(
mock_logger: object,
create_temp_config_file: Callable[[dict], str],
) -> None:
"""Test handling of server entries that are not dictionaries."""
config_content = {"mcpServers": {"server1": "not_a_dict"}}
tmp_config_path = create_temp_config_file(config_content)
loaded_params = load_named_server_configs_from_file(tmp_config_path, {})
assert len(loaded_params) == 0 # No servers should be loaded
mock_logger.warning.assert_called_with(
"Skipping invalid server config for '%s' in %s. Entry is not a dictionary.",
"server1",
tmp_config_path,
)
@patch("mcp_proxy.config_loader.logger")
def test_server_entry_missing_command(
mock_logger: object,
create_temp_config_file: Callable[[dict], str],
) -> None:
"""Test handling of server entries missing the command field."""
config_content = {"mcpServers": {"server_no_command": {"args": ["arg1"]}}}
tmp_config_path = create_temp_config_file(config_content)
loaded_params = load_named_server_configs_from_file(tmp_config_path, {})
assert "server_no_command" not in loaded_params
mock_logger.warning.assert_called_with(
"Named server '%s' from config is missing 'command'. Skipping.",
"server_no_command",
)
@patch("mcp_proxy.config_loader.logger")
def test_server_entry_invalid_args_type(
mock_logger: object,
create_temp_config_file: Callable[[dict], str],
) -> None:
"""Test handling of server entries with invalid args type."""
config_content = {
"mcpServers": {
"server_invalid_args": {"command": "mycmd", "args": "not_a_list"},
},
}
tmp_config_path = create_temp_config_file(config_content)
loaded_params = load_named_server_configs_from_file(tmp_config_path, {})
assert "server_invalid_args" not in loaded_params
mock_logger.warning.assert_called_with(
"Named server '%s' from config has invalid 'args' (must be a list). Skipping.",
"server_invalid_args",
)
def test_empty_mcpservers_dict(create_temp_config_file: Callable[[dict], str]) -> None:
"""Test handling of configuration files with empty mcpServers dictionary."""
config_content = {"mcpServers": {}}
tmp_config_path = create_temp_config_file(config_content)
loaded_params = load_named_server_configs_from_file(tmp_config_path, {})
assert len(loaded_params) == 0
def test_config_file_is_empty_json_object(create_temp_config_file: Callable[[dict], str]) -> None:
"""Test handling of configuration files with empty JSON objects."""
config_content = {} # Empty JSON object
tmp_config_path = create_temp_config_file(config_content)
with pytest.raises(ValueError, match="Missing 'mcpServers' key"):
load_named_server_configs_from_file(tmp_config_path, {})
def test_config_file_is_empty_string() -> None:
"""Test handling of configuration files with empty content."""
# Create a file with an empty string
with tempfile.NamedTemporaryFile(
mode="w",
delete=False,
suffix=".json",
) as tmp_config:
tmp_config.write("") # Empty content
tmp_config_path = tmp_config.name
try:
with pytest.raises(json.JSONDecodeError):
load_named_server_configs_from_file(tmp_config_path, {})
finally:
path = Path(tmp_config_path)
if path.exists():
path.unlink()

View File

@@ -1,18 +1,68 @@
"""Tests for the sse server."""
# ruff: noqa: PLR2004
import asyncio
import contextlib
import typing as t
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import uvicorn
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters
from mcp.client.streamable_http import streamablehttp_client
from mcp.server import FastMCP
from mcp.server import FastMCP, Server
from mcp.types import TextContent
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from mcp_proxy.mcp_server import create_starlette_app
from mcp_proxy.mcp_server import MCPServerSettings, create_single_instance_routes, run_mcp_server
def create_starlette_app(
mcp_server: Server[t.Any],
allow_origins: list[str] | None = None,
*,
debug: bool = False,
stateless: bool = False,
) -> Starlette:
"""Create a Starlette application for the MCP server.
Args:
mcp_server: The MCP server instance to wrap
allow_origins: List of allowed CORS origins
debug: Enable debug mode
stateless: Whether to use stateless HTTP sessions
Returns:
Starlette application instance
"""
routes, http_manager = create_single_instance_routes(mcp_server, stateless_instance=stateless)
middleware: list[Middleware] = []
if allow_origins:
middleware.append(
Middleware(
CORSMiddleware,
allow_origins=allow_origins,
allow_methods=["*"],
allow_headers=["*"],
),
)
@contextlib.asynccontextmanager
async def lifespan(_app: Starlette) -> t.AsyncIterator[None]:
async with http_manager.run():
yield
return Starlette(
debug=debug,
routes=routes,
middleware=middleware,
lifespan=lifespan,
)
class BackgroundServer(uvicorn.Server):
@@ -64,7 +114,6 @@ def make_background_server(**kwargs) -> BackgroundServer: # noqa: ANN003
return BackgroundServer(config)
@pytest.mark.asyncio
async def test_sse_transport() -> None:
"""Test basic glue code for the SSE transport and a fake MCP server."""
server = make_background_server(debug=True)
@@ -77,7 +126,6 @@ async def test_sse_transport() -> None:
assert response.prompts[0].name == "prompt1"
@pytest.mark.asyncio
async def test_http_transport() -> None:
"""Test HTTP transport layer functionality."""
server = make_background_server(debug=True)
@@ -118,3 +166,509 @@ async def test_stateless_http_transport() -> None:
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
assert tool_result.content[0].text == f"Echo: test_{i}"
# Unit tests for run_mcp_server method
@pytest.fixture
def mock_settings() -> MCPServerSettings:
"""Create mock MCP server settings for testing."""
return MCPServerSettings(
bind_host="127.0.0.1",
port=8080,
stateless=False,
allow_origins=["*"],
log_level="INFO",
)
@pytest.fixture
def mock_stdio_params() -> StdioServerParameters:
"""Create mock stdio server parameters for testing."""
return StdioServerParameters(
command="echo",
args=["hello"],
env={"TEST_VAR": "test_value"},
cwd="/tmp", # noqa: S108
)
class AsyncContextManagerMock: # noqa: D101
def __init__(self, mock) -> None: # noqa: ANN001, D107
self.mock = mock
async def __aenter__(self): # noqa: ANN204, D105
return self.mock
async def __aexit__(self, *args): # noqa: ANN002, ANN204, D105
pass
def setup_async_context_mocks() -> tuple[
AsyncContextManagerMock,
AsyncContextManagerMock,
AsyncMock,
MagicMock,
list[MagicMock],
]:
"""Helper function to set up async context manager mocks."""
# Setup stdio client mock
mock_streams = (AsyncMock(), AsyncMock())
# Setup client session mock
mock_session = AsyncMock()
# Setup HTTP manager mock
mock_http_manager = MagicMock()
mock_http_manager.run.return_value = AsyncContextManagerMock(None)
mock_routes = [MagicMock()]
return (
AsyncContextManagerMock(mock_streams),
AsyncContextManagerMock(mock_session),
mock_session,
mock_http_manager,
mock_routes,
)
async def test_run_mcp_server_no_servers_configured(mock_settings: MCPServerSettings) -> None:
"""Test run_mcp_server when no servers are configured."""
with patch("mcp_proxy.mcp_server.logger") as mock_logger:
await run_mcp_server(mock_settings, None, {})
mock_logger.error.assert_called_once_with("No servers configured to run.")
async def test_run_mcp_server_with_default_server(
mock_settings: MCPServerSettings,
mock_stdio_params: StdioServerParameters,
) -> None:
"""Test run_mcp_server with a default server configuration."""
with (
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
patch("uvicorn.Server") as mock_uvicorn_server,
patch("mcp_proxy.mcp_server.logger") as mock_logger,
):
# Setup mocks
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
setup_async_context_mocks()
)
mock_stdio_client.return_value = mock_stdio_context
mock_client_session.return_value = mock_session_context
mock_proxy = AsyncMock()
mock_create_proxy.return_value = mock_proxy
mock_create_routes.return_value = (mock_routes, mock_http_manager)
mock_server_instance = AsyncMock()
mock_uvicorn_server.return_value = mock_server_instance
# Run the function
await run_mcp_server(mock_settings, mock_stdio_params, {})
# Verify calls
mock_stdio_client.assert_called_once_with(mock_stdio_params)
mock_create_proxy.assert_called_once_with(mock_session)
mock_create_routes.assert_called_once_with(
mock_proxy,
stateless_instance=mock_settings.stateless,
)
mock_logger.info.assert_any_call(
"Setting up default server: %s %s",
mock_stdio_params.command,
" ".join(mock_stdio_params.args),
)
mock_server_instance.serve.assert_called_once()
async def test_run_mcp_server_with_named_servers(
mock_settings: MCPServerSettings,
mock_stdio_params: StdioServerParameters,
) -> None:
"""Test run_mcp_server with named servers configuration."""
named_servers = {
"server1": mock_stdio_params,
"server2": StdioServerParameters(
command="python",
args=["-m", "mcp_server"],
env={"PYTHON_PATH": "/usr/bin/python"},
cwd="/home/user",
),
}
with (
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
patch("uvicorn.Server") as mock_uvicorn_server,
patch("mcp_proxy.mcp_server.logger") as mock_logger,
):
# Setup mocks
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
setup_async_context_mocks()
)
mock_stdio_client.return_value = mock_stdio_context
mock_client_session.return_value = mock_session_context
mock_proxy = AsyncMock()
mock_create_proxy.return_value = mock_proxy
mock_create_routes.return_value = (mock_routes, mock_http_manager)
mock_server_instance = AsyncMock()
mock_uvicorn_server.return_value = mock_server_instance
# Run the function
await run_mcp_server(mock_settings, None, named_servers)
# Verify calls
assert mock_stdio_client.call_count == 2
assert mock_create_proxy.call_count == 2
assert mock_create_routes.call_count == 2
# Check that named servers were logged
mock_logger.info.assert_any_call(
"Setting up named server '%s': %s %s",
"server1",
mock_stdio_params.command,
" ".join(mock_stdio_params.args),
)
mock_logger.info.assert_any_call(
"Setting up named server '%s': %s %s",
"server2",
"python",
"-m mcp_server",
)
mock_server_instance.serve.assert_called_once()
async def test_run_mcp_server_with_cors_middleware(
mock_stdio_params: StdioServerParameters,
) -> None:
"""Test run_mcp_server adds CORS middleware when allow_origins is set."""
settings_with_cors = MCPServerSettings(
bind_host="0.0.0.0", # noqa: S104
port=9090,
allow_origins=["http://localhost:3000", "https://example.com"],
)
with (
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
patch("mcp_proxy.mcp_server.Starlette") as mock_starlette,
patch("uvicorn.Server") as mock_uvicorn_server,
):
# Setup mocks
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
setup_async_context_mocks()
)
mock_stdio_client.return_value = mock_stdio_context
mock_client_session.return_value = mock_session_context
mock_proxy = AsyncMock()
mock_create_proxy.return_value = mock_proxy
mock_create_routes.return_value = (mock_routes, mock_http_manager)
mock_server_instance = AsyncMock()
mock_uvicorn_server.return_value = mock_server_instance
# Run the function
await run_mcp_server(settings_with_cors, mock_stdio_params, {})
# Verify Starlette was called with middleware
mock_starlette.assert_called_once()
call_args = mock_starlette.call_args
middleware = call_args.kwargs["middleware"]
assert len(middleware) == 1
assert middleware[0].cls == CORSMiddleware
async def test_run_mcp_server_debug_mode(
mock_stdio_params: StdioServerParameters,
) -> None:
"""Test run_mcp_server with debug mode enabled."""
debug_settings = MCPServerSettings(
bind_host="127.0.0.1",
port=8080,
log_level="DEBUG",
)
with (
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
patch("mcp_proxy.mcp_server.Starlette") as mock_starlette,
patch("uvicorn.Server") as mock_uvicorn_server,
):
# Setup mocks
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
setup_async_context_mocks()
)
mock_stdio_client.return_value = mock_stdio_context
mock_client_session.return_value = mock_session_context
mock_proxy = AsyncMock()
mock_create_proxy.return_value = mock_proxy
mock_create_routes.return_value = (mock_routes, mock_http_manager)
mock_server_instance = AsyncMock()
mock_uvicorn_server.return_value = mock_server_instance
# Run the function
await run_mcp_server(debug_settings, mock_stdio_params, {})
# Verify Starlette was called with debug=True
mock_starlette.assert_called_once()
call_args = mock_starlette.call_args
assert call_args.kwargs["debug"] is True
async def test_run_mcp_server_stateless_mode(
mock_stdio_params: StdioServerParameters,
) -> None:
"""Test run_mcp_server with stateless mode enabled."""
stateless_settings = MCPServerSettings(
bind_host="127.0.0.1",
port=8080,
stateless=True,
)
with (
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
patch("uvicorn.Server") as mock_uvicorn_server,
):
# Setup mocks
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
setup_async_context_mocks()
)
mock_stdio_client.return_value = mock_stdio_context
mock_client_session.return_value = mock_session_context
mock_proxy = AsyncMock()
mock_create_proxy.return_value = mock_proxy
mock_create_routes.return_value = (mock_routes, mock_http_manager)
mock_server_instance = AsyncMock()
mock_uvicorn_server.return_value = mock_server_instance
# Run the function
await run_mcp_server(stateless_settings, mock_stdio_params, {})
# Verify create_single_instance_routes was called with stateless_instance=True
mock_create_routes.assert_called_once_with(
mock_proxy,
stateless_instance=True,
)
async def test_run_mcp_server_uvicorn_config(
mock_settings: MCPServerSettings,
mock_stdio_params: StdioServerParameters,
) -> None:
"""Test run_mcp_server creates correct uvicorn configuration."""
with (
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
patch("uvicorn.Config") as mock_uvicorn_config,
patch("uvicorn.Server") as mock_uvicorn_server,
):
# Setup mocks
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
setup_async_context_mocks()
)
mock_stdio_client.return_value = mock_stdio_context
mock_client_session.return_value = mock_session_context
mock_proxy = AsyncMock()
mock_create_proxy.return_value = mock_proxy
mock_create_routes.return_value = (mock_routes, mock_http_manager)
mock_config = MagicMock()
mock_uvicorn_config.return_value = mock_config
mock_server_instance = AsyncMock()
mock_uvicorn_server.return_value = mock_server_instance
# Run the function
await run_mcp_server(mock_settings, mock_stdio_params, {})
# Verify uvicorn.Config was called with correct parameters
mock_uvicorn_config.assert_called_once()
call_args = mock_uvicorn_config.call_args
assert call_args.kwargs["host"] == mock_settings.bind_host
assert call_args.kwargs["port"] == mock_settings.port
assert call_args.kwargs["log_level"] == mock_settings.log_level.lower()
async def test_run_mcp_server_global_status_updates(
mock_settings: MCPServerSettings,
mock_stdio_params: StdioServerParameters,
) -> None:
"""Test run_mcp_server updates global status correctly."""
from mcp_proxy.mcp_server import _global_status
# Clear global status before test
_global_status["server_instances"].clear()
named_servers = {"test_server": mock_stdio_params}
with (
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
patch("uvicorn.Server") as mock_uvicorn_server,
):
# Setup mocks
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
setup_async_context_mocks()
)
mock_stdio_client.return_value = mock_stdio_context
mock_client_session.return_value = mock_session_context
mock_proxy = AsyncMock()
mock_create_proxy.return_value = mock_proxy
mock_create_routes.return_value = (mock_routes, mock_http_manager)
mock_server_instance = AsyncMock()
mock_uvicorn_server.return_value = mock_server_instance
# Run the function
await run_mcp_server(mock_settings, mock_stdio_params, named_servers)
# Verify global status was updated
assert "default" in _global_status["server_instances"]
assert "test_server" in _global_status["server_instances"]
assert _global_status["server_instances"]["default"] == "configured"
assert _global_status["server_instances"]["test_server"] == "configured"
async def test_run_mcp_server_sse_url_logging(
mock_settings: MCPServerSettings,
mock_stdio_params: StdioServerParameters,
) -> None:
"""Test run_mcp_server logs correct SSE URLs."""
named_servers = {"test_server": mock_stdio_params}
with (
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
patch("uvicorn.Server") as mock_uvicorn_server,
patch("mcp_proxy.mcp_server.logger") as mock_logger,
):
# Setup mocks
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
setup_async_context_mocks()
)
mock_stdio_client.return_value = mock_stdio_context
mock_client_session.return_value = mock_session_context
mock_proxy = AsyncMock()
mock_create_proxy.return_value = mock_proxy
mock_create_routes.return_value = (mock_routes, mock_http_manager)
mock_server_instance = AsyncMock()
mock_uvicorn_server.return_value = mock_server_instance
# Run the function
await run_mcp_server(mock_settings, mock_stdio_params, named_servers)
# Verify SSE URLs were logged
expected_default_url = f"http://{mock_settings.bind_host}:{mock_settings.port}/sse"
expected_named_url = (
f"http://{mock_settings.bind_host}:{mock_settings.port}/servers/test_server/sse"
)
mock_logger.info.assert_any_call("Serving MCP Servers via SSE:")
mock_logger.info.assert_any_call(" - %s", expected_default_url)
mock_logger.info.assert_any_call(" - %s", expected_named_url)
async def test_run_mcp_server_exception_handling(
mock_settings: MCPServerSettings,
mock_stdio_params: StdioServerParameters,
) -> None:
"""Test run_mcp_server handles exceptions properly."""
with (
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
patch("mcp_proxy.mcp_server.ClientSession"),
):
# Setup mocks to raise an exception
mock_stdio_client.side_effect = Exception("Connection failed")
# Should not raise, function should handle exceptions gracefully
try:
await run_mcp_server(mock_settings, mock_stdio_params, {})
except Exception as e: # noqa: BLE001
# If an exception is raised, it should be the expected one
assert "Connection failed" in str(e) # noqa: PT017
async def test_run_mcp_server_both_default_and_named_servers(
mock_settings: MCPServerSettings,
mock_stdio_params: StdioServerParameters,
) -> None:
"""Test run_mcp_server with both default and named servers."""
named_servers = {"named_server": mock_stdio_params}
with (
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
patch("uvicorn.Server") as mock_uvicorn_server,
patch("mcp_proxy.mcp_server.logger") as mock_logger,
):
# Setup mocks
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
setup_async_context_mocks()
)
mock_stdio_client.return_value = mock_stdio_context
mock_client_session.return_value = mock_session_context
mock_proxy = AsyncMock()
mock_create_proxy.return_value = mock_proxy
mock_create_routes.return_value = (mock_routes, mock_http_manager)
mock_server_instance = AsyncMock()
mock_uvicorn_server.return_value = mock_server_instance
# Run the function with both default and named servers
await run_mcp_server(mock_settings, mock_stdio_params, named_servers)
# Verify both servers were set up
assert mock_stdio_client.call_count == 2 # One for default, one for named
assert mock_create_proxy.call_count == 2
assert mock_create_routes.call_count == 2
# Verify logging for both servers
mock_logger.info.assert_any_call(
"Setting up default server: %s %s",
mock_stdio_params.command,
" ".join(mock_stdio_params.args),
)
mock_logger.info.assert_any_call(
"Setting up named server '%s': %s %s",
"named_server",
mock_stdio_params.command,
" ".join(mock_stdio_params.args),
)
mock_server_instance.serve.assert_called_once()