Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mcp_proxy_for_aws/middleware/tool_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging
from collections.abc import Awaitable, Callable
from fastmcp.server.middleware import Middleware, MiddlewareContext
from fastmcp.tools.tool import Tool
from fastmcp.tools import Tool
from typing import Sequence


Expand Down
68 changes: 2 additions & 66 deletions mcp_proxy_for_aws/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,84 +16,20 @@
import logging
from fastmcp import Client
from fastmcp.client.transports import ClientTransport
from fastmcp.exceptions import NotFoundError
from fastmcp.server.proxy import ClientFactoryT
from fastmcp.server.proxy import FastMCPProxy as _FastMCPProxy
from fastmcp.server.proxy import ProxyClient as _ProxyClient
from fastmcp.server.proxy import ProxyToolManager as _ProxyToolManager
from fastmcp.tools import Tool
from fastmcp.server.providers.proxy import ProxyClient as _ProxyClient
from mcp import McpError
from mcp.types import InitializeRequest, JSONRPCError, JSONRPCMessage
from typing import Any
from typing_extensions import override


logger = logging.getLogger(__name__)


class AWSProxyToolManager(_ProxyToolManager):
"""Customized proxy tool manager that better suites our needs."""

def __init__(self, client_factory: ClientFactoryT, **kwargs: Any):
"""Initialize a proxy tool manager.

Cached tools are set to None.
"""
super().__init__(client_factory, **kwargs)
self._cached_tools: dict[str, Tool] | None = None

@override
async def get_tool(self, key: str) -> Tool:
"""Return the tool from cached tools.

This method is invoked when the client tries to call a tool.

tool = self.get_tool(key)
tool.invoke(...)

The parent class implementation always make a mcp call to list the tools.
Since the client already knows the name of the tools, list_tool is not necessary.
We are wasting a network call just to get the tools which were already listed.

In case the server supports notifications/tools/listChanged, the `get_tools` method
will be called explicity , hence, we are not missing the change to the tool list.
"""
if self._cached_tools is None:
logger.debug('cached_tools not found, calling get_tools')
self._cached_tools = await self.get_tools()
if key in self._cached_tools:
return self._cached_tools[key]
raise NotFoundError(f'Tool {key!r} not found')

@override
async def get_tools(self) -> dict[str, Tool]:
"""Return list tools."""
self._cached_tools = await super(AWSProxyToolManager, self).get_tools()
return self._cached_tools


class AWSMCPProxy(_FastMCPProxy):
"""Customized MCP Proxy to better suite our needs."""

def __init__(
self,
*,
client_factory: ClientFactoryT,
**kwargs,
):
"""Initialize a client."""
super().__init__(client_factory=client_factory, **kwargs)
self._tool_manager = AWSProxyToolManager(
client_factory=self.client_factory,
transformations=self._tool_manager.transformations,
)


class AWSMCPProxyClient(_ProxyClient):
"""Proxy client that handles HTTP errors when connection fails."""

def __init__(self, transport: ClientTransport, max_connect_retry=3, **kwargs):
"""Constructor of AutoRefreshProxyCilent."""
"""Constructor of AWSMCPProxyClient."""
super().__init__(transport, **kwargs)
self._max_connect_retry = max_connect_retry

Expand Down
5 changes: 3 additions & 2 deletions mcp_proxy_for_aws/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
import logging
from fastmcp.server.middleware.error_handling import RetryMiddleware
from fastmcp.server.middleware.logging import LoggingMiddleware
from fastmcp.server.providers.proxy import FastMCPProxy
from fastmcp.server.server import FastMCP
from mcp_proxy_for_aws import __version__
from mcp_proxy_for_aws.cli import parse_args
from mcp_proxy_for_aws.logging_config import configure_logging
from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware
from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware
from mcp_proxy_for_aws.proxy import AWSMCPProxy, AWSMCPProxyClientFactory
from mcp_proxy_for_aws.proxy import AWSMCPProxyClientFactory
from mcp_proxy_for_aws.utils import (
create_transport_with_sigv4,
determine_aws_region,
Expand Down Expand Up @@ -87,7 +88,7 @@ async def run_proxy(args) -> None:
client_factory = AWSMCPProxyClientFactory(transport)

try:
proxy = AWSMCPProxy(
proxy = FastMCPProxy(
client_factory=client_factory,
name='MCP Proxy for AWS',
version=__version__,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ description = "MCP Proxy for AWS"
readme = "README.md"
requires-python = ">=3.10,<3.15"
dependencies = [
"fastmcp~=2.14.1",
"fastmcp~=3.2.0",
"boto3>=1.41.0",
"botocore[crt]>=1.41.0",
]
Expand Down
6 changes: 3 additions & 3 deletions tests/integ/mcp/simple_mcp_server/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ def greet(name: str):


@mcp.tool
def add_tool_multiply(ctx: Context):
async def add_tool_multiply(ctx: Context):
"""MCP Tool used for testing dynamic tool behavior through the proxy."""
if not ctx.get_state('multiply_registered'):
if not await ctx.get_state('multiply_registered'):

@mcp.tool
def multiply(x: int, y: int):
"""Multiply two numbers."""
return x * y

ctx.set_state('multiply_registered', True)
await ctx.set_state('multiply_registered', True)
return 'Tool "multiply" added successfully'
return 'Tool "multiply" already exists'

Expand Down
60 changes: 0 additions & 60 deletions tests/unit/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,75 +17,15 @@
import httpx
import pytest
from fastmcp.client.transports import ClientTransport
from fastmcp.exceptions import NotFoundError
from fastmcp.tools import Tool
from mcp import McpError
from mcp.types import ErrorData, InitializeRequest, JSONRPCError
from mcp_proxy_for_aws.proxy import (
AWSMCPProxy,
AWSMCPProxyClient,
AWSMCPProxyClientFactory,
AWSProxyToolManager,
)
from unittest.mock import AsyncMock, Mock, patch


@pytest.mark.asyncio
async def test_tool_manager_get_tool_with_cache():
"""Test get_tool returns from cache when available."""
mock_factory = Mock()
manager = AWSProxyToolManager(mock_factory)
mock_tool = Mock(spec=Tool)
manager._cached_tools = {'test_tool': mock_tool}

result = await manager.get_tool('test_tool')
assert result == mock_tool


@pytest.mark.asyncio
async def test_tool_manager_get_tool_without_cache():
"""Test get_tool fetches tools when cache is empty."""
mock_factory = Mock()
manager = AWSProxyToolManager(mock_factory)
mock_tool = Mock(spec=Tool)

with patch.object(manager, 'get_tools', return_value={'test_tool': mock_tool}):
result = await manager.get_tool('test_tool')
assert result == mock_tool
assert manager._cached_tools == {'test_tool': mock_tool}


@pytest.mark.asyncio
async def test_tool_manager_get_tool_not_found():
"""Test get_tool raises NotFoundError when tool doesn't exist."""
mock_factory = Mock()
manager = AWSProxyToolManager(mock_factory)
manager._cached_tools = {}

with pytest.raises(NotFoundError, match="Tool 'missing_tool' not found"):
await manager.get_tool('missing_tool')


@pytest.mark.asyncio
async def test_tool_manager_get_tools_updates_cache():
"""Test get_tools updates the cache."""
mock_factory = Mock()
manager = AWSProxyToolManager(mock_factory)
mock_tools = {'tool1': Mock(spec=Tool), 'tool2': Mock(spec=Tool)}

with patch('mcp_proxy_for_aws.proxy._ProxyToolManager.get_tools', return_value=mock_tools):
result = await manager.get_tools()
assert result == mock_tools
assert manager._cached_tools == mock_tools


def test_proxy_initialization():
"""Test AWSMCPProxy initializes with custom tool manager."""
mock_factory = Mock()
proxy = AWSMCPProxy(client_factory=mock_factory, name='test')
assert isinstance(proxy._tool_manager, AWSProxyToolManager)


@pytest.mark.asyncio
async def test_proxy_client_connect_success():
"""Test successful connection."""
Expand Down
28 changes: 14 additions & 14 deletions tests/unit/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class TestServer:

@patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.AWSMCPProxy')
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@patch('mcp_proxy_for_aws.server.determine_service_name')
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
Expand All @@ -43,7 +43,7 @@ async def test_setup_mcp_mode(
mock_add_filtering,
mock_determine_service,
mock_determine_region,
mock_aws_proxy,
mock_fastmcp_proxy,
mock_create_transport,
mock_client_factory_class,
):
Expand Down Expand Up @@ -79,7 +79,7 @@ async def test_setup_mcp_mode(
mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_proxy.add_middleware = Mock()
mock_aws_proxy.return_value = mock_proxy
mock_fastmcp_proxy.return_value = mock_proxy

# Act
await run_proxy(mock_args)
Expand All @@ -97,7 +97,7 @@ async def test_setup_mcp_mode(
# call_args[0][4] is the Timeout object
assert call_args[0][5] is None # profile
mock_client_factory_class.assert_called_once_with(mock_transport)
mock_aws_proxy.assert_called_once()
mock_fastmcp_proxy.assert_called_once()
mock_add_filtering.assert_called_once_with(mock_proxy, True)
mock_add_retry.assert_called_once_with(mock_proxy, 1)
mock_proxy.run_async.assert_called_once_with(
Expand All @@ -106,7 +106,7 @@ async def test_setup_mcp_mode(

@patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.AWSMCPProxy')
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@patch('mcp_proxy_for_aws.server.determine_service_name')
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
Expand All @@ -115,7 +115,7 @@ async def test_setup_mcp_mode_no_retries(
mock_add_filtering,
mock_determine_service,
mock_determine_region,
mock_aws_proxy,
mock_fastmcp_proxy,
mock_create_transport,
mock_client_factory_class,
):
Expand Down Expand Up @@ -151,7 +151,7 @@ async def test_setup_mcp_mode_no_retries(
mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_proxy.add_middleware = Mock()
mock_aws_proxy.return_value = mock_proxy
mock_fastmcp_proxy.return_value = mock_proxy

# Act
await run_proxy(mock_args)
Expand All @@ -172,15 +172,15 @@ async def test_setup_mcp_mode_no_retries(
# call_args[0][4] is the Timeout object
assert call_args[0][5] == 'test-profile' # profile
mock_client_factory_class.assert_called_once_with(mock_transport)
mock_aws_proxy.assert_called_once()
mock_fastmcp_proxy.assert_called_once()
mock_add_filtering.assert_called_once_with(mock_proxy, False)
mock_proxy.run_async.assert_called_once_with(
transport='stdio', show_banner=False, log_level='INFO'
)

@patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.AWSMCPProxy')
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@patch('mcp_proxy_for_aws.server.determine_service_name')
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
Expand All @@ -189,7 +189,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
mock_add_filtering,
mock_determine_service,
mock_determine_region,
mock_aws_proxy,
mock_fastmcp_proxy,
mock_create_transport,
mock_client_factory_class,
):
Expand Down Expand Up @@ -222,7 +222,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_proxy.add_middleware = Mock()
mock_aws_proxy.return_value = mock_proxy
mock_fastmcp_proxy.return_value = mock_proxy

# Act
await run_proxy(mock_args)
Expand All @@ -235,7 +235,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(

@patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.AWSMCPProxy')
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@patch('mcp_proxy_for_aws.server.determine_service_name')
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
Expand All @@ -244,7 +244,7 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
mock_add_filtering,
mock_determine_service,
mock_determine_region,
mock_aws_proxy,
mock_fastmcp_proxy,
mock_create_transport,
mock_client_factory_class,
):
Expand Down Expand Up @@ -277,7 +277,7 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_proxy.add_middleware = Mock()
mock_aws_proxy.return_value = mock_proxy
mock_fastmcp_proxy.return_value = mock_proxy

# Act
await run_proxy(mock_args)
Expand Down
Loading
Loading