Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mcp_proxy_for_aws/middleware/tool_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging
from collections.abc import Awaitable, Callable
from fastmcp.server.middleware import Middleware, MiddlewareContext
from fastmcp.tools.tool import Tool
from fastmcp.tools import Tool
from typing import Sequence


Expand Down
68 changes: 2 additions & 66 deletions mcp_proxy_for_aws/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,84 +16,20 @@
import logging
from fastmcp import Client
from fastmcp.client.transports import ClientTransport
from fastmcp.exceptions import NotFoundError
from fastmcp.server.proxy import ClientFactoryT
from fastmcp.server.proxy import FastMCPProxy as _FastMCPProxy
from fastmcp.server.proxy import ProxyClient as _ProxyClient
from fastmcp.server.proxy import ProxyToolManager as _ProxyToolManager
from fastmcp.tools import Tool
from fastmcp.server.providers.proxy import ProxyClient as _ProxyClient
from mcp import McpError
from mcp.types import InitializeRequest, JSONRPCError, JSONRPCMessage
from typing import Any
from typing_extensions import override


logger = logging.getLogger(__name__)


class AWSProxyToolManager(_ProxyToolManager):
"""Customized proxy tool manager that better suites our needs."""

def __init__(self, client_factory: ClientFactoryT, **kwargs: Any):
"""Initialize a proxy tool manager.

Cached tools are set to None.
"""
super().__init__(client_factory, **kwargs)
self._cached_tools: dict[str, Tool] | None = None

@override
async def get_tool(self, key: str) -> Tool:
"""Return the tool from cached tools.

This method is invoked when the client tries to call a tool.

tool = self.get_tool(key)
tool.invoke(...)

The parent class implementation always make a mcp call to list the tools.
Since the client already knows the name of the tools, list_tool is not necessary.
We are wasting a network call just to get the tools which were already listed.

In case the server supports notifications/tools/listChanged, the `get_tools` method
will be called explicity , hence, we are not missing the change to the tool list.
"""
if self._cached_tools is None:
logger.debug('cached_tools not found, calling get_tools')
self._cached_tools = await self.get_tools()
if key in self._cached_tools:
return self._cached_tools[key]
raise NotFoundError(f'Tool {key!r} not found')

@override
async def get_tools(self) -> dict[str, Tool]:
"""Return list tools."""
self._cached_tools = await super(AWSProxyToolManager, self).get_tools()
return self._cached_tools


class AWSMCPProxy(_FastMCPProxy):
"""Customized MCP Proxy to better suite our needs."""

def __init__(
self,
*,
client_factory: ClientFactoryT,
**kwargs,
):
"""Initialize a client."""
super().__init__(client_factory=client_factory, **kwargs)
self._tool_manager = AWSProxyToolManager(
client_factory=self.client_factory,
transformations=self._tool_manager.transformations,
)


class AWSMCPProxyClient(_ProxyClient):
"""Proxy client that handles HTTP errors when connection fails."""

def __init__(self, transport: ClientTransport, max_connect_retry=3, **kwargs):
"""Constructor of AutoRefreshProxyCilent."""
"""Constructor of AWSMCPProxyClient."""
super().__init__(transport, **kwargs)
self._max_connect_retry = max_connect_retry

Expand Down
5 changes: 3 additions & 2 deletions mcp_proxy_for_aws/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
import logging
from fastmcp.server.middleware.error_handling import RetryMiddleware
from fastmcp.server.middleware.logging import LoggingMiddleware
from fastmcp.server.providers.proxy import FastMCPProxy
from fastmcp.server.server import FastMCP
from mcp_proxy_for_aws import __version__
from mcp_proxy_for_aws.cli import parse_args
from mcp_proxy_for_aws.logging_config import configure_logging
from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware
from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware
from mcp_proxy_for_aws.proxy import AWSMCPProxy, AWSMCPProxyClientFactory
from mcp_proxy_for_aws.proxy import AWSMCPProxyClientFactory
from mcp_proxy_for_aws.utils import (
create_transport_with_sigv4,
determine_aws_region,
Expand Down Expand Up @@ -87,7 +88,7 @@ async def run_proxy(args) -> None:
client_factory = AWSMCPProxyClientFactory(transport)

try:
proxy = AWSMCPProxy(
proxy = FastMCPProxy(
client_factory=client_factory,
name='MCP Proxy for AWS',
version=__version__,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ description = "MCP Proxy for AWS"
readme = "README.md"
requires-python = ">=3.10,<3.15"
dependencies = [
"fastmcp~=2.14.1",
"fastmcp (>=3.2.0,<4)",
"boto3>=1.41.0",
"botocore[crt]>=1.41.0",
]
Expand Down
6 changes: 3 additions & 3 deletions tests/integ/mcp/simple_mcp_server/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ def greet(name: str):


@mcp.tool
def add_tool_multiply(ctx: Context):
async def add_tool_multiply(ctx: Context):
"""MCP Tool used for testing dynamic tool behavior through the proxy."""
if not ctx.get_state('multiply_registered'):
if not await ctx.get_state('multiply_registered'):

@mcp.tool
def multiply(x: int, y: int):
"""Multiply two numbers."""
return x * y

ctx.set_state('multiply_registered', True)
await ctx.set_state('multiply_registered', True)
return 'Tool "multiply" added successfully'
return 'Tool "multiply" already exists'

Expand Down
60 changes: 0 additions & 60 deletions tests/unit/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,75 +17,15 @@
import httpx
import pytest
from fastmcp.client.transports import ClientTransport
from fastmcp.exceptions import NotFoundError
from fastmcp.tools import Tool
from mcp import McpError
from mcp.types import ErrorData, InitializeRequest, JSONRPCError
from mcp_proxy_for_aws.proxy import (
AWSMCPProxy,
AWSMCPProxyClient,
AWSMCPProxyClientFactory,
AWSProxyToolManager,
)
from unittest.mock import AsyncMock, Mock, patch


@pytest.mark.asyncio
async def test_tool_manager_get_tool_with_cache():
"""Test get_tool returns from cache when available."""
mock_factory = Mock()
manager = AWSProxyToolManager(mock_factory)
mock_tool = Mock(spec=Tool)
manager._cached_tools = {'test_tool': mock_tool}

result = await manager.get_tool('test_tool')
assert result == mock_tool


@pytest.mark.asyncio
async def test_tool_manager_get_tool_without_cache():
"""Test get_tool fetches tools when cache is empty."""
mock_factory = Mock()
manager = AWSProxyToolManager(mock_factory)
mock_tool = Mock(spec=Tool)

with patch.object(manager, 'get_tools', return_value={'test_tool': mock_tool}):
result = await manager.get_tool('test_tool')
assert result == mock_tool
assert manager._cached_tools == {'test_tool': mock_tool}


@pytest.mark.asyncio
async def test_tool_manager_get_tool_not_found():
"""Test get_tool raises NotFoundError when tool doesn't exist."""
mock_factory = Mock()
manager = AWSProxyToolManager(mock_factory)
manager._cached_tools = {}

with pytest.raises(NotFoundError, match="Tool 'missing_tool' not found"):
await manager.get_tool('missing_tool')


@pytest.mark.asyncio
async def test_tool_manager_get_tools_updates_cache():
"""Test get_tools updates the cache."""
mock_factory = Mock()
manager = AWSProxyToolManager(mock_factory)
mock_tools = {'tool1': Mock(spec=Tool), 'tool2': Mock(spec=Tool)}

with patch('mcp_proxy_for_aws.proxy._ProxyToolManager.get_tools', return_value=mock_tools):
result = await manager.get_tools()
assert result == mock_tools
assert manager._cached_tools == mock_tools


def test_proxy_initialization():
"""Test AWSMCPProxy initializes with custom tool manager."""
mock_factory = Mock()
proxy = AWSMCPProxy(client_factory=mock_factory, name='test')
assert isinstance(proxy._tool_manager, AWSProxyToolManager)


@pytest.mark.asyncio
async def test_proxy_client_connect_success():
"""Test successful connection."""
Expand Down
28 changes: 14 additions & 14 deletions tests/unit/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class TestServer:

@patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.AWSMCPProxy')
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@patch('mcp_proxy_for_aws.server.determine_service_name')
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
Expand All @@ -43,7 +43,7 @@ async def test_setup_mcp_mode(
mock_add_filtering,
mock_determine_service,
mock_determine_region,
mock_aws_proxy,
mock_fastmcp_proxy,
mock_create_transport,
mock_client_factory_class,
):
Expand Down Expand Up @@ -79,7 +79,7 @@ async def test_setup_mcp_mode(
mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_proxy.add_middleware = Mock()
mock_aws_proxy.return_value = mock_proxy
mock_fastmcp_proxy.return_value = mock_proxy

# Act
await run_proxy(mock_args)
Expand All @@ -97,7 +97,7 @@ async def test_setup_mcp_mode(
# call_args[0][4] is the Timeout object
assert call_args[0][5] is None # profile
mock_client_factory_class.assert_called_once_with(mock_transport)
mock_aws_proxy.assert_called_once()
mock_fastmcp_proxy.assert_called_once()
mock_add_filtering.assert_called_once_with(mock_proxy, True)
mock_add_retry.assert_called_once_with(mock_proxy, 1)
mock_proxy.run_async.assert_called_once_with(
Expand All @@ -106,7 +106,7 @@ async def test_setup_mcp_mode(

@patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.AWSMCPProxy')
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@patch('mcp_proxy_for_aws.server.determine_service_name')
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
Expand All @@ -115,7 +115,7 @@ async def test_setup_mcp_mode_no_retries(
mock_add_filtering,
mock_determine_service,
mock_determine_region,
mock_aws_proxy,
mock_fastmcp_proxy,
mock_create_transport,
mock_client_factory_class,
):
Expand Down Expand Up @@ -151,7 +151,7 @@ async def test_setup_mcp_mode_no_retries(
mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_proxy.add_middleware = Mock()
mock_aws_proxy.return_value = mock_proxy
mock_fastmcp_proxy.return_value = mock_proxy

# Act
await run_proxy(mock_args)
Expand All @@ -172,15 +172,15 @@ async def test_setup_mcp_mode_no_retries(
# call_args[0][4] is the Timeout object
assert call_args[0][5] == 'test-profile' # profile
mock_client_factory_class.assert_called_once_with(mock_transport)
mock_aws_proxy.assert_called_once()
mock_fastmcp_proxy.assert_called_once()
mock_add_filtering.assert_called_once_with(mock_proxy, False)
mock_proxy.run_async.assert_called_once_with(
transport='stdio', show_banner=False, log_level='INFO'
)

@patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.AWSMCPProxy')
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@patch('mcp_proxy_for_aws.server.determine_service_name')
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
Expand All @@ -189,7 +189,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
mock_add_filtering,
mock_determine_service,
mock_determine_region,
mock_aws_proxy,
mock_fastmcp_proxy,
mock_create_transport,
mock_client_factory_class,
):
Expand Down Expand Up @@ -222,7 +222,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_proxy.add_middleware = Mock()
mock_aws_proxy.return_value = mock_proxy
mock_fastmcp_proxy.return_value = mock_proxy

# Act
await run_proxy(mock_args)
Expand All @@ -235,7 +235,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(

@patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.AWSMCPProxy')
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@patch('mcp_proxy_for_aws.server.determine_service_name')
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
Expand All @@ -244,7 +244,7 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
mock_add_filtering,
mock_determine_service,
mock_determine_region,
mock_aws_proxy,
mock_fastmcp_proxy,
mock_create_transport,
mock_client_factory_class,
):
Expand Down Expand Up @@ -277,7 +277,7 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_proxy.add_middleware = Mock()
mock_aws_proxy.return_value = mock_proxy
mock_fastmcp_proxy.return_value = mock_proxy

# Act
await run_proxy(mock_args)
Expand Down
Loading
Loading