feat: support passing 'stateless' and 'cwd' arguments (#62)
1. Add support for --stateless parameter configuration 2. Add support for --cwd parameter passing working directory to mcp stdio server 3. Use StreamableHTTPSessionManager from the latest python-mcp-sdk release to manage sessions, simplifying code 4. Optimize test cases
This commit is contained in:
25
README.md
25
README.md
@@ -114,16 +114,18 @@ separator.
|
||||
|
||||
Arguments
|
||||
|
||||
| Name | Required | Description | Example |
|
||||
|---------------------------|----------------------------|------------------------------------------------------------------|-----------------------|
|
||||
| `command_or_url` | Yes | The command to spawn the MCP stdio server | uvx mcp-server-fetch |
|
||||
| `--port` | No, random available | The MCP server port to listen on | 8080 |
|
||||
| `--host` | No, `127.0.0.1` by default | The host IP address that the MCP server will listen on | 0.0.0.0 |
|
||||
| `--env` | No | Additional environment variables to pass to the MCP stdio server | FOO=BAR |
|
||||
| `--pass-environment` | No | Pass through all environment variables when spawning the server | --no-pass-environment |
|
||||
| `--allow-origin` | No | Pass through all environment variables when spawning the server | --allow-cors "\*" |
|
||||
| `--sse-port` (deprecated) | No, random available | The SSE server port to listen on | 8080 |
|
||||
| `--sse-host` (deprecated) | No, `127.0.0.1` by default | The host IP address that the SSE server will listen on | 0.0.0.0 |
|
||||
| Name | Required | Description | Example |
|
||||
|---------------------------|----------------------------|---------------------------------------------------------------------------------------------|-----------------------|
|
||||
| `command_or_url` | Yes | The command to spawn the MCP stdio server | uvx mcp-server-fetch |
|
||||
| `--port` | No, random available | The MCP server port to listen on | 8080 |
|
||||
| `--host` | No, `127.0.0.1` by default | The host IP address that the MCP server will listen on | 0.0.0.0 |
|
||||
| `--env` | No | Additional environment variables to pass to the MCP stdio server | FOO=BAR |
|
||||
| `--cwd` | No | The working directory to pass to the MCP stdio server process. | /tmp |
|
||||
| `--pass-environment` | No | Pass through all environment variables when spawning the server | --no-pass-environment |
|
||||
| `--allow-origin` | No | Allowed origins for the SSE server. Can be used multiple times. Default is no CORS allowed. | --allow-cors "\*" |
|
||||
| `--stateless` | No | Enable stateless mode for streamable http transports. Default is False | --no-stateless |
|
||||
| `--sse-port` (deprecated) | No, random available | The SSE server port to listen on | 8080 |
|
||||
| `--sse-host` (deprecated) | No, `127.0.0.1` by default | The host IP address that the SSE server will listen on | 0.0.0.0 |
|
||||
|
||||
### 2.2 Example usage
|
||||
|
||||
@@ -147,7 +149,8 @@ mcp-proxy --host=0.0.0.0 --port=8080 uvx mcp-server-fetch
|
||||
mcp-proxy --port=8080 -- uvx mcp-server-fetch --user-agent=YourUserAgent
|
||||
```
|
||||
|
||||
This will start an MCP server that can be connected to at `http://127.0.0.1:8080/sse` via SSE, or `http://127.0.0.1:8080/mcp/` via StreamableHttp
|
||||
This will start an MCP server that can be connected to at `http://127.0.0.1:8080/sse` via SSE, or
|
||||
`http://127.0.0.1:8080/mcp/` via StreamableHttp
|
||||
|
||||
## Installation
|
||||
|
||||
|
||||
@@ -77,6 +77,11 @@ def main() -> None:
|
||||
help="Environment variables used when spawning the server. Can be used multiple times.",
|
||||
default=[],
|
||||
)
|
||||
stdio_client_options.add_argument(
|
||||
"--cwd",
|
||||
default=None,
|
||||
help="The working directory to use when spawning the process.",
|
||||
)
|
||||
stdio_client_options.add_argument(
|
||||
"--pass-environment",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
@@ -90,30 +95,36 @@ def main() -> None:
|
||||
default=False,
|
||||
)
|
||||
|
||||
sse_server_group = parser.add_argument_group("SSE server options")
|
||||
sse_server_group.add_argument(
|
||||
mcp_server_group = parser.add_argument_group("SSE server options")
|
||||
mcp_server_group.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Port to expose an SSE server on. Default is a random port",
|
||||
)
|
||||
sse_server_group.add_argument(
|
||||
mcp_server_group.add_argument(
|
||||
"--host",
|
||||
default=None,
|
||||
help="Host to expose an SSE server on. Default is 127.0.0.1",
|
||||
)
|
||||
sse_server_group.add_argument(
|
||||
mcp_server_group.add_argument(
|
||||
"--stateless",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Enable stateless mode for streamable http transports. Default is False",
|
||||
default=False,
|
||||
)
|
||||
mcp_server_group.add_argument(
|
||||
"--sse-port",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Port to expose an SSE server on. Default is a random port",
|
||||
)
|
||||
sse_server_group.add_argument(
|
||||
mcp_server_group.add_argument(
|
||||
"--sse-host",
|
||||
default="127.0.0.1",
|
||||
help="Host to expose an SSE server on. Default is 127.0.0.1",
|
||||
)
|
||||
sse_server_group.add_argument(
|
||||
mcp_server_group.add_argument(
|
||||
"--allow-origin",
|
||||
nargs="+",
|
||||
default=[],
|
||||
@@ -161,11 +172,13 @@ def main() -> None:
|
||||
command=args.command_or_url,
|
||||
args=args.args,
|
||||
env=env,
|
||||
cwd=args.cwd if args.cwd else None,
|
||||
)
|
||||
|
||||
mcp_settings = MCPServerSettings(
|
||||
bind_host=args.host if args.host is not None else args.sse_host,
|
||||
port=args.port if args.port is not None else args.sse_port,
|
||||
stateless=args.stateless,
|
||||
allow_origins=args.allow_origin if len(args.allow_origin) > 0 else None,
|
||||
log_level="DEBUG" if args.debug else "INFO",
|
||||
)
|
||||
|
||||
@@ -5,18 +5,14 @@ import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Literal
|
||||
from uuid import uuid4
|
||||
from typing import Literal
|
||||
|
||||
import anyio
|
||||
import uvicorn
|
||||
from anyio.abc import TaskStatus
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
from mcp.server import Server
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from mcp.server.streamable_http import StreamableHTTPServerTransport
|
||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
@@ -28,28 +24,6 @@ from starlette.types import Receive, Scope, Send
|
||||
from .proxy_server import create_proxy_server
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Global task group that will be initialized in the lifespan
|
||||
task_group = None
|
||||
|
||||
MCP_SESSION_ID_HEADER = "mcp-session-id"
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(_: Starlette) -> AsyncIterator[None]:
|
||||
"""Application lifespan context manager for managing task group."""
|
||||
global task_group # noqa: PLW0603
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
task_group = tg
|
||||
logger.info("Application started, task group initialized!")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.info("Application shutting down, cleaning up resources...")
|
||||
if task_group:
|
||||
tg.cancel_scope.cancel()
|
||||
task_group = None
|
||||
logger.info("Resources cleaned up successfully.")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -58,13 +32,15 @@ class MCPServerSettings:
|
||||
|
||||
bind_host: str
|
||||
port: int
|
||||
stateless: bool = False
|
||||
allow_origins: list[str] | None = None
|
||||
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
|
||||
|
||||
|
||||
def create_starlette_app( # noqa: C901, Refactor required for complexity
|
||||
def create_starlette_app(
|
||||
mcp_server: Server[object],
|
||||
*,
|
||||
stateless: bool = False,
|
||||
allow_origins: list[str] | None = None,
|
||||
debug: bool = False,
|
||||
) -> Starlette:
|
||||
@@ -97,57 +73,17 @@ def create_starlette_app( # noqa: C901, Refactor required for complexity
|
||||
mcp_server.create_initialization_options(),
|
||||
)
|
||||
|
||||
# Refer: https://github.com/modelcontextprotocol/python-sdk/blob/5d8eaf77be00dbd9b33a7fe1e38cb0da77e49401/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py
|
||||
# We need to store the server instances between requests
|
||||
server_instances: dict[str, Any] = {}
|
||||
# Lock to prevent race conditions when creating new sessions
|
||||
session_creation_lock = anyio.Lock()
|
||||
# Refer: https://github.com/modelcontextprotocol/python-sdk/blob/v1.8.0/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py
|
||||
http = StreamableHTTPSessionManager(
|
||||
app=mcp_server,
|
||||
event_store=None,
|
||||
json_response=True,
|
||||
stateless=stateless,
|
||||
)
|
||||
|
||||
async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
_update_mcp_activity()
|
||||
request = Request(scope, receive)
|
||||
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
|
||||
if request_mcp_session_id is not None and request_mcp_session_id in server_instances:
|
||||
transport = server_instances[request_mcp_session_id]
|
||||
logger.debug("Session already exists, handling request directly")
|
||||
await transport.handle_request(scope, receive, send)
|
||||
elif request_mcp_session_id is None:
|
||||
# try to establish new session
|
||||
logger.debug("Creating new transport")
|
||||
# Use lock to prevent race conditions when creating new sessions
|
||||
async with session_creation_lock:
|
||||
new_session_id = uuid4().hex
|
||||
http_transport = StreamableHTTPServerTransport(
|
||||
mcp_session_id=new_session_id,
|
||||
is_json_response_enabled=True,
|
||||
)
|
||||
server_instances[new_session_id] = http_transport
|
||||
logger.info("Created new transport with session ID: %s", new_session_id)
|
||||
|
||||
async def run_server(task_status: TaskStatus[Any] | None = None) -> None:
|
||||
async with http_transport.connect() as streams:
|
||||
read_stream, write_stream = streams
|
||||
if task_status:
|
||||
task_status.started()
|
||||
await mcp_server.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
mcp_server.create_initialization_options(),
|
||||
)
|
||||
|
||||
if not task_group:
|
||||
raise RuntimeError("Task group is not initialized")
|
||||
|
||||
await task_group.start(run_server)
|
||||
|
||||
# Handle the HTTP request and return the response
|
||||
await http_transport.handle_request(scope, receive, send)
|
||||
else:
|
||||
response = Response(
|
||||
"Bad Request: No valid session ID provided",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
await http.handle_request(scope, receive, send)
|
||||
|
||||
async def handle_status(_: Request) -> Response:
|
||||
"""Health check and service usage monitoring endpoint.
|
||||
@@ -159,6 +95,16 @@ def create_starlette_app( # noqa: C901, Refactor required for complexity
|
||||
"""
|
||||
return JSONResponse(status)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(_: Starlette) -> AsyncIterator[None]:
|
||||
"""Context manager for session manager."""
|
||||
async with http.run():
|
||||
logger.info("Application started with StreamableHTTP session manager!")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.info("Application shutting down...")
|
||||
|
||||
middleware: list[Middleware] = []
|
||||
if allow_origins is not None:
|
||||
middleware.append(
|
||||
|
||||
@@ -6,11 +6,11 @@ import typing as t
|
||||
|
||||
import pytest
|
||||
import uvicorn
|
||||
from mcp import types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.server import Server
|
||||
from mcp.server import FastMCP
|
||||
from mcp.types import TextContent
|
||||
|
||||
from mcp_proxy.mcp_server import create_starlette_app
|
||||
|
||||
@@ -42,19 +42,32 @@ class BackgroundServer(uvicorn.Server):
|
||||
return f"http://{hostport[0]}:{hostport[1]}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_starlette_app() -> None:
|
||||
"""Test basic glue code for the SSE transport and a fake MCP server."""
|
||||
mcp_server: Server[object] = Server("prompt-server")
|
||||
def make_background_server(**kwargs) -> BackgroundServer: # noqa: ANN003
|
||||
"""Create a BackgroundServer instance with specified parameters."""
|
||||
mcp = FastMCP("TestServer")
|
||||
|
||||
@mcp_server.list_prompts() # type: ignore[no-untyped-call,misc]
|
||||
async def list_prompts() -> list[types.Prompt]:
|
||||
return [types.Prompt(name="prompt1")]
|
||||
@mcp.prompt(name="prompt1")
|
||||
async def list_prompts() -> str:
|
||||
return "hello world"
|
||||
|
||||
app = create_starlette_app(mcp_server, allow_origins=["*"])
|
||||
@mcp.tool(name="echo")
|
||||
async def call_tool(message: str) -> str:
|
||||
return f"Echo: {message}"
|
||||
|
||||
app = create_starlette_app(
|
||||
mcp._mcp_server, # noqa: SLF001
|
||||
allow_origins=["*"],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
config = uvicorn.Config(app, port=0, log_level="info")
|
||||
server = BackgroundServer(config)
|
||||
return BackgroundServer(config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_transport() -> None:
|
||||
"""Test basic glue code for the SSE transport and a fake MCP server."""
|
||||
server = make_background_server(debug=True)
|
||||
async with server.run_in_background():
|
||||
sse_url = f"{server.url}/sse"
|
||||
async with sse_client(url=sse_url) as streams, ClientSession(*streams) as session:
|
||||
@@ -63,6 +76,12 @@ async def test_create_starlette_app() -> None:
|
||||
assert len(response.prompts) == 1
|
||||
assert response.prompts[0].name == "prompt1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_transport() -> None:
|
||||
"""Test HTTP transport layer functionality."""
|
||||
server = make_background_server(debug=True)
|
||||
async with server.run_in_background():
|
||||
http_url = f"{server.url}/mcp/"
|
||||
async with (
|
||||
streamablehttp_client(url=http_url) as (read, write, _),
|
||||
@@ -72,3 +91,30 @@ async def test_create_starlette_app() -> None:
|
||||
response = await session.list_prompts()
|
||||
assert len(response.prompts) == 1
|
||||
assert response.prompts[0].name == "prompt1"
|
||||
|
||||
for i in range(3):
|
||||
tool_result = await session.call_tool("echo", {"message": f"test_{i}"})
|
||||
assert len(tool_result.content) == 1
|
||||
assert isinstance(tool_result.content[0], TextContent)
|
||||
assert tool_result.content[0].text == f"Echo: test_{i}"
|
||||
|
||||
|
||||
async def test_stateless_http_transport() -> None:
|
||||
"""Test stateless HTTP transport functionality."""
|
||||
server = make_background_server(debug=True, stateless=True)
|
||||
async with server.run_in_background():
|
||||
http_url = f"{server.url}/mcp/"
|
||||
async with (
|
||||
streamablehttp_client(url=http_url) as (read, write, _),
|
||||
ClientSession(read, write) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
response = await session.list_prompts()
|
||||
assert len(response.prompts) == 1
|
||||
assert response.prompts[0].name == "prompt1"
|
||||
|
||||
for i in range(3):
|
||||
tool_result = await session.call_tool("echo", {"message": f"test_{i}"})
|
||||
assert len(tool_result.content) == 1
|
||||
assert isinstance(tool_result.content[0], TextContent)
|
||||
assert tool_result.content[0].text == f"Echo: test_{i}"
|
||||
|
||||
Reference in New Issue
Block a user