From 092b1448c3a92ce0b7330acc7ffab4679285f7a5 Mon Sep 17 00:00:00 2001 From: Anass Taher Date: Fri, 3 Apr 2026 14:37:43 +0200 Subject: [PATCH 1/8] feat: add --tool-timeout flag and ToolTimeoutMiddleware - Add optional --tool-timeout CLI flag to cap tool call duration - Rename error_handling middleware to ToolTimeoutMiddleware - Return graceful isError=True response instead of hanging - Suggest long-lived credentials on 401/403 errors - Document --tool-timeout in README troubleshooting --- CHANGELOG.md | 6 +- README.md | 9 +- mcp_proxy_for_aws/cli.py | 9 ++ .../middleware/tool_timeout_middleware.py | 90 +++++++++++ mcp_proxy_for_aws/server.py | 15 ++ tests/unit/test_cli.py | 1 + tests/unit/test_tool_timeout_middleware.py | 151 ++++++++++++++++++ 7 files changed, 279 insertions(+), 2 deletions(-) create mode 100644 mcp_proxy_for_aws/middleware/tool_timeout_middleware.py create mode 100644 tests/unit/test_tool_timeout_middleware.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 416c01c0..4643d8e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## v1.1.8 (2026-04-02) +## Unreleased + +### Fixed + +- Simplify error middleware and suggest long-lived AWS credentials on auth errors (#216) ### Added diff --git a/README.md b/README.md index ec6ccd65..555f5883 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,7 @@ docker build -t mcp-proxy-for-aws . | `--connect-timeout` | Set desired connect timeout in seconds | 60 |No | | `--read-timeout` | Set desired read timeout in seconds | 120 |No | | `--write-timeout` | Set desired write timeout in seconds | 180 |No | +| `--tool-timeout` | Maximum seconds a tool call may take before being cancelled. When set, returns a graceful error to the agent instead of hanging indefinitely | Not set |No | ### Optional Environment Variables @@ -325,12 +326,18 @@ uv sync ## Troubleshooting -### Handling `Authentication error - Invalid credentials` +### Authentication errors We try to autodetect the service from the url, sometimes this fails, ensure that `--service` is set correctly to the service you are attempting to connect to. Otherwise the SigV4 signing will not be able to be verified by the service you connect to, resulting in this error. Also ensure that you have valid IAM credentials on your machine before retrying. +For long-running sessions, consider using long-lived credentials: +- Use an AWS profile via `--profile` +- Use IAM Identity Center and run `aws sso login` before starting the proxy + +### Client hangs on tool calls +If your MCP client hangs waiting for a tool call response (e.g., due to expired credentials or an unresponsive endpoint), use `--tool-timeout` to set a maximum duration in seconds for each tool call. When the timeout is exceeded, the proxy returns a graceful error to the agent instead of hanging indefinitely. ## Development & Contributing diff --git a/mcp_proxy_for_aws/cli.py b/mcp_proxy_for_aws/cli.py index dc666659..47aedaf2 100644 --- a/mcp_proxy_for_aws/cli.py +++ b/mcp_proxy_for_aws/cli.py @@ -159,4 +159,13 @@ def parse_args(): help='Write timeout (seconds) when connecting to endpoint (default: 180)', ) + parser.add_argument( + '--tool-timeout', + type=within_range(0), + default=None, + help='Maximum seconds a tool call may take before being cancelled. ' + 'When set, wraps each tool call with a timeout and returns a graceful error ' + 'to the agent instead of hanging. Not set by default.', + ) + return parser.parse_args() diff --git a/mcp_proxy_for_aws/middleware/tool_timeout_middleware.py b/mcp_proxy_for_aws/middleware/tool_timeout_middleware.py new file mode 100644 index 00000000..d6c7543d --- /dev/null +++ b/mcp_proxy_for_aws/middleware/tool_timeout_middleware.py @@ -0,0 +1,90 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import anyio +import httpx +import logging +import mcp.types as mt +from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext +from fastmcp.tools.tool import ToolResult + + +logger = logging.getLogger(__name__) + + +class _FailedToolResult(ToolResult): + """A ToolResult that signals an error via the MCP isError flag.""" + + def to_mcp_result(self) -> mt.CallToolResult: + return mt.CallToolResult(content=self.content, isError=True) + + +class ToolTimeoutMiddleware(Middleware): + """Middleware that ensures tool calls never hang and always return a response. + + Implements two layers of protection: + 1. Timeout — bounds how long a tool call can take, breaking any hang. + 2. Error propagation — catches any error and returns it as a ToolResult + so the agent always gets a response. + + Reconnection is handled automatically by fastmcp on every tool call. + """ + + def __init__( + self, + tool_call_timeout: float | None = None, + ) -> None: + """Initialize the middleware. + + Args: + tool_call_timeout: Maximum seconds a tool call may take before being + cancelled. None means no timeout (not recommended). + """ + super().__init__() + self._tool_call_timeout = tool_call_timeout + + async def on_call_tool( + self, + context: MiddlewareContext[mt.CallToolRequestParams], + call_next: CallNext[mt.CallToolRequestParams, ToolResult], + ) -> ToolResult: + """Wrap tool calls with timeout and error handling.""" + try: + with anyio.fail_after(self._tool_call_timeout): + return await call_next(context) + except Exception as e: + tool_name = context.message.name + logger.error('Tool call %r failed: %s.', tool_name, e) + message = f'Tool call {tool_name!r} failed: {e}. Please retry.' + if self._is_credential_error(e): + message += ( + ' This may be caused by expired or invalid AWS credentials.' + ' Consider using long-lived credentials such as an AWS profile' + ' (--profile) or IAM Identity Center (aws sso login).' + ) + return self._error_result(message) + + @staticmethod + def _is_credential_error(error: Exception) -> bool: + """Check if the error is likely caused by expired or invalid credentials.""" + return isinstance(error, httpx.HTTPStatusError) and error.response.status_code in ( + 401, + 403, + ) + + @staticmethod + def _error_result(message: str) -> ToolResult: + return _FailedToolResult( + content=[mt.TextContent(type='text', text=message)], + ) diff --git a/mcp_proxy_for_aws/server.py b/mcp_proxy_for_aws/server.py index d49f0e42..c3cc4c58 100644 --- a/mcp_proxy_for_aws/server.py +++ b/mcp_proxy_for_aws/server.py @@ -32,6 +32,7 @@ from mcp_proxy_for_aws.logging_config import configure_logging from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware +from mcp_proxy_for_aws.middleware.tool_timeout_middleware import ToolTimeoutMiddleware from mcp_proxy_for_aws.proxy import AWSMCPProxy, AWSMCPProxyClientFactory from mcp_proxy_for_aws.utils import ( create_transport_with_sigv4, @@ -95,6 +96,7 @@ async def run_proxy(args) -> None: ), ) proxy.add_middleware(InitializeMiddleware(client_factory)) + add_tool_timeout_middleware(proxy, args.tool_timeout) add_logging_middleware(proxy, args.log_level) add_tool_filtering_middleware(proxy, args.read_only) @@ -108,6 +110,19 @@ async def run_proxy(args) -> None: await client_factory.disconnect() +def add_tool_timeout_middleware(mcp: FastMCP, tool_timeout: float | None = None) -> None: + """Add tool timeout middleware if a tool timeout is configured. + + Args: + mcp: The FastMCP instance to add the middleware to + tool_timeout: Maximum seconds a tool call may take. None disables the middleware. + """ + if tool_timeout is None: + return + logger.info('Adding tool timeout middleware with tool_timeout=%s', tool_timeout) + mcp.add_middleware(ToolTimeoutMiddleware(tool_call_timeout=tool_timeout)) + + def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None: """Add tool filtering middleware to target MCP server. diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 890fd05d..8e2b7475 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -38,6 +38,7 @@ def test_parse_args_minimal(self): assert args.connect_timeout == 60.0 assert args.read_timeout == 120.0 assert args.write_timeout == 180.0 + assert args.tool_timeout is None @patch( 'sys.argv', diff --git a/tests/unit/test_tool_timeout_middleware.py b/tests/unit/test_tool_timeout_middleware.py new file mode 100644 index 00000000..dc6ae61f --- /dev/null +++ b/tests/unit/test_tool_timeout_middleware.py @@ -0,0 +1,151 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ToolTimeoutMiddleware.""" + +import anyio +import httpx +import mcp.types as mt +import pytest +from fastmcp.server.middleware import MiddlewareContext +from fastmcp.tools.tool import ToolResult +from mcp import McpError +from mcp.types import ErrorData +from mcp_proxy_for_aws.middleware.tool_timeout_middleware import ( + ToolTimeoutMiddleware, + _FailedToolResult, +) +from typing import Optional +from unittest.mock import AsyncMock, Mock + + +def _make_context(tool_name: str = 'test_tool') -> MiddlewareContext[mt.CallToolRequestParams]: + """Create a minimal MiddlewareContext for tool calls.""" + params = Mock(spec=mt.CallToolRequestParams) + params.name = tool_name + return MiddlewareContext[mt.CallToolRequestParams]( + message=params, + type='request', + method='tools/call', + ) + + +def _make_middleware(tool_call_timeout: Optional[float] = 5.0) -> ToolTimeoutMiddleware: + """Create a ToolTimeoutMiddleware with mocked dependencies.""" + middleware = ToolTimeoutMiddleware( + tool_call_timeout=tool_call_timeout, + ) + return middleware + + +def _get_text(result: ToolResult, index: int = 0) -> str: + """Extract text from a ToolResult content item.""" + content = result.content[index] + assert isinstance(content, mt.TextContent) + return content.text + + +class TestToolTimeoutMiddleware: + """Test cases for ToolTimeoutMiddleware.""" + + @pytest.mark.asyncio + async def test_passes_through_on_success(self): + """Successful tool calls pass through unchanged.""" + middleware = _make_middleware() + expected = ToolResult(content=[mt.TextContent(type='text', text='ok')]) + call_next = AsyncMock(return_value=expected) + context = _make_context() + + result = await middleware.on_call_tool(context, call_next) + + assert result is expected + assert not isinstance(result, _FailedToolResult) + call_next.assert_awaited_once_with(context) + + @pytest.mark.asyncio + async def test_catches_exception_returns_error_result(self): + """Exceptions are caught and returned as error ToolResults.""" + middleware = _make_middleware() + call_next = AsyncMock( + side_effect=McpError(ErrorData(code=-1, message='Connection closed')) + ) + context = _make_context() + + result = await middleware.on_call_tool(context, call_next) + + assert isinstance(result, _FailedToolResult) + assert len(result.content) == 1 + text = _get_text(result) + assert 'Connection closed' in text + + @pytest.mark.asyncio + async def test_timeout_returns_error_result(self): + """Tool calls that exceed the timeout return an error ToolResult.""" + middleware = _make_middleware(tool_call_timeout=0.1) + + async def hang_forever(context: MiddlewareContext[mt.CallToolRequestParams]) -> ToolResult: + await anyio.sleep(999) + return ToolResult(content=[]) # unreachable + + context = _make_context(tool_name='slow_tool') + + result = await middleware.on_call_tool(context, hang_forever) + + assert isinstance(result, _FailedToolResult) + assert len(result.content) == 1 + text = _get_text(result) + assert 'slow_tool' in text + + @pytest.mark.asyncio + async def test_credential_error_suggests_profile(self): + """Credential errors suggest using long-lived credentials.""" + middleware = _make_middleware() + response = Mock(spec=httpx.Response) + response.status_code = 401 + call_next = AsyncMock( + side_effect=httpx.HTTPStatusError('Unauthorized', request=Mock(), response=response) + ) + context = _make_context() + + result = await middleware.on_call_tool(context, call_next) + + assert isinstance(result, _FailedToolResult) + text = _get_text(result) + assert 'expired or invalid AWS credentials' in text + assert '--profile' in text + + @pytest.mark.asyncio + async def test_non_credential_error_no_suggestion(self): + """Non-credential errors do not suggest credential remediation.""" + middleware = _make_middleware() + call_next = AsyncMock(side_effect=RuntimeError('transport died')) + context = _make_context() + + result = await middleware.on_call_tool(context, call_next) + + assert isinstance(result, _FailedToolResult) + text = _get_text(result) + assert '--profile' not in text + + @pytest.mark.asyncio + async def test_no_timeout_when_none(self): + """When tool_call_timeout is None, no timeout is applied.""" + middleware = _make_middleware(tool_call_timeout=None) + expected = ToolResult(content=[mt.TextContent(type='text', text='ok')]) + call_next = AsyncMock(return_value=expected) + context = _make_context() + + result = await middleware.on_call_tool(context, call_next) + + assert result is expected From ec3268f381496efa222f3a5b1c4607393b2ab3c3 Mon Sep 17 00:00:00 2001 From: Anass Taher Date: Tue, 7 Apr 2026 11:45:47 +0200 Subject: [PATCH 2/8] refactor: rename ToolTimeoutMiddleware to ToolErrorMiddleware - Rename class, files, CLI flag (--tool-error-timeout), and all references - Default tool-error-timeout to 300s --- README.md | 4 ++-- mcp_proxy_for_aws/cli.py | 6 +++--- ...ut_middleware.py => tool_error_middleware.py} | 2 +- mcp_proxy_for_aws/server.py | 16 +++++++--------- tests/unit/test_cli.py | 2 +- ...ddleware.py => test_tool_error_middleware.py} | 16 ++++++++-------- 6 files changed, 22 insertions(+), 24 deletions(-) rename mcp_proxy_for_aws/middleware/{tool_timeout_middleware.py => tool_error_middleware.py} (98%) rename tests/unit/{test_tool_timeout_middleware.py => test_tool_error_middleware.py} (93%) diff --git a/README.md b/README.md index cb4a900a..2c22b0e0 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ docker build -t mcp-proxy-for-aws . | `--connect-timeout` | Set desired connect timeout in seconds | 60 |No | | `--read-timeout` | Set desired read timeout in seconds | 120 |No | | `--write-timeout` | Set desired write timeout in seconds | 180 |No | -| `--tool-timeout` | Maximum seconds a tool call may take before being cancelled. When set, returns a graceful error to the agent instead of hanging indefinitely | Not set |No | +| `--tool-error-timeout` | Maximum seconds a tool call may take before being cancelled. When set, returns a graceful error to the agent instead of hanging indefinitely | 300 |No | ### Optional Environment Variables @@ -353,7 +353,7 @@ For long-running sessions, consider using long-lived credentials: - Use IAM Identity Center and run `aws sso login` before starting the proxy ### Client hangs on tool calls -If your MCP client hangs waiting for a tool call response (e.g., due to expired credentials or an unresponsive endpoint), use `--tool-timeout` to set a maximum duration in seconds for each tool call. When the timeout is exceeded, the proxy returns a graceful error to the agent instead of hanging indefinitely. +If your MCP client hangs waiting for a tool call response (e.g., due to expired credentials or an unresponsive endpoint), use `--tool-error-timeout` to set a maximum duration in seconds for each tool call. When the timeout is exceeded, the proxy returns a graceful error to the agent instead of hanging indefinitely. ## Development & Contributing diff --git a/mcp_proxy_for_aws/cli.py b/mcp_proxy_for_aws/cli.py index 47aedaf2..b636af10 100644 --- a/mcp_proxy_for_aws/cli.py +++ b/mcp_proxy_for_aws/cli.py @@ -160,12 +160,12 @@ def parse_args(): ) parser.add_argument( - '--tool-timeout', + '--tool-error-timeout', type=within_range(0), - default=None, + default=300.0, help='Maximum seconds a tool call may take before being cancelled. ' 'When set, wraps each tool call with a timeout and returns a graceful error ' - 'to the agent instead of hanging. Not set by default.', + 'to the agent instead of hanging (default: 300).', ) return parser.parse_args() diff --git a/mcp_proxy_for_aws/middleware/tool_timeout_middleware.py b/mcp_proxy_for_aws/middleware/tool_error_middleware.py similarity index 98% rename from mcp_proxy_for_aws/middleware/tool_timeout_middleware.py rename to mcp_proxy_for_aws/middleware/tool_error_middleware.py index d6c7543d..405fbb1e 100644 --- a/mcp_proxy_for_aws/middleware/tool_timeout_middleware.py +++ b/mcp_proxy_for_aws/middleware/tool_error_middleware.py @@ -30,7 +30,7 @@ def to_mcp_result(self) -> mt.CallToolResult: return mt.CallToolResult(content=self.content, isError=True) -class ToolTimeoutMiddleware(Middleware): +class ToolErrorMiddleware(Middleware): """Middleware that ensures tool calls never hang and always return a response. Implements two layers of protection: diff --git a/mcp_proxy_for_aws/server.py b/mcp_proxy_for_aws/server.py index c6c07a52..ddfcb65a 100644 --- a/mcp_proxy_for_aws/server.py +++ b/mcp_proxy_for_aws/server.py @@ -33,7 +33,7 @@ from mcp_proxy_for_aws.logging_config import configure_logging from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware -from mcp_proxy_for_aws.middleware.tool_timeout_middleware import ToolTimeoutMiddleware +from mcp_proxy_for_aws.middleware.tool_error_middleware import ToolErrorMiddleware from mcp_proxy_for_aws.proxy import AWSMCPProxy, AWSMCPProxyClientFactory from mcp_proxy_for_aws.utils import ( create_transport_with_sigv4, @@ -98,7 +98,7 @@ async def run_proxy(args) -> None: ), ) proxy.add_middleware(InitializeMiddleware(client_factory)) - add_tool_timeout_middleware(proxy, args.tool_timeout) + add_tool_error_middleware(proxy, args.tool_error_timeout) add_logging_middleware(proxy, args.log_level) add_tool_filtering_middleware(proxy, args.read_only) @@ -112,17 +112,15 @@ async def run_proxy(args) -> None: await client_factory.disconnect() -def add_tool_timeout_middleware(mcp: FastMCP, tool_timeout: float | None = None) -> None: - """Add tool timeout middleware if a tool timeout is configured. +def add_tool_error_middleware(mcp: FastMCP, tool_error_timeout: float) -> None: + """Add tool error middleware. Args: mcp: The FastMCP instance to add the middleware to - tool_timeout: Maximum seconds a tool call may take. None disables the middleware. + tool_error_timeout: Maximum seconds a tool call may take. """ - if tool_timeout is None: - return - logger.info('Adding tool timeout middleware with tool_timeout=%s', tool_timeout) - mcp.add_middleware(ToolTimeoutMiddleware(tool_call_timeout=tool_timeout)) + logger.info('Adding tool error middleware with tool_error_timeout=%s', tool_error_timeout) + mcp.add_middleware(ToolErrorMiddleware(tool_call_timeout=tool_error_timeout)) def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None: diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 8e2b7475..12bc6409 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -38,7 +38,7 @@ def test_parse_args_minimal(self): assert args.connect_timeout == 60.0 assert args.read_timeout == 120.0 assert args.write_timeout == 180.0 - assert args.tool_timeout is None + assert args.tool_error_timeout == 300.0 @patch( 'sys.argv', diff --git a/tests/unit/test_tool_timeout_middleware.py b/tests/unit/test_tool_error_middleware.py similarity index 93% rename from tests/unit/test_tool_timeout_middleware.py rename to tests/unit/test_tool_error_middleware.py index dc6ae61f..ba8100d7 100644 --- a/tests/unit/test_tool_timeout_middleware.py +++ b/tests/unit/test_tool_error_middleware.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for ToolTimeoutMiddleware.""" +"""Unit tests for ToolErrorMiddleware.""" import anyio import httpx @@ -22,8 +22,8 @@ from fastmcp.tools.tool import ToolResult from mcp import McpError from mcp.types import ErrorData -from mcp_proxy_for_aws.middleware.tool_timeout_middleware import ( - ToolTimeoutMiddleware, +from mcp_proxy_for_aws.middleware.tool_error_middleware import ( + ToolErrorMiddleware, _FailedToolResult, ) from typing import Optional @@ -41,9 +41,9 @@ def _make_context(tool_name: str = 'test_tool') -> MiddlewareContext[mt.CallTool ) -def _make_middleware(tool_call_timeout: Optional[float] = 5.0) -> ToolTimeoutMiddleware: - """Create a ToolTimeoutMiddleware with mocked dependencies.""" - middleware = ToolTimeoutMiddleware( +def _make_middleware(tool_call_timeout: Optional[float] = 5.0) -> ToolErrorMiddleware: + """Create a ToolErrorMiddleware with mocked dependencies.""" + middleware = ToolErrorMiddleware( tool_call_timeout=tool_call_timeout, ) return middleware @@ -56,8 +56,8 @@ def _get_text(result: ToolResult, index: int = 0) -> str: return content.text -class TestToolTimeoutMiddleware: - """Test cases for ToolTimeoutMiddleware.""" +class TestToolErrorMiddleware: + """Test cases for ToolErrorMiddleware.""" @pytest.mark.asyncio async def test_passes_through_on_success(self): From 56dc2e79e6e692327a253455819b833de2a64031 Mon Sep 17 00:00:00 2001 From: Anass Taher Date: Tue, 7 Apr 2026 12:03:21 +0200 Subject: [PATCH 3/8] fix: sort imports in server.py --- mcp_proxy_for_aws/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp_proxy_for_aws/server.py b/mcp_proxy_for_aws/server.py index ddfcb65a..031164ea 100644 --- a/mcp_proxy_for_aws/server.py +++ b/mcp_proxy_for_aws/server.py @@ -32,8 +32,8 @@ from mcp_proxy_for_aws.cli import parse_args from mcp_proxy_for_aws.logging_config import configure_logging from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware -from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware from mcp_proxy_for_aws.middleware.tool_error_middleware import ToolErrorMiddleware +from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware from mcp_proxy_for_aws.proxy import AWSMCPProxy, AWSMCPProxyClientFactory from mcp_proxy_for_aws.utils import ( create_transport_with_sigv4, From 8c543fe28470b26848d009972b198d4d79a22eb1 Mon Sep 17 00:00:00 2001 From: Anass Taher Date: Tue, 7 Apr 2026 13:39:11 +0200 Subject: [PATCH 4/8] fix: remove Optional timeout, decouple tests from private class --- .../middleware/tool_error_middleware.py | 4 +-- tests/unit/test_tool_error_middleware.py | 36 ++++++++----------- 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/mcp_proxy_for_aws/middleware/tool_error_middleware.py b/mcp_proxy_for_aws/middleware/tool_error_middleware.py index 405fbb1e..bf058f95 100644 --- a/mcp_proxy_for_aws/middleware/tool_error_middleware.py +++ b/mcp_proxy_for_aws/middleware/tool_error_middleware.py @@ -43,13 +43,13 @@ class ToolErrorMiddleware(Middleware): def __init__( self, - tool_call_timeout: float | None = None, + tool_call_timeout: float = 300.0, ) -> None: """Initialize the middleware. Args: tool_call_timeout: Maximum seconds a tool call may take before being - cancelled. None means no timeout (not recommended). + cancelled. """ super().__init__() self._tool_call_timeout = tool_call_timeout diff --git a/tests/unit/test_tool_error_middleware.py b/tests/unit/test_tool_error_middleware.py index ba8100d7..03284d3f 100644 --- a/tests/unit/test_tool_error_middleware.py +++ b/tests/unit/test_tool_error_middleware.py @@ -22,11 +22,7 @@ from fastmcp.tools.tool import ToolResult from mcp import McpError from mcp.types import ErrorData -from mcp_proxy_for_aws.middleware.tool_error_middleware import ( - ToolErrorMiddleware, - _FailedToolResult, -) -from typing import Optional +from mcp_proxy_for_aws.middleware.tool_error_middleware import ToolErrorMiddleware from unittest.mock import AsyncMock, Mock @@ -41,7 +37,7 @@ def _make_context(tool_name: str = 'test_tool') -> MiddlewareContext[mt.CallTool ) -def _make_middleware(tool_call_timeout: Optional[float] = 5.0) -> ToolErrorMiddleware: +def _make_middleware(tool_call_timeout: float = 5.0) -> ToolErrorMiddleware: """Create a ToolErrorMiddleware with mocked dependencies.""" middleware = ToolErrorMiddleware( tool_call_timeout=tool_call_timeout, @@ -49,6 +45,13 @@ def _make_middleware(tool_call_timeout: Optional[float] = 5.0) -> ToolErrorMiddl return middleware +def _is_error(result: ToolResult) -> bool: + """Check if a ToolResult has the MCP isError flag set.""" + mcp_result = result.to_mcp_result() + assert isinstance(mcp_result, mt.CallToolResult) + return bool(mcp_result.isError) + + def _get_text(result: ToolResult, index: int = 0) -> str: """Extract text from a ToolResult content item.""" content = result.content[index] @@ -70,7 +73,7 @@ async def test_passes_through_on_success(self): result = await middleware.on_call_tool(context, call_next) assert result is expected - assert not isinstance(result, _FailedToolResult) + assert not _is_error(result) call_next.assert_awaited_once_with(context) @pytest.mark.asyncio @@ -84,7 +87,7 @@ async def test_catches_exception_returns_error_result(self): result = await middleware.on_call_tool(context, call_next) - assert isinstance(result, _FailedToolResult) + assert _is_error(result) assert len(result.content) == 1 text = _get_text(result) assert 'Connection closed' in text @@ -102,7 +105,7 @@ async def hang_forever(context: MiddlewareContext[mt.CallToolRequestParams]) -> result = await middleware.on_call_tool(context, hang_forever) - assert isinstance(result, _FailedToolResult) + assert _is_error(result) assert len(result.content) == 1 text = _get_text(result) assert 'slow_tool' in text @@ -120,7 +123,7 @@ async def test_credential_error_suggests_profile(self): result = await middleware.on_call_tool(context, call_next) - assert isinstance(result, _FailedToolResult) + assert _is_error(result) text = _get_text(result) assert 'expired or invalid AWS credentials' in text assert '--profile' in text @@ -134,18 +137,7 @@ async def test_non_credential_error_no_suggestion(self): result = await middleware.on_call_tool(context, call_next) - assert isinstance(result, _FailedToolResult) + assert _is_error(result) text = _get_text(result) assert '--profile' not in text - @pytest.mark.asyncio - async def test_no_timeout_when_none(self): - """When tool_call_timeout is None, no timeout is applied.""" - middleware = _make_middleware(tool_call_timeout=None) - expected = ToolResult(content=[mt.TextContent(type='text', text='ok')]) - call_next = AsyncMock(return_value=expected) - context = _make_context() - - result = await middleware.on_call_tool(context, call_next) - - assert result is expected From 0ad995302c19b2f8e73ba706878a64f9a4a2c5ec Mon Sep 17 00:00:00 2001 From: Anass Taher Date: Tue, 7 Apr 2026 13:39:56 +0200 Subject: [PATCH 5/8] fix: handle non-CallToolResult in success test --- tests/unit/test_tool_error_middleware.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_tool_error_middleware.py b/tests/unit/test_tool_error_middleware.py index 03284d3f..40710eba 100644 --- a/tests/unit/test_tool_error_middleware.py +++ b/tests/unit/test_tool_error_middleware.py @@ -73,7 +73,8 @@ async def test_passes_through_on_success(self): result = await middleware.on_call_tool(context, call_next) assert result is expected - assert not _is_error(result) + mcp_result = result.to_mcp_result() + assert not isinstance(mcp_result, mt.CallToolResult) or not mcp_result.isError call_next.assert_awaited_once_with(context) @pytest.mark.asyncio From e43fa1a9c4b2f1e357fc490e23508ba4892bbe8e Mon Sep 17 00:00:00 2001 From: Anass Taher Date: Tue, 7 Apr 2026 13:44:02 +0200 Subject: [PATCH 6/8] fix: remove trailing blank line in test file --- tests/unit/test_tool_error_middleware.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_tool_error_middleware.py b/tests/unit/test_tool_error_middleware.py index 40710eba..2150018c 100644 --- a/tests/unit/test_tool_error_middleware.py +++ b/tests/unit/test_tool_error_middleware.py @@ -141,4 +141,3 @@ async def test_non_credential_error_no_suggestion(self): assert _is_error(result) text = _get_text(result) assert '--profile' not in text - From f4d3f8b679e73559db692936beb68ec1d6226b9d Mon Sep 17 00:00:00 2001 From: Anass Taher Date: Tue, 7 Apr 2026 15:27:03 +0200 Subject: [PATCH 7/8] refactor: rename --tool-error-timeout to --tool-timeout --- README.md | 4 ++-- mcp_proxy_for_aws/cli.py | 2 +- mcp_proxy_for_aws/server.py | 10 +++++----- tests/unit/test_cli.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 2c22b0e0..1ff03d1f 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ docker build -t mcp-proxy-for-aws . | `--connect-timeout` | Set desired connect timeout in seconds | 60 |No | | `--read-timeout` | Set desired read timeout in seconds | 120 |No | | `--write-timeout` | Set desired write timeout in seconds | 180 |No | -| `--tool-error-timeout` | Maximum seconds a tool call may take before being cancelled. When set, returns a graceful error to the agent instead of hanging indefinitely | 300 |No | +| `--tool-timeout` | Maximum seconds a tool call may take before being cancelled. When set, returns a graceful error to the agent instead of hanging indefinitely | 300 |No | ### Optional Environment Variables @@ -353,7 +353,7 @@ For long-running sessions, consider using long-lived credentials: - Use IAM Identity Center and run `aws sso login` before starting the proxy ### Client hangs on tool calls -If your MCP client hangs waiting for a tool call response (e.g., due to expired credentials or an unresponsive endpoint), use `--tool-error-timeout` to set a maximum duration in seconds for each tool call. When the timeout is exceeded, the proxy returns a graceful error to the agent instead of hanging indefinitely. +If your MCP client hangs waiting for a tool call response (e.g., due to expired credentials or an unresponsive endpoint), use `--tool-timeout` to set a maximum duration in seconds for each tool call. When the timeout is exceeded, the proxy returns a graceful error to the agent instead of hanging indefinitely. ## Development & Contributing diff --git a/mcp_proxy_for_aws/cli.py b/mcp_proxy_for_aws/cli.py index b636af10..7340c5b6 100644 --- a/mcp_proxy_for_aws/cli.py +++ b/mcp_proxy_for_aws/cli.py @@ -160,7 +160,7 @@ def parse_args(): ) parser.add_argument( - '--tool-error-timeout', + '--tool-timeout', type=within_range(0), default=300.0, help='Maximum seconds a tool call may take before being cancelled. ' diff --git a/mcp_proxy_for_aws/server.py b/mcp_proxy_for_aws/server.py index 031164ea..de5edb32 100644 --- a/mcp_proxy_for_aws/server.py +++ b/mcp_proxy_for_aws/server.py @@ -98,7 +98,7 @@ async def run_proxy(args) -> None: ), ) proxy.add_middleware(InitializeMiddleware(client_factory)) - add_tool_error_middleware(proxy, args.tool_error_timeout) + add_tool_error_middleware(proxy, args.tool_timeout) add_logging_middleware(proxy, args.log_level) add_tool_filtering_middleware(proxy, args.read_only) @@ -112,15 +112,15 @@ async def run_proxy(args) -> None: await client_factory.disconnect() -def add_tool_error_middleware(mcp: FastMCP, tool_error_timeout: float) -> None: +def add_tool_error_middleware(mcp: FastMCP, tool_timeout: float) -> None: """Add tool error middleware. Args: mcp: The FastMCP instance to add the middleware to - tool_error_timeout: Maximum seconds a tool call may take. + tool_timeout: Maximum seconds a tool call may take. """ - logger.info('Adding tool error middleware with tool_error_timeout=%s', tool_error_timeout) - mcp.add_middleware(ToolErrorMiddleware(tool_call_timeout=tool_error_timeout)) + logger.info('Adding tool error middleware with tool_timeout=%s', tool_timeout) + mcp.add_middleware(ToolErrorMiddleware(tool_call_timeout=tool_timeout)) def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None: diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 12bc6409..8aba705b 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -38,7 +38,7 @@ def test_parse_args_minimal(self): assert args.connect_timeout == 60.0 assert args.read_timeout == 120.0 assert args.write_timeout == 180.0 - assert args.tool_error_timeout == 300.0 + assert args.tool_timeout == 300.0 @patch( 'sys.argv', From 734c25e8bd866cb23d09f6783365a7e8d66ad689 Mon Sep 17 00:00:00 2001 From: Anass Taher Date: Wed, 8 Apr 2026 11:24:02 +0200 Subject: [PATCH 8/8] refactor: raise ToolError instead of custom _FailedToolResult subclass --- .../middleware/tool_error_middleware.py | 22 ++----- tests/unit/test_tool_error_middleware.py | 65 +++++-------------- 2 files changed, 22 insertions(+), 65 deletions(-) diff --git a/mcp_proxy_for_aws/middleware/tool_error_middleware.py b/mcp_proxy_for_aws/middleware/tool_error_middleware.py index bf058f95..621b2b8f 100644 --- a/mcp_proxy_for_aws/middleware/tool_error_middleware.py +++ b/mcp_proxy_for_aws/middleware/tool_error_middleware.py @@ -16,27 +16,21 @@ import httpx import logging import mcp.types as mt +from fastmcp.exceptions import ToolError from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext -from fastmcp.tools.tool import ToolResult +from fastmcp.tools import ToolResult logger = logging.getLogger(__name__) -class _FailedToolResult(ToolResult): - """A ToolResult that signals an error via the MCP isError flag.""" - - def to_mcp_result(self) -> mt.CallToolResult: - return mt.CallToolResult(content=self.content, isError=True) - - class ToolErrorMiddleware(Middleware): """Middleware that ensures tool calls never hang and always return a response. Implements two layers of protection: 1. Timeout — bounds how long a tool call can take, breaking any hang. - 2. Error propagation — catches any error and returns it as a ToolResult - so the agent always gets a response. + 2. Error propagation — catches any error and returns an error message + to the agent so it always gets a response. Reconnection is handled automatically by fastmcp on every tool call. """ @@ -73,7 +67,7 @@ async def on_call_tool( ' Consider using long-lived credentials such as an AWS profile' ' (--profile) or IAM Identity Center (aws sso login).' ) - return self._error_result(message) + raise ToolError(message) from e @staticmethod def _is_credential_error(error: Exception) -> bool: @@ -82,9 +76,3 @@ def _is_credential_error(error: Exception) -> bool: 401, 403, ) - - @staticmethod - def _error_result(message: str) -> ToolResult: - return _FailedToolResult( - content=[mt.TextContent(type='text', text=message)], - ) diff --git a/tests/unit/test_tool_error_middleware.py b/tests/unit/test_tool_error_middleware.py index 2150018c..cc65cf93 100644 --- a/tests/unit/test_tool_error_middleware.py +++ b/tests/unit/test_tool_error_middleware.py @@ -18,8 +18,9 @@ import httpx import mcp.types as mt import pytest +from fastmcp.exceptions import ToolError from fastmcp.server.middleware import MiddlewareContext -from fastmcp.tools.tool import ToolResult +from fastmcp.tools import ToolResult from mcp import McpError from mcp.types import ErrorData from mcp_proxy_for_aws.middleware.tool_error_middleware import ToolErrorMiddleware @@ -39,24 +40,7 @@ def _make_context(tool_name: str = 'test_tool') -> MiddlewareContext[mt.CallTool def _make_middleware(tool_call_timeout: float = 5.0) -> ToolErrorMiddleware: """Create a ToolErrorMiddleware with mocked dependencies.""" - middleware = ToolErrorMiddleware( - tool_call_timeout=tool_call_timeout, - ) - return middleware - - -def _is_error(result: ToolResult) -> bool: - """Check if a ToolResult has the MCP isError flag set.""" - mcp_result = result.to_mcp_result() - assert isinstance(mcp_result, mt.CallToolResult) - return bool(mcp_result.isError) - - -def _get_text(result: ToolResult, index: int = 0) -> str: - """Extract text from a ToolResult content item.""" - content = result.content[index] - assert isinstance(content, mt.TextContent) - return content.text + return ToolErrorMiddleware(tool_call_timeout=tool_call_timeout) class TestToolErrorMiddleware: @@ -73,29 +57,23 @@ async def test_passes_through_on_success(self): result = await middleware.on_call_tool(context, call_next) assert result is expected - mcp_result = result.to_mcp_result() - assert not isinstance(mcp_result, mt.CallToolResult) or not mcp_result.isError call_next.assert_awaited_once_with(context) @pytest.mark.asyncio - async def test_catches_exception_returns_error_result(self): - """Exceptions are caught and returned as error ToolResults.""" + async def test_catches_exception_raises_tool_error(self): + """Exceptions are caught and raised as ToolError.""" middleware = _make_middleware() call_next = AsyncMock( side_effect=McpError(ErrorData(code=-1, message='Connection closed')) ) context = _make_context() - result = await middleware.on_call_tool(context, call_next) - - assert _is_error(result) - assert len(result.content) == 1 - text = _get_text(result) - assert 'Connection closed' in text + with pytest.raises(ToolError, match='Connection closed'): + await middleware.on_call_tool(context, call_next) @pytest.mark.asyncio - async def test_timeout_returns_error_result(self): - """Tool calls that exceed the timeout return an error ToolResult.""" + async def test_timeout_raises_tool_error(self): + """Tool calls that exceed the timeout raise a ToolError.""" middleware = _make_middleware(tool_call_timeout=0.1) async def hang_forever(context: MiddlewareContext[mt.CallToolRequestParams]) -> ToolResult: @@ -104,12 +82,8 @@ async def hang_forever(context: MiddlewareContext[mt.CallToolRequestParams]) -> context = _make_context(tool_name='slow_tool') - result = await middleware.on_call_tool(context, hang_forever) - - assert _is_error(result) - assert len(result.content) == 1 - text = _get_text(result) - assert 'slow_tool' in text + with pytest.raises(ToolError, match='slow_tool'): + await middleware.on_call_tool(context, hang_forever) @pytest.mark.asyncio async def test_credential_error_suggests_profile(self): @@ -122,12 +96,9 @@ async def test_credential_error_suggests_profile(self): ) context = _make_context() - result = await middleware.on_call_tool(context, call_next) - - assert _is_error(result) - text = _get_text(result) - assert 'expired or invalid AWS credentials' in text - assert '--profile' in text + with pytest.raises(ToolError, match='expired or invalid AWS credentials') as exc_info: + await middleware.on_call_tool(context, call_next) + assert '--profile' in str(exc_info.value) @pytest.mark.asyncio async def test_non_credential_error_no_suggestion(self): @@ -136,8 +107,6 @@ async def test_non_credential_error_no_suggestion(self): call_next = AsyncMock(side_effect=RuntimeError('transport died')) context = _make_context() - result = await middleware.on_call_tool(context, call_next) - - assert _is_error(result) - text = _get_text(result) - assert '--profile' not in text + with pytest.raises(ToolError) as exc_info: + await middleware.on_call_tool(context, call_next) + assert '--profile' not in str(exc_info.value)