refactor: separate client and proxy server in preparation for more client behaviors (#7)
* Refactor to separate client and proxy server in preparation for more client behaviors * Fix incorrect package names
This commit is contained in:
@@ -1,143 +1 @@
|
||||
"""Create a local server that proxies requests to a remote server over SSE."""
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
from mcp import server, types
|
||||
from mcp.client.session import ClientSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def create_proxy_server(remote_app: ClientSession) -> server.Server: # noqa: C901
|
||||
"""Create a server instance from a remote app."""
|
||||
response = await remote_app.initialize()
|
||||
capabilities = response.capabilities
|
||||
|
||||
app = server.Server(response.serverInfo.name)
|
||||
|
||||
if capabilities.prompts:
|
||||
|
||||
async def _list_prompts(_: t.Any) -> types.ServerResult: # noqa: ANN401
|
||||
result = await remote_app.list_prompts()
|
||||
return types.ServerResult(result)
|
||||
|
||||
app.request_handlers[types.ListPromptsRequest] = _list_prompts
|
||||
|
||||
async def _get_prompt(req: types.GetPromptRequest) -> types.ServerResult:
|
||||
result = await remote_app.get_prompt(req.params.name, req.params.arguments)
|
||||
return types.ServerResult(result)
|
||||
|
||||
app.request_handlers[types.GetPromptRequest] = _get_prompt
|
||||
|
||||
if capabilities.resources:
|
||||
|
||||
async def _list_resources(_: t.Any) -> types.ServerResult: # noqa: ANN401
|
||||
result = await remote_app.list_resources()
|
||||
return types.ServerResult(result)
|
||||
|
||||
app.request_handlers[types.ListResourcesRequest] = _list_resources
|
||||
|
||||
# list_resource_templates() is not implemented in the client
|
||||
# async def _list_resource_templates(_: t.Any) -> types.ServerResult:
|
||||
# result = await remote_app.list_resource_templates()
|
||||
# return types.ServerResult(result)
|
||||
|
||||
# app.request_handlers[types.ListResourceTemplatesRequest] = _list_resource_templates
|
||||
|
||||
async def _read_resource(req: types.ReadResourceRequest) -> types.ServerResult:
|
||||
result = await remote_app.read_resource(req.params.uri)
|
||||
return types.ServerResult(result)
|
||||
|
||||
app.request_handlers[types.ReadResourceRequest] = _read_resource
|
||||
|
||||
if capabilities.logging:
|
||||
|
||||
async def _set_logging_level(req: types.SetLevelRequest) -> types.ServerResult:
|
||||
await remote_app.set_logging_level(req.params.level)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
app.request_handlers[types.SetLevelRequest] = _set_logging_level
|
||||
|
||||
if capabilities.resources:
|
||||
|
||||
async def _subscribe_resource(req: types.SubscribeRequest) -> types.ServerResult:
|
||||
await remote_app.subscribe_resource(req.params.uri)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
app.request_handlers[types.SubscribeRequest] = _subscribe_resource
|
||||
|
||||
async def _unsubscribe_resource(req: types.UnsubscribeRequest) -> types.ServerResult:
|
||||
await remote_app.unsubscribe_resource(req.params.uri)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
app.request_handlers[types.UnsubscribeRequest] = _unsubscribe_resource
|
||||
|
||||
if capabilities.tools:
|
||||
|
||||
async def _list_tools(_: t.Any) -> types.ServerResult: # noqa: ANN401
|
||||
tools = await remote_app.list_tools()
|
||||
return types.ServerResult(tools)
|
||||
|
||||
app.request_handlers[types.ListToolsRequest] = _list_tools
|
||||
|
||||
async def _call_tool(req: types.CallToolRequest) -> types.ServerResult:
|
||||
try:
|
||||
result = await remote_app.call_tool(
|
||||
req.params.name,
|
||||
(req.params.arguments or {}),
|
||||
)
|
||||
return types.ServerResult(result)
|
||||
except Exception as e: # noqa: BLE001
|
||||
return types.ServerResult(
|
||||
types.CallToolResult(
|
||||
content=[types.TextContent(type="text", text=str(e))],
|
||||
isError=True,
|
||||
),
|
||||
)
|
||||
|
||||
app.request_handlers[types.CallToolRequest] = _call_tool
|
||||
|
||||
async def _send_progress_notification(req: types.ProgressNotification) -> None:
|
||||
await remote_app.send_progress_notification(
|
||||
req.params.progressToken,
|
||||
req.params.progress,
|
||||
req.params.total,
|
||||
)
|
||||
|
||||
app.notification_handlers[types.ProgressNotification] = _send_progress_notification
|
||||
|
||||
async def _complete(req: types.CompleteRequest) -> types.ServerResult:
|
||||
result = await remote_app.complete(
|
||||
req.params.ref,
|
||||
req.params.argument.model_dump(),
|
||||
)
|
||||
return types.ServerResult(result)
|
||||
|
||||
app.request_handlers[types.CompleteRequest] = _complete
|
||||
|
||||
return app
|
||||
|
||||
|
||||
async def run_sse_client(url: str, api_access_token: str | None = None) -> None:
|
||||
"""Run the SSE client.
|
||||
|
||||
Args:
|
||||
url: The URL to connect to.
|
||||
api_access_token: The API access token to use for authentication.
|
||||
|
||||
"""
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
headers = {}
|
||||
if api_access_token is not None:
|
||||
headers["Authorization"] = f"Bearer {api_access_token}"
|
||||
|
||||
async with sse_client(url=url, headers=headers) as streams, ClientSession(*streams) as session:
|
||||
app = await create_proxy_server(session)
|
||||
async with server.stdio_server() as (read_stream, write_stream):
|
||||
await app.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
app.create_initialization_options(),
|
||||
)
|
||||
"""Library for proxying MCP servers across different transports."""
|
||||
|
||||
@@ -11,7 +11,7 @@ import logging
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from . import run_sse_client
|
||||
from .sse_client import run_sse_client
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
SSE_URL: t.Final[str] = os.getenv("SSE_URL", "")
|
||||
|
||||
119
src/mcp_proxy/proxy_server.py
Normal file
119
src/mcp_proxy/proxy_server.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Create an MCP server that proxies requests throgh an MCP client.
|
||||
|
||||
This server is created independent of any transport mechanism.
|
||||
"""
|
||||
|
||||
import typing as t
|
||||
|
||||
from mcp import server, types
|
||||
from mcp.client.session import ClientSession
|
||||
|
||||
|
||||
async def create_proxy_server(remote_app: ClientSession) -> server.Server: # noqa: C901
|
||||
"""Create a server instance from a remote app."""
|
||||
response = await remote_app.initialize()
|
||||
capabilities = response.capabilities
|
||||
|
||||
app = server.Server(response.serverInfo.name)
|
||||
|
||||
if capabilities.prompts:
|
||||
|
||||
async def _list_prompts(_: t.Any) -> types.ServerResult: # noqa: ANN401
|
||||
result = await remote_app.list_prompts()
|
||||
return types.ServerResult(result)
|
||||
|
||||
app.request_handlers[types.ListPromptsRequest] = _list_prompts
|
||||
|
||||
async def _get_prompt(req: types.GetPromptRequest) -> types.ServerResult:
|
||||
result = await remote_app.get_prompt(req.params.name, req.params.arguments)
|
||||
return types.ServerResult(result)
|
||||
|
||||
app.request_handlers[types.GetPromptRequest] = _get_prompt
|
||||
|
||||
if capabilities.resources:
|
||||
|
||||
async def _list_resources(_: t.Any) -> types.ServerResult: # noqa: ANN401
|
||||
result = await remote_app.list_resources()
|
||||
return types.ServerResult(result)
|
||||
|
||||
app.request_handlers[types.ListResourcesRequest] = _list_resources
|
||||
|
||||
# list_resource_templates() is not implemented in the client
|
||||
# async def _list_resource_templates(_: t.Any) -> types.ServerResult:
|
||||
# result = await remote_app.list_resource_templates()
|
||||
# return types.ServerResult(result)
|
||||
|
||||
# app.request_handlers[types.ListResourceTemplatesRequest] = _list_resource_templates
|
||||
|
||||
async def _read_resource(req: types.ReadResourceRequest) -> types.ServerResult:
|
||||
result = await remote_app.read_resource(req.params.uri)
|
||||
return types.ServerResult(result)
|
||||
|
||||
app.request_handlers[types.ReadResourceRequest] = _read_resource
|
||||
|
||||
if capabilities.logging:
|
||||
|
||||
async def _set_logging_level(req: types.SetLevelRequest) -> types.ServerResult:
|
||||
await remote_app.set_logging_level(req.params.level)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
app.request_handlers[types.SetLevelRequest] = _set_logging_level
|
||||
|
||||
if capabilities.resources:
|
||||
|
||||
async def _subscribe_resource(req: types.SubscribeRequest) -> types.ServerResult:
|
||||
await remote_app.subscribe_resource(req.params.uri)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
app.request_handlers[types.SubscribeRequest] = _subscribe_resource
|
||||
|
||||
async def _unsubscribe_resource(req: types.UnsubscribeRequest) -> types.ServerResult:
|
||||
await remote_app.unsubscribe_resource(req.params.uri)
|
||||
return types.ServerResult(types.EmptyResult())
|
||||
|
||||
app.request_handlers[types.UnsubscribeRequest] = _unsubscribe_resource
|
||||
|
||||
if capabilities.tools:
|
||||
|
||||
async def _list_tools(_: t.Any) -> types.ServerResult: # noqa: ANN401
|
||||
tools = await remote_app.list_tools()
|
||||
return types.ServerResult(tools)
|
||||
|
||||
app.request_handlers[types.ListToolsRequest] = _list_tools
|
||||
|
||||
async def _call_tool(req: types.CallToolRequest) -> types.ServerResult:
|
||||
try:
|
||||
result = await remote_app.call_tool(
|
||||
req.params.name,
|
||||
(req.params.arguments or {}),
|
||||
)
|
||||
return types.ServerResult(result)
|
||||
except Exception as e: # noqa: BLE001
|
||||
return types.ServerResult(
|
||||
types.CallToolResult(
|
||||
content=[types.TextContent(type="text", text=str(e))],
|
||||
isError=True,
|
||||
),
|
||||
)
|
||||
|
||||
app.request_handlers[types.CallToolRequest] = _call_tool
|
||||
|
||||
async def _send_progress_notification(req: types.ProgressNotification) -> None:
|
||||
await remote_app.send_progress_notification(
|
||||
req.params.progressToken,
|
||||
req.params.progress,
|
||||
req.params.total,
|
||||
)
|
||||
|
||||
app.notification_handlers[types.ProgressNotification] = _send_progress_notification
|
||||
|
||||
async def _complete(req: types.CompleteRequest) -> types.ServerResult:
|
||||
result = await remote_app.complete(
|
||||
req.params.ref,
|
||||
req.params.argument.model_dump(),
|
||||
)
|
||||
return types.ServerResult(result)
|
||||
|
||||
app.request_handlers[types.CompleteRequest] = _complete
|
||||
|
||||
return app
|
||||
29
src/mcp_proxy/sse_client.py
Normal file
29
src/mcp_proxy/sse_client.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Create a local server that proxies requests to a remote server over SSE."""
|
||||
|
||||
from mcp import server
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from .proxy_server import create_proxy_server
|
||||
|
||||
|
||||
async def run_sse_client(url: str, api_access_token: str | None = None) -> None:
|
||||
"""Run the SSE client.
|
||||
|
||||
Args:
|
||||
url: The URL to connect to.
|
||||
api_access_token: The API access token to use for authentication.
|
||||
|
||||
"""
|
||||
headers = {}
|
||||
if api_access_token is not None:
|
||||
headers["Authorization"] = f"Bearer {api_access_token}"
|
||||
|
||||
async with sse_client(url=url, headers=headers) as streams, ClientSession(*streams) as session:
|
||||
app = await create_proxy_server(session)
|
||||
async with server.stdio_server() as (read_stream, write_stream):
|
||||
await app.run(
|
||||
read_stream,
|
||||
write_stream,
|
||||
app.create_initialization_options(),
|
||||
)
|
||||
@@ -21,7 +21,7 @@ from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.memory import create_connected_server_and_client_session
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from mcp_proxy import create_proxy_server
|
||||
from mcp_proxy.proxy_server import create_proxy_server
|
||||
|
||||
TOOL_INPUT_SCHEMA = {"type": "object", "properties": {"input1": {"type": "string"}}}
|
||||
|
||||
Reference in New Issue
Block a user