feat: support passing 'stateless' and 'cwd' arguments (#62)

1. Add support for --stateless parameter configuration
2. Add support for --cwd parameter passing working directory to mcp
stdio server
3. Use StreamableHTTPSessionManager from the latest python-mcp-sdk
release to manage sessions, simplifying code
4. Optimize test cases
This commit is contained in:
caydenwei
2025-05-11 19:26:37 +08:00
committed by GitHub
parent 8fee3d9833
commit 2980a50ad2
4 changed files with 113 additions and 105 deletions

View File

@@ -114,16 +114,18 @@ separator.
Arguments
| Name | Required | Description | Example |
|---------------------------|----------------------------|------------------------------------------------------------------|-----------------------|
| `command_or_url` | Yes | The command to spawn the MCP stdio server | uvx mcp-server-fetch |
| `--port` | No, random available | The MCP server port to listen on | 8080 |
| `--host` | No, `127.0.0.1` by default | The host IP address that the MCP server will listen on | 0.0.0.0 |
| `--env` | No | Additional environment variables to pass to the MCP stdio server | FOO=BAR |
| `--pass-environment` | No | Pass through all environment variables when spawning the server | --no-pass-environment |
| `--allow-origin` | No | Pass through all environment variables when spawning the server | --allow-cors "\*" |
| `--sse-port` (deprecated) | No, random available | The SSE server port to listen on | 8080 |
| `--sse-host` (deprecated) | No, `127.0.0.1` by default | The host IP address that the SSE server will listen on | 0.0.0.0 |
| Name | Required | Description | Example |
|---------------------------|----------------------------|---------------------------------------------------------------------------------------------|-----------------------|
| `command_or_url` | Yes | The command to spawn the MCP stdio server | uvx mcp-server-fetch |
| `--port` | No, random available | The MCP server port to listen on | 8080 |
| `--host` | No, `127.0.0.1` by default | The host IP address that the MCP server will listen on | 0.0.0.0 |
| `--env` | No | Additional environment variables to pass to the MCP stdio server | FOO=BAR |
| `--cwd` | No | The working directory to pass to the MCP stdio server process. | /tmp |
| `--pass-environment` | No | Pass through all environment variables when spawning the server | --no-pass-environment |
| `--allow-origin` | No | Allowed origins for the SSE server. Can be used multiple times. Default is no CORS allowed. | --allow-cors "\*" |
| `--stateless` | No | Enable stateless mode for streamable http transports. Default is False | --no-stateless |
| `--sse-port` (deprecated) | No, random available | The SSE server port to listen on | 8080 |
| `--sse-host` (deprecated) | No, `127.0.0.1` by default | The host IP address that the SSE server will listen on | 0.0.0.0 |
### 2.2 Example usage
@@ -147,7 +149,8 @@ mcp-proxy --host=0.0.0.0 --port=8080 uvx mcp-server-fetch
mcp-proxy --port=8080 -- uvx mcp-server-fetch --user-agent=YourUserAgent
```
This will start an MCP server that can be connected to at `http://127.0.0.1:8080/sse` via SSE, or `http://127.0.0.1:8080/mcp/` via StreamableHttp
This will start an MCP server that can be connected to at `http://127.0.0.1:8080/sse` via SSE, or
`http://127.0.0.1:8080/mcp/` via StreamableHttp
## Installation

View File

@@ -77,6 +77,11 @@ def main() -> None:
help="Environment variables used when spawning the server. Can be used multiple times.",
default=[],
)
stdio_client_options.add_argument(
"--cwd",
default=None,
help="The working directory to use when spawning the process.",
)
stdio_client_options.add_argument(
"--pass-environment",
action=argparse.BooleanOptionalAction,
@@ -90,30 +95,36 @@ def main() -> None:
default=False,
)
sse_server_group = parser.add_argument_group("SSE server options")
sse_server_group.add_argument(
mcp_server_group = parser.add_argument_group("SSE server options")
mcp_server_group.add_argument(
"--port",
type=int,
default=None,
help="Port to expose an SSE server on. Default is a random port",
)
sse_server_group.add_argument(
mcp_server_group.add_argument(
"--host",
default=None,
help="Host to expose an SSE server on. Default is 127.0.0.1",
)
sse_server_group.add_argument(
mcp_server_group.add_argument(
"--stateless",
action=argparse.BooleanOptionalAction,
help="Enable stateless mode for streamable http transports. Default is False",
default=False,
)
mcp_server_group.add_argument(
"--sse-port",
type=int,
default=0,
help="Port to expose an SSE server on. Default is a random port",
)
sse_server_group.add_argument(
mcp_server_group.add_argument(
"--sse-host",
default="127.0.0.1",
help="Host to expose an SSE server on. Default is 127.0.0.1",
)
sse_server_group.add_argument(
mcp_server_group.add_argument(
"--allow-origin",
nargs="+",
default=[],
@@ -161,11 +172,13 @@ def main() -> None:
command=args.command_or_url,
args=args.args,
env=env,
cwd=args.cwd if args.cwd else None,
)
mcp_settings = MCPServerSettings(
bind_host=args.host if args.host is not None else args.sse_host,
port=args.port if args.port is not None else args.sse_port,
stateless=args.stateless,
allow_origins=args.allow_origin if len(args.allow_origin) > 0 else None,
log_level="DEBUG" if args.debug else "INFO",
)

View File

@@ -5,18 +5,14 @@ import logging
from collections.abc import AsyncIterator
from dataclasses import dataclass
from datetime import datetime, timezone
from http import HTTPStatus
from typing import Any, Literal
from uuid import uuid4
from typing import Literal
import anyio
import uvicorn
from anyio.abc import TaskStatus
from mcp.client.session import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.server import Server
from mcp.server.sse import SseServerTransport
from mcp.server.streamable_http import StreamableHTTPServerTransport
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
@@ -28,28 +24,6 @@ from starlette.types import Receive, Scope, Send
from .proxy_server import create_proxy_server
logger = logging.getLogger(__name__)
# Global task group that will be initialized in the lifespan
task_group = None
MCP_SESSION_ID_HEADER = "mcp-session-id"
@contextlib.asynccontextmanager
async def lifespan(_: Starlette) -> AsyncIterator[None]:
"""Application lifespan context manager for managing task group."""
global task_group # noqa: PLW0603
async with anyio.create_task_group() as tg:
task_group = tg
logger.info("Application started, task group initialized!")
try:
yield
finally:
logger.info("Application shutting down, cleaning up resources...")
if task_group:
tg.cancel_scope.cancel()
task_group = None
logger.info("Resources cleaned up successfully.")
@dataclass
@@ -58,13 +32,15 @@ class MCPServerSettings:
bind_host: str
port: int
stateless: bool = False
allow_origins: list[str] | None = None
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
def create_starlette_app( # noqa: C901, Refactor required for complexity
def create_starlette_app(
mcp_server: Server[object],
*,
stateless: bool = False,
allow_origins: list[str] | None = None,
debug: bool = False,
) -> Starlette:
@@ -97,57 +73,17 @@ def create_starlette_app( # noqa: C901, Refactor required for complexity
mcp_server.create_initialization_options(),
)
# Refer: https://github.com/modelcontextprotocol/python-sdk/blob/5d8eaf77be00dbd9b33a7fe1e38cb0da77e49401/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py
# We need to store the server instances between requests
server_instances: dict[str, Any] = {}
# Lock to prevent race conditions when creating new sessions
session_creation_lock = anyio.Lock()
# Refer: https://github.com/modelcontextprotocol/python-sdk/blob/v1.8.0/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py
http = StreamableHTTPSessionManager(
app=mcp_server,
event_store=None,
json_response=True,
stateless=stateless,
)
async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
_update_mcp_activity()
request = Request(scope, receive)
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
if request_mcp_session_id is not None and request_mcp_session_id in server_instances:
transport = server_instances[request_mcp_session_id]
logger.debug("Session already exists, handling request directly")
await transport.handle_request(scope, receive, send)
elif request_mcp_session_id is None:
# try to establish new session
logger.debug("Creating new transport")
# Use lock to prevent race conditions when creating new sessions
async with session_creation_lock:
new_session_id = uuid4().hex
http_transport = StreamableHTTPServerTransport(
mcp_session_id=new_session_id,
is_json_response_enabled=True,
)
server_instances[new_session_id] = http_transport
logger.info("Created new transport with session ID: %s", new_session_id)
async def run_server(task_status: TaskStatus[Any] | None = None) -> None:
async with http_transport.connect() as streams:
read_stream, write_stream = streams
if task_status:
task_status.started()
await mcp_server.run(
read_stream,
write_stream,
mcp_server.create_initialization_options(),
)
if not task_group:
raise RuntimeError("Task group is not initialized")
await task_group.start(run_server)
# Handle the HTTP request and return the response
await http_transport.handle_request(scope, receive, send)
else:
response = Response(
"Bad Request: No valid session ID provided",
status_code=HTTPStatus.BAD_REQUEST,
)
await response(scope, receive, send)
await http.handle_request(scope, receive, send)
async def handle_status(_: Request) -> Response:
"""Health check and service usage monitoring endpoint.
@@ -159,6 +95,16 @@ def create_starlette_app( # noqa: C901, Refactor required for complexity
"""
return JSONResponse(status)
@contextlib.asynccontextmanager
async def lifespan(_: Starlette) -> AsyncIterator[None]:
"""Context manager for session manager."""
async with http.run():
logger.info("Application started with StreamableHTTP session manager!")
try:
yield
finally:
logger.info("Application shutting down...")
middleware: list[Middleware] = []
if allow_origins is not None:
middleware.append(

View File

@@ -6,11 +6,11 @@ import typing as t
import pytest
import uvicorn
from mcp import types
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.server import Server
from mcp.server import FastMCP
from mcp.types import TextContent
from mcp_proxy.mcp_server import create_starlette_app
@@ -42,19 +42,32 @@ class BackgroundServer(uvicorn.Server):
return f"http://{hostport[0]}:{hostport[1]}"
@pytest.mark.asyncio
async def test_create_starlette_app() -> None:
"""Test basic glue code for the SSE transport and a fake MCP server."""
mcp_server: Server[object] = Server("prompt-server")
def make_background_server(**kwargs) -> BackgroundServer: # noqa: ANN003
"""Create a BackgroundServer instance with specified parameters."""
mcp = FastMCP("TestServer")
@mcp_server.list_prompts() # type: ignore[no-untyped-call,misc]
async def list_prompts() -> list[types.Prompt]:
return [types.Prompt(name="prompt1")]
@mcp.prompt(name="prompt1")
async def list_prompts() -> str:
return "hello world"
app = create_starlette_app(mcp_server, allow_origins=["*"])
@mcp.tool(name="echo")
async def call_tool(message: str) -> str:
return f"Echo: {message}"
app = create_starlette_app(
mcp._mcp_server, # noqa: SLF001
allow_origins=["*"],
**kwargs,
)
config = uvicorn.Config(app, port=0, log_level="info")
server = BackgroundServer(config)
return BackgroundServer(config)
@pytest.mark.asyncio
async def test_sse_transport() -> None:
"""Test basic glue code for the SSE transport and a fake MCP server."""
server = make_background_server(debug=True)
async with server.run_in_background():
sse_url = f"{server.url}/sse"
async with sse_client(url=sse_url) as streams, ClientSession(*streams) as session:
@@ -63,6 +76,12 @@ async def test_create_starlette_app() -> None:
assert len(response.prompts) == 1
assert response.prompts[0].name == "prompt1"
@pytest.mark.asyncio
async def test_http_transport() -> None:
"""Test HTTP transport layer functionality."""
server = make_background_server(debug=True)
async with server.run_in_background():
http_url = f"{server.url}/mcp/"
async with (
streamablehttp_client(url=http_url) as (read, write, _),
@@ -72,3 +91,30 @@ async def test_create_starlette_app() -> None:
response = await session.list_prompts()
assert len(response.prompts) == 1
assert response.prompts[0].name == "prompt1"
for i in range(3):
tool_result = await session.call_tool("echo", {"message": f"test_{i}"})
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
assert tool_result.content[0].text == f"Echo: test_{i}"
async def test_stateless_http_transport() -> None:
"""Test stateless HTTP transport functionality."""
server = make_background_server(debug=True, stateless=True)
async with server.run_in_background():
http_url = f"{server.url}/mcp/"
async with (
streamablehttp_client(url=http_url) as (read, write, _),
ClientSession(read, write) as session,
):
await session.initialize()
response = await session.list_prompts()
assert len(response.prompts) == 1
assert response.prompts[0].name == "prompt1"
for i in range(3):
tool_result = await session.call_tool("echo", {"message": f"test_{i}"})
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
assert tool_result.content[0].text == f"Echo: test_{i}"