diff --git a/README.md b/README.md index f4a0b37..453fba6 100644 --- a/README.md +++ b/README.md @@ -114,16 +114,18 @@ separator. Arguments -| Name | Required | Description | Example | -|---------------------------|----------------------------|------------------------------------------------------------------|-----------------------| -| `command_or_url` | Yes | The command to spawn the MCP stdio server | uvx mcp-server-fetch | -| `--port` | No, random available | The MCP server port to listen on | 8080 | -| `--host` | No, `127.0.0.1` by default | The host IP address that the MCP server will listen on | 0.0.0.0 | -| `--env` | No | Additional environment variables to pass to the MCP stdio server | FOO=BAR | -| `--pass-environment` | No | Pass through all environment variables when spawning the server | --no-pass-environment | -| `--allow-origin` | No | Pass through all environment variables when spawning the server | --allow-cors "\*" | -| `--sse-port` (deprecated) | No, random available | The SSE server port to listen on | 8080 | -| `--sse-host` (deprecated) | No, `127.0.0.1` by default | The host IP address that the SSE server will listen on | 0.0.0.0 | +| Name | Required | Description | Example | +|---------------------------|----------------------------|---------------------------------------------------------------------------------------------|-----------------------| +| `command_or_url` | Yes | The command to spawn the MCP stdio server | uvx mcp-server-fetch | +| `--port` | No, random available | The MCP server port to listen on | 8080 | +| `--host` | No, `127.0.0.1` by default | The host IP address that the MCP server will listen on | 0.0.0.0 | +| `--env` | No | Additional environment variables to pass to the MCP stdio server | FOO=BAR | +| `--cwd` | No | The working directory to pass to the MCP stdio server process. | /tmp | +| `--pass-environment` | No | Pass through all environment variables when spawning the server | --no-pass-environment | +| `--allow-origin` | No | Allowed origins for the SSE server. Can be used multiple times. Default is no CORS allowed. | --allow-cors "\*" | +| `--stateless` | No | Enable stateless mode for streamable http transports. Default is False | --no-stateless | +| `--sse-port` (deprecated) | No, random available | The SSE server port to listen on | 8080 | +| `--sse-host` (deprecated) | No, `127.0.0.1` by default | The host IP address that the SSE server will listen on | 0.0.0.0 | ### 2.2 Example usage @@ -147,7 +149,8 @@ mcp-proxy --host=0.0.0.0 --port=8080 uvx mcp-server-fetch mcp-proxy --port=8080 -- uvx mcp-server-fetch --user-agent=YourUserAgent ``` -This will start an MCP server that can be connected to at `http://127.0.0.1:8080/sse` via SSE, or `http://127.0.0.1:8080/mcp/` via StreamableHttp +This will start an MCP server that can be connected to at `http://127.0.0.1:8080/sse` via SSE, or +`http://127.0.0.1:8080/mcp/` via StreamableHttp ## Installation diff --git a/src/mcp_proxy/__main__.py b/src/mcp_proxy/__main__.py index dc21fc1..fdf945b 100644 --- a/src/mcp_proxy/__main__.py +++ b/src/mcp_proxy/__main__.py @@ -77,6 +77,11 @@ def main() -> None: help="Environment variables used when spawning the server. Can be used multiple times.", default=[], ) + stdio_client_options.add_argument( + "--cwd", + default=None, + help="The working directory to use when spawning the process.", + ) stdio_client_options.add_argument( "--pass-environment", action=argparse.BooleanOptionalAction, @@ -90,30 +95,36 @@ def main() -> None: default=False, ) - sse_server_group = parser.add_argument_group("SSE server options") - sse_server_group.add_argument( + mcp_server_group = parser.add_argument_group("SSE server options") + mcp_server_group.add_argument( "--port", type=int, default=None, help="Port to expose an SSE server on. Default is a random port", ) - sse_server_group.add_argument( + mcp_server_group.add_argument( "--host", default=None, help="Host to expose an SSE server on. Default is 127.0.0.1", ) - sse_server_group.add_argument( + mcp_server_group.add_argument( + "--stateless", + action=argparse.BooleanOptionalAction, + help="Enable stateless mode for streamable http transports. Default is False", + default=False, + ) + mcp_server_group.add_argument( "--sse-port", type=int, default=0, help="Port to expose an SSE server on. Default is a random port", ) - sse_server_group.add_argument( + mcp_server_group.add_argument( "--sse-host", default="127.0.0.1", help="Host to expose an SSE server on. Default is 127.0.0.1", ) - sse_server_group.add_argument( + mcp_server_group.add_argument( "--allow-origin", nargs="+", default=[], @@ -161,11 +172,13 @@ def main() -> None: command=args.command_or_url, args=args.args, env=env, + cwd=args.cwd if args.cwd else None, ) mcp_settings = MCPServerSettings( bind_host=args.host if args.host is not None else args.sse_host, port=args.port if args.port is not None else args.sse_port, + stateless=args.stateless, allow_origins=args.allow_origin if len(args.allow_origin) > 0 else None, log_level="DEBUG" if args.debug else "INFO", ) diff --git a/src/mcp_proxy/mcp_server.py b/src/mcp_proxy/mcp_server.py index a470834..bb7b0fb 100644 --- a/src/mcp_proxy/mcp_server.py +++ b/src/mcp_proxy/mcp_server.py @@ -5,18 +5,14 @@ import logging from collections.abc import AsyncIterator from dataclasses import dataclass from datetime import datetime, timezone -from http import HTTPStatus -from typing import Any, Literal -from uuid import uuid4 +from typing import Literal -import anyio import uvicorn -from anyio.abc import TaskStatus from mcp.client.session import ClientSession from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.server import Server from mcp.server.sse import SseServerTransport -from mcp.server.streamable_http import StreamableHTTPServerTransport +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware @@ -28,28 +24,6 @@ from starlette.types import Receive, Scope, Send from .proxy_server import create_proxy_server logger = logging.getLogger(__name__) -# Global task group that will be initialized in the lifespan -task_group = None - -MCP_SESSION_ID_HEADER = "mcp-session-id" - - -@contextlib.asynccontextmanager -async def lifespan(_: Starlette) -> AsyncIterator[None]: - """Application lifespan context manager for managing task group.""" - global task_group # noqa: PLW0603 - - async with anyio.create_task_group() as tg: - task_group = tg - logger.info("Application started, task group initialized!") - try: - yield - finally: - logger.info("Application shutting down, cleaning up resources...") - if task_group: - tg.cancel_scope.cancel() - task_group = None - logger.info("Resources cleaned up successfully.") @dataclass @@ -58,13 +32,15 @@ class MCPServerSettings: bind_host: str port: int + stateless: bool = False allow_origins: list[str] | None = None log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" -def create_starlette_app( # noqa: C901, Refactor required for complexity +def create_starlette_app( mcp_server: Server[object], *, + stateless: bool = False, allow_origins: list[str] | None = None, debug: bool = False, ) -> Starlette: @@ -97,57 +73,17 @@ def create_starlette_app( # noqa: C901, Refactor required for complexity mcp_server.create_initialization_options(), ) - # Refer: https://github.com/modelcontextprotocol/python-sdk/blob/5d8eaf77be00dbd9b33a7fe1e38cb0da77e49401/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py - # We need to store the server instances between requests - server_instances: dict[str, Any] = {} - # Lock to prevent race conditions when creating new sessions - session_creation_lock = anyio.Lock() + # Refer: https://github.com/modelcontextprotocol/python-sdk/blob/v1.8.0/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py + http = StreamableHTTPSessionManager( + app=mcp_server, + event_store=None, + json_response=True, + stateless=stateless, + ) async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: _update_mcp_activity() - request = Request(scope, receive) - request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) - if request_mcp_session_id is not None and request_mcp_session_id in server_instances: - transport = server_instances[request_mcp_session_id] - logger.debug("Session already exists, handling request directly") - await transport.handle_request(scope, receive, send) - elif request_mcp_session_id is None: - # try to establish new session - logger.debug("Creating new transport") - # Use lock to prevent race conditions when creating new sessions - async with session_creation_lock: - new_session_id = uuid4().hex - http_transport = StreamableHTTPServerTransport( - mcp_session_id=new_session_id, - is_json_response_enabled=True, - ) - server_instances[new_session_id] = http_transport - logger.info("Created new transport with session ID: %s", new_session_id) - - async def run_server(task_status: TaskStatus[Any] | None = None) -> None: - async with http_transport.connect() as streams: - read_stream, write_stream = streams - if task_status: - task_status.started() - await mcp_server.run( - read_stream, - write_stream, - mcp_server.create_initialization_options(), - ) - - if not task_group: - raise RuntimeError("Task group is not initialized") - - await task_group.start(run_server) - - # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) - else: - response = Response( - "Bad Request: No valid session ID provided", - status_code=HTTPStatus.BAD_REQUEST, - ) - await response(scope, receive, send) + await http.handle_request(scope, receive, send) async def handle_status(_: Request) -> Response: """Health check and service usage monitoring endpoint. @@ -159,6 +95,16 @@ def create_starlette_app( # noqa: C901, Refactor required for complexity """ return JSONResponse(status) + @contextlib.asynccontextmanager + async def lifespan(_: Starlette) -> AsyncIterator[None]: + """Context manager for session manager.""" + async with http.run(): + logger.info("Application started with StreamableHTTP session manager!") + try: + yield + finally: + logger.info("Application shutting down...") + middleware: list[Middleware] = [] if allow_origins is not None: middleware.append( diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index d1b8767..cadf80d 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -6,11 +6,11 @@ import typing as t import pytest import uvicorn -from mcp import types from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client -from mcp.server import Server +from mcp.server import FastMCP +from mcp.types import TextContent from mcp_proxy.mcp_server import create_starlette_app @@ -42,19 +42,32 @@ class BackgroundServer(uvicorn.Server): return f"http://{hostport[0]}:{hostport[1]}" -@pytest.mark.asyncio -async def test_create_starlette_app() -> None: - """Test basic glue code for the SSE transport and a fake MCP server.""" - mcp_server: Server[object] = Server("prompt-server") +def make_background_server(**kwargs) -> BackgroundServer: # noqa: ANN003 + """Create a BackgroundServer instance with specified parameters.""" + mcp = FastMCP("TestServer") - @mcp_server.list_prompts() # type: ignore[no-untyped-call,misc] - async def list_prompts() -> list[types.Prompt]: - return [types.Prompt(name="prompt1")] + @mcp.prompt(name="prompt1") + async def list_prompts() -> str: + return "hello world" - app = create_starlette_app(mcp_server, allow_origins=["*"]) + @mcp.tool(name="echo") + async def call_tool(message: str) -> str: + return f"Echo: {message}" + + app = create_starlette_app( + mcp._mcp_server, # noqa: SLF001 + allow_origins=["*"], + **kwargs, + ) config = uvicorn.Config(app, port=0, log_level="info") - server = BackgroundServer(config) + return BackgroundServer(config) + + +@pytest.mark.asyncio +async def test_sse_transport() -> None: + """Test basic glue code for the SSE transport and a fake MCP server.""" + server = make_background_server(debug=True) async with server.run_in_background(): sse_url = f"{server.url}/sse" async with sse_client(url=sse_url) as streams, ClientSession(*streams) as session: @@ -63,6 +76,12 @@ async def test_create_starlette_app() -> None: assert len(response.prompts) == 1 assert response.prompts[0].name == "prompt1" + +@pytest.mark.asyncio +async def test_http_transport() -> None: + """Test HTTP transport layer functionality.""" + server = make_background_server(debug=True) + async with server.run_in_background(): http_url = f"{server.url}/mcp/" async with ( streamablehttp_client(url=http_url) as (read, write, _), @@ -72,3 +91,30 @@ async def test_create_starlette_app() -> None: response = await session.list_prompts() assert len(response.prompts) == 1 assert response.prompts[0].name == "prompt1" + + for i in range(3): + tool_result = await session.call_tool("echo", {"message": f"test_{i}"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == f"Echo: test_{i}" + + +async def test_stateless_http_transport() -> None: + """Test stateless HTTP transport functionality.""" + server = make_background_server(debug=True, stateless=True) + async with server.run_in_background(): + http_url = f"{server.url}/mcp/" + async with ( + streamablehttp_client(url=http_url) as (read, write, _), + ClientSession(read, write) as session, + ): + await session.initialize() + response = await session.list_prompts() + assert len(response.prompts) == 1 + assert response.prompts[0].name == "prompt1" + + for i in range(3): + tool_result = await session.call_tool("echo", {"message": f"test_{i}"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == f"Echo: test_{i}"