Update tests to exercise client behavior

This commit is contained in:
Allen Porter
2024-12-28 11:39:20 -08:00
parent b59f34e855
commit 728404641b
2 changed files with 165 additions and 133 deletions

View File

@@ -1,81 +1,109 @@
"""Tests for the mcp-proxy module."""
"""Tests for the mcp-proxy module.
Tests are running in two modes:
- One where the server is exercised directly though an in memory client, just to
set a baseline for the expected behavior.
- Another where the server is exercised through a proxy server, which forwards
the requests to the original server.
The same test code is run on both to ensure parity.
"""
from typing import Any
from collections.abc import AsyncGenerator, Callable
from contextlib import asynccontextmanager, AbstractAsyncContextManager
import pytest
import anyio
from mcp import types
from mcp.client.session import ClientSession
from mcp.server import Server
from mcp.shared.exceptions import McpError
from mcp.shared.memory import create_connected_server_and_client_session
from mcp.server.stdio import stdio_server
from mcp_proxy import configure_app, create_server
from mcp_proxy import create_proxy_server
async def run_server() -> None:
"""Run a stdio server."""
TOOL_INPUT_SCHEMA = {
"type": "object",
"properties": {
"input1": {"type": "string"}
}
}
server = Server("test")
SessionContextManager = Callable[[Server], AbstractAsyncContextManager[ClientSession]]
@server.list_prompts()
async def handle_list_prompts() -> list[types.Prompt]:
return [
types.Prompt(name="prompt1"),
]
async with stdio_server() as (read_stream, write_stream):
await server.run(read_stream, write_stream, server.create_initialization_options())
# Direct server connection
in_memory: SessionContextManager = create_connected_server_and_client_session
@asynccontextmanager
async def proxy(server: Server) -> AsyncGenerator[ClientSession, None]:
"""Create a connection to the server through the proxy server."""
async with in_memory(server) as session:
wrapped_server = await create_proxy_server(session)
async with in_memory(wrapped_server) as wrapped_session:
yield wrapped_session
async def test_list_prompts():
@pytest.fixture(params=["server", "proxy"], scope="function")
def session_generator(request: Any) -> SessionContextManager:
"""Fixture that returns a client creation strategy either direct or using the proxy."""
if request.param == "server":
return in_memory
return proxy
async def test_list_prompts(session_generator: SessionContextManager):
"""Test list_prompts."""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
types.JSONRPCMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
types.JSONRPCMessage
](1)
server = Server("test")
server = Server("prompt-server")
@server.list_prompts()
async def list_prompts() -> list[types.Prompt]:
return [
types.Prompt(name="prompt1"),
]
async def run_server() -> None:
print("running server")
await server.run(client_to_server_receive, server_to_client_send, server.create_initialization_options())
return [types.Prompt(name="prompt1")]
async def listen_session():
print("listening session")
print(session)
async for message in session.incoming_messages:
if isinstance(message, Exception):
raise message
print("message")
print(message)
async with session_generator(server) as session:
result = await session.initialize()
assert result.serverInfo.name == "prompt-server"
assert result.capabilities
assert result.capabilities.prompts
assert not result.capabilities.tools
assert not result.capabilities.resources
assert not result.capabilities.logging
# Create an in memory connection to the fake server
async with create_connected_server_and_client_session(server) as session:
# Baseline behavior for client
result = await session.list_prompts()
assert result.prompts == [types.Prompt(name="prompt1")]
with pytest.raises(McpError, match="Method not found"):
await session.list_tools()
# Create a proxy instance to the in memory server
wrapped_server = await create_server("name", session)
# Create a client to the proxy server
async with create_connected_server_and_client_session(server) as wrapped_session:
await wrapped_session.initialize()
async def test_list_tools(session_generator: SessionContextManager):
"""Test list_tools."""
result = await wrapped_session.list_prompts()
assert result.prompts == [types.Prompt(name="prompt1")]
server = Server("tools-server")
with pytest.raises(McpError, match="Method not found"):
await wrapped_session.list_tools()
@server.list_tools()
async def list_tools() -> list[types.Tool]:
return [types.Tool(
name="tool-name",
description="tool-description",
inputSchema=TOOL_INPUT_SCHEMA
)]
async with session_generator(server) as session:
result = await session.initialize()
assert result.serverInfo.name == "tools-server"
assert result.capabilities
assert result.capabilities.tools
assert not result.capabilities.prompts
assert not result.capabilities.resources
assert not result.capabilities.logging
result = await session.list_tools()
assert len(result.tools) == 1
assert result.tools[0].name == "tool-name"
assert result.tools[0].description == "tool-description"
assert result.tools[0].inputSchema == TOOL_INPUT_SCHEMA
with pytest.raises(McpError, match="Method not found"):
await session.list_prompts()