diff --git a/src/mcp_proxy/__init__.py b/src/mcp_proxy/__init__.py index 0ec6d78..8c95f43 100644 --- a/src/mcp_proxy/__init__.py +++ b/src/mcp_proxy/__init__.py @@ -7,79 +7,89 @@ from mcp.client.session import ClientSession logger = logging.getLogger(__name__) -async def create_server(name: str, remote_app: ClientSession): - app = server.Server(name) +async def create_proxy_server(remote_app: ClientSession): + """Create a server instance from a remote app.""" - async def _list_prompts(_: t.Any) -> types.ServerResult: - result = await remote_app.list_prompts() - return types.ServerResult(result) + response = await remote_app.initialize() + capabilities = response.capabilities - app.request_handlers[types.ListPromptsRequest] = _list_prompts + app = server.Server(response.serverInfo.name) - async def _get_prompt(req: types.GetPromptRequest) -> types.ServerResult: - result = await remote_app.get_prompt(req.params.name, req.params.arguments) - return types.ServerResult(result) - - app.request_handlers[types.GetPromptRequest] = _get_prompt - - async def _list_resources(_: t.Any) -> types.ServerResult: - result = await remote_app.list_resources() - return types.ServerResult(result) - - app.request_handlers[types.ListResourcesRequest] = _list_resources - - # list_resource_templates() is not implemented in the client - # async def _list_resource_templates(_: t.Any) -> types.ServerResult: - # result = await remote_app.list_resource_templates() - # return types.ServerResult(result) - - # app.request_handlers[types.ListResourceTemplatesRequest] = _list_resource_templates - - async def _read_resource(req: types.ReadResourceRequest): - result = await remote_app.read_resource(req.params.uri) - return types.ServerResult(result) - - app.request_handlers[types.ReadResourceRequest] = _read_resource - - async def _set_logging_level(req: types.SetLevelRequest): - await remote_app.set_logging_level(req.params.level) - return types.ServerResult(types.EmptyResult()) - - app.request_handlers[types.SetLevelRequest] = _set_logging_level - - async def _subscribe_resource(req: types.SubscribeRequest): - await remote_app.subscribe_resource(req.params.uri) - return types.ServerResult(types.EmptyResult()) - - app.request_handlers[types.SubscribeRequest] = _subscribe_resource - - async def _unsubscribe_resource(req: types.UnsubscribeRequest): - await remote_app.unsubscribe_resource(req.params.uri) - return types.ServerResult(types.EmptyResult()) - - app.request_handlers[types.UnsubscribeRequest] = _unsubscribe_resource - - async def _list_tools(_: t.Any): - tools = await remote_app.list_tools() - return types.ServerResult(tools) - - app.request_handlers[types.ListToolsRequest] = _list_tools - - async def _call_tool(req: types.CallToolRequest) -> types.ServerResult: - try: - result = await remote_app.call_tool( - req.params.name, (req.params.arguments or {}) - ) + if capabilities.prompts: + async def _list_prompts(_: t.Any) -> types.ServerResult: + result = await remote_app.list_prompts() return types.ServerResult(result) - except Exception as e: - return types.ServerResult( - types.CallToolResult( - content=[types.TextContent(type="text", text=str(e))], - isError=True, - ) - ) - app.request_handlers[types.CallToolRequest] = _call_tool + app.request_handlers[types.ListPromptsRequest] = _list_prompts + + async def _get_prompt(req: types.GetPromptRequest) -> types.ServerResult: + result = await remote_app.get_prompt(req.params.name, req.params.arguments) + return types.ServerResult(result) + + app.request_handlers[types.GetPromptRequest] = _get_prompt + + if capabilities.resources: + async def _list_resources(_: t.Any) -> types.ServerResult: + result = await remote_app.list_resources() + return types.ServerResult(result) + + app.request_handlers[types.ListResourcesRequest] = _list_resources + + # list_resource_templates() is not implemented in the client + # async def _list_resource_templates(_: t.Any) -> types.ServerResult: + # result = await remote_app.list_resource_templates() + # return types.ServerResult(result) + + # app.request_handlers[types.ListResourceTemplatesRequest] = _list_resource_templates + + async def _read_resource(req: types.ReadResourceRequest): + result = await remote_app.read_resource(req.params.uri) + return types.ServerResult(result) + + app.request_handlers[types.ReadResourceRequest] = _read_resource + + if capabilities.logging: + async def _set_logging_level(req: types.SetLevelRequest): + await remote_app.set_logging_level(req.params.level) + return types.ServerResult(types.EmptyResult()) + + app.request_handlers[types.SetLevelRequest] = _set_logging_level + + if capabilities.resources: + async def _subscribe_resource(req: types.SubscribeRequest): + await remote_app.subscribe_resource(req.params.uri) + return types.ServerResult(types.EmptyResult()) + + app.request_handlers[types.SubscribeRequest] = _subscribe_resource + + async def _unsubscribe_resource(req: types.UnsubscribeRequest): + await remote_app.unsubscribe_resource(req.params.uri) + return types.ServerResult(types.EmptyResult()) + + app.request_handlers[types.UnsubscribeRequest] = _unsubscribe_resource + + if capabilities.tools: + async def _list_tools(_: t.Any): + tools = await remote_app.list_tools() + return types.ServerResult(tools) + + app.request_handlers[types.ListToolsRequest] = _list_tools + + async def _call_tool(req: types.CallToolRequest) -> types.ServerResult: + try: + result = await remote_app.call_tool( + req.params.name, (req.params.arguments or {}) + ) + return types.ServerResult(result) + except Exception as e: + return types.ServerResult( + types.CallToolResult( + content=[types.TextContent(type="text", text=str(e))], + isError=True, + ) + ) + + app.request_handlers[types.CallToolRequest] = _call_tool async def _send_progress_notification(req: types.ProgressNotification): await remote_app.send_progress_notification( @@ -99,21 +109,15 @@ async def create_server(name: str, remote_app: ClientSession): return app -async def configure_app(name: str, remote_app: ClientSession): - app = await create_server(name, remote_app) - async with server.stdio_server() as (read_stream, write_stream): - await app.run( - read_stream, - write_stream, - app.create_initialization_options(), - ) - - async def run_sse_client(url: str): from mcp.client.sse import sse_client async with sse_client(url=url) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: - response = await session.initialize() - - await configure_app(response.serverInfo.name, session) + app = await create_proxy_server(session) + async with server.stdio_server() as (read_stream, write_stream): + await app.run( + read_stream, + write_stream, + app.create_initialization_options(), + ) diff --git a/tests/test_init.py b/tests/test_init.py index 391ee10..95d0924 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1,81 +1,109 @@ -"""Tests for the mcp-proxy module.""" +"""Tests for the mcp-proxy module. + +Tests are running in two modes: +- One where the server is exercised directly though an in memory client, just to + set a baseline for the expected behavior. +- Another where the server is exercised through a proxy server, which forwards + the requests to the original server. + +The same test code is run on both to ensure parity. +""" + +from typing import Any +from collections.abc import AsyncGenerator, Callable +from contextlib import asynccontextmanager, AbstractAsyncContextManager import pytest -import anyio + from mcp import types from mcp.client.session import ClientSession from mcp.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_connected_server_and_client_session -from mcp.server.stdio import stdio_server -from mcp_proxy import configure_app, create_server +from mcp_proxy import create_proxy_server -async def run_server() -> None: - """Run a stdio server.""" +TOOL_INPUT_SCHEMA = { + "type": "object", + "properties": { + "input1": {"type": "string"} + } +} - server = Server("test") +SessionContextManager = Callable[[Server], AbstractAsyncContextManager[ClientSession]] - @server.list_prompts() - async def handle_list_prompts() -> list[types.Prompt]: - return [ - types.Prompt(name="prompt1"), - ] - - async with stdio_server() as (read_stream, write_stream): - await server.run(read_stream, write_stream, server.create_initialization_options()) +# Direct server connection +in_memory: SessionContextManager = create_connected_server_and_client_session + +@asynccontextmanager +async def proxy(server: Server) -> AsyncGenerator[ClientSession, None]: + """Create a connection to the server through the proxy server.""" + async with in_memory(server) as session: + wrapped_server = await create_proxy_server(session) + async with in_memory(wrapped_server) as wrapped_session: + yield wrapped_session -async def test_list_prompts(): +@pytest.fixture(params=["server", "proxy"], scope="function") +def session_generator(request: Any) -> SessionContextManager: + """Fixture that returns a client creation strategy either direct or using the proxy.""" + if request.param == "server": + return in_memory + return proxy + + +async def test_list_prompts(session_generator: SessionContextManager): """Test list_prompts.""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - types.JSONRPCMessage - ](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - types.JSONRPCMessage - ](1) - - server = Server("test") + server = Server("prompt-server") @server.list_prompts() async def list_prompts() -> list[types.Prompt]: - return [ - types.Prompt(name="prompt1"), - ] - - async def run_server() -> None: - print("running server") - await server.run(client_to_server_receive, server_to_client_send, server.create_initialization_options()) + return [types.Prompt(name="prompt1")] - async def listen_session(): - print("listening session") - print(session) - async for message in session.incoming_messages: - if isinstance(message, Exception): - raise message - print("message") - print(message) + async with session_generator(server) as session: + result = await session.initialize() + assert result.serverInfo.name == "prompt-server" + assert result.capabilities + assert result.capabilities.prompts + assert not result.capabilities.tools + assert not result.capabilities.resources + assert not result.capabilities.logging - # Create an in memory connection to the fake server - async with create_connected_server_and_client_session(server) as session: - - # Baseline behavior for client result = await session.list_prompts() assert result.prompts == [types.Prompt(name="prompt1")] with pytest.raises(McpError, match="Method not found"): await session.list_tools() - # Create a proxy instance to the in memory server - wrapped_server = await create_server("name", session) - # Create a client to the proxy server - async with create_connected_server_and_client_session(server) as wrapped_session: - await wrapped_session.initialize() +async def test_list_tools(session_generator: SessionContextManager): + """Test list_tools.""" - result = await wrapped_session.list_prompts() - assert result.prompts == [types.Prompt(name="prompt1")] + server = Server("tools-server") - with pytest.raises(McpError, match="Method not found"): - await wrapped_session.list_tools() + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [types.Tool( + name="tool-name", + description="tool-description", + inputSchema=TOOL_INPUT_SCHEMA + )] + + async with session_generator(server) as session: + result = await session.initialize() + assert result.serverInfo.name == "tools-server" + assert result.capabilities + assert result.capabilities.tools + assert not result.capabilities.prompts + assert not result.capabilities.resources + assert not result.capabilities.logging + + result = await session.list_tools() + assert len(result.tools) == 1 + assert result.tools[0].name == "tool-name" + assert result.tools[0].description == "tool-description" + assert result.tools[0].inputSchema == TOOL_INPUT_SCHEMA + + with pytest.raises(McpError, match="Method not found"): + await session.list_prompts()