This PR adds support for running multiple MCP (STDIO) servers and serving them via a single mcp-proxy instance, each under its own named path in the URL.

Example usage:

```
mcp-proxy --port 8080 --named-server fetch 'uvx mcp-server-fetch' --named-server github 'npx -y @modelcontextprotocol/server-github'
```

This would serve:

- `http://localhost:8080/servers/fetch`
- `http://localhost:8080/servers/github`

I've also added the ability to provide a standard MCP client config file, with accompanying tests. Please feel free to make any changes as you see fit, or reject the PR if it does not align with your goals.

Thank you,

---------

Co-authored-by: Magnus Tidemann <magnustidemann@gmail.com>
Co-authored-by: Sergey Parfenyuk <sergey.parfenyuk@gmail.com>
675 lines
25 KiB
Python
675 lines
25 KiB
Python
"""Tests for the sse server."""
|
|
# ruff: noqa: PLR2004
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import typing as t
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import uvicorn
|
|
from mcp.client.session import ClientSession
|
|
from mcp.client.sse import sse_client
|
|
from mcp.client.stdio import StdioServerParameters
|
|
from mcp.client.streamable_http import streamablehttp_client
|
|
from mcp.server import FastMCP, Server
|
|
from mcp.types import TextContent
|
|
from starlette.applications import Starlette
|
|
from starlette.middleware import Middleware
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
|
|
from mcp_proxy.mcp_server import MCPServerSettings, create_single_instance_routes, run_mcp_server
|
|
|
|
|
|
def create_starlette_app(
    mcp_server: Server[t.Any],
    allow_origins: list[str] | None = None,
    *,
    debug: bool = False,
    stateless: bool = False,
) -> Starlette:
    """Build a Starlette app that serves ``mcp_server`` through the proxy routes.

    Args:
        mcp_server: MCP server instance handed to ``create_single_instance_routes``.
        allow_origins: Origins to allow via CORS; a falsy value adds no middleware.
        debug: Forwarded to ``Starlette(debug=...)``.
        stateless: Forwarded as ``stateless_instance`` when building the routes.

    Returns:
        The configured Starlette application.

    """
    app_routes, manager = create_single_instance_routes(
        mcp_server,
        stateless_instance=stateless,
    )

    @contextlib.asynccontextmanager
    async def lifespan(_app: Starlette) -> t.AsyncIterator[None]:
        # Hold the HTTP manager's run() context open while the app is alive.
        async with manager.run():
            yield

    # CORS is opt-in: only wrap the app when at least one origin was supplied.
    cors_stack: list[Middleware] = (
        [
            Middleware(
                CORSMiddleware,
                allow_origins=allow_origins,
                allow_methods=["*"],
                allow_headers=["*"],
            ),
        ]
        if allow_origins
        else []
    )

    return Starlette(
        debug=debug,
        routes=app_routes,
        middleware=cors_stack,
        lifespan=lifespan,
    )
|
|
|
|
|
|
class BackgroundServer(uvicorn.Server):
    """A uvicorn server that tests run as a background asyncio task."""

    def install_signal_handlers(self) -> None:
        """Override as a no-op so that no signal handlers are installed."""
        return

    @contextlib.asynccontextmanager
    async def run_in_background(self) -> t.AsyncIterator[None]:
        """Serve from a background task for the duration of the ``async with`` body.

        Entry waits until ``self.started`` is set; exit flags the server to stop
        and awaits the serving task.
        """
        serve_task = asyncio.create_task(self.serve())
        try:
            # Poll in 1 ms steps until the server reports that it has started.
            while not self.started:  # noqa: ASYNC110
                await asyncio.sleep(1e-3)
            yield
        finally:
            self.should_exit = True
            self.force_exit = True
            await serve_task

    @property
    def url(self) -> str:
        """Return the ``http://host:port`` URL of the first bound socket."""
        addresses = []
        for srv in self.servers:
            for sock in srv.sockets:
                addresses.append(sock.getsockname())  # noqa: PERF401
        first = next(iter(addresses))
        host = first[0]
        port = first[1]
        return f"http://{host}:{port}"
|
|
|
|
|
|
def make_background_server(**kwargs) -> BackgroundServer:  # noqa: ANN003
    """Build a BackgroundServer wrapping a small FastMCP test server.

    The test server registers one prompt ("prompt1") and one tool ("echo").
    ``kwargs`` are forwarded to ``create_starlette_app``; CORS is always opened
    to every origin and the server binds to an ephemeral port (``port=0``).
    """
    test_mcp = FastMCP("TestServer")

    @test_mcp.prompt(name="prompt1")
    async def list_prompts() -> str:
        return "hello world"

    @test_mcp.tool(name="echo")
    async def call_tool(message: str) -> str:
        return f"Echo: {message}"

    starlette_app = create_starlette_app(
        test_mcp._mcp_server,  # noqa: SLF001
        allow_origins=["*"],
        **kwargs,
    )
    return BackgroundServer(uvicorn.Config(starlette_app, port=0, log_level="info"))
|
|
|
|
|
|
async def test_sse_transport() -> None:
    """Exercise the SSE endpoint end to end against the in-process test server."""
    background = make_background_server(debug=True)
    async with background.run_in_background():
        endpoint = f"{background.url}/sse"
        async with (
            sse_client(url=endpoint) as streams,
            ClientSession(*streams) as session,
        ):
            await session.initialize()
            listing = await session.list_prompts()
            prompts = listing.prompts
            assert len(prompts) == 1
            assert prompts[0].name == "prompt1"
|
|
|
|
|
|
async def test_http_transport() -> None:
    """Exercise the streamable HTTP endpoint: list prompts, then call ``echo`` thrice."""
    background = make_background_server(debug=True)
    async with background.run_in_background():
        endpoint = f"{background.url}/mcp/"
        async with (
            streamablehttp_client(url=endpoint) as (reader, writer, _),
            ClientSession(reader, writer) as session,
        ):
            await session.initialize()

            prompts = (await session.list_prompts()).prompts
            assert len(prompts) == 1
            assert prompts[0].name == "prompt1"

            # Several calls on one session check that it stays usable.
            for idx in range(3):
                outcome = await session.call_tool("echo", {"message": f"test_{idx}"})
                content = outcome.content
                assert len(content) == 1
                assert isinstance(content[0], TextContent)
                assert content[0].text == f"Echo: test_{idx}"
|
|
|
|
|
|
async def test_stateless_http_transport() -> None:
    """Repeat the HTTP transport checks with the server built in stateless mode."""
    background = make_background_server(debug=True, stateless=True)
    async with background.run_in_background():
        endpoint = f"{background.url}/mcp/"
        async with (
            streamablehttp_client(url=endpoint) as (reader, writer, _),
            ClientSession(reader, writer) as session,
        ):
            await session.initialize()

            prompts = (await session.list_prompts()).prompts
            assert len(prompts) == 1
            assert prompts[0].name == "prompt1"

            # Several calls through the same client session must all succeed.
            for idx in range(3):
                outcome = await session.call_tool("echo", {"message": f"test_{idx}"})
                content = outcome.content
                assert len(content) == 1
                assert isinstance(content[0], TextContent)
                assert content[0].text == f"Echo: test_{idx}"
|
|
|
|
|
|
# Unit tests for run_mcp_server method
|
|
|
|
|
|
@pytest.fixture
def mock_settings() -> MCPServerSettings:
    """Provide baseline server settings shared by the ``run_mcp_server`` tests."""
    options: dict[str, t.Any] = {
        "bind_host": "127.0.0.1",
        "port": 8080,
        "stateless": False,
        "allow_origins": ["*"],
        "log_level": "INFO",
    }
    return MCPServerSettings(**options)
|
|
|
|
|
|
@pytest.fixture
def mock_stdio_params() -> StdioServerParameters:
    """Provide stdio parameters describing a harmless ``echo hello`` command."""
    command, arguments = "echo", ["hello"]
    environment = {"TEST_VAR": "test_value"}
    return StdioServerParameters(
        command=command,
        args=arguments,
        env=environment,
        cwd="/tmp",  # noqa: S108
    )
|
|
|
|
|
|
class AsyncContextManagerMock:
    """Minimal async context manager that hands back a preset object on entry."""

    def __init__(self, mock) -> None:  # noqa: ANN001
        """Remember ``mock`` as the value ``__aenter__`` will return."""
        self.mock = mock

    async def __aenter__(self):  # noqa: ANN204
        """Return the object supplied at construction time."""
        return self.mock

    async def __aexit__(self, *args):  # noqa: ANN002, ANN204
        """Do nothing on exit; the implicit ``None`` leaves exceptions unsuppressed."""
|
|
|
|
|
|
def setup_async_context_mocks() -> tuple[
    AsyncContextManagerMock,
    AsyncContextManagerMock,
    AsyncMock,
    MagicMock,
    list[MagicMock],
]:
    """Create the mock objects shared by the ``run_mcp_server`` unit tests.

    Returns:
        A 5-tuple of: an async context manager yielding a pair of stream mocks,
        an async context manager yielding the session mock, the session mock
        itself, an HTTP manager mock whose ``run()`` returns an async context
        manager, and a one-element list of route mocks.

    """
    # HTTP manager: run() must be usable in ``async with``.
    http_manager = MagicMock()
    http_manager.run.return_value = AsyncContextManagerMock(None)

    # Client session, and the context manager that yields it.
    session = AsyncMock()
    session_cm = AsyncContextManagerMock(session)

    # stdio client: a context manager yielding a pair of stream mocks.
    stdio_cm = AsyncContextManagerMock((AsyncMock(), AsyncMock()))

    return stdio_cm, session_cm, session, http_manager, [MagicMock()]
|
|
|
|
|
|
async def test_run_mcp_server_no_servers_configured(mock_settings: MCPServerSettings) -> None:
    """Check that an error is logged when neither a default nor a named server is given."""
    logger_patch = patch("mcp_proxy.mcp_server.logger")
    with logger_patch as logger_mock:
        await run_mcp_server(mock_settings, None, {})
        logger_mock.error.assert_called_once_with("No servers configured to run.")
|
|
|
|
|
|
async def test_run_mcp_server_with_default_server(
    mock_settings: MCPServerSettings,
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check the wiring performed by run_mcp_server for a single default server."""
    # Build every collaborator first, then wire them in through the patches.
    stdio_cm, session_cm, session, http_manager, routes = setup_async_context_mocks()
    proxy = AsyncMock()
    uvicorn_instance = AsyncMock()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_cm) as stdio_client_mock,
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_cm),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=proxy) as create_proxy_mock,
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ) as create_routes_mock,
        patch("uvicorn.Server", return_value=uvicorn_instance),
        patch("mcp_proxy.mcp_server.logger") as logger_mock,
    ):
        await run_mcp_server(mock_settings, mock_stdio_params, {})

        # The stdio client, proxy and routes are each created exactly once.
        stdio_client_mock.assert_called_once_with(mock_stdio_params)
        create_proxy_mock.assert_called_once_with(session)
        create_routes_mock.assert_called_once_with(
            proxy,
            stateless_instance=mock_settings.stateless,
        )

        joined_args = " ".join(mock_stdio_params.args)
        logger_mock.info.assert_any_call(
            "Setting up default server: %s %s",
            mock_stdio_params.command,
            joined_args,
        )
        uvicorn_instance.serve.assert_called_once()
|
|
|
|
|
|
async def test_run_mcp_server_with_named_servers(
    mock_settings: MCPServerSettings,
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that two named servers (and no default) are each set up and logged."""
    second_params = StdioServerParameters(
        command="python",
        args=["-m", "mcp_server"],
        env={"PYTHON_PATH": "/usr/bin/python"},
        cwd="/home/user",
    )
    named_servers = {"server1": mock_stdio_params, "server2": second_params}

    stdio_cm, session_cm, _session, http_manager, routes = setup_async_context_mocks()
    proxy = AsyncMock()
    uvicorn_instance = AsyncMock()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_cm) as stdio_client_mock,
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_cm),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=proxy) as create_proxy_mock,
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ) as create_routes_mock,
        patch("uvicorn.Server", return_value=uvicorn_instance),
        patch("mcp_proxy.mcp_server.logger") as logger_mock,
    ):
        await run_mcp_server(mock_settings, None, named_servers)

        # One stdio client, proxy and route set per named server.
        assert stdio_client_mock.call_count == 2
        assert create_proxy_mock.call_count == 2
        assert create_routes_mock.call_count == 2

        # Each named server is announced in the log.
        log_template = "Setting up named server '%s': %s %s"
        logger_mock.info.assert_any_call(
            log_template,
            "server1",
            mock_stdio_params.command,
            " ".join(mock_stdio_params.args),
        )
        logger_mock.info.assert_any_call(
            log_template,
            "server2",
            "python",
            "-m mcp_server",
        )

        uvicorn_instance.serve.assert_called_once()
|
|
|
|
|
|
async def test_run_mcp_server_with_cors_middleware(
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that Starlette receives one CORS middleware when allow_origins is set."""
    cors_settings = MCPServerSettings(
        bind_host="0.0.0.0",  # noqa: S104
        port=9090,
        allow_origins=["http://localhost:3000", "https://example.com"],
    )

    stdio_cm, session_cm, _session, http_manager, routes = setup_async_context_mocks()
    proxy = AsyncMock()
    uvicorn_instance = AsyncMock()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_cm),
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_cm),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=proxy),
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ),
        patch("mcp_proxy.mcp_server.Starlette") as starlette_mock,
        patch("uvicorn.Server", return_value=uvicorn_instance),
    ):
        await run_mcp_server(cors_settings, mock_stdio_params, {})

        # Inspect the keyword arguments of the single Starlette(...) call.
        starlette_mock.assert_called_once()
        passed_middleware = starlette_mock.call_args.kwargs["middleware"]

        assert len(passed_middleware) == 1
        assert passed_middleware[0].cls == CORSMiddleware
|
|
|
|
|
|
async def test_run_mcp_server_debug_mode(
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that ``log_level="DEBUG"`` makes Starlette receive ``debug=True``."""
    debug_settings = MCPServerSettings(
        bind_host="127.0.0.1",
        port=8080,
        log_level="DEBUG",
    )

    stdio_cm, session_cm, _session, http_manager, routes = setup_async_context_mocks()
    proxy = AsyncMock()
    uvicorn_instance = AsyncMock()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_cm),
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_cm),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=proxy),
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ),
        patch("mcp_proxy.mcp_server.Starlette") as starlette_mock,
        patch("uvicorn.Server", return_value=uvicorn_instance),
    ):
        await run_mcp_server(debug_settings, mock_stdio_params, {})

        # The app must be constructed once, with debug switched on.
        starlette_mock.assert_called_once()
        starlette_kwargs = starlette_mock.call_args.kwargs
        assert starlette_kwargs["debug"] is True
|
|
|
|
|
|
async def test_run_mcp_server_stateless_mode(
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that ``stateless=True`` settings reach ``create_single_instance_routes``."""
    stateless_settings = MCPServerSettings(
        bind_host="127.0.0.1",
        port=8080,
        stateless=True,
    )

    stdio_cm, session_cm, _session, http_manager, routes = setup_async_context_mocks()
    proxy = AsyncMock()
    uvicorn_instance = AsyncMock()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_cm),
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_cm),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=proxy),
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ) as create_routes_mock,
        patch("uvicorn.Server", return_value=uvicorn_instance),
    ):
        await run_mcp_server(stateless_settings, mock_stdio_params, {})

        # The routes are built once, for the proxy, with the stateless flag on.
        create_routes_mock.assert_called_once_with(
            proxy,
            stateless_instance=True,
        )
|
|
|
|
|
|
async def test_run_mcp_server_uvicorn_config(
    mock_settings: MCPServerSettings,
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check the host, port and log level passed to ``uvicorn.Config``."""
    stdio_cm, session_cm, _session, http_manager, routes = setup_async_context_mocks()
    proxy = AsyncMock()
    config_instance = MagicMock()
    uvicorn_instance = AsyncMock()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_cm),
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_cm),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=proxy),
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ),
        patch("uvicorn.Config", return_value=config_instance) as uvicorn_config_mock,
        patch("uvicorn.Server", return_value=uvicorn_instance),
    ):
        await run_mcp_server(mock_settings, mock_stdio_params, {})

        # Exactly one Config is built, and it mirrors the settings.
        uvicorn_config_mock.assert_called_once()
        config_kwargs = uvicorn_config_mock.call_args.kwargs

        assert config_kwargs["host"] == mock_settings.bind_host
        assert config_kwargs["port"] == mock_settings.port
        assert config_kwargs["log_level"] == mock_settings.log_level.lower()
|
|
|
|
|
|
async def test_run_mcp_server_global_status_updates(
    mock_settings: MCPServerSettings,
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that default and named servers are recorded as "configured"."""
    from mcp_proxy.mcp_server import _global_status

    # Start from an empty registry so earlier tests cannot satisfy the asserts.
    instances = _global_status["server_instances"]
    instances.clear()

    named_servers = {"test_server": mock_stdio_params}

    stdio_cm, session_cm, _session, http_manager, routes = setup_async_context_mocks()
    proxy = AsyncMock()
    uvicorn_instance = AsyncMock()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_cm),
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_cm),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=proxy),
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ),
        patch("uvicorn.Server", return_value=uvicorn_instance),
    ):
        await run_mcp_server(mock_settings, mock_stdio_params, named_servers)

        # Look the registry up again rather than trusting the earlier reference.
        recorded = _global_status["server_instances"]
        assert "default" in recorded
        assert "test_server" in recorded
        assert recorded["default"] == "configured"
        assert recorded["test_server"] == "configured"
|
|
|
|
|
|
async def test_run_mcp_server_sse_url_logging(
    mock_settings: MCPServerSettings,
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that the SSE URLs of the default and named servers are logged."""
    named_servers = {"test_server": mock_stdio_params}

    stdio_cm, session_cm, _session, http_manager, routes = setup_async_context_mocks()
    proxy = AsyncMock()
    uvicorn_instance = AsyncMock()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_cm),
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_cm),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=proxy),
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ),
        patch("uvicorn.Server", return_value=uvicorn_instance),
        patch("mcp_proxy.mcp_server.logger") as logger_mock,
    ):
        await run_mcp_server(mock_settings, mock_stdio_params, named_servers)

        # The default server lives at /sse, named ones under /servers/<name>/sse.
        default_url = f"http://{mock_settings.bind_host}:{mock_settings.port}/sse"
        named_url = f"http://{mock_settings.bind_host}:{mock_settings.port}/servers/test_server/sse"

        info_log = logger_mock.info
        info_log.assert_any_call("Serving MCP Servers via SSE:")
        info_log.assert_any_call(" - %s", default_url)
        info_log.assert_any_call(" - %s", named_url)
|
|
|
|
|
|
async def test_run_mcp_server_exception_handling(
    mock_settings: MCPServerSettings,
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that a failing stdio client surfaces, at most, its own error.

    NOTE(review): this test passes both when ``run_mcp_server`` swallows the
    error and when it lets it propagate; tighten it once the intended contract
    is confirmed.
    """
    connection_error = Exception("Connection failed")

    with (
        patch("mcp_proxy.mcp_server.stdio_client", side_effect=connection_error),
        patch("mcp_proxy.mcp_server.ClientSession"),
    ):
        try:
            await run_mcp_server(mock_settings, mock_stdio_params, {})
        except Exception as exc:  # noqa: BLE001
            # Anything that escapes must be the injected failure.
            assert "Connection failed" in str(exc)  # noqa: PT017
|
|
|
|
|
|
async def test_run_mcp_server_both_default_and_named_servers(
    mock_settings: MCPServerSettings,
    mock_stdio_params: StdioServerParameters,
) -> None:
    """Check that a default server and a named server are both set up in one run."""
    named_servers = {"named_server": mock_stdio_params}

    stdio_cm, session_cm, _session, http_manager, routes = setup_async_context_mocks()
    proxy = AsyncMock()
    uvicorn_instance = AsyncMock()

    with (
        patch("mcp_proxy.mcp_server.stdio_client", return_value=stdio_cm) as stdio_client_mock,
        patch("mcp_proxy.mcp_server.ClientSession", return_value=session_cm),
        patch("mcp_proxy.mcp_server.create_proxy_server", return_value=proxy) as create_proxy_mock,
        patch(
            "mcp_proxy.mcp_server.create_single_instance_routes",
            return_value=(routes, http_manager),
        ) as create_routes_mock,
        patch("uvicorn.Server", return_value=uvicorn_instance),
        patch("mcp_proxy.mcp_server.logger") as logger_mock,
    ):
        await run_mcp_server(mock_settings, mock_stdio_params, named_servers)

        # Two servers in total: the default one plus the named one.
        assert stdio_client_mock.call_count == 2
        assert create_proxy_mock.call_count == 2
        assert create_routes_mock.call_count == 2

        # Both set-ups are announced in the log.
        command = mock_stdio_params.command
        joined_args = " ".join(mock_stdio_params.args)
        logger_mock.info.assert_any_call(
            "Setting up default server: %s %s",
            command,
            joined_args,
        )
        logger_mock.info.assert_any_call(
            "Setting up named server '%s': %s %s",
            "named_server",
            command,
            joined_args,
        )

        uvicorn_instance.serve.assert_called_once()
|