Update tests to exercise client behavior
This commit is contained in:
@@ -1,81 +1,109 @@
|
||||
"""Tests for the mcp-proxy module."""
|
||||
"""Tests for the mcp-proxy module.
|
||||
|
||||
Tests are running in two modes:
|
||||
- One where the server is exercised directly though an in memory client, just to
|
||||
set a baseline for the expected behavior.
|
||||
- Another where the server is exercised through a proxy server, which forwards
|
||||
the requests to the original server.
|
||||
|
||||
The same test code is run on both to ensure parity.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from contextlib import asynccontextmanager, AbstractAsyncContextManager
|
||||
|
||||
import pytest
|
||||
import anyio
|
||||
|
||||
from mcp import types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.server import Server
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.memory import create_connected_server_and_client_session
|
||||
from mcp.server.stdio import stdio_server
|
||||
|
||||
from mcp_proxy import configure_app, create_server
|
||||
from mcp_proxy import create_proxy_server
|
||||
|
||||
async def run_server() -> None:
|
||||
"""Run a stdio server."""
|
||||
TOOL_INPUT_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input1": {"type": "string"}
|
||||
}
|
||||
}
|
||||
|
||||
server = Server("test")
|
||||
SessionContextManager = Callable[[Server], AbstractAsyncContextManager[ClientSession]]
|
||||
|
||||
@server.list_prompts()
|
||||
async def handle_list_prompts() -> list[types.Prompt]:
|
||||
return [
|
||||
types.Prompt(name="prompt1"),
|
||||
]
|
||||
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
await server.run(read_stream, write_stream, server.create_initialization_options())
|
||||
# Direct server connection
|
||||
in_memory: SessionContextManager = create_connected_server_and_client_session
|
||||
|
||||
@asynccontextmanager
|
||||
async def proxy(server: Server) -> AsyncGenerator[ClientSession, None]:
|
||||
"""Create a connection to the server through the proxy server."""
|
||||
async with in_memory(server) as session:
|
||||
wrapped_server = await create_proxy_server(session)
|
||||
async with in_memory(wrapped_server) as wrapped_session:
|
||||
yield wrapped_session
|
||||
|
||||
|
||||
async def test_list_prompts():
|
||||
@pytest.fixture(params=["server", "proxy"], scope="function")
|
||||
def session_generator(request: Any) -> SessionContextManager:
|
||||
"""Fixture that returns a client creation strategy either direct or using the proxy."""
|
||||
if request.param == "server":
|
||||
return in_memory
|
||||
return proxy
|
||||
|
||||
|
||||
async def test_list_prompts(session_generator: SessionContextManager):
|
||||
"""Test list_prompts."""
|
||||
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
|
||||
types.JSONRPCMessage
|
||||
](1)
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
|
||||
types.JSONRPCMessage
|
||||
](1)
|
||||
|
||||
server = Server("test")
|
||||
server = Server("prompt-server")
|
||||
|
||||
@server.list_prompts()
|
||||
async def list_prompts() -> list[types.Prompt]:
|
||||
return [
|
||||
types.Prompt(name="prompt1"),
|
||||
]
|
||||
|
||||
async def run_server() -> None:
|
||||
print("running server")
|
||||
await server.run(client_to_server_receive, server_to_client_send, server.create_initialization_options())
|
||||
return [types.Prompt(name="prompt1")]
|
||||
|
||||
async def listen_session():
|
||||
print("listening session")
|
||||
print(session)
|
||||
async for message in session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
raise message
|
||||
print("message")
|
||||
print(message)
|
||||
async with session_generator(server) as session:
|
||||
result = await session.initialize()
|
||||
assert result.serverInfo.name == "prompt-server"
|
||||
assert result.capabilities
|
||||
assert result.capabilities.prompts
|
||||
assert not result.capabilities.tools
|
||||
assert not result.capabilities.resources
|
||||
assert not result.capabilities.logging
|
||||
|
||||
# Create an in memory connection to the fake server
|
||||
async with create_connected_server_and_client_session(server) as session:
|
||||
|
||||
# Baseline behavior for client
|
||||
result = await session.list_prompts()
|
||||
assert result.prompts == [types.Prompt(name="prompt1")]
|
||||
|
||||
with pytest.raises(McpError, match="Method not found"):
|
||||
await session.list_tools()
|
||||
|
||||
# Create a proxy instance to the in memory server
|
||||
wrapped_server = await create_server("name", session)
|
||||
|
||||
# Create a client to the proxy server
|
||||
async with create_connected_server_and_client_session(server) as wrapped_session:
|
||||
await wrapped_session.initialize()
|
||||
async def test_list_tools(session_generator: SessionContextManager):
|
||||
"""Test list_tools."""
|
||||
|
||||
result = await wrapped_session.list_prompts()
|
||||
assert result.prompts == [types.Prompt(name="prompt1")]
|
||||
server = Server("tools-server")
|
||||
|
||||
with pytest.raises(McpError, match="Method not found"):
|
||||
await wrapped_session.list_tools()
|
||||
@server.list_tools()
|
||||
async def list_tools() -> list[types.Tool]:
|
||||
return [types.Tool(
|
||||
name="tool-name",
|
||||
description="tool-description",
|
||||
inputSchema=TOOL_INPUT_SCHEMA
|
||||
)]
|
||||
|
||||
async with session_generator(server) as session:
|
||||
result = await session.initialize()
|
||||
assert result.serverInfo.name == "tools-server"
|
||||
assert result.capabilities
|
||||
assert result.capabilities.tools
|
||||
assert not result.capabilities.prompts
|
||||
assert not result.capabilities.resources
|
||||
assert not result.capabilities.logging
|
||||
|
||||
result = await session.list_tools()
|
||||
assert len(result.tools) == 1
|
||||
assert result.tools[0].name == "tool-name"
|
||||
assert result.tools[0].description == "tool-description"
|
||||
assert result.tools[0].inputSchema == TOOL_INPUT_SCHEMA
|
||||
|
||||
with pytest.raises(McpError, match="Method not found"):
|
||||
await session.list_prompts()
|
||||
|
||||
Reference in New Issue
Block a user