Files
mcp-proxy/tests/test_mcp_server.py
Sam b25056fadd feat: support proxying multiple MCP stdio servers to SSE (#65)
This PR adds support for running multiple MCP (STDIO) servers and
serving them up via a single mcp-proxy instance, each with a named path
in the URL.

Example usage:

```
mcp-proxy --port 8080 --named-server fetch 'uvx mcp-server-fetch' --named-server github 'npx -y @modelcontextprotocol/server-github'
```

Would serve:

- `http://localhost:8080/servers/fetch`
- `http://localhost:8080/servers/github`

I've also added the ability to provide a standard mcp client config file
with accompanying tests.

Please feel free to make any changes as you see fit, or reject the PR if
it does not align with your goals.

Thank you,

---------

Co-authored-by: Magnus Tidemann <magnustidemann@gmail.com>
Co-authored-by: Sergey Parfenyuk <sergey.parfenyuk@gmail.com>
2025-05-27 11:48:25 +02:00

675 lines
25 KiB
Python

"""Tests for the sse server."""
# ruff: noqa: PLR2004
import asyncio
import contextlib
import typing as t
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import uvicorn
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters
from mcp.client.streamable_http import streamablehttp_client
from mcp.server import FastMCP, Server
from mcp.types import TextContent
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from mcp_proxy.mcp_server import MCPServerSettings, create_single_instance_routes, run_mcp_server
def create_starlette_app(
    mcp_server: Server[t.Any],
    allow_origins: list[str] | None = None,
    *,
    debug: bool = False,
    stateless: bool = False,
) -> Starlette:
    """Build the Starlette app that exposes ``mcp_server`` for the tests.

    Args:
        mcp_server: MCP server instance whose routes are created.
        allow_origins: Origins to allow via CORS; no CORS middleware is added
            when this is ``None`` or empty.
        debug: Passed through to ``Starlette(debug=...)``.
        stateless: Forwarded as ``stateless_instance`` when creating the routes.

    Returns:
        The configured Starlette application.
    """
    instance_routes, session_manager = create_single_instance_routes(
        mcp_server,
        stateless_instance=stateless,
    )

    # CORS is opt-in: only a truthy origin list installs the middleware.
    cors_layers: list[Middleware] = (
        [
            Middleware(
                CORSMiddleware,
                allow_origins=allow_origins,
                allow_methods=["*"],
                allow_headers=["*"],
            ),
        ]
        if allow_origins
        else []
    )

    @contextlib.asynccontextmanager
    async def lifespan(_app: Starlette) -> t.AsyncIterator[None]:
        # Keep the HTTP manager running for as long as the app is alive.
        async with session_manager.run():
            yield

    return Starlette(
        debug=debug,
        routes=instance_routes,
        middleware=cors_layers,
        lifespan=lifespan,
    )
class BackgroundServer(uvicorn.Server):
    """A uvicorn test server driven as an asyncio task on the caller's event loop."""

    def install_signal_handlers(self) -> None:
        """Do not install signal handlers."""

    @contextlib.asynccontextmanager
    async def run_in_background(self) -> t.AsyncIterator[None]:
        """Serve in a background asyncio task for the duration of the ``async with`` body.

        Control is yielded once ``self.started`` becomes true. On exit the
        server is asked to shut down and the serve task is awaited.

        Raises:
            RuntimeError: If ``serve()`` returns before the server reports
                that it has started.
        """
        task = asyncio.create_task(self.serve())
        try:
            while not self.started:  # noqa: ASYNC110
                if task.done():
                    # serve() ended without ever starting; without this check
                    # the poll would spin forever. Awaiting the task re-raises
                    # any error it hit.
                    await task
                    msg = "Server exited before it finished starting"
                    raise RuntimeError(msg)
                await asyncio.sleep(1e-3)
            yield
        finally:
            self.should_exit = self.force_exit = True
            await task

    @property
    def url(self) -> str:
        """Return the ``http://host:port`` URL of the first bound socket."""
        # A generator avoids building the full address list just to take one item.
        host, port, *_ = next(
            sock.getsockname() for server in self.servers for sock in server.sockets
        )
        return f"http://{host}:{port}"
def make_background_server(**kwargs) -> BackgroundServer:  # noqa: ANN003
    """Build a BackgroundServer around a FastMCP app with one prompt and one tool.

    Keyword arguments are forwarded to ``create_starlette_app``.
    """
    fast_mcp = FastMCP("TestServer")

    @fast_mcp.prompt(name="prompt1")
    async def list_prompts() -> str:
        return "hello world"

    @fast_mcp.tool(name="echo")
    async def call_tool(message: str) -> str:
        return f"Echo: {message}"

    starlette_app = create_starlette_app(
        fast_mcp._mcp_server,  # noqa: SLF001
        allow_origins=["*"],
        **kwargs,
    )
    # port=0 lets the OS pick a free port.
    return BackgroundServer(uvicorn.Config(starlette_app, port=0, log_level="info"))
async def test_sse_transport() -> None:
    """Check that a client can list prompts from the fake MCP server over SSE."""
    background = make_background_server(debug=True)
    async with background.run_in_background():
        async with (
            sse_client(url=f"{background.url}/sse") as streams,
            ClientSession(*streams) as session,
        ):
            await session.initialize()
            listing = await session.list_prompts()
            # Exactly one prompt, registered under the name "prompt1".
            assert [prompt.name for prompt in listing.prompts] == ["prompt1"]
async def test_http_transport() -> None:
    """Check prompt listing and repeated tool calls over streamable HTTP."""
    background = make_background_server(debug=True)
    async with background.run_in_background():
        endpoint = f"{background.url}/mcp/"
        async with (
            streamablehttp_client(url=endpoint) as (reader, writer, _),
            ClientSession(reader, writer) as session,
        ):
            await session.initialize()

            listing = await session.list_prompts()
            assert [prompt.name for prompt in listing.prompts] == ["prompt1"]

            # Several calls in a row on the same session must all succeed.
            for message in ("test_0", "test_1", "test_2"):
                outcome = await session.call_tool("echo", {"message": message})
                assert len(outcome.content) == 1
                first = outcome.content[0]
                assert isinstance(first, TextContent)
                assert first.text == f"Echo: {message}"
async def test_stateless_http_transport() -> None:
    """Check prompt listing and repeated tool calls over stateless streamable HTTP."""
    background = make_background_server(debug=True, stateless=True)
    async with background.run_in_background():
        endpoint = f"{background.url}/mcp/"
        async with (
            streamablehttp_client(url=endpoint) as (reader, writer, _),
            ClientSession(reader, writer) as session,
        ):
            await session.initialize()

            listing = await session.list_prompts()
            assert [prompt.name for prompt in listing.prompts] == ["prompt1"]

            # Several calls in a row must all succeed in stateless mode too.
            for message in ("test_0", "test_1", "test_2"):
                outcome = await session.call_tool("echo", {"message": message})
                assert len(outcome.content) == 1
                first = outcome.content[0]
                assert isinstance(first, TextContent)
                assert first.text == f"Echo: {message}"
# Unit tests for the run_mcp_server function
@pytest.fixture
def mock_settings() -> MCPServerSettings:
    """Provide baseline MCPServerSettings shared by the run_mcp_server tests."""
    settings = MCPServerSettings(
        bind_host="127.0.0.1",
        port=8080,
        stateless=False,
        allow_origins=["*"],
        log_level="INFO",
    )
    return settings
@pytest.fixture
def mock_stdio_params() -> StdioServerParameters:
    """Provide StdioServerParameters for an ``echo hello`` command."""
    params = StdioServerParameters(
        command="echo",
        args=["hello"],
        env={"TEST_VAR": "test_value"},
        cwd="/tmp",  # noqa: S108
    )
    return params
class AsyncContextManagerMock:
    """Minimal async context manager that hands back a preset object on entry."""

    def __init__(self, mock: object) -> None:
        """Store the object that ``async with`` should produce."""
        self.mock = mock

    async def __aenter__(self) -> object:
        """Return the stored object."""
        return self.mock

    async def __aexit__(self, *exc_info: object) -> None:
        """Do nothing on exit; exceptions are not suppressed."""
        return None
def setup_async_context_mocks() -> tuple[
    AsyncContextManagerMock,
    AsyncContextManagerMock,
    AsyncMock,
    MagicMock,
    list[MagicMock],
]:
    """Build the mocks shared by the run_mcp_server tests.

    Returns:
        A tuple of (stdio client context, client session context, session mock,
        HTTP manager mock, route list).
    """
    # The stdio client context yields a pair of stream mocks.
    read_stream, write_stream = AsyncMock(), AsyncMock()
    stdio_context = AsyncContextManagerMock((read_stream, write_stream))

    # The client session context yields the session mock itself.
    session = AsyncMock()
    session_context = AsyncContextManagerMock(session)

    # http_manager.run() must be usable in ``async with``.
    http_manager = MagicMock()
    http_manager.run.return_value = AsyncContextManagerMock(None)

    return stdio_context, session_context, session, http_manager, [MagicMock()]
async def test_run_mcp_server_no_servers_configured(mock_settings: MCPServerSettings) -> None:
    """Check that run_mcp_server logs an error when given no servers at all."""
    logger_patch = patch("mcp_proxy.mcp_server.logger")
    with logger_patch as logger_mock:
        await run_mcp_server(mock_settings, None, {})
        logger_mock.error.assert_called_once_with("No servers configured to run.")
async def test_run_mcp_server_with_default_server(
    mock_settings: MCPServerSettings,
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check the wiring performed by run_mcp_server for a single default server."""
    stdio_ctx, session_ctx, session, http_manager, routes = setup_async_context_mocks()
    proxy = AsyncMock()
    uvicorn_server = AsyncMock()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_ctx) as stdio_client_mock,
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_ctx),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=proxy) as create_proxy_mock,
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ) as create_routes_mock,
        patch("uvicorn.Server", return_value=uvicorn_server),
        patch("mcp_proxy.mcp_server.logger") as logger_mock,
    ):
        await run_mcp_server(mock_settings, mock_stdio_params, {})

        # stdio client -> session -> proxy -> routes, each invoked exactly once.
        stdio_client_mock.assert_called_once_with(mock_stdio_params)
        create_proxy_mock.assert_called_once_with(session)
        create_routes_mock.assert_called_once_with(
            proxy,
            stateless_instance=mock_settings.stateless,
        )
        logger_mock.info.assert_any_call(
            "Setting up default server: %s %s",
            mock_stdio_params.command,
            " ".join(mock_stdio_params.args),
        )
        uvicorn_server.serve.assert_called_once()
async def test_run_mcp_server_with_named_servers(
    mock_settings: MCPServerSettings,
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that run_mcp_server sets up every named server when no default is given."""
    second_params = StdioServerParameters(
        command="python",
        args=["-m", "mcp_server"],
        env={"PYTHON_PATH": "/usr/bin/python"},
        cwd="/home/user",
    )
    named_servers = {"server1": mock_stdio_params, "server2": second_params}

    stdio_ctx, session_ctx, _session, http_manager, routes = setup_async_context_mocks()
    uvicorn_server = AsyncMock()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_ctx) as stdio_client_mock,
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_ctx),
        patch(
            "mcp_proxy.mcp_server.create_proxy_server",
            return_value=AsyncMock(),
        ) as create_proxy_mock,
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ) as create_routes_mock,
        patch("uvicorn.Server", return_value=uvicorn_server),
        patch("mcp_proxy.mcp_server.logger") as logger_mock,
    ):
        await run_mcp_server(mock_settings, None, named_servers)

        # One stdio client, proxy and route set per named server.
        assert stdio_client_mock.call_count == 2
        assert create_proxy_mock.call_count == 2
        assert create_routes_mock.call_count == 2

        # Each named server is announced with its name, command and joined args.
        for name, params in named_servers.items():
            logger_mock.info.assert_any_call(
                "Setting up named server '%s': %s %s",
                name,
                params.command,
                " ".join(params.args),
            )
        uvicorn_server.serve.assert_called_once()
async def test_run_mcp_server_with_cors_middleware(
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that run_mcp_server passes a CORS middleware to Starlette when origins are set."""
    settings_with_cors = MCPServerSettings(
        bind_host="0.0.0.0",  # noqa: S104
        port=9090,
        allow_origins=["http://localhost:3000", "https://example.com"],
    )
    stdio_ctx, session_ctx, _session, http_manager, routes = setup_async_context_mocks()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_ctx),
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_ctx),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=AsyncMock()),
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ),
        patch("mcp_proxy.mcp_server.Starlette") as starlette_mock,
        patch("uvicorn.Server", return_value=AsyncMock()),
    ):
        await run_mcp_server(settings_with_cors, mock_stdio_params, {})

        # The app must be built once, with exactly one middleware: CORS.
        starlette_mock.assert_called_once()
        installed = starlette_mock.call_args.kwargs["middleware"]
        assert len(installed) == 1
        assert installed[0].cls == CORSMiddleware
async def test_run_mcp_server_debug_mode(
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that a DEBUG log level makes run_mcp_server build Starlette with debug=True."""
    debug_settings = MCPServerSettings(
        bind_host="127.0.0.1",
        port=8080,
        log_level="DEBUG",
    )
    stdio_ctx, session_ctx, _session, http_manager, routes = setup_async_context_mocks()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_ctx),
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_ctx),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=AsyncMock()),
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ),
        patch("mcp_proxy.mcp_server.Starlette") as starlette_mock,
        patch("uvicorn.Server", return_value=AsyncMock()),
    ):
        await run_mcp_server(debug_settings, mock_stdio_params, {})

        starlette_mock.assert_called_once()
        starlette_kwargs = starlette_mock.call_args.kwargs
        assert starlette_kwargs["debug"] is True
async def test_run_mcp_server_stateless_mode(
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that stateless settings reach create_single_instance_routes."""
    stateless_settings = MCPServerSettings(
        bind_host="127.0.0.1",
        port=8080,
        stateless=True,
    )
    stdio_ctx, session_ctx, _session, http_manager, routes = setup_async_context_mocks()
    proxy = AsyncMock()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_ctx),
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_ctx),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=proxy),
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ) as create_routes_mock,
        patch("uvicorn.Server", return_value=AsyncMock()),
    ):
        await run_mcp_server(stateless_settings, mock_stdio_params, {})

        # The routes must be created once, for the proxy, in stateless mode.
        create_routes_mock.assert_called_once_with(
            proxy,
            stateless_instance=True,
        )
async def test_run_mcp_server_uvicorn_config(
    mock_settings: MCPServerSettings,
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that run_mcp_server builds uvicorn.Config from the given settings."""
    stdio_ctx, session_ctx, _session, http_manager, routes = setup_async_context_mocks()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_ctx),
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_ctx),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=AsyncMock()),
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ),
        patch("uvicorn.Config", return_value=MagicMock()) as config_mock,
        patch("uvicorn.Server", return_value=AsyncMock()),
    ):
        await run_mcp_server(mock_settings, mock_stdio_params, {})

        config_mock.assert_called_once()
        config_kwargs = config_mock.call_args.kwargs
        assert config_kwargs["host"] == mock_settings.bind_host
        assert config_kwargs["port"] == mock_settings.port
        # The log level is expected in lower case.
        assert config_kwargs["log_level"] == mock_settings.log_level.lower()
async def test_run_mcp_server_global_status_updates(
    mock_settings: MCPServerSettings,
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that run_mcp_server records each server as configured in the global status."""
    from mcp_proxy.mcp_server import _global_status

    # Start from an empty registry so entries from other tests cannot leak in.
    _global_status["server_instances"].clear()

    named_servers = {"test_server": mock_stdio_params}
    stdio_ctx, session_ctx, _session, http_manager, routes = setup_async_context_mocks()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_ctx),
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_ctx),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=AsyncMock()),
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ),
        patch("uvicorn.Server", return_value=AsyncMock()),
    ):
        await run_mcp_server(mock_settings, mock_stdio_params, named_servers)

        # Re-read the registry after the run rather than caching it beforehand.
        recorded = _global_status["server_instances"]
        for server_name in ("default", "test_server"):
            assert server_name in recorded
        for server_name in ("default", "test_server"):
            assert recorded[server_name] == "configured"
async def test_run_mcp_server_sse_url_logging(
    mock_settings: MCPServerSettings,
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that run_mcp_server logs the SSE URL of the default and the named server."""
    named_servers = {"test_server": mock_stdio_params}
    stdio_ctx, session_ctx, _session, http_manager, routes = setup_async_context_mocks()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_ctx),
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_ctx),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=AsyncMock()),
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ),
        patch("uvicorn.Server", return_value=AsyncMock()),
        patch("mcp_proxy.mcp_server.logger") as logger_mock,
    ):
        await run_mcp_server(mock_settings, mock_stdio_params, named_servers)

        # Both URLs share the same host:port prefix taken from the settings.
        base_url = f"http://{mock_settings.bind_host}:{mock_settings.port}"
        logger_mock.info.assert_any_call("Serving MCP Servers via SSE:")
        logger_mock.info.assert_any_call(" - %s", f"{base_url}/sse")
        logger_mock.info.assert_any_call(" - %s", f"{base_url}/servers/test_server/sse")
async def test_run_mcp_server_exception_handling(
    mock_settings: MCPServerSettings,
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check run_mcp_server's behaviour when the stdio client fails to start.

    The test accepts either outcome: no exception, or one whose message
    contains "Connection failed".
    """
    stdio_failure = Exception("Connection failed")
    with (
        patch("mcp_proxy.mcp_server.stdio_client", side_effect=stdio_failure),
        patch("mcp_proxy.mcp_server.ClientSession"),
    ):
        # NOTE(review): this passes whether or not run_mcp_server raises, so it
        # does not pin down which behaviour is intended - confirm and tighten.
        try:
            await run_mcp_server(mock_settings, mock_stdio_params, {})
        except Exception as exc:  # noqa: BLE001
            # Anything that does escape must be the injected failure.
            assert "Connection failed" in str(exc)  # noqa: PT017
async def test_run_mcp_server_both_default_and_named_servers(
    mock_settings: MCPServerSettings,
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that run_mcp_server sets up a default and a named server together."""
    named_servers = {"named_server": mock_stdio_params}
    stdio_ctx, session_ctx, _session, http_manager, routes = setup_async_context_mocks()
    uvicorn_server = AsyncMock()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_ctx) as stdio_client_mock,
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_ctx),
        patch(
            "mcp_proxy.mcp_server.create_proxy_server",
            return_value=AsyncMock(),
        ) as create_proxy_mock,
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ) as create_routes_mock,
        patch("uvicorn.Server", return_value=uvicorn_server),
        patch("mcp_proxy.mcp_server.logger") as logger_mock,
    ):
        await run_mcp_server(mock_settings, mock_stdio_params, named_servers)

        # One set of calls for the default server and one for the named server.
        assert stdio_client_mock.call_count == 2
        assert create_proxy_mock.call_count == 2
        assert create_routes_mock.call_count == 2

        joined_args = " ".join(mock_stdio_params.args)
        logger_mock.info.assert_any_call(
            "Setting up default server: %s %s",
            mock_stdio_params.command,
            joined_args,
        )
        logger_mock.info.assert_any_call(
            "Setting up named server '%s': %s %s",
            "named_server",
            mock_stdio_params.command,
            joined_args,
        )
        uvicorn_server.serve.assert_called_once()