feat: support proxying multiple MCP stdio servers to SSE (#65)
This PR adds support for running multiple MCP (STDIO) servers and serving them up via a single mcp-proxy instance, each with a named path in the URL. Example usage: ``` mcp-proxy --port 8080 --named-server fetch 'uvx mcp-server-fetch' --named-server github 'npx -y @modelcontextprotocol/server-github' ``` Would serve: - `http://localhost:8080/servers/fetch` - `http://localhost:8080/servers/github` I've also added the ability to provide a standard mcp client config file with accompanying tests. Please feel free to make any changes as you see fit, or reject the PR if it does not align with your goals. Thank you, --------- Co-authored-by: Magnus Tidemann <magnustidemann@gmail.com> Co-authored-by: Sergey Parfenyuk <sergey.parfenyuk@gmail.com>
This commit is contained in:
145
README.md
145
README.md
@@ -14,6 +14,7 @@
|
||||
- [2. SSE to stdio](#2-sse-to-stdio)
|
||||
- [2.1 Configuration](#21-configuration)
|
||||
- [2.2 Example usage](#22-example-usage)
|
||||
- [Named Servers](#named-servers)
|
||||
- [Installation](#installation)
|
||||
- [Installing via Smithery](#installing-via-smithery)
|
||||
- [Installing via PyPI](#installing-via-pypi)
|
||||
@@ -23,6 +24,7 @@
|
||||
- [Extending the container image](#extending-the-container-image)
|
||||
- [Docker Compose Setup](#docker-compose-setup)
|
||||
- [Command line arguments](#command-line-arguments)
|
||||
- [Example config file](#example-config-file)
|
||||
- [Testing](#testing)
|
||||
|
||||
## About
|
||||
@@ -115,18 +117,20 @@ separator.
|
||||
|
||||
Arguments
|
||||
|
||||
| Name | Required | Description | Example |
|
||||
| ------------------------- | -------------------------- | --------------------------------------------------------------------------------------------- | --------------------- |
|
||||
| `command_or_url` | Yes | The command to spawn the MCP stdio server | uvx mcp-server-fetch |
|
||||
| `--port` | No, random available | The MCP server port to listen on | 8080 |
|
||||
| `--host` | No, `127.0.0.1` by default | The host IP address that the MCP server will listen on | 0.0.0.0 |
|
||||
| `--env` | No | Additional environment variables to pass to the MCP stdio server. Can be used multiple times. | FOO BAR |
|
||||
| `--cwd` | No | The working directory to pass to the MCP stdio server process. | /tmp |
|
||||
| `--pass-environment` | No | Pass through all environment variables when spawning the server | --no-pass-environment |
|
||||
| `--allow-origin` | No | Allowed origins for the SSE server. Can be used multiple times. Default is no CORS allowed. | --allow-cors "\*" |
|
||||
| `--stateless` | No | Enable stateless mode for streamable http transports. Default is False | --no-stateless |
|
||||
| `--sse-port` (deprecated) | No, random available | The SSE server port to listen on | 8080 |
|
||||
| `--sse-host` (deprecated) | No, `127.0.0.1` by default | The host IP address that the SSE server will listen on | 0.0.0.0 |
|
||||
| Name | Required | Description | Example |
|
||||
|--------------------------------------|----------------------------|---------------------------------------------------------------------------------------------|---------------------------------------------|
|
||||
| `command_or_url` | Yes | The command to spawn the MCP stdio server | uvx mcp-server-fetch |
|
||||
| `--port` | No, random available | The MCP server port to listen on | 8080 |
|
||||
| `--host` | No, `127.0.0.1` by default | The host IP address that the MCP server will listen on | 0.0.0.0 |
|
||||
| `--env` | No | Additional environment variables to pass to the MCP stdio server. Can be used multiple times. | FOO BAR |
|
||||
| `--cwd` | No | The working directory to pass to the MCP stdio server process. | /tmp |
|
||||
| `--pass-environment` | No | Pass through all environment variables when spawning the server | --no-pass-environment |
|
||||
| `--allow-origin` | No | Allowed origins for the SSE server. Can be used multiple times. Default is no CORS allowed. | --allow-origin "\*" |
|
||||
| `--stateless` | No | Enable stateless mode for streamable http transports. Default is False | --no-stateless |
|
||||
| `--named-server NAME COMMAND_STRING` | No | Defines a named stdio server. | --named-server fetch 'uvx mcp-server-fetch' |
|
||||
| `--named-server-config FILE_PATH` | No | Path to a JSON file defining named stdio servers. | --named-server-config /path/to/servers.json |
|
||||
| `--sse-port` (deprecated) | No, random available | The SSE server port to listen on | 8080 |
|
||||
| `--sse-host` (deprecated) | No, `127.0.0.1` by default | The host IP address that the SSE server will listen on | 0.0.0.0 |
|
||||
|
||||
### 2.2 Example usage
|
||||
|
||||
@@ -148,10 +152,61 @@ mcp-proxy --host=0.0.0.0 --port=8080 uvx mcp-server-fetch
|
||||
# Note that the `--` separator is used to separate the `mcp-proxy` arguments from the `mcp-server-fetch` arguments
|
||||
# (deprecated) mcp-proxy --sse-port=8080 -- uvx mcp-server-fetch --user-agent=YourUserAgent
|
||||
mcp-proxy --port=8080 -- uvx mcp-server-fetch --user-agent=YourUserAgent
|
||||
|
||||
# Start multiple named MCP servers behind the proxy
|
||||
mcp-proxy --port=8080 --named-server fetch 'uvx mcp-server-fetch' --named-server fetch2 'uvx mcp-server-fetch'
|
||||
|
||||
# Start multiple named MCP servers using a configuration file
|
||||
mcp-proxy --port=8080 --named-server-config ./servers.json
|
||||
```
|
||||
|
||||
This will start an MCP server that can be connected to at `http://127.0.0.1:8080/sse` via SSE, or
|
||||
`http://127.0.0.1:8080/mcp/` via StreamableHttp
|
||||
## Named Servers
|
||||
|
||||
- `NAME` is used in the URL path `/servers/NAME/`.
|
||||
- `COMMAND_STRING` is the command to start the server (e.g., 'uvx mcp-server-fetch').
|
||||
- Can be used multiple times.
|
||||
- This argument is ignored if `--named-server-config` is used.
|
||||
- `FILE_PATH` - If provided, this is the exclusive source for named servers, and `--named-server` CLI arguments are ignored.
|
||||
|
||||
If a default server is specified (the `command_or_url` argument without `--named-server` or `--named-server-config`), it will be accessible at the root paths (e.g., `http://127.0.0.1:8080/sse`).
|
||||
|
||||
Named servers (whether defined by `--named-server` or `--named-server-config`) will be accessible under `/servers/<server-name>/` (e.g., `http://127.0.0.1:8080/servers/fetch1/sse`).
|
||||
The `/status` endpoint provides global status.
|
||||
|
||||
**JSON Configuration File Format for `--named-server-config`:**
|
||||
|
||||
The JSON file should follow this structure:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"fetch": {
|
||||
"disabled": false,
|
||||
"timeout": 60,
|
||||
"command": "uvx",
|
||||
"args": [
|
||||
"mcp-server-fetch"
|
||||
],
|
||||
"transportType": "stdio"
|
||||
},
|
||||
"github": {
|
||||
"timeout": 60,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-github"
|
||||
],
|
||||
"transportType": "stdio"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
- `mcpServers`: A dictionary where each key is the server name (used in the URL path, e.g., `/servers/fetch/`) and the value is an object defining the server.
|
||||
- `command`: (Required) The command to execute for the stdio server.
|
||||
- `args`: (Optional) A list of arguments for the command. Defaults to an empty list.
|
||||
- `enabled`: (Optional) If `false`, this server definition will be skipped. Defaults to `true`.
|
||||
- `timeout` and `transportType`: These fields are present in standard MCP client configurations but are currently **ignored** by `mcp-proxy` when loading named servers. The transport type is implicitly "stdio".
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -258,14 +313,20 @@ services:
|
||||
|
||||
```bash
|
||||
usage: mcp-proxy [-h] [-H KEY VALUE] [--transport {sse,streamablehttp}] [-e KEY VALUE] [--cwd CWD] [--pass-environment | --no-pass-environment]
|
||||
[--debug | --no-debug] [--port PORT] [--host HOST] [--stateless | --no-stateless] [--sse-port SSE_PORT] [--sse-host SSE_HOST]
|
||||
[--allow-origin ALLOW_ORIGIN [ALLOW_ORIGIN ...]]
|
||||
[--debug | --no-debug] [--named-server NAME COMMAND_STRING] [--named-server-config FILE_PATH] [--port PORT] [--host HOST]
|
||||
[--stateless | --no-stateless] [--sse-port SSE_PORT] [--sse-host SSE_HOST] [--allow-origin ALLOW_ORIGIN [ALLOW_ORIGIN ...]]
|
||||
[command_or_url] [args ...]
|
||||
|
||||
Start the MCP proxy in one of two possible modes: as a client or a server.
|
||||
Start the MCP proxy.
|
||||
It can run as an SSE client (connecting to a remote SSE server and exposing stdio).
|
||||
Or, it can run as an SSE server (connecting to local stdio command(s) and exposing them over SSE).
|
||||
When running as an SSE server, it can proxy a single default stdio command or multiple named stdio commands (defined via CLI or a config file).
|
||||
|
||||
positional arguments:
|
||||
command_or_url Command or URL to connect to. When a URL, will run an SSE/StreamableHTTP client, otherwise will run the given command and connect as a stdio client. See corresponding options for more details.
|
||||
command_or_url Command or URL.
|
||||
If URL (http/https): Runs in SSE/StreamableHTTP client mode.
|
||||
If command string: Runs in SSE server mode, this is the default stdio server.
|
||||
If --named-server or --named-server-config is used, this can be omitted if no default server is desired.
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
@@ -277,12 +338,16 @@ SSE/StreamableHTTP client options:
|
||||
The transport to use for the client. Default is SSE.
|
||||
|
||||
stdio client options:
|
||||
args Any extra arguments to the command to spawn the server
|
||||
-e, --env KEY VALUE Environment variables used when spawning the server. Can be used multiple times.
|
||||
--cwd CWD The working directory to use when spawning the process.
|
||||
args Any extra arguments to the command to spawn the default server. Ignored if only named servers are defined.
|
||||
-e, --env KEY VALUE Environment variables used when spawning the default server. Can be used multiple times. For named servers, environment is inherited or passed via --pass-environment.
|
||||
--cwd CWD The working directory to use when spawning the default server process. Named servers inherit the proxy's CWD.
|
||||
--pass-environment, --no-pass-environment
|
||||
Pass through all environment variables when spawning the server.
|
||||
Pass through all environment variables when spawning all server processes.
|
||||
--debug, --no-debug Enable debug mode with detailed logging output.
|
||||
--named-server NAME COMMAND_STRING
|
||||
Define a named stdio server. NAME is for the URL path /servers/NAME/. COMMAND_STRING is a single string with the command and its arguments (e.g., 'uvx mcp-server-fetch --timeout 10'). These servers inherit the proxy's CWD and environment from --pass-environment. Can be specified multiple times. Ignored if --named-server-config is used.
|
||||
--named-server-config FILE_PATH
|
||||
Path to a JSON configuration file for named stdio servers. If provided, this will be the exclusive source for named server definitions, and any --named-server CLI arguments will be ignored.
|
||||
|
||||
SSE server options:
|
||||
--port PORT Port to expose an SSE server on. Default is a random port
|
||||
@@ -298,9 +363,39 @@ Examples:
|
||||
mcp-proxy http://localhost:8080/sse
|
||||
mcp-proxy --transport streamablehttp http://localhost:8080/mcp
|
||||
mcp-proxy --headers Authorization 'Bearer YOUR_TOKEN' http://localhost:8080/sse
|
||||
mcp-proxy --port 8080 -- your-command --arg1 value1 --arg2 value2
|
||||
mcp-proxy your-command --port 8080 -e KEY VALUE -e ANOTHER_KEY ANOTHER_VALUE
|
||||
mcp-proxy your-command --port 8080 --allow-origin='*'
|
||||
mcp-proxy --port 8080 -- my-default-command --arg1 value1
|
||||
mcp-proxy --port 8080 --named-server fetch1 'uvx mcp-server-fetch' --named-server tool2 'my-custom-tool --verbose'
|
||||
mcp-proxy --port 8080 --named-server-config /path/to/servers.json
|
||||
mcp-proxy --port 8080 --named-server-config /path/to/servers.json -- my-default-command --arg1
|
||||
mcp-proxy --port 8080 -e KEY VALUE -e ANOTHER_KEY ANOTHER_VALUE -- my-default-command
|
||||
mcp-proxy --port 8080 --allow-origin='*' -- my-default-command
|
||||
```
|
||||
|
||||
### Example config file
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"fetch": {
|
||||
"enabled": true,
|
||||
"timeout": 60,
|
||||
"command": "uvx",
|
||||
"args": [
|
||||
"mcp-server-fetch"
|
||||
],
|
||||
"transportType": "stdio"
|
||||
},
|
||||
"github": {
|
||||
"timeout": 60,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-github"
|
||||
],
|
||||
"transportType": "stdio"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
23
config_example.json
Normal file
23
config_example.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"fetch": {
|
||||
"enabled": true,
|
||||
"timeout": 60,
|
||||
"command": "uvx",
|
||||
"args": [
|
||||
"mcp-server-fetch"
|
||||
],
|
||||
"transportType": "stdio"
|
||||
},
|
||||
"github": {
|
||||
"enabled": false,
|
||||
"timeout": 60,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-github"
|
||||
],
|
||||
"transportType": "stdio"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -8,13 +8,16 @@ Two ways to run the application:
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
from mcp.client.stdio import StdioServerParameters
|
||||
|
||||
from .config_loader import load_named_server_configs_from_file
|
||||
from .mcp_server import MCPServerSettings, run_mcp_server
|
||||
from .sse_client import run_sse_client
|
||||
from .streamablehttp_client import run_streamablehttp_client
|
||||
@@ -26,8 +29,8 @@ SSE_URL: t.Final[str | None] = os.getenv(
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Start the client using asyncio."""
|
||||
def _setup_argument_parser() -> argparse.ArgumentParser:
|
||||
"""Set up and return the argument parser for the MCP proxy."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=("Start the MCP proxy in one of two possible modes: as a client or a server."),
|
||||
epilog=(
|
||||
@@ -36,19 +39,29 @@ def main() -> None:
|
||||
" mcp-proxy --transport streamablehttp http://localhost:8080/mcp\n"
|
||||
" mcp-proxy --headers Authorization 'Bearer YOUR_TOKEN' http://localhost:8080/sse\n"
|
||||
" mcp-proxy --port 8080 -- your-command --arg1 value1 --arg2 value2\n"
|
||||
" mcp-proxy --named-server fetch 'uvx mcp-server-fetch' --port 8080\n"
|
||||
" mcp-proxy your-command --port 8080 -e KEY VALUE -e ANOTHER_KEY ANOTHER_VALUE\n"
|
||||
" mcp-proxy your-command --port 8080 --allow-origin='*'\n"
|
||||
),
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
)
|
||||
|
||||
_add_arguments_to_parser(parser)
|
||||
return parser
|
||||
|
||||
|
||||
def _add_arguments_to_parser(parser: argparse.ArgumentParser) -> None:
|
||||
"""Add all arguments to the argument parser."""
|
||||
parser.add_argument(
|
||||
"command_or_url",
|
||||
help=(
|
||||
"Command or URL to connect to. When a URL, will run an SSE/StreamableHTTP client, "
|
||||
"otherwise will run the given command and connect as a stdio client. "
|
||||
"Command or URL to connect to. When a URL, will run an SSE/StreamableHTTP client. "
|
||||
"Otherwise, if --named-server is not used, this will be the command "
|
||||
"for the default stdio client. If --named-server is used, this argument "
|
||||
"is ignored for stdio mode unless no default server is desired. "
|
||||
"See corresponding options for more details."
|
||||
),
|
||||
nargs="?", # Required below to allow for coming form env var
|
||||
nargs="?",
|
||||
default=SSE_URL,
|
||||
)
|
||||
|
||||
@@ -73,7 +86,10 @@ def main() -> None:
|
||||
stdio_client_options.add_argument(
|
||||
"args",
|
||||
nargs="*",
|
||||
help="Any extra arguments to the command to spawn the server",
|
||||
help=(
|
||||
"Any extra arguments to the command to spawn the default server. "
|
||||
"Ignored if only named servers are defined."
|
||||
),
|
||||
)
|
||||
stdio_client_options.add_argument(
|
||||
"-e",
|
||||
@@ -81,18 +97,25 @@ def main() -> None:
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("KEY", "VALUE"),
|
||||
help="Environment variables used when spawning the server. Can be used multiple times.",
|
||||
help=(
|
||||
"Environment variables used when spawning the default server. Can be "
|
||||
"used multiple times. For named servers, environment is inherited or "
|
||||
"passed via --pass-environment."
|
||||
),
|
||||
default=[],
|
||||
)
|
||||
stdio_client_options.add_argument(
|
||||
"--cwd",
|
||||
default=None,
|
||||
help="The working directory to use when spawning the process.",
|
||||
help=(
|
||||
"The working directory to use when spawning the default server process. "
|
||||
"Named servers inherit the proxy's CWD."
|
||||
),
|
||||
)
|
||||
stdio_client_options.add_argument(
|
||||
"--pass-environment",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Pass through all environment variables when spawning the server.",
|
||||
help="Pass through all environment variables when spawning all server processes.",
|
||||
default=False,
|
||||
)
|
||||
stdio_client_options.add_argument(
|
||||
@@ -101,6 +124,31 @@ def main() -> None:
|
||||
help="Enable debug mode with detailed logging output.",
|
||||
default=False,
|
||||
)
|
||||
stdio_client_options.add_argument(
|
||||
"--named-server",
|
||||
action="append",
|
||||
nargs=2,
|
||||
metavar=("NAME", "COMMAND_STRING"),
|
||||
help=(
|
||||
"Define a named stdio server. NAME is for the URL path /servers/NAME/. "
|
||||
"COMMAND_STRING is a single string with the command and its arguments "
|
||||
"(e.g., 'uvx mcp-server-fetch --timeout 10'). "
|
||||
"These servers inherit the proxy's CWD and environment from --pass-environment."
|
||||
),
|
||||
default=[],
|
||||
dest="named_server_definitions",
|
||||
)
|
||||
stdio_client_options.add_argument(
|
||||
"--named-server-config",
|
||||
type=str,
|
||||
default=None,
|
||||
metavar="FILE_PATH",
|
||||
help=(
|
||||
"Path to a JSON configuration file for named stdio servers. "
|
||||
"If provided, this will be the exclusive source for named server definitions, "
|
||||
"and any --named-server CLI arguments will be ignored."
|
||||
),
|
||||
)
|
||||
|
||||
mcp_server_group = parser.add_argument_group("SSE server options")
|
||||
mcp_server_group.add_argument(
|
||||
@@ -135,65 +183,222 @@ def main() -> None:
|
||||
"--allow-origin",
|
||||
nargs="+",
|
||||
default=[],
|
||||
help="Allowed origins for the SSE server. "
|
||||
"Can be used multiple times. Default is no CORS allowed.",
|
||||
help=(
|
||||
"Allowed origins for the SSE server. Can be used multiple times. "
|
||||
"Default is no CORS allowed."
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command_or_url:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
def _setup_logging(*, debug: bool) -> logging.Logger:
|
||||
"""Set up logging configuration and return the logger."""
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG if args.debug else logging.INFO,
|
||||
level=logging.DEBUG if debug else logging.INFO,
|
||||
format="[%(levelname)1.1s %(asctime)s.%(msecs).03d %(name)s] %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
if (
|
||||
SSE_URL
|
||||
or args.command_or_url.startswith("http://")
|
||||
or args.command_or_url.startswith("https://")
|
||||
|
||||
def _handle_sse_client_mode(
|
||||
args_parsed: argparse.Namespace,
|
||||
logger: logging.Logger,
|
||||
) -> None:
|
||||
"""Handle SSE/StreamableHTTP client mode operation."""
|
||||
if args_parsed.named_server_definitions:
|
||||
logger.warning(
|
||||
"--named-server arguments are ignored when command_or_url is an HTTP/HTTPS URL "
|
||||
"(SSE/StreamableHTTP client mode).",
|
||||
)
|
||||
# Start a client connected to the SSE server, and expose as a stdio server
|
||||
logger.debug("Starting SSE/StreamableHTTP client and stdio server")
|
||||
headers = dict(args_parsed.headers)
|
||||
if api_access_token := os.getenv("API_ACCESS_TOKEN", None):
|
||||
headers["Authorization"] = f"Bearer {api_access_token}"
|
||||
|
||||
if args_parsed.transport == "streamablehttp":
|
||||
asyncio.run(run_streamablehttp_client(args_parsed.command_or_url, headers=headers))
|
||||
else:
|
||||
asyncio.run(run_sse_client(args_parsed.command_or_url, headers=headers))
|
||||
|
||||
|
||||
def _configure_default_server(
|
||||
args_parsed: argparse.Namespace,
|
||||
base_env: dict[str, str],
|
||||
logger: logging.Logger,
|
||||
) -> StdioServerParameters | None:
|
||||
"""Configure the default server if applicable."""
|
||||
if not (
|
||||
args_parsed.command_or_url
|
||||
and not args_parsed.command_or_url.startswith(("http://", "https://"))
|
||||
):
|
||||
# Start a client connected to the SSE server, and expose as a stdio server
|
||||
logger.debug("Starting SSE client and stdio server")
|
||||
headers = dict(args.headers)
|
||||
if api_access_token := os.getenv("API_ACCESS_TOKEN", None):
|
||||
headers["Authorization"] = f"Bearer {api_access_token}"
|
||||
if args.transport == "streamablehttp":
|
||||
asyncio.run(run_streamablehttp_client(args.command_or_url, headers=headers))
|
||||
else:
|
||||
asyncio.run(run_sse_client(args.command_or_url, headers=headers))
|
||||
return None
|
||||
|
||||
default_server_env = base_env.copy()
|
||||
default_server_env.update(dict(args_parsed.env)) # Specific env vars for default server
|
||||
|
||||
default_stdio_params = StdioServerParameters(
|
||||
command=args_parsed.command_or_url,
|
||||
args=args_parsed.args,
|
||||
env=default_server_env,
|
||||
cwd=args_parsed.cwd if args_parsed.cwd else None,
|
||||
)
|
||||
logger.info(
|
||||
"Configured default server: %s %s",
|
||||
args_parsed.command_or_url,
|
||||
" ".join(args_parsed.args),
|
||||
)
|
||||
return default_stdio_params
|
||||
|
||||
|
||||
def _load_named_servers_from_config(
|
||||
config_path: str,
|
||||
base_env: dict[str, str],
|
||||
logger: logging.Logger,
|
||||
) -> dict[str, StdioServerParameters]:
|
||||
"""Load named server configurations from a file."""
|
||||
try:
|
||||
return load_named_server_configs_from_file(config_path, base_env)
|
||||
except (FileNotFoundError, json.JSONDecodeError, ValueError):
|
||||
# Specific errors are already logged by the loader function
|
||||
# We log a generic message here before exiting
|
||||
logger.exception(
|
||||
"Failed to load server configurations from %s. Exiting.",
|
||||
config_path,
|
||||
)
|
||||
sys.exit(1)
|
||||
except Exception: # Catch any other unexpected errors from loader
|
||||
logger.exception(
|
||||
"An unexpected error occurred while loading server configurations from %s. Exiting.",
|
||||
config_path,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _configure_named_servers_from_cli(
|
||||
named_server_definitions: list[tuple[str, str]],
|
||||
base_env: dict[str, str],
|
||||
logger: logging.Logger,
|
||||
) -> dict[str, StdioServerParameters]:
|
||||
"""Configure named servers from CLI arguments."""
|
||||
named_stdio_params: dict[str, StdioServerParameters] = {}
|
||||
|
||||
for name, command_string in named_server_definitions:
|
||||
try:
|
||||
command_parts = shlex.split(command_string)
|
||||
if not command_parts: # Handle empty command_string
|
||||
logger.error("Empty COMMAND_STRING for named server '%s'. Skipping.", name)
|
||||
continue
|
||||
|
||||
command = command_parts[0]
|
||||
command_args = command_parts[1:]
|
||||
# Named servers inherit base_env (which includes passed-through env)
|
||||
# and use the proxy's CWD.
|
||||
named_stdio_params[name] = StdioServerParameters(
|
||||
command=command,
|
||||
args=command_args,
|
||||
env=base_env.copy(), # Each named server gets a copy of the base env
|
||||
cwd=None, # Named servers run in the proxy's CWD
|
||||
)
|
||||
logger.info("Configured named server '%s': %s", name, command_string)
|
||||
except IndexError: # Should be caught by the check for empty command_parts
|
||||
logger.exception(
|
||||
"Invalid COMMAND_STRING for named server '%s': '%s'. Must include a command.",
|
||||
name,
|
||||
command_string,
|
||||
)
|
||||
sys.exit(1)
|
||||
except Exception:
|
||||
logger.exception("Error parsing COMMAND_STRING for named server '%s'", name)
|
||||
sys.exit(1)
|
||||
|
||||
return named_stdio_params
|
||||
|
||||
|
||||
def _create_mcp_settings(args_parsed: argparse.Namespace) -> MCPServerSettings:
|
||||
"""Create MCP server settings from parsed arguments."""
|
||||
return MCPServerSettings(
|
||||
bind_host=args_parsed.host if args_parsed.host is not None else args_parsed.sse_host,
|
||||
port=args_parsed.port if args_parsed.port is not None else args_parsed.sse_port,
|
||||
stateless=args_parsed.stateless,
|
||||
allow_origins=args_parsed.allow_origin if len(args_parsed.allow_origin) > 0 else None,
|
||||
log_level="DEBUG" if args_parsed.debug else "INFO",
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Start the client using asyncio."""
|
||||
parser = _setup_argument_parser()
|
||||
args_parsed = parser.parse_args()
|
||||
logger = _setup_logging(debug=args_parsed.debug)
|
||||
|
||||
# Validate required arguments
|
||||
if (
|
||||
not args_parsed.command_or_url
|
||||
and not args_parsed.named_server_definitions
|
||||
and not args_parsed.named_server_config
|
||||
):
|
||||
parser.print_help()
|
||||
logger.error(
|
||||
"Either a command_or_url for a default server or at least one --named-server "
|
||||
"(or --named-server-config) must be provided for stdio mode.",
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Handle SSE client mode if URL is provided
|
||||
if args_parsed.command_or_url and args_parsed.command_or_url.startswith(
|
||||
("http://", "https://"),
|
||||
):
|
||||
_handle_sse_client_mode(args_parsed, logger)
|
||||
return
|
||||
|
||||
# Start a client connected to the given command, and expose as an SSE server
|
||||
logger.debug("Starting stdio client and SSE server")
|
||||
# Start stdio client(s) and expose as an SSE server
|
||||
logger.debug("Configuring stdio client(s) and SSE server")
|
||||
|
||||
# The environment variables passed to the server process
|
||||
env: dict[str, str] = {}
|
||||
# Pass through current environment variables if configured
|
||||
if args.pass_environment:
|
||||
env.update(os.environ)
|
||||
# Pass in and override any environment variables with those passed on the command line
|
||||
env.update(dict(args.env))
|
||||
# Base environment for all spawned processes
|
||||
base_env: dict[str, str] = {}
|
||||
if args_parsed.pass_environment:
|
||||
base_env.update(os.environ)
|
||||
|
||||
stdio_params = StdioServerParameters(
|
||||
command=args.command_or_url,
|
||||
args=args.args,
|
||||
env=env,
|
||||
cwd=args.cwd if args.cwd else None,
|
||||
# Configure default server
|
||||
default_stdio_params = _configure_default_server(args_parsed, base_env, logger)
|
||||
|
||||
# Configure named servers
|
||||
named_stdio_params: dict[str, StdioServerParameters] = {}
|
||||
if args_parsed.named_server_config:
|
||||
if args_parsed.named_server_definitions:
|
||||
logger.warning(
|
||||
"--named-server CLI arguments are ignored when --named-server-config is provided.",
|
||||
)
|
||||
named_stdio_params = _load_named_servers_from_config(
|
||||
args_parsed.named_server_config,
|
||||
base_env,
|
||||
logger,
|
||||
)
|
||||
elif args_parsed.named_server_definitions:
|
||||
named_stdio_params = _configure_named_servers_from_cli(
|
||||
args_parsed.named_server_definitions,
|
||||
base_env,
|
||||
logger,
|
||||
)
|
||||
|
||||
# Ensure at least one server is configured
|
||||
if not default_stdio_params and not named_stdio_params:
|
||||
parser.print_help()
|
||||
logger.error(
|
||||
"No stdio servers configured. Provide a default command or use --named-server.",
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Create MCP server settings and run the server
|
||||
mcp_settings = _create_mcp_settings(args_parsed)
|
||||
asyncio.run(
|
||||
run_mcp_server(
|
||||
default_server_params=default_stdio_params,
|
||||
named_server_params=named_stdio_params,
|
||||
mcp_settings=mcp_settings,
|
||||
),
|
||||
)
|
||||
|
||||
mcp_settings = MCPServerSettings(
|
||||
bind_host=args.host if args.host is not None else args.sse_host,
|
||||
port=args.port if args.port is not None else args.sse_port,
|
||||
stateless=args.stateless,
|
||||
allow_origins=args.allow_origin if len(args.allow_origin) > 0 else None,
|
||||
log_level="DEBUG" if args.debug else "INFO",
|
||||
)
|
||||
asyncio.run(run_mcp_server(stdio_params, mcp_settings))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
99
src/mcp_proxy/config_loader.py
Normal file
99
src/mcp_proxy/config_loader.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Configuration loader for MCP proxy.
|
||||
|
||||
This module provides functionality to load named server configurations from JSON files.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from mcp.client.stdio import StdioServerParameters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_named_server_configs_from_file(
|
||||
config_file_path: str,
|
||||
base_env: dict[str, str],
|
||||
) -> dict[str, StdioServerParameters]:
|
||||
"""Loads named server configurations from a JSON file.
|
||||
|
||||
Args:
|
||||
config_file_path: Path to the JSON configuration file.
|
||||
base_env: The base environment dictionary to be inherited by servers.
|
||||
|
||||
Returns:
|
||||
A dictionary of named server parameters.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the config file is not found.
|
||||
json.JSONDecodeError: If the config file contains invalid JSON.
|
||||
ValueError: If the config file format is invalid.
|
||||
"""
|
||||
named_stdio_params: dict[str, StdioServerParameters] = {}
|
||||
logger.info("Loading named server configurations from: %s", config_file_path)
|
||||
|
||||
try:
|
||||
with Path(config_file_path).open() as f:
|
||||
config_data = json.load(f)
|
||||
except FileNotFoundError:
|
||||
logger.exception("Configuration file not found: %s", config_file_path)
|
||||
raise
|
||||
except json.JSONDecodeError:
|
||||
logger.exception("Error decoding JSON from configuration file: %s", config_file_path)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Unexpected error opening or reading configuration file %s",
|
||||
config_file_path,
|
||||
)
|
||||
error_message = f"Could not read configuration file: {e}"
|
||||
raise ValueError(error_message) from e
|
||||
|
||||
if not isinstance(config_data, dict) or "mcpServers" not in config_data:
|
||||
msg = f"Invalid config file format in {config_file_path}. Missing 'mcpServers' key."
|
||||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
for name, server_config in config_data.get("mcpServers", {}).items():
|
||||
if not isinstance(server_config, dict):
|
||||
logger.warning(
|
||||
"Skipping invalid server config for '%s' in %s. Entry is not a dictionary.",
|
||||
name,
|
||||
config_file_path,
|
||||
)
|
||||
continue
|
||||
if not server_config.get("enabled", True): # Default to True if 'enabled' is not present
|
||||
logger.info("Named server '%s' from config is not enabled. Skipping.", name)
|
||||
continue
|
||||
|
||||
command = server_config.get("command")
|
||||
command_args = server_config.get("args", [])
|
||||
|
||||
if not command:
|
||||
logger.warning(
|
||||
"Named server '%s' from config is missing 'command'. Skipping.",
|
||||
name,
|
||||
)
|
||||
continue
|
||||
if not isinstance(command_args, list):
|
||||
logger.warning(
|
||||
"Named server '%s' from config has invalid 'args' (must be a list). Skipping.",
|
||||
name,
|
||||
)
|
||||
continue
|
||||
|
||||
named_stdio_params[name] = StdioServerParameters(
|
||||
command=command,
|
||||
args=command_args,
|
||||
env=base_env.copy(),
|
||||
cwd=None,
|
||||
)
|
||||
logger.info(
|
||||
"Configured named server '%s' from config: %s %s",
|
||||
name,
|
||||
command,
|
||||
" ".join(command_args),
|
||||
)
|
||||
|
||||
return named_stdio_params
|
||||
@@ -5,12 +5,12 @@ import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import uvicorn
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
from mcp.server import Server
|
||||
from mcp.server import Server as MCPServerSDK # Renamed to avoid conflict
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
from starlette.applications import Starlette
|
||||
@@ -18,7 +18,7 @@ from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.routing import Mount, Route
|
||||
from starlette.routing import BaseRoute, Mount, Route
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from .proxy_server import create_proxy_server
|
||||
@@ -37,123 +37,154 @@ class MCPServerSettings:
|
||||
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
|
||||
|
||||
|
||||
def create_starlette_app(
|
||||
mcp_server: Server[object],
|
||||
# To store last activity for multiple servers if needed, though status endpoint is global for now.
|
||||
_global_status: dict[str, Any] = {
|
||||
"api_last_activity": datetime.now(timezone.utc).isoformat(),
|
||||
"server_instances": {}, # Could be used to store per-instance status later
|
||||
}
|
||||
|
||||
|
||||
def _update_global_activity() -> None:
|
||||
_global_status["api_last_activity"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
async def _handle_status(_: Request) -> Response:
|
||||
"""Global health check and service usage monitoring endpoint."""
|
||||
return JSONResponse(_global_status)
|
||||
|
||||
|
||||
def create_single_instance_routes(
|
||||
mcp_server_instance: MCPServerSDK[object],
|
||||
*,
|
||||
stateless: bool = False,
|
||||
allow_origins: list[str] | None = None,
|
||||
debug: bool = False,
|
||||
) -> Starlette:
|
||||
"""Create a Starlette application that can serve the mcp server with SSE or Streamable http."""
|
||||
logger.debug("Creating Starlette app with stateless: %s and debug: %s", stateless, debug)
|
||||
# record the last activity of api
|
||||
status = {
|
||||
"api_last_activity": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
stateless_instance: bool,
|
||||
) -> tuple[list[BaseRoute], StreamableHTTPSessionManager]: # Return the manager itself
|
||||
"""Create Starlette routes and the HTTP session manager for a single MCP server instance."""
|
||||
logger.debug(
|
||||
"Creating routes for a single MCP server instance (stateless: %s)",
|
||||
stateless_instance,
|
||||
)
|
||||
|
||||
def _update_mcp_activity() -> None:
|
||||
status.update(
|
||||
{
|
||||
"api_last_activity": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
sse_transport = SseServerTransport("messages/")
|
||||
http_session_manager = StreamableHTTPSessionManager(
|
||||
app=mcp_server_instance,
|
||||
event_store=None,
|
||||
json_response=True,
|
||||
stateless=stateless_instance,
|
||||
)
|
||||
|
||||
sse = SseServerTransport("/messages/")
|
||||
|
||||
async def handle_sse(request: Request) -> None:
|
||||
async with sse.connect_sse(
|
||||
async def handle_sse_instance(request: Request) -> None:
|
||||
async with sse_transport.connect_sse(
|
||||
request.scope,
|
||||
request.receive,
|
||||
request._send, # noqa: SLF001
|
||||
) as (read_stream, write_stream):
|
||||
_update_mcp_activity()
|
||||
|
||||
await mcp_server.run(
|
||||
_update_global_activity()
|
||||
await mcp_server_instance.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
mcp_server.create_initialization_options(),
|
||||
mcp_server_instance.create_initialization_options(),
|
||||
)
|
||||
|
||||
# Refer: https://github.com/modelcontextprotocol/python-sdk/blob/v1.8.0/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py
|
||||
http = StreamableHTTPSessionManager(
|
||||
app=mcp_server,
|
||||
event_store=None,
|
||||
json_response=True,
|
||||
stateless=stateless,
|
||||
)
|
||||
async def handle_streamable_http_instance(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
_update_global_activity()
|
||||
await http_session_manager.handle_request(scope, receive, send)
|
||||
|
||||
async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
_update_mcp_activity()
|
||||
await http.handle_request(scope, receive, send)
|
||||
|
||||
async def handle_status(_: Request) -> Response:
|
||||
"""Health check and service usage monitoring endpoint.
|
||||
|
||||
Purpose of this handler:
|
||||
- Provides a dedicated API endpoint for external health checks.
|
||||
- Returns last API activity timestamp to monitor service usage patterns and uptime.
|
||||
- Serves as basic infrastructure for potential future service metrics expansion.
|
||||
"""
|
||||
return JSONResponse(status)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(_: Starlette) -> AsyncIterator[None]:
|
||||
"""Context manager for session manager."""
|
||||
async with http.run():
|
||||
logger.info("Application started with StreamableHTTP session manager!")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.info("Application shutting down...")
|
||||
|
||||
middleware: list[Middleware] = []
|
||||
if allow_origins is not None:
|
||||
middleware.append(
|
||||
Middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=allow_origins,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
),
|
||||
)
|
||||
|
||||
return Starlette(
|
||||
debug=debug,
|
||||
middleware=middleware,
|
||||
routes=[
|
||||
Route("/status", endpoint=handle_status),
|
||||
Mount("/mcp", app=handle_streamable_http),
|
||||
Route("/sse", endpoint=handle_sse),
|
||||
Mount("/messages/", app=sse.handle_post_message),
|
||||
],
|
||||
lifespan=lifespan,
|
||||
)
|
||||
routes = [
|
||||
Mount("/mcp", app=handle_streamable_http_instance),
|
||||
Route("/sse", endpoint=handle_sse_instance),
|
||||
Mount("/messages/", app=sse_transport.handle_post_message),
|
||||
]
|
||||
return routes, http_session_manager
|
||||
|
||||
|
||||
async def run_mcp_server(
|
||||
stdio_params: StdioServerParameters,
|
||||
mcp_settings: MCPServerSettings,
|
||||
default_server_params: StdioServerParameters | None = None,
|
||||
named_server_params: dict[str, StdioServerParameters] | None = None,
|
||||
) -> None:
|
||||
"""Run the stdio client and expose an MCP server.
|
||||
"""Run stdio client(s) and expose an MCP server with multiple possible backends."""
|
||||
if named_server_params is None:
|
||||
named_server_params = {}
|
||||
|
||||
Args:
|
||||
stdio_params: The parameters for the stdio client that spawns a stdio server.
|
||||
mcp_settings: The settings for the MCP server that accepts incoming requests.
|
||||
all_routes: list[BaseRoute] = [
|
||||
Route("/status", endpoint=_handle_status), # Global status endpoint
|
||||
]
|
||||
# Use AsyncExitStack to manage lifecycles of multiple components
|
||||
async with contextlib.AsyncExitStack() as stack:
|
||||
# Manage lifespans of all StreamableHTTPSessionManagers
|
||||
@contextlib.asynccontextmanager
|
||||
async def combined_lifespan(_app: Starlette) -> AsyncIterator[None]:
|
||||
logger.info("Main application lifespan starting...")
|
||||
# All http_session_managers' .run() are already entered into the stack
|
||||
yield
|
||||
logger.info("Main application lifespan shutting down...")
|
||||
|
||||
"""
|
||||
async with stdio_client(stdio_params) as streams, ClientSession(*streams) as session:
|
||||
logger.debug("Starting proxy server...")
|
||||
mcp_server = await create_proxy_server(session)
|
||||
# Setup default server if configured
|
||||
if default_server_params:
|
||||
logger.info(
|
||||
"Setting up default server: %s %s",
|
||||
default_server_params.command,
|
||||
" ".join(default_server_params.args),
|
||||
)
|
||||
stdio_streams = await stack.enter_async_context(stdio_client(default_server_params))
|
||||
session = await stack.enter_async_context(ClientSession(*stdio_streams))
|
||||
proxy = await create_proxy_server(session)
|
||||
|
||||
# Bind request handling to MCP server
|
||||
starlette_app = create_starlette_app(
|
||||
mcp_server,
|
||||
stateless=mcp_settings.stateless,
|
||||
allow_origins=mcp_settings.allow_origins,
|
||||
instance_routes, http_manager = create_single_instance_routes(
|
||||
proxy,
|
||||
stateless_instance=mcp_settings.stateless,
|
||||
)
|
||||
await stack.enter_async_context(http_manager.run()) # Manage lifespan by calling run()
|
||||
all_routes.extend(instance_routes)
|
||||
_global_status["server_instances"]["default"] = "configured"
|
||||
|
||||
# Setup named servers
|
||||
for name, params in named_server_params.items():
|
||||
logger.info(
|
||||
"Setting up named server '%s': %s %s",
|
||||
name,
|
||||
params.command,
|
||||
" ".join(params.args),
|
||||
)
|
||||
stdio_streams_named = await stack.enter_async_context(stdio_client(params))
|
||||
session_named = await stack.enter_async_context(ClientSession(*stdio_streams_named))
|
||||
proxy_named = await create_proxy_server(session_named)
|
||||
|
||||
instance_routes_named, http_manager_named = create_single_instance_routes(
|
||||
proxy_named,
|
||||
stateless_instance=mcp_settings.stateless,
|
||||
)
|
||||
await stack.enter_async_context(
|
||||
http_manager_named.run(),
|
||||
) # Manage lifespan by calling run()
|
||||
|
||||
# Mount these routes under /servers/<name>/
|
||||
server_mount = Mount(f"/servers/{name}", routes=instance_routes_named)
|
||||
all_routes.append(server_mount)
|
||||
_global_status["server_instances"][name] = "configured"
|
||||
|
||||
if not default_server_params and not named_server_params:
|
||||
logger.error("No servers configured to run.")
|
||||
return
|
||||
|
||||
middleware: list[Middleware] = []
|
||||
if mcp_settings.allow_origins:
|
||||
middleware.append(
|
||||
Middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=mcp_settings.allow_origins,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
),
|
||||
)
|
||||
|
||||
starlette_app = Starlette(
|
||||
debug=(mcp_settings.log_level == "DEBUG"),
|
||||
routes=all_routes,
|
||||
middleware=middleware,
|
||||
lifespan=combined_lifespan,
|
||||
)
|
||||
|
||||
# Configure HTTP server
|
||||
config = uvicorn.Config(
|
||||
starlette_app,
|
||||
host=mcp_settings.bind_host,
|
||||
@@ -161,8 +192,27 @@ async def run_mcp_server(
|
||||
log_level=mcp_settings.log_level.lower(),
|
||||
)
|
||||
http_server = uvicorn.Server(config)
|
||||
|
||||
# Print out the SSE URLs for all configured servers
|
||||
base_url = f"http://{mcp_settings.bind_host}:{mcp_settings.port}"
|
||||
sse_urls = []
|
||||
|
||||
# Add default server if configured
|
||||
if default_server_params:
|
||||
sse_urls.append(f"{base_url}/sse")
|
||||
|
||||
# Add named servers
|
||||
sse_urls.extend([f"{base_url}/servers/{name}/sse" for name in named_server_params])
|
||||
|
||||
# Display the SSE URLs prominently
|
||||
if sse_urls:
|
||||
# Using print directly for user visibility, with noqa to ignore linter warnings
|
||||
logger.info("Serving MCP Servers via SSE:")
|
||||
for url in sse_urls:
|
||||
logger.info(" - %s", url)
|
||||
|
||||
logger.debug(
|
||||
"Serving incoming requests on %s:%s",
|
||||
"Serving incoming MCP requests on %s:%s",
|
||||
mcp_settings.bind_host,
|
||||
mcp_settings.port,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Create an MCP server that proxies requests throgh an MCP client.
|
||||
"""Create an MCP server that proxies requests through an MCP client.
|
||||
|
||||
This server is created independent of any transport mechanism.
|
||||
"""
|
||||
|
||||
244
tests/test_config_loader.py
Normal file
244
tests/test_config_loader.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""Tests for the configuration loader module."""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import tempfile
|
||||
from collections.abc import Callable, Generator
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from mcp.client.stdio import StdioServerParameters
|
||||
|
||||
from mcp_proxy.config_loader import load_named_server_configs_from_file
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_temp_config_file() -> Generator[Callable[[dict], str], None, None]:
|
||||
"""Creates a temporary JSON config file and returns its path."""
|
||||
temp_files: list[str] = []
|
||||
|
||||
def _create_temp_config_file(config_content: dict) -> str:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
delete=False,
|
||||
suffix=".json",
|
||||
) as tmp_config:
|
||||
json.dump(config_content, tmp_config)
|
||||
temp_files.append(tmp_config.name)
|
||||
return tmp_config.name
|
||||
|
||||
yield _create_temp_config_file
|
||||
|
||||
for f_path in temp_files:
|
||||
path = Path(f_path)
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
|
||||
|
||||
def test_load_valid_config(create_temp_config_file: Callable[[dict], str]) -> None:
|
||||
"""Test loading a valid configuration file."""
|
||||
config_content = {
|
||||
"mcpServers": {
|
||||
"server1": {
|
||||
"command": "echo",
|
||||
"args": ["hello"],
|
||||
"enabled": True,
|
||||
},
|
||||
"server2": {
|
||||
"command": "cat",
|
||||
"args": ["file.txt"],
|
||||
},
|
||||
},
|
||||
}
|
||||
tmp_config_path = create_temp_config_file(config_content)
|
||||
base_env = {"PASSED": "env_value"}
|
||||
|
||||
loaded_params = load_named_server_configs_from_file(tmp_config_path, base_env)
|
||||
|
||||
assert "server1" in loaded_params
|
||||
assert loaded_params["server1"].command == "echo"
|
||||
assert loaded_params["server1"].args == ["hello"]
|
||||
assert (
|
||||
loaded_params["server1"].env == base_env
|
||||
) # Env is a copy, check if it contains base_env items
|
||||
|
||||
assert "server2" in loaded_params
|
||||
assert loaded_params["server2"].command == "cat"
|
||||
assert loaded_params["server2"].args == ["file.txt"]
|
||||
assert loaded_params["server2"].env == base_env
|
||||
|
||||
|
||||
def test_load_config_with_not_enabled_server(
|
||||
create_temp_config_file: Callable[[dict], str],
|
||||
) -> None:
|
||||
"""Test loading a configuration with disabled servers."""
|
||||
config_content = {
|
||||
"mcpServers": {
|
||||
"explicitly_enabled_server": {"command": "true_command", "enabled": True},
|
||||
# No 'enabled' flag, defaults to True
|
||||
"implicitly_enabled_server": {"command": "another_true_command"},
|
||||
"not_enabled_server": {"command": "false_command", "enabled": False},
|
||||
},
|
||||
}
|
||||
tmp_config_path = create_temp_config_file(config_content)
|
||||
loaded_params = load_named_server_configs_from_file(tmp_config_path, {})
|
||||
|
||||
assert "explicitly_enabled_server" in loaded_params
|
||||
assert loaded_params["explicitly_enabled_server"].command == "true_command"
|
||||
assert "implicitly_enabled_server" in loaded_params
|
||||
assert loaded_params["implicitly_enabled_server"].command == "another_true_command"
|
||||
assert "not_enabled_server" not in loaded_params
|
||||
|
||||
|
||||
def test_file_not_found() -> None:
|
||||
"""Test handling of non-existent configuration files."""
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_named_server_configs_from_file("non_existent_file.json", {})
|
||||
|
||||
|
||||
def test_json_decode_error() -> None:
|
||||
"""Test handling of invalid JSON in configuration files."""
|
||||
# Create a file with invalid JSON content
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
delete=False,
|
||||
suffix=".json",
|
||||
) as tmp_config:
|
||||
tmp_config.write("this is not json {")
|
||||
tmp_config_path = tmp_config.name
|
||||
|
||||
# Use try/finally to ensure cleanup
|
||||
try:
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
load_named_server_configs_from_file(tmp_config_path, {})
|
||||
finally:
|
||||
path = Path(tmp_config_path)
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
|
||||
|
||||
def test_load_example_fetch_config_if_uvx_exists() -> None:
|
||||
"""Test loading the example fetch configuration if uvx is available."""
|
||||
if not shutil.which("uvx"):
|
||||
pytest.skip("uvx command not found in PATH, skipping test for example config.")
|
||||
|
||||
# Assuming the test is run from the root of the repository
|
||||
example_config_path = Path(__file__).parent.parent / "config_example.json"
|
||||
|
||||
if not example_config_path.exists():
|
||||
pytest.fail(
|
||||
f"Example config file not found at expected path: {example_config_path}",
|
||||
)
|
||||
|
||||
base_env = {"EXAMPLE_ENV": "true"}
|
||||
loaded_params = load_named_server_configs_from_file(example_config_path, base_env)
|
||||
|
||||
assert "fetch" in loaded_params
|
||||
fetch_param = loaded_params["fetch"]
|
||||
assert isinstance(fetch_param, StdioServerParameters)
|
||||
assert fetch_param.command == "uvx"
|
||||
assert fetch_param.args == ["mcp-server-fetch"]
|
||||
assert fetch_param.env == base_env
|
||||
# The 'timeout' and 'transportType' fields from the config are currently ignored by the loader,
|
||||
# so no need to assert them on StdioServerParameters.
|
||||
|
||||
|
||||
def test_invalid_config_format_missing_mcpservers(
|
||||
create_temp_config_file: Callable[[dict], str],
|
||||
) -> None:
|
||||
"""Test handling of configuration files missing the mcpServers key."""
|
||||
config_content = {"some_other_key": "value"}
|
||||
tmp_config_path = create_temp_config_file(config_content)
|
||||
|
||||
with pytest.raises(ValueError, match="Missing 'mcpServers' key"):
|
||||
load_named_server_configs_from_file(tmp_config_path, {})
|
||||
|
||||
|
||||
@patch("mcp_proxy.config_loader.logger")
|
||||
def test_invalid_server_entry_not_dict(
|
||||
mock_logger: object,
|
||||
create_temp_config_file: Callable[[dict], str],
|
||||
) -> None:
|
||||
"""Test handling of server entries that are not dictionaries."""
|
||||
config_content = {"mcpServers": {"server1": "not_a_dict"}}
|
||||
tmp_config_path = create_temp_config_file(config_content)
|
||||
|
||||
loaded_params = load_named_server_configs_from_file(tmp_config_path, {})
|
||||
assert len(loaded_params) == 0 # No servers should be loaded
|
||||
mock_logger.warning.assert_called_with(
|
||||
"Skipping invalid server config for '%s' in %s. Entry is not a dictionary.",
|
||||
"server1",
|
||||
tmp_config_path,
|
||||
)
|
||||
|
||||
|
||||
@patch("mcp_proxy.config_loader.logger")
|
||||
def test_server_entry_missing_command(
|
||||
mock_logger: object,
|
||||
create_temp_config_file: Callable[[dict], str],
|
||||
) -> None:
|
||||
"""Test handling of server entries missing the command field."""
|
||||
config_content = {"mcpServers": {"server_no_command": {"args": ["arg1"]}}}
|
||||
tmp_config_path = create_temp_config_file(config_content)
|
||||
loaded_params = load_named_server_configs_from_file(tmp_config_path, {})
|
||||
assert "server_no_command" not in loaded_params
|
||||
mock_logger.warning.assert_called_with(
|
||||
"Named server '%s' from config is missing 'command'. Skipping.",
|
||||
"server_no_command",
|
||||
)
|
||||
|
||||
|
||||
@patch("mcp_proxy.config_loader.logger")
|
||||
def test_server_entry_invalid_args_type(
|
||||
mock_logger: object,
|
||||
create_temp_config_file: Callable[[dict], str],
|
||||
) -> None:
|
||||
"""Test handling of server entries with invalid args type."""
|
||||
config_content = {
|
||||
"mcpServers": {
|
||||
"server_invalid_args": {"command": "mycmd", "args": "not_a_list"},
|
||||
},
|
||||
}
|
||||
tmp_config_path = create_temp_config_file(config_content)
|
||||
loaded_params = load_named_server_configs_from_file(tmp_config_path, {})
|
||||
assert "server_invalid_args" not in loaded_params
|
||||
mock_logger.warning.assert_called_with(
|
||||
"Named server '%s' from config has invalid 'args' (must be a list). Skipping.",
|
||||
"server_invalid_args",
|
||||
)
|
||||
|
||||
|
||||
def test_empty_mcpservers_dict(create_temp_config_file: Callable[[dict], str]) -> None:
|
||||
"""Test handling of configuration files with empty mcpServers dictionary."""
|
||||
config_content = {"mcpServers": {}}
|
||||
tmp_config_path = create_temp_config_file(config_content)
|
||||
loaded_params = load_named_server_configs_from_file(tmp_config_path, {})
|
||||
assert len(loaded_params) == 0
|
||||
|
||||
|
||||
def test_config_file_is_empty_json_object(create_temp_config_file: Callable[[dict], str]) -> None:
|
||||
"""Test handling of configuration files with empty JSON objects."""
|
||||
config_content = {} # Empty JSON object
|
||||
tmp_config_path = create_temp_config_file(config_content)
|
||||
with pytest.raises(ValueError, match="Missing 'mcpServers' key"):
|
||||
load_named_server_configs_from_file(tmp_config_path, {})
|
||||
|
||||
|
||||
def test_config_file_is_empty_string() -> None:
|
||||
"""Test handling of configuration files with empty content."""
|
||||
# Create a file with an empty string
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
delete=False,
|
||||
suffix=".json",
|
||||
) as tmp_config:
|
||||
tmp_config.write("") # Empty content
|
||||
tmp_config_path = tmp_config.name
|
||||
try:
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
load_named_server_configs_from_file(tmp_config_path, {})
|
||||
finally:
|
||||
path = Path(tmp_config_path)
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
@@ -1,18 +1,68 @@
|
||||
"""Tests for the sse server."""
|
||||
# ruff: noqa: PLR2004
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import typing as t
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import uvicorn
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import StdioServerParameters
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.server import FastMCP
|
||||
from mcp.server import FastMCP, Server
|
||||
from mcp.types import TextContent
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from mcp_proxy.mcp_server import create_starlette_app
|
||||
from mcp_proxy.mcp_server import MCPServerSettings, create_single_instance_routes, run_mcp_server
|
||||
|
||||
|
||||
def create_starlette_app(
|
||||
mcp_server: Server[t.Any],
|
||||
allow_origins: list[str] | None = None,
|
||||
*,
|
||||
debug: bool = False,
|
||||
stateless: bool = False,
|
||||
) -> Starlette:
|
||||
"""Create a Starlette application for the MCP server.
|
||||
|
||||
Args:
|
||||
mcp_server: The MCP server instance to wrap
|
||||
allow_origins: List of allowed CORS origins
|
||||
debug: Enable debug mode
|
||||
stateless: Whether to use stateless HTTP sessions
|
||||
|
||||
Returns:
|
||||
Starlette application instance
|
||||
"""
|
||||
routes, http_manager = create_single_instance_routes(mcp_server, stateless_instance=stateless)
|
||||
|
||||
middleware: list[Middleware] = []
|
||||
if allow_origins:
|
||||
middleware.append(
|
||||
Middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=allow_origins,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
),
|
||||
)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(_app: Starlette) -> t.AsyncIterator[None]:
|
||||
async with http_manager.run():
|
||||
yield
|
||||
|
||||
return Starlette(
|
||||
debug=debug,
|
||||
routes=routes,
|
||||
middleware=middleware,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
|
||||
class BackgroundServer(uvicorn.Server):
|
||||
@@ -64,7 +114,6 @@ def make_background_server(**kwargs) -> BackgroundServer: # noqa: ANN003
|
||||
return BackgroundServer(config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_transport() -> None:
|
||||
"""Test basic glue code for the SSE transport and a fake MCP server."""
|
||||
server = make_background_server(debug=True)
|
||||
@@ -77,7 +126,6 @@ async def test_sse_transport() -> None:
|
||||
assert response.prompts[0].name == "prompt1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_transport() -> None:
|
||||
"""Test HTTP transport layer functionality."""
|
||||
server = make_background_server(debug=True)
|
||||
@@ -118,3 +166,509 @@ async def test_stateless_http_transport() -> None:
|
||||
assert len(tool_result.content) == 1
|
||||
assert isinstance(tool_result.content[0], TextContent)
|
||||
assert tool_result.content[0].text == f"Echo: test_{i}"
|
||||
|
||||
|
||||
# Unit tests for run_mcp_server method
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings() -> MCPServerSettings:
|
||||
"""Create mock MCP server settings for testing."""
|
||||
return MCPServerSettings(
|
||||
bind_host="127.0.0.1",
|
||||
port=8080,
|
||||
stateless=False,
|
||||
allow_origins=["*"],
|
||||
log_level="INFO",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stdio_params() -> StdioServerParameters:
|
||||
"""Create mock stdio server parameters for testing."""
|
||||
return StdioServerParameters(
|
||||
command="echo",
|
||||
args=["hello"],
|
||||
env={"TEST_VAR": "test_value"},
|
||||
cwd="/tmp", # noqa: S108
|
||||
)
|
||||
|
||||
|
||||
class AsyncContextManagerMock: # noqa: D101
|
||||
def __init__(self, mock) -> None: # noqa: ANN001, D107
|
||||
self.mock = mock
|
||||
|
||||
async def __aenter__(self): # noqa: ANN204, D105
|
||||
return self.mock
|
||||
|
||||
async def __aexit__(self, *args): # noqa: ANN002, ANN204, D105
|
||||
pass
|
||||
|
||||
|
||||
def setup_async_context_mocks() -> tuple[
|
||||
AsyncContextManagerMock,
|
||||
AsyncContextManagerMock,
|
||||
AsyncMock,
|
||||
MagicMock,
|
||||
list[MagicMock],
|
||||
]:
|
||||
"""Helper function to set up async context manager mocks."""
|
||||
# Setup stdio client mock
|
||||
mock_streams = (AsyncMock(), AsyncMock())
|
||||
|
||||
# Setup client session mock
|
||||
mock_session = AsyncMock()
|
||||
|
||||
# Setup HTTP manager mock
|
||||
mock_http_manager = MagicMock()
|
||||
mock_http_manager.run.return_value = AsyncContextManagerMock(None)
|
||||
mock_routes = [MagicMock()]
|
||||
|
||||
return (
|
||||
AsyncContextManagerMock(mock_streams),
|
||||
AsyncContextManagerMock(mock_session),
|
||||
mock_session,
|
||||
mock_http_manager,
|
||||
mock_routes,
|
||||
)
|
||||
|
||||
|
||||
async def test_run_mcp_server_no_servers_configured(mock_settings: MCPServerSettings) -> None:
|
||||
"""Test run_mcp_server when no servers are configured."""
|
||||
with patch("mcp_proxy.mcp_server.logger") as mock_logger:
|
||||
await run_mcp_server(mock_settings, None, {})
|
||||
mock_logger.error.assert_called_once_with("No servers configured to run.")
|
||||
|
||||
|
||||
async def test_run_mcp_server_with_default_server(
|
||||
mock_settings: MCPServerSettings,
|
||||
mock_stdio_params: StdioServerParameters,
|
||||
) -> None:
|
||||
"""Test run_mcp_server with a default server configuration."""
|
||||
with (
|
||||
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
|
||||
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
|
||||
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
|
||||
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
|
||||
patch("uvicorn.Server") as mock_uvicorn_server,
|
||||
patch("mcp_proxy.mcp_server.logger") as mock_logger,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
|
||||
setup_async_context_mocks()
|
||||
)
|
||||
mock_stdio_client.return_value = mock_stdio_context
|
||||
mock_client_session.return_value = mock_session_context
|
||||
|
||||
mock_proxy = AsyncMock()
|
||||
mock_create_proxy.return_value = mock_proxy
|
||||
mock_create_routes.return_value = (mock_routes, mock_http_manager)
|
||||
|
||||
mock_server_instance = AsyncMock()
|
||||
mock_uvicorn_server.return_value = mock_server_instance
|
||||
|
||||
# Run the function
|
||||
await run_mcp_server(mock_settings, mock_stdio_params, {})
|
||||
|
||||
# Verify calls
|
||||
mock_stdio_client.assert_called_once_with(mock_stdio_params)
|
||||
mock_create_proxy.assert_called_once_with(mock_session)
|
||||
mock_create_routes.assert_called_once_with(
|
||||
mock_proxy,
|
||||
stateless_instance=mock_settings.stateless,
|
||||
)
|
||||
mock_logger.info.assert_any_call(
|
||||
"Setting up default server: %s %s",
|
||||
mock_stdio_params.command,
|
||||
" ".join(mock_stdio_params.args),
|
||||
)
|
||||
mock_server_instance.serve.assert_called_once()
|
||||
|
||||
|
||||
async def test_run_mcp_server_with_named_servers(
|
||||
mock_settings: MCPServerSettings,
|
||||
mock_stdio_params: StdioServerParameters,
|
||||
) -> None:
|
||||
"""Test run_mcp_server with named servers configuration."""
|
||||
named_servers = {
|
||||
"server1": mock_stdio_params,
|
||||
"server2": StdioServerParameters(
|
||||
command="python",
|
||||
args=["-m", "mcp_server"],
|
||||
env={"PYTHON_PATH": "/usr/bin/python"},
|
||||
cwd="/home/user",
|
||||
),
|
||||
}
|
||||
|
||||
with (
|
||||
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
|
||||
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
|
||||
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
|
||||
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
|
||||
patch("uvicorn.Server") as mock_uvicorn_server,
|
||||
patch("mcp_proxy.mcp_server.logger") as mock_logger,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
|
||||
setup_async_context_mocks()
|
||||
)
|
||||
mock_stdio_client.return_value = mock_stdio_context
|
||||
mock_client_session.return_value = mock_session_context
|
||||
|
||||
mock_proxy = AsyncMock()
|
||||
mock_create_proxy.return_value = mock_proxy
|
||||
mock_create_routes.return_value = (mock_routes, mock_http_manager)
|
||||
|
||||
mock_server_instance = AsyncMock()
|
||||
mock_uvicorn_server.return_value = mock_server_instance
|
||||
|
||||
# Run the function
|
||||
await run_mcp_server(mock_settings, None, named_servers)
|
||||
|
||||
# Verify calls
|
||||
assert mock_stdio_client.call_count == 2
|
||||
assert mock_create_proxy.call_count == 2
|
||||
assert mock_create_routes.call_count == 2
|
||||
|
||||
# Check that named servers were logged
|
||||
mock_logger.info.assert_any_call(
|
||||
"Setting up named server '%s': %s %s",
|
||||
"server1",
|
||||
mock_stdio_params.command,
|
||||
" ".join(mock_stdio_params.args),
|
||||
)
|
||||
mock_logger.info.assert_any_call(
|
||||
"Setting up named server '%s': %s %s",
|
||||
"server2",
|
||||
"python",
|
||||
"-m mcp_server",
|
||||
)
|
||||
|
||||
mock_server_instance.serve.assert_called_once()
|
||||
|
||||
|
||||
async def test_run_mcp_server_with_cors_middleware(
|
||||
mock_stdio_params: StdioServerParameters,
|
||||
) -> None:
|
||||
"""Test run_mcp_server adds CORS middleware when allow_origins is set."""
|
||||
settings_with_cors = MCPServerSettings(
|
||||
bind_host="0.0.0.0", # noqa: S104
|
||||
port=9090,
|
||||
allow_origins=["http://localhost:3000", "https://example.com"],
|
||||
)
|
||||
|
||||
with (
|
||||
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
|
||||
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
|
||||
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
|
||||
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
|
||||
patch("mcp_proxy.mcp_server.Starlette") as mock_starlette,
|
||||
patch("uvicorn.Server") as mock_uvicorn_server,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
|
||||
setup_async_context_mocks()
|
||||
)
|
||||
mock_stdio_client.return_value = mock_stdio_context
|
||||
mock_client_session.return_value = mock_session_context
|
||||
|
||||
mock_proxy = AsyncMock()
|
||||
mock_create_proxy.return_value = mock_proxy
|
||||
mock_create_routes.return_value = (mock_routes, mock_http_manager)
|
||||
|
||||
mock_server_instance = AsyncMock()
|
||||
mock_uvicorn_server.return_value = mock_server_instance
|
||||
|
||||
# Run the function
|
||||
await run_mcp_server(settings_with_cors, mock_stdio_params, {})
|
||||
|
||||
# Verify Starlette was called with middleware
|
||||
mock_starlette.assert_called_once()
|
||||
call_args = mock_starlette.call_args
|
||||
middleware = call_args.kwargs["middleware"]
|
||||
|
||||
assert len(middleware) == 1
|
||||
assert middleware[0].cls == CORSMiddleware
|
||||
|
||||
|
||||
async def test_run_mcp_server_debug_mode(
|
||||
mock_stdio_params: StdioServerParameters,
|
||||
) -> None:
|
||||
"""Test run_mcp_server with debug mode enabled."""
|
||||
debug_settings = MCPServerSettings(
|
||||
bind_host="127.0.0.1",
|
||||
port=8080,
|
||||
log_level="DEBUG",
|
||||
)
|
||||
|
||||
with (
|
||||
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
|
||||
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
|
||||
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
|
||||
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
|
||||
patch("mcp_proxy.mcp_server.Starlette") as mock_starlette,
|
||||
patch("uvicorn.Server") as mock_uvicorn_server,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
|
||||
setup_async_context_mocks()
|
||||
)
|
||||
mock_stdio_client.return_value = mock_stdio_context
|
||||
mock_client_session.return_value = mock_session_context
|
||||
|
||||
mock_proxy = AsyncMock()
|
||||
mock_create_proxy.return_value = mock_proxy
|
||||
mock_create_routes.return_value = (mock_routes, mock_http_manager)
|
||||
|
||||
mock_server_instance = AsyncMock()
|
||||
mock_uvicorn_server.return_value = mock_server_instance
|
||||
|
||||
# Run the function
|
||||
await run_mcp_server(debug_settings, mock_stdio_params, {})
|
||||
|
||||
# Verify Starlette was called with debug=True
|
||||
mock_starlette.assert_called_once()
|
||||
call_args = mock_starlette.call_args
|
||||
assert call_args.kwargs["debug"] is True
|
||||
|
||||
|
||||
async def test_run_mcp_server_stateless_mode(
|
||||
mock_stdio_params: StdioServerParameters,
|
||||
) -> None:
|
||||
"""Test run_mcp_server with stateless mode enabled."""
|
||||
stateless_settings = MCPServerSettings(
|
||||
bind_host="127.0.0.1",
|
||||
port=8080,
|
||||
stateless=True,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
|
||||
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
|
||||
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
|
||||
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
|
||||
patch("uvicorn.Server") as mock_uvicorn_server,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
|
||||
setup_async_context_mocks()
|
||||
)
|
||||
mock_stdio_client.return_value = mock_stdio_context
|
||||
mock_client_session.return_value = mock_session_context
|
||||
|
||||
mock_proxy = AsyncMock()
|
||||
mock_create_proxy.return_value = mock_proxy
|
||||
mock_create_routes.return_value = (mock_routes, mock_http_manager)
|
||||
|
||||
mock_server_instance = AsyncMock()
|
||||
mock_uvicorn_server.return_value = mock_server_instance
|
||||
|
||||
# Run the function
|
||||
await run_mcp_server(stateless_settings, mock_stdio_params, {})
|
||||
|
||||
# Verify create_single_instance_routes was called with stateless_instance=True
|
||||
mock_create_routes.assert_called_once_with(
|
||||
mock_proxy,
|
||||
stateless_instance=True,
|
||||
)
|
||||
|
||||
|
||||
async def test_run_mcp_server_uvicorn_config(
|
||||
mock_settings: MCPServerSettings,
|
||||
mock_stdio_params: StdioServerParameters,
|
||||
) -> None:
|
||||
"""Test run_mcp_server creates correct uvicorn configuration."""
|
||||
with (
|
||||
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
|
||||
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
|
||||
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
|
||||
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
|
||||
patch("uvicorn.Config") as mock_uvicorn_config,
|
||||
patch("uvicorn.Server") as mock_uvicorn_server,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
|
||||
setup_async_context_mocks()
|
||||
)
|
||||
mock_stdio_client.return_value = mock_stdio_context
|
||||
mock_client_session.return_value = mock_session_context
|
||||
|
||||
mock_proxy = AsyncMock()
|
||||
mock_create_proxy.return_value = mock_proxy
|
||||
mock_create_routes.return_value = (mock_routes, mock_http_manager)
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_uvicorn_config.return_value = mock_config
|
||||
|
||||
mock_server_instance = AsyncMock()
|
||||
mock_uvicorn_server.return_value = mock_server_instance
|
||||
|
||||
# Run the function
|
||||
await run_mcp_server(mock_settings, mock_stdio_params, {})
|
||||
|
||||
# Verify uvicorn.Config was called with correct parameters
|
||||
mock_uvicorn_config.assert_called_once()
|
||||
call_args = mock_uvicorn_config.call_args
|
||||
|
||||
assert call_args.kwargs["host"] == mock_settings.bind_host
|
||||
assert call_args.kwargs["port"] == mock_settings.port
|
||||
assert call_args.kwargs["log_level"] == mock_settings.log_level.lower()
|
||||
|
||||
|
||||
async def test_run_mcp_server_global_status_updates(
|
||||
mock_settings: MCPServerSettings,
|
||||
mock_stdio_params: StdioServerParameters,
|
||||
) -> None:
|
||||
"""Test run_mcp_server updates global status correctly."""
|
||||
from mcp_proxy.mcp_server import _global_status
|
||||
|
||||
# Clear global status before test
|
||||
_global_status["server_instances"].clear()
|
||||
|
||||
named_servers = {"test_server": mock_stdio_params}
|
||||
|
||||
with (
|
||||
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
|
||||
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
|
||||
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
|
||||
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
|
||||
patch("uvicorn.Server") as mock_uvicorn_server,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
|
||||
setup_async_context_mocks()
|
||||
)
|
||||
mock_stdio_client.return_value = mock_stdio_context
|
||||
mock_client_session.return_value = mock_session_context
|
||||
|
||||
mock_proxy = AsyncMock()
|
||||
mock_create_proxy.return_value = mock_proxy
|
||||
mock_create_routes.return_value = (mock_routes, mock_http_manager)
|
||||
|
||||
mock_server_instance = AsyncMock()
|
||||
mock_uvicorn_server.return_value = mock_server_instance
|
||||
|
||||
# Run the function
|
||||
await run_mcp_server(mock_settings, mock_stdio_params, named_servers)
|
||||
|
||||
# Verify global status was updated
|
||||
assert "default" in _global_status["server_instances"]
|
||||
assert "test_server" in _global_status["server_instances"]
|
||||
assert _global_status["server_instances"]["default"] == "configured"
|
||||
assert _global_status["server_instances"]["test_server"] == "configured"
|
||||
|
||||
|
||||
async def test_run_mcp_server_sse_url_logging(
|
||||
mock_settings: MCPServerSettings,
|
||||
mock_stdio_params: StdioServerParameters,
|
||||
) -> None:
|
||||
"""Test run_mcp_server logs correct SSE URLs."""
|
||||
named_servers = {"test_server": mock_stdio_params}
|
||||
|
||||
with (
|
||||
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
|
||||
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
|
||||
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
|
||||
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
|
||||
patch("uvicorn.Server") as mock_uvicorn_server,
|
||||
patch("mcp_proxy.mcp_server.logger") as mock_logger,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
|
||||
setup_async_context_mocks()
|
||||
)
|
||||
mock_stdio_client.return_value = mock_stdio_context
|
||||
mock_client_session.return_value = mock_session_context
|
||||
|
||||
mock_proxy = AsyncMock()
|
||||
mock_create_proxy.return_value = mock_proxy
|
||||
mock_create_routes.return_value = (mock_routes, mock_http_manager)
|
||||
|
||||
mock_server_instance = AsyncMock()
|
||||
mock_uvicorn_server.return_value = mock_server_instance
|
||||
|
||||
# Run the function
|
||||
await run_mcp_server(mock_settings, mock_stdio_params, named_servers)
|
||||
|
||||
# Verify SSE URLs were logged
|
||||
expected_default_url = f"http://{mock_settings.bind_host}:{mock_settings.port}/sse"
|
||||
expected_named_url = (
|
||||
f"http://{mock_settings.bind_host}:{mock_settings.port}/servers/test_server/sse"
|
||||
)
|
||||
|
||||
mock_logger.info.assert_any_call("Serving MCP Servers via SSE:")
|
||||
mock_logger.info.assert_any_call(" - %s", expected_default_url)
|
||||
mock_logger.info.assert_any_call(" - %s", expected_named_url)
|
||||
|
||||
|
||||
async def test_run_mcp_server_exception_handling(
|
||||
mock_settings: MCPServerSettings,
|
||||
mock_stdio_params: StdioServerParameters,
|
||||
) -> None:
|
||||
"""Test run_mcp_server handles exceptions properly."""
|
||||
with (
|
||||
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
|
||||
patch("mcp_proxy.mcp_server.ClientSession"),
|
||||
):
|
||||
# Setup mocks to raise an exception
|
||||
mock_stdio_client.side_effect = Exception("Connection failed")
|
||||
|
||||
# Should not raise, function should handle exceptions gracefully
|
||||
try:
|
||||
await run_mcp_server(mock_settings, mock_stdio_params, {})
|
||||
except Exception as e: # noqa: BLE001
|
||||
# If an exception is raised, it should be the expected one
|
||||
assert "Connection failed" in str(e) # noqa: PT017
|
||||
|
||||
|
||||
async def test_run_mcp_server_both_default_and_named_servers(
|
||||
mock_settings: MCPServerSettings,
|
||||
mock_stdio_params: StdioServerParameters,
|
||||
) -> None:
|
||||
"""Test run_mcp_server with both default and named servers."""
|
||||
named_servers = {"named_server": mock_stdio_params}
|
||||
|
||||
with (
|
||||
patch("mcp_proxy.mcp_server.stdio_client") as mock_stdio_client,
|
||||
patch("mcp_proxy.mcp_server.ClientSession") as mock_client_session,
|
||||
patch("mcp_proxy.mcp_server.create_proxy_server") as mock_create_proxy,
|
||||
patch("mcp_proxy.mcp_server.create_single_instance_routes") as mock_create_routes,
|
||||
patch("uvicorn.Server") as mock_uvicorn_server,
|
||||
patch("mcp_proxy.mcp_server.logger") as mock_logger,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_stdio_context, mock_session_context, mock_session, mock_http_manager, mock_routes = (
|
||||
setup_async_context_mocks()
|
||||
)
|
||||
mock_stdio_client.return_value = mock_stdio_context
|
||||
mock_client_session.return_value = mock_session_context
|
||||
|
||||
mock_proxy = AsyncMock()
|
||||
mock_create_proxy.return_value = mock_proxy
|
||||
mock_create_routes.return_value = (mock_routes, mock_http_manager)
|
||||
|
||||
mock_server_instance = AsyncMock()
|
||||
mock_uvicorn_server.return_value = mock_server_instance
|
||||
|
||||
# Run the function with both default and named servers
|
||||
await run_mcp_server(mock_settings, mock_stdio_params, named_servers)
|
||||
|
||||
# Verify both servers were set up
|
||||
assert mock_stdio_client.call_count == 2 # One for default, one for named
|
||||
assert mock_create_proxy.call_count == 2
|
||||
assert mock_create_routes.call_count == 2
|
||||
|
||||
# Verify logging for both servers
|
||||
mock_logger.info.assert_any_call(
|
||||
"Setting up default server: %s %s",
|
||||
mock_stdio_params.command,
|
||||
" ".join(mock_stdio_params.args),
|
||||
)
|
||||
mock_logger.info.assert_any_call(
|
||||
"Setting up named server '%s': %s %s",
|
||||
"named_server",
|
||||
mock_stdio_params.command,
|
||||
" ".join(mock_stdio_params.args),
|
||||
)
|
||||
|
||||
mock_server_instance.serve.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user