refactor: separate client and proxy server in preparation for more client behaviors (#7)

* Refactor to separate client and proxy server in preparation for more client behaviors

* Fix incorrect package names
This commit is contained in:
Allen Porter
2024-12-31 01:10:33 -08:00
committed by GitHub
parent c132722d66
commit 8423905ca2
5 changed files with 151 additions and 145 deletions

View File

@@ -1,143 +1 @@
"""Create a local server that proxies requests to a remote server over SSE."""
import logging
import typing as t
from mcp import server, types
from mcp.client.session import ClientSession
logger = logging.getLogger(__name__)
async def create_proxy_server(remote_app: ClientSession) -> server.Server: # noqa: C901
"""Create a server instance from a remote app."""
response = await remote_app.initialize()
capabilities = response.capabilities
app = server.Server(response.serverInfo.name)
if capabilities.prompts:
async def _list_prompts(_: t.Any) -> types.ServerResult: # noqa: ANN401
result = await remote_app.list_prompts()
return types.ServerResult(result)
app.request_handlers[types.ListPromptsRequest] = _list_prompts
async def _get_prompt(req: types.GetPromptRequest) -> types.ServerResult:
result = await remote_app.get_prompt(req.params.name, req.params.arguments)
return types.ServerResult(result)
app.request_handlers[types.GetPromptRequest] = _get_prompt
if capabilities.resources:
async def _list_resources(_: t.Any) -> types.ServerResult: # noqa: ANN401
result = await remote_app.list_resources()
return types.ServerResult(result)
app.request_handlers[types.ListResourcesRequest] = _list_resources
# list_resource_templates() is not implemented in the client
# async def _list_resource_templates(_: t.Any) -> types.ServerResult:
# result = await remote_app.list_resource_templates()
# return types.ServerResult(result)
# app.request_handlers[types.ListResourceTemplatesRequest] = _list_resource_templates
async def _read_resource(req: types.ReadResourceRequest) -> types.ServerResult:
result = await remote_app.read_resource(req.params.uri)
return types.ServerResult(result)
app.request_handlers[types.ReadResourceRequest] = _read_resource
if capabilities.logging:
async def _set_logging_level(req: types.SetLevelRequest) -> types.ServerResult:
await remote_app.set_logging_level(req.params.level)
return types.ServerResult(types.EmptyResult())
app.request_handlers[types.SetLevelRequest] = _set_logging_level
if capabilities.resources:
async def _subscribe_resource(req: types.SubscribeRequest) -> types.ServerResult:
await remote_app.subscribe_resource(req.params.uri)
return types.ServerResult(types.EmptyResult())
app.request_handlers[types.SubscribeRequest] = _subscribe_resource
async def _unsubscribe_resource(req: types.UnsubscribeRequest) -> types.ServerResult:
await remote_app.unsubscribe_resource(req.params.uri)
return types.ServerResult(types.EmptyResult())
app.request_handlers[types.UnsubscribeRequest] = _unsubscribe_resource
if capabilities.tools:
async def _list_tools(_: t.Any) -> types.ServerResult: # noqa: ANN401
tools = await remote_app.list_tools()
return types.ServerResult(tools)
app.request_handlers[types.ListToolsRequest] = _list_tools
async def _call_tool(req: types.CallToolRequest) -> types.ServerResult:
try:
result = await remote_app.call_tool(
req.params.name,
(req.params.arguments or {}),
)
return types.ServerResult(result)
except Exception as e: # noqa: BLE001
return types.ServerResult(
types.CallToolResult(
content=[types.TextContent(type="text", text=str(e))],
isError=True,
),
)
app.request_handlers[types.CallToolRequest] = _call_tool
async def _send_progress_notification(req: types.ProgressNotification) -> None:
await remote_app.send_progress_notification(
req.params.progressToken,
req.params.progress,
req.params.total,
)
app.notification_handlers[types.ProgressNotification] = _send_progress_notification
async def _complete(req: types.CompleteRequest) -> types.ServerResult:
result = await remote_app.complete(
req.params.ref,
req.params.argument.model_dump(),
)
return types.ServerResult(result)
app.request_handlers[types.CompleteRequest] = _complete
return app
async def run_sse_client(url: str, api_access_token: str | None = None) -> None:
"""Run the SSE client.
Args:
url: The URL to connect to.
api_access_token: The API access token to use for authentication.
"""
from mcp.client.sse import sse_client
headers = {}
if api_access_token is not None:
headers["Authorization"] = f"Bearer {api_access_token}"
async with sse_client(url=url, headers=headers) as streams, ClientSession(*streams) as session:
app = await create_proxy_server(session)
async with server.stdio_server() as (read_stream, write_stream):
await app.run(
read_stream,
write_stream,
app.create_initialization_options(),
)
"""Library for proxying MCP servers across different transports."""

View File

@@ -11,7 +11,7 @@ import logging
import os
import typing as t
from . import run_sse_client
from .sse_client import run_sse_client
logging.basicConfig(level=logging.DEBUG)
SSE_URL: t.Final[str] = os.getenv("SSE_URL", "")

View File

@@ -0,0 +1,119 @@
"""Create an MCP server that proxies requests throgh an MCP client.
This server is created independent of any transport mechanism.
"""
import typing as t
from mcp import server, types
from mcp.client.session import ClientSession
async def create_proxy_server(remote_app: ClientSession) -> server.Server: # noqa: C901
"""Create a server instance from a remote app."""
response = await remote_app.initialize()
capabilities = response.capabilities
app = server.Server(response.serverInfo.name)
if capabilities.prompts:
async def _list_prompts(_: t.Any) -> types.ServerResult: # noqa: ANN401
result = await remote_app.list_prompts()
return types.ServerResult(result)
app.request_handlers[types.ListPromptsRequest] = _list_prompts
async def _get_prompt(req: types.GetPromptRequest) -> types.ServerResult:
result = await remote_app.get_prompt(req.params.name, req.params.arguments)
return types.ServerResult(result)
app.request_handlers[types.GetPromptRequest] = _get_prompt
if capabilities.resources:
async def _list_resources(_: t.Any) -> types.ServerResult: # noqa: ANN401
result = await remote_app.list_resources()
return types.ServerResult(result)
app.request_handlers[types.ListResourcesRequest] = _list_resources
# list_resource_templates() is not implemented in the client
# async def _list_resource_templates(_: t.Any) -> types.ServerResult:
# result = await remote_app.list_resource_templates()
# return types.ServerResult(result)
# app.request_handlers[types.ListResourceTemplatesRequest] = _list_resource_templates
async def _read_resource(req: types.ReadResourceRequest) -> types.ServerResult:
result = await remote_app.read_resource(req.params.uri)
return types.ServerResult(result)
app.request_handlers[types.ReadResourceRequest] = _read_resource
if capabilities.logging:
async def _set_logging_level(req: types.SetLevelRequest) -> types.ServerResult:
await remote_app.set_logging_level(req.params.level)
return types.ServerResult(types.EmptyResult())
app.request_handlers[types.SetLevelRequest] = _set_logging_level
if capabilities.resources:
async def _subscribe_resource(req: types.SubscribeRequest) -> types.ServerResult:
await remote_app.subscribe_resource(req.params.uri)
return types.ServerResult(types.EmptyResult())
app.request_handlers[types.SubscribeRequest] = _subscribe_resource
async def _unsubscribe_resource(req: types.UnsubscribeRequest) -> types.ServerResult:
await remote_app.unsubscribe_resource(req.params.uri)
return types.ServerResult(types.EmptyResult())
app.request_handlers[types.UnsubscribeRequest] = _unsubscribe_resource
if capabilities.tools:
async def _list_tools(_: t.Any) -> types.ServerResult: # noqa: ANN401
tools = await remote_app.list_tools()
return types.ServerResult(tools)
app.request_handlers[types.ListToolsRequest] = _list_tools
async def _call_tool(req: types.CallToolRequest) -> types.ServerResult:
try:
result = await remote_app.call_tool(
req.params.name,
(req.params.arguments or {}),
)
return types.ServerResult(result)
except Exception as e: # noqa: BLE001
return types.ServerResult(
types.CallToolResult(
content=[types.TextContent(type="text", text=str(e))],
isError=True,
),
)
app.request_handlers[types.CallToolRequest] = _call_tool
async def _send_progress_notification(req: types.ProgressNotification) -> None:
await remote_app.send_progress_notification(
req.params.progressToken,
req.params.progress,
req.params.total,
)
app.notification_handlers[types.ProgressNotification] = _send_progress_notification
async def _complete(req: types.CompleteRequest) -> types.ServerResult:
result = await remote_app.complete(
req.params.ref,
req.params.argument.model_dump(),
)
return types.ServerResult(result)
app.request_handlers[types.CompleteRequest] = _complete
return app

View File

@@ -0,0 +1,29 @@
"""Create a local server that proxies requests to a remote server over SSE."""
from mcp import server
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from .proxy_server import create_proxy_server
async def run_sse_client(url: str, api_access_token: str | None = None) -> None:
"""Run the SSE client.
Args:
url: The URL to connect to.
api_access_token: The API access token to use for authentication.
"""
headers = {}
if api_access_token is not None:
headers["Authorization"] = f"Bearer {api_access_token}"
async with sse_client(url=url, headers=headers) as streams, ClientSession(*streams) as session:
app = await create_proxy_server(session)
async with server.stdio_server() as (read_stream, write_stream):
await app.run(
read_stream,
write_stream,
app.create_initialization_options(),
)

View File

@@ -21,7 +21,7 @@ from mcp.shared.exceptions import McpError
from mcp.shared.memory import create_connected_server_and_client_session
from pydantic import AnyUrl
from mcp_proxy import create_proxy_server
from mcp_proxy.proxy_server import create_proxy_server
TOOL_INPUT_SCHEMA = {"type": "object", "properties": {"input1": {"type": "string"}}}