11
.vscode/launch.json
vendored
11
.vscode/launch.json
vendored
@@ -6,6 +6,17 @@
|
||||
"type": "python",
|
||||
"request": "test",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Debug mcp-proxy",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"python": "${command:python.interpreterPath}",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"module": "mcp_proxy",
|
||||
"args": ["--sse-port=8080", "--debug", "--", "uvx", "mcp-server-fetch"]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -260,6 +260,7 @@ stdio client options:
|
||||
Environment variables used when spawning the server. Can be used multiple times.
|
||||
--pass-environment, --no-pass-environment
|
||||
Pass through all environment variables when spawning the server.
|
||||
--debug, --no-debug Enable debug mode with detailed logging output.
|
||||
|
||||
SSE server options:
|
||||
--sse-port SSE_PORT Port to expose an SSE server on. Default is a random port
|
||||
|
||||
@@ -18,7 +18,6 @@ from mcp.client.stdio import StdioServerParameters
|
||||
from .sse_client import run_sse_client
|
||||
from .sse_server import SseServerSettings, run_sse_server
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
SSE_URL: t.Final[str | None] = os.getenv(
|
||||
"SSE_URL",
|
||||
None,
|
||||
@@ -84,6 +83,12 @@ def main() -> None:
|
||||
help="Pass through all environment variables when spawning the server.",
|
||||
default=False,
|
||||
)
|
||||
stdio_client_options.add_argument(
|
||||
"--debug",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Enable debug mode with detailed logging output.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
sse_server_group = parser.add_argument_group("SSE server options")
|
||||
sse_server_group.add_argument(
|
||||
@@ -110,13 +115,16 @@ def main() -> None:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if (
|
||||
SSE_URL
|
||||
or args.command_or_url.startswith("http://")
|
||||
or args.command_or_url.startswith("https://")
|
||||
):
|
||||
# Start a client connected to the SSE server, and expose as a stdio server
|
||||
logging.debug("Starting SSE client and stdio server")
|
||||
logger.debug("Starting SSE client and stdio server")
|
||||
headers = dict(args.headers)
|
||||
if api_access_token := os.getenv("API_ACCESS_TOKEN", None):
|
||||
headers["Authorization"] = f"Bearer {api_access_token}"
|
||||
@@ -124,7 +132,7 @@ def main() -> None:
|
||||
return
|
||||
|
||||
# Start a client connected to the given command, and expose as an SSE server
|
||||
logging.debug("Starting stdio client and SSE server")
|
||||
logger.debug("Starting stdio client and SSE server")
|
||||
|
||||
# The environment variables passed to the server process
|
||||
env: dict[str, str] = {}
|
||||
@@ -143,6 +151,7 @@ def main() -> None:
|
||||
bind_host=args.sse_host,
|
||||
port=args.sse_port,
|
||||
allow_origins=args.allow_origin if len(args.allow_origin) > 0 else None,
|
||||
log_level="DEBUG" if args.debug else "INFO",
|
||||
)
|
||||
asyncio.run(run_sse_server(stdio_params, sse_settings))
|
||||
|
||||
|
||||
@@ -3,20 +3,26 @@
|
||||
This server is created independent of any transport mechanism.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
from mcp import server, types
|
||||
from mcp.client.session import ClientSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def create_proxy_server(remote_app: ClientSession) -> server.Server[object]: # noqa: C901
|
||||
|
||||
async def create_proxy_server(remote_app: ClientSession) -> server.Server[object]: # noqa: C901, PLR0915
|
||||
"""Create a server instance from a remote app."""
|
||||
logger.debug("Sending initalization request to remote MCP server...")
|
||||
response = await remote_app.initialize()
|
||||
capabilities = response.capabilities
|
||||
|
||||
logger.debug("Configuring proxied MCP server...")
|
||||
app: server.Server[object] = server.Server(name=response.serverInfo.name)
|
||||
|
||||
if capabilities.prompts:
|
||||
logger.debug("Capabilities: adding Prompts...")
|
||||
|
||||
async def _list_prompts(_: t.Any) -> types.ServerResult: # noqa: ANN401
|
||||
result = await remote_app.list_prompts()
|
||||
@@ -31,6 +37,7 @@ async def create_proxy_server(remote_app: ClientSession) -> server.Server[object
|
||||
app.request_handlers[types.GetPromptRequest] = _get_prompt
|
||||
|
||||
if capabilities.resources:
|
||||
logger.debug("Capabilities: adding Resources...")
|
||||
|
||||
async def _list_resources(_: t.Any) -> types.ServerResult: # noqa: ANN401
|
||||
result = await remote_app.list_resources()
|
||||
@@ -38,12 +45,11 @@ async def create_proxy_server(remote_app: ClientSession) -> server.Server[object
|
||||
|
||||
app.request_handlers[types.ListResourcesRequest] = _list_resources
|
||||
|
||||
# list_resource_templates() is not implemented in the client
|
||||
# async def _list_resource_templates(_: t.Any) -> types.ServerResult:
|
||||
# result = await remote_app.list_resource_templates()
|
||||
# return types.ServerResult(result)
|
||||
async def _list_resource_templates(_: t.Any) -> types.ServerResult: # noqa: ANN401
|
||||
result = await remote_app.list_resource_templates()
|
||||
return types.ServerResult(result)
|
||||
|
||||
# app.request_handlers[types.ListResourceTemplatesRequest] = _list_resource_templates
|
||||
app.request_handlers[types.ListResourceTemplatesRequest] = _list_resource_templates
|
||||
|
||||
async def _read_resource(req: types.ReadResourceRequest) -> types.ServerResult:
|
||||
result = await remote_app.read_resource(req.params.uri)
|
||||
@@ -52,6 +58,7 @@ async def create_proxy_server(remote_app: ClientSession) -> server.Server[object
|
||||
app.request_handlers[types.ReadResourceRequest] = _read_resource
|
||||
|
||||
if capabilities.logging:
|
||||
logger.debug("Capabilities: adding Logging...")
|
||||
|
||||
async def _set_logging_level(req: types.SetLevelRequest) -> types.ServerResult:
|
||||
await remote_app.set_logging_level(req.params.level)
|
||||
@@ -60,6 +67,7 @@ async def create_proxy_server(remote_app: ClientSession) -> server.Server[object
|
||||
app.request_handlers[types.SetLevelRequest] = _set_logging_level
|
||||
|
||||
if capabilities.resources:
|
||||
logger.debug("Capabilities: adding Resources...")
|
||||
|
||||
async def _subscribe_resource(req: types.SubscribeRequest) -> types.ServerResult:
|
||||
await remote_app.subscribe_resource(req.params.uri)
|
||||
@@ -74,6 +82,7 @@ async def create_proxy_server(remote_app: ClientSession) -> server.Server[object
|
||||
app.request_handlers[types.UnsubscribeRequest] = _unsubscribe_resource
|
||||
|
||||
if capabilities.tools:
|
||||
logger.debug("Capabilities: adding Tools...")
|
||||
|
||||
async def _list_tools(_: t.Any) -> types.ServerResult: # noqa: ANN401
|
||||
tools = await remote_app.list_tools()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Create a local SSE server that proxies requests to a stdio MCP server."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
@@ -16,6 +17,8 @@ from starlette.routing import Mount, Route
|
||||
|
||||
from .proxy_server import create_proxy_server
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SseServerSettings:
|
||||
@@ -81,6 +84,7 @@ async def run_sse_server(
|
||||
|
||||
"""
|
||||
async with stdio_client(stdio_params) as streams, ClientSession(*streams) as session:
|
||||
logger.debug("Starting proxy server...")
|
||||
mcp_server = await create_proxy_server(session)
|
||||
|
||||
# Bind SSE request handling to MCP server
|
||||
@@ -98,4 +102,9 @@ async def run_sse_server(
|
||||
log_level=sse_settings.log_level.lower(),
|
||||
)
|
||||
http_server = uvicorn.Server(config)
|
||||
logger.debug(
|
||||
"Serving incoming requests on %s:%s",
|
||||
sse_settings.bind_host,
|
||||
sse_settings.port,
|
||||
)
|
||||
await http_server.serve()
|
||||
|
||||
@@ -110,6 +110,20 @@ def server_can_list_resources(server: Server[object], resource: types.Resource)
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_can_list_resource_templates(
|
||||
server_can_list_resources: Server[object],
|
||||
resource_template: types.ResourceTemplate,
|
||||
) -> Server[object]:
|
||||
"""Return a server instance with resources."""
|
||||
|
||||
@server_can_list_resources.list_resource_templates() # type: ignore[no-untyped-call,misc]
|
||||
async def _() -> list[types.ResourceTemplate]:
|
||||
return [resource_template]
|
||||
|
||||
return server_can_list_resources
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_can_subscribe_resource(
|
||||
server_can_list_resources: Server[object],
|
||||
@@ -307,6 +321,39 @@ async def test_list_resources(
|
||||
assert list_resources_result.resources == [resource]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"resource",
|
||||
[
|
||||
types.Resource(
|
||||
uri=AnyUrl("scheme://resource-uri"),
|
||||
name="resource-name",
|
||||
description="resource-description",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"resource_template",
|
||||
[
|
||||
types.ResourceTemplate(
|
||||
uriTemplate="scheme://resource-uri/{resource}",
|
||||
name="resource-name",
|
||||
description="resource-description",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_list_resource_templates(
|
||||
session_generator: SessionContextManager,
|
||||
server_can_list_resource_templates: Server[object],
|
||||
resource_template: types.ResourceTemplate,
|
||||
) -> None:
|
||||
"""Test get_resource."""
|
||||
async with session_generator(server_can_list_resource_templates) as session:
|
||||
await session.initialize()
|
||||
|
||||
list_resources_result = await session.list_resource_templates()
|
||||
assert list_resources_result.resourceTemplates == [resource_template]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prompt_callback", [AsyncMock()])
|
||||
@pytest.mark.parametrize("prompt", [types.Prompt(name="prompt1")])
|
||||
async def test_get_prompt(
|
||||
|
||||
Reference in New Issue
Block a user