diff --git a/CHANGELOG.md b/CHANGELOG.md index bee750a..589bd50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Simplify error middleware and suggest long-lived AWS credentials on auth errors (#216) - Use new streamable http client (#228) - Add URL scheme validation to prevent credential interception (#169) - Prevent credential exposure in logs (#167) diff --git a/README.md b/README.md index ea5a731..1ff03d1 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,7 @@ docker build -t mcp-proxy-for-aws . | `--connect-timeout` | Set desired connect timeout in seconds | 60 |No | | `--read-timeout` | Set desired read timeout in seconds | 120 |No | | `--write-timeout` | Set desired write timeout in seconds | 180 |No | +| `--tool-timeout` | Maximum seconds a tool call may take before being cancelled. When set, returns a graceful error to the agent instead of hanging indefinitely | 300 |No | ### Optional Environment Variables @@ -341,12 +342,18 @@ uv sync ## Troubleshooting -### Handling `Authentication error - Invalid credentials` +### Authentication errors We try to autodetect the service from the url, sometimes this fails, ensure that `--service` is set correctly to the service you are attempting to connect to. Otherwise the SigV4 signing will not be able to be verified by the service you connect to, resulting in this error. Also ensure that you have valid IAM credentials on your machine before retrying. +For long-running sessions, consider using long-lived credentials: +- Use an AWS profile via `--profile` +- Use IAM Identity Center and run `aws sso login` before starting the proxy + +### Client hangs on tool calls +If your MCP client hangs waiting for a tool call response (e.g., due to expired credentials or an unresponsive endpoint), use `--tool-timeout` to set a maximum duration in seconds for each tool call. 
When the timeout is exceeded, the proxy returns a graceful error to the agent instead of hanging indefinitely. ## Development & Contributing diff --git a/mcp_proxy_for_aws/cli.py b/mcp_proxy_for_aws/cli.py index dc66665..7340c5b 100644 --- a/mcp_proxy_for_aws/cli.py +++ b/mcp_proxy_for_aws/cli.py @@ -159,4 +159,13 @@ def parse_args(): help='Write timeout (seconds) when connecting to endpoint (default: 180)', ) + parser.add_argument( + '--tool-timeout', + type=within_range(0), + default=300.0, + help='Maximum seconds a tool call may take before being cancelled. ' + 'When set, wraps each tool call with a timeout and returns a graceful error ' + 'to the agent instead of hanging (default: 300).', + ) + return parser.parse_args() diff --git a/mcp_proxy_for_aws/middleware/tool_error_middleware.py b/mcp_proxy_for_aws/middleware/tool_error_middleware.py new file mode 100644 index 0000000..621b2b8 --- /dev/null +++ b/mcp_proxy_for_aws/middleware/tool_error_middleware.py @@ -0,0 +1,78 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import anyio +import httpx +import logging +import mcp.types as mt +from fastmcp.exceptions import ToolError +from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext +from fastmcp.tools import ToolResult + + +logger = logging.getLogger(__name__) + + +class ToolErrorMiddleware(Middleware): + """Middleware that ensures tool calls never hang and always return a response. + + Implements two layers of protection: + 1. Timeout — bounds how long a tool call can take, breaking any hang. + 2. Error propagation — catches any error and returns an error message + to the agent so it always gets a response. + + Reconnection is handled automatically by fastmcp on every tool call. + """ + + def __init__(self, tool_call_timeout: float = 300.0) -> None: + """Initialize the middleware. + + Args: + tool_call_timeout: Maximum seconds a tool call may take before being + cancelled. + """ + super().__init__() + self._tool_call_timeout = tool_call_timeout + + async def on_call_tool( + self, + context: MiddlewareContext[mt.CallToolRequestParams], + call_next: CallNext[mt.CallToolRequestParams, ToolResult], + ) -> ToolResult: + """Run the tool call under the timeout and re-raise any failure as a ToolError.""" + try: + with anyio.fail_after(self._tool_call_timeout): + return await call_next(context) + except Exception as e: + tool_name = context.message.name + # anyio.fail_after raises a bare TimeoutError (empty str), so name the cause. + timed_out = isinstance(e, TimeoutError) and not str(e) + reason = f'timed out after {self._tool_call_timeout} seconds' if timed_out else e + logger.error('Tool call %r failed: %s.', tool_name, reason) + message = f'Tool call {tool_name!r} failed: {reason}. Please retry.' + if self._is_credential_error(e): + message += ( + ' This may be caused by expired or invalid AWS credentials.' + ' Consider using long-lived credentials such as an AWS profile' + ' (--profile) or IAM Identity Center (aws sso login).'
+ ) + raise ToolError(message) from e + + @staticmethod + def _is_credential_error(error: Exception) -> bool: + """Check if the error is likely caused by expired or invalid credentials.""" + return isinstance(error, httpx.HTTPStatusError) and error.response.status_code in ( + 401, + 403, + ) diff --git a/mcp_proxy_for_aws/server.py b/mcp_proxy_for_aws/server.py index 28acb30..decfb24 100644 --- a/mcp_proxy_for_aws/server.py +++ b/mcp_proxy_for_aws/server.py @@ -33,6 +33,7 @@ from mcp_proxy_for_aws.cli import parse_args from mcp_proxy_for_aws.logging_config import configure_logging from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware +from mcp_proxy_for_aws.middleware.tool_error_middleware import ToolErrorMiddleware from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware from mcp_proxy_for_aws.proxy import AWSMCPProxyClientFactory from mcp_proxy_for_aws.utils import ( @@ -98,6 +99,7 @@ async def run_proxy(args) -> None: ), ) proxy.add_middleware(InitializeMiddleware(client_factory)) + add_tool_error_middleware(proxy, args.tool_timeout) add_logging_middleware(proxy, args.log_level) add_tool_filtering_middleware(proxy, args.read_only) @@ -111,6 +113,17 @@ async def run_proxy(args) -> None: await client_factory.disconnect() +def add_tool_error_middleware(mcp: FastMCP, tool_timeout: float) -> None: + """Add tool error middleware. + + Args: + mcp: The FastMCP instance to add the middleware to + tool_timeout: Maximum seconds a tool call may take. + """ + logger.info('Adding tool error middleware with tool_timeout=%s', tool_timeout) + mcp.add_middleware(ToolErrorMiddleware(tool_call_timeout=tool_timeout)) + + def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None: """Add tool filtering middleware to target MCP server. 
diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 890fd05..8aba705 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -38,6 +38,7 @@ def test_parse_args_minimal(self): assert args.connect_timeout == 60.0 assert args.read_timeout == 120.0 assert args.write_timeout == 180.0 + assert args.tool_timeout == 300.0 @patch( 'sys.argv', diff --git a/tests/unit/test_tool_error_middleware.py b/tests/unit/test_tool_error_middleware.py new file mode 100644 index 0000000..cc65cf9 --- /dev/null +++ b/tests/unit/test_tool_error_middleware.py @@ -0,0 +1,112 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for ToolErrorMiddleware.""" + +import anyio +import httpx +import mcp.types as mt +import pytest +from fastmcp.exceptions import ToolError +from fastmcp.server.middleware import MiddlewareContext +from fastmcp.tools import ToolResult +from mcp import McpError +from mcp.types import ErrorData +from mcp_proxy_for_aws.middleware.tool_error_middleware import ToolErrorMiddleware +from unittest.mock import AsyncMock, Mock + + +def _make_context(tool_name: str = 'test_tool') -> MiddlewareContext[mt.CallToolRequestParams]: + """Create a minimal MiddlewareContext for tool calls.""" + params = Mock(spec=mt.CallToolRequestParams) + params.name = tool_name + return MiddlewareContext[mt.CallToolRequestParams]( + message=params, + type='request', + method='tools/call', + ) + + +def _make_middleware(tool_call_timeout: float = 5.0) -> ToolErrorMiddleware: + """Create a ToolErrorMiddleware with mocked dependencies.""" + return ToolErrorMiddleware(tool_call_timeout=tool_call_timeout) + + +class TestToolErrorMiddleware: + """Test cases for ToolErrorMiddleware.""" + + @pytest.mark.asyncio + async def test_passes_through_on_success(self): + """Successful tool calls pass through unchanged.""" + middleware = _make_middleware() + expected = ToolResult(content=[mt.TextContent(type='text', text='ok')]) + call_next = AsyncMock(return_value=expected) + context = _make_context() + + result = await middleware.on_call_tool(context, call_next) + + assert result is expected + call_next.assert_awaited_once_with(context) + + @pytest.mark.asyncio + async def test_catches_exception_raises_tool_error(self): + """Exceptions are caught and raised as ToolError.""" + middleware = _make_middleware() + call_next = AsyncMock( + side_effect=McpError(ErrorData(code=-1, message='Connection closed')) + ) + context = _make_context() + + with pytest.raises(ToolError, match='Connection closed'): + await middleware.on_call_tool(context, call_next) + + @pytest.mark.asyncio + async def 
test_timeout_raises_tool_error(self): + """Tool calls that exceed the timeout raise a ToolError.""" + middleware = _make_middleware(tool_call_timeout=0.1) + + async def hang_forever(context: MiddlewareContext[mt.CallToolRequestParams]) -> ToolResult: + await anyio.sleep(999) + return ToolResult(content=[]) # unreachable + + context = _make_context(tool_name='slow_tool') + + with pytest.raises(ToolError, match='slow_tool'): + await middleware.on_call_tool(context, hang_forever) + + @pytest.mark.asyncio + async def test_credential_error_suggests_profile(self): + """Credential errors suggest using long-lived credentials.""" + middleware = _make_middleware() + response = Mock(spec=httpx.Response) + response.status_code = 401 + call_next = AsyncMock( + side_effect=httpx.HTTPStatusError('Unauthorized', request=Mock(), response=response) + ) + context = _make_context() + + with pytest.raises(ToolError, match='expired or invalid AWS credentials') as exc_info: + await middleware.on_call_tool(context, call_next) + assert '--profile' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_non_credential_error_no_suggestion(self): + """Non-credential errors do not suggest credential remediation.""" + middleware = _make_middleware() + call_next = AsyncMock(side_effect=RuntimeError('transport died')) + context = _make_context() + + with pytest.raises(ToolError) as exc_info: + await middleware.on_call_tool(context, call_next) + assert '--profile' not in str(exc_info.value)