From e730450bf4effd255d0d0cab1f87373643446248 Mon Sep 17 00:00:00 2001 From: Zhengfeng Date: Sun, 19 Oct 2025 02:01:28 +0800 Subject: [PATCH] fix: align /mcp streamable HTTP handling with python-sdk (#119) - mirror the python-sdk fix so `/mcp` requests are scope-normalised to `/mcp/` before hitting the StreamableHTTP session manager, eliminating the 307 redirect/404 regression introduced in #89 - extend the HTTP transport test to cover both `/mcp/` and `/mcp`, ensuring the proxy works with SDK clients out of the box Co-authored-by: Zhengfeng --- src/mcp_proxy/mcp_server.py | 44 +++++++++++++++++++++++++++++++++++-- tests/test_mcp_server.py | 9 +++++--- 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/src/mcp_proxy/mcp_server.py b/src/mcp_proxy/mcp_server.py index a39abb5..bfe1d91 100644 --- a/src/mcp_proxy/mcp_server.py +++ b/src/mcp_proxy/mcp_server.py @@ -2,7 +2,7 @@ import contextlib import logging -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Awaitable, Callable from dataclasses import dataclass from datetime import datetime, timezone from typing import Any, Literal @@ -48,6 +48,19 @@ def _update_global_activity() -> None: _global_status["api_last_activity"] = datetime.now(timezone.utc).isoformat() +class _ASGIEndpointAdapter: + """Wrap a coroutine function into an ASGI application.""" + + def __init__(self, endpoint: Callable[[Scope, Receive, Send], Awaitable[None]]) -> None: + self._endpoint = endpoint + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + await self._endpoint(scope, receive, send) + + +HTTP_METHODS = ["DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT", "TRACE"] + + async def _handle_status(_: Request) -> Response: """Global health check and service usage monitoring endpoint.""" return JSONResponse(_global_status) @@ -88,9 +101,36 @@ def create_single_instance_routes( async def handle_streamable_http_instance(scope: Scope, receive: Receive, send: Send) -> None: _update_global_activity() - await http_session_manager.handle_request(scope, receive, send) + updated_scope = scope + if scope.get("type") == "http": + path = scope.get("path", "") + if path and path.rstrip("/") == "/mcp" and not path.endswith("/"): + updated_scope = dict(scope) + normalized_path = path + "/" + logger.debug( + "Normalized request path from '%s' to '%s' without redirect", + path, + normalized_path, + ) + updated_scope["path"] = normalized_path + + raw_path = scope.get("raw_path") + if raw_path: + if b"?" in raw_path: + path_part, query_part = raw_path.split(b"?", 1) + updated_scope["raw_path"] = path_part.rstrip(b"/") + b"/?" + query_part + else: + updated_scope["raw_path"] = raw_path.rstrip(b"/") + b"/" + + await http_session_manager.handle_request(updated_scope, receive, send) routes = [ + Route( + "/mcp", + endpoint=_ASGIEndpointAdapter(handle_streamable_http_instance), + methods=HTTP_METHODS, + include_in_schema=False, + ), Mount("/mcp", app=handle_streamable_http_instance), Route("/sse", endpoint=handle_sse_instance), Mount("/messages/", app=sse_transport.handle_post_message), diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index daeb8df..dfc540a 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -58,12 +58,14 @@ def create_starlette_app( async with http_manager.run(): yield - return Starlette( + app = Starlette( debug=debug, routes=routes, middleware=middleware, lifespan=lifespan, ) + app.router.redirect_slashes = False + return app class BackgroundServer(uvicorn.Server): @@ -149,11 +151,12 @@ async def test_sse_transport() -> None: assert response.prompts[0].name == "prompt1" -async def test_http_transport() -> None: +@pytest.mark.parametrize("path_suffix", ["/mcp/", "/mcp"]) +async def test_http_transport(path_suffix: str) -> None: """Test HTTP transport layer functionality.""" server = make_background_server(debug=True) async with server.run_in_background(): - http_url = f"{server.url}/mcp/" + http_url = f"{server.url}{path_suffix}" async with ( streamablehttp_client(url=http_url) as (read, write, _), ClientSession(read, write) as session,