From b59f34e8555e047eabd2d3544824b6a617468574 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 28 Dec 2024 09:20:17 -0800 Subject: [PATCH] Add basic test for mcp-proxy --- pyproject.toml | 8 ++++ src/mcp_proxy/__init__.py | 9 ++++- tests/__init__.py | 1 + tests/test_init.py | 81 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 97 insertions(+), 2 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/test_init.py diff --git a/pyproject.toml b/pyproject.toml index 65df37a..197e4cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,3 +11,11 @@ build-backend = "setuptools.build_meta" [project.scripts] mcp-proxy = "mcp_proxy.__main__:main" + +[tool.pytest.ini_options] +pythonpath = "src" +addopts = [ + "--import-mode=importlib", +] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" diff --git a/src/mcp_proxy/__init__.py b/src/mcp_proxy/__init__.py index 0b98f49..0ec6d78 100644 --- a/src/mcp_proxy/__init__.py +++ b/src/mcp_proxy/__init__.py @@ -7,7 +7,7 @@ from mcp.client.session import ClientSession logger = logging.getLogger(__name__) -async def confugure_app(name: str, remote_app: ClientSession): +async def create_server(name: str, remote_app: ClientSession): app = server.Server(name) async def _list_prompts(_: t.Any) -> types.ServerResult: @@ -96,6 +96,11 @@ async def confugure_app(name: str, remote_app: ClientSession): app.request_handlers[types.CompleteRequest] = _complete + return app + + +async def configure_app(name: str, remote_app: ClientSession): + app = await create_server(name, remote_app) async with server.stdio_server() as (read_stream, write_stream): await app.run( read_stream, @@ -111,4 +116,4 @@ async def run_sse_client(url: str): async with ClientSession(read_stream, write_stream) as session: response = await session.initialize() - await confugure_app(response.serverInfo.name, session) + await configure_app(response.serverInfo.name, session) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..56648f1 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for mcp-proxy.""" diff --git a/tests/test_init.py b/tests/test_init.py new file mode 100644 index 0000000..391ee10 --- /dev/null +++ b/tests/test_init.py @@ -0,0 +1,81 @@ +"""Tests for the mcp-proxy module.""" + +import pytest +import anyio +from mcp import types +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.shared.exceptions import McpError +from mcp.shared.memory import create_connected_server_and_client_session +from mcp.server.stdio import stdio_server + +from mcp_proxy import configure_app, create_server + +async def run_server() -> None: + """Run a stdio server.""" + + server = Server("test") + + @server.list_prompts() + async def handle_list_prompts() -> list[types.Prompt]: + return [ + types.Prompt(name="prompt1"), + ] + + async with stdio_server() as (read_stream, write_stream): + await server.run(read_stream, write_stream, server.create_initialization_options()) + + +async def test_list_prompts(): + """Test list_prompts.""" + + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + types.JSONRPCMessage + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + types.JSONRPCMessage + ](1) + + server = Server("test") + + @server.list_prompts() + async def list_prompts() -> list[types.Prompt]: + return [ + types.Prompt(name="prompt1"), + ] + + async def run_server() -> None: + print("running server") + await server.run(client_to_server_receive, server_to_client_send, server.create_initialization_options()) + + async def listen_session(): + print("listening session") + print(session) + async for message in session.incoming_messages: + if isinstance(message, Exception): + raise message + print("message") + print(message) + + # Create an in memory connection to the fake server + async with create_connected_server_and_client_session(server) as session: + + # Baseline behavior for client + result = await session.list_prompts() + assert result.prompts == [types.Prompt(name="prompt1")] + + with pytest.raises(McpError, match="Method not found"): + await session.list_tools() + + # Create a proxy instance to the in memory server + wrapped_server = await create_server("name", session) + + # Create a client to the proxy server + async with create_connected_server_and_client_session(server) as wrapped_session: + await wrapped_session.initialize() + + result = await wrapped_session.list_prompts() + assert result.prompts == [types.Prompt(name="prompt1")] + + with pytest.raises(McpError, match="Method not found"): + await wrapped_session.list_tools()