Skip to content
Open
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## v1.1.8 (2026-04-02)
## Unreleased

### Fixed

- Simplify error middleware and suggest long-lived AWS credentials on auth errors (#216)

### Added

Expand Down
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ docker build -t mcp-proxy-for-aws .
| `--connect-timeout` | Set desired connect timeout in seconds | 60 |No |
| `--read-timeout` | Set desired read timeout in seconds | 120 |No |
| `--write-timeout` | Set desired write timeout in seconds | 180 |No |
| `--tool-error-timeout` | Maximum seconds a tool call may take before being cancelled. When set, returns a graceful error to the agent instead of hanging indefinitely | 300 |No |

### Optional Environment Variables

Expand Down Expand Up @@ -341,12 +342,18 @@ uv sync

## Troubleshooting

### Handling `Authentication error - Invalid credentials`
### Authentication errors
We try to autodetect the service from the URL, but this sometimes fails. Ensure that `--service` is set correctly to the
service you are attempting to connect to.
Otherwise, the service you connect to will not be able to verify the SigV4 signature, resulting in this error.
Also ensure that you have valid IAM credentials on your machine before retrying.

For long-running sessions, consider using long-lived credentials:
- Use an AWS profile via `--profile`
- Use IAM Identity Center and run `aws sso login` before starting the proxy

### Client hangs on tool calls
If your MCP client hangs waiting for a tool call response (e.g., due to expired credentials or an unresponsive endpoint), use `--tool-error-timeout` to set a maximum duration in seconds for each tool call. When the timeout is exceeded, the proxy returns a graceful error to the agent instead of hanging indefinitely.

## Development & Contributing

Expand Down
9 changes: 9 additions & 0 deletions mcp_proxy_for_aws/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,13 @@ def parse_args():
help='Write timeout (seconds) when connecting to endpoint (default: 180)',
)

parser.add_argument(
'--tool-error-timeout',
type=within_range(0),
default=300.0,
help='Maximum seconds a tool call may take before being cancelled. '
'When set, wraps each tool call with a timeout and returns a graceful error '
'to the agent instead of hanging (default: 300).',
)

return parser.parse_args()
90 changes: 90 additions & 0 deletions mcp_proxy_for_aws/middleware/tool_error_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import anyio
import httpx
import logging
import mcp.types as mt
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
from fastmcp.tools.tool import ToolResult


logger = logging.getLogger(__name__)


class _FailedToolResult(ToolResult):
    """ToolResult variant whose MCP conversion always carries ``isError=True``."""

    def to_mcp_result(self) -> mt.CallToolResult:
        """Return a CallToolResult holding this result's content, flagged as an error."""
        failed = mt.CallToolResult(content=self.content, isError=True)
        return failed


class ToolErrorMiddleware(Middleware):
    """Middleware that ensures tool calls never hang and always return a response.

    Implements two layers of protection:
    1. Timeout — bounds how long a tool call can take, breaking any hang.
    2. Error propagation — catches any error and returns it as a ToolResult
       so the agent always gets a response.

    Reconnection is handled automatically by fastmcp on every tool call.
    """

    def __init__(
        self,
        tool_call_timeout: float = 300.0,
    ) -> None:
        """Initialize the middleware.

        Args:
            tool_call_timeout: Maximum seconds a tool call may take before being
                cancelled.
        """
        super().__init__()
        self._tool_call_timeout = tool_call_timeout

    async def on_call_tool(
        self,
        context: MiddlewareContext[mt.CallToolRequestParams],
        call_next: CallNext[mt.CallToolRequestParams, ToolResult],
    ) -> ToolResult:
        """Wrap a tool call with a timeout and convert failures into error results.

        Args:
            context: Middleware context for the tool call; ``context.message.name``
                is used as the tool name in log and error messages.
            call_next: The next handler in the middleware chain.

        Returns:
            The downstream ToolResult on success; otherwise an error ToolResult
            (``isError=True``) whose text describes the failure.
        """
        try:
            with anyio.fail_after(self._tool_call_timeout):
                return await call_next(context)
        except Exception as e:
            tool_name = context.message.name
            # str(e) is empty for the TimeoutError raised by fail_after (and for
            # some other exceptions), so derive a reason that is never blank.
            reason = self._describe_error(e)
            logger.error('Tool call %r failed: %s.', tool_name, reason)
            message = f'Tool call {tool_name!r} failed: {reason}. Please retry.'
            if self._is_credential_error(e):
                message += (
                    ' This may be caused by expired or invalid AWS credentials.'
                    ' Consider using long-lived credentials such as an AWS profile'
                    ' (--profile) or IAM Identity Center (aws sso login).'
                )
            return self._error_result(message)

    def _describe_error(self, error: Exception) -> str:
        """Return a non-empty, human-readable reason for a failed tool call.

        Uses the exception's own text when it has one. A message-less
        TimeoutError is described together with the configured tool call
        timeout; any other message-less exception is described by its class name.
        """
        text = str(error)
        if text:
            return text
        if isinstance(error, TimeoutError):
            return f'timed out (tool call timeout is {self._tool_call_timeout} seconds)'
        return type(error).__name__

    @staticmethod
    def _is_credential_error(error: Exception) -> bool:
        """Check if the error is likely caused by expired or invalid credentials.

        Only an ``httpx.HTTPStatusError`` whose response status is 401 or 403
        is treated as a credential error.
        """
        return isinstance(error, httpx.HTTPStatusError) and error.response.status_code in (
            401,
            403,
        )

    @staticmethod
    def _error_result(message: str) -> ToolResult:
        """Build an error ToolResult containing *message* as its single text item."""
        return _FailedToolResult(
            content=[mt.TextContent(type='text', text=message)],
        )
13 changes: 13 additions & 0 deletions mcp_proxy_for_aws/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from mcp_proxy_for_aws.cli import parse_args
from mcp_proxy_for_aws.logging_config import configure_logging
from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware
from mcp_proxy_for_aws.middleware.tool_error_middleware import ToolErrorMiddleware
from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware
from mcp_proxy_for_aws.proxy import AWSMCPProxy, AWSMCPProxyClientFactory
from mcp_proxy_for_aws.utils import (
Expand Down Expand Up @@ -97,6 +98,7 @@ async def run_proxy(args) -> None:
),
)
proxy.add_middleware(InitializeMiddleware(client_factory))
add_tool_error_middleware(proxy, args.tool_error_timeout)
add_logging_middleware(proxy, args.log_level)
add_tool_filtering_middleware(proxy, args.read_only)

Expand All @@ -110,6 +112,17 @@ async def run_proxy(args) -> None:
await client_factory.disconnect()


def add_tool_error_middleware(mcp: FastMCP, tool_error_timeout: float) -> None:
    """Register a ToolErrorMiddleware on the given server.

    Args:
        mcp: The FastMCP instance that receives the middleware.
        tool_error_timeout: Upper bound, in seconds, handed to the middleware
            as its per-tool-call timeout.
    """
    logger.info('Adding tool error middleware with tool_error_timeout=%s', tool_error_timeout)
    error_middleware = ToolErrorMiddleware(tool_call_timeout=tool_error_timeout)
    mcp.add_middleware(error_middleware)


def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None:
"""Add tool filtering middleware to target MCP server.

Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_parse_args_minimal(self):
assert args.connect_timeout == 60.0
assert args.read_timeout == 120.0
assert args.write_timeout == 180.0
assert args.tool_error_timeout == 300.0

@patch(
'sys.argv',
Expand Down
143 changes: 143 additions & 0 deletions tests/unit/test_tool_error_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for ToolErrorMiddleware."""

import anyio
import httpx
import mcp.types as mt
import pytest
from fastmcp.server.middleware import MiddlewareContext
from fastmcp.tools.tool import ToolResult
from mcp import McpError
from mcp.types import ErrorData
from mcp_proxy_for_aws.middleware.tool_error_middleware import ToolErrorMiddleware
from unittest.mock import AsyncMock, Mock


def _make_context(tool_name: str = 'test_tool') -> MiddlewareContext[mt.CallToolRequestParams]:
    """Build a bare-bones ``tools/call`` request context whose message is named *tool_name*."""
    request_params = Mock(spec=mt.CallToolRequestParams)
    request_params.name = tool_name
    context_type = MiddlewareContext[mt.CallToolRequestParams]
    return context_type(message=request_params, type='request', method='tools/call')


def _make_middleware(tool_call_timeout: float = 5.0) -> ToolErrorMiddleware:
    """Return a ToolErrorMiddleware configured with the given per-call timeout."""
    return ToolErrorMiddleware(tool_call_timeout=tool_call_timeout)


def _is_error(result: ToolResult) -> bool:
    """Report whether *result* converts to a CallToolResult with a truthy ``isError``."""
    converted = result.to_mcp_result()
    assert isinstance(converted, mt.CallToolResult)
    return True if converted.isError else False


def _get_text(result: ToolResult, index: int = 0) -> str:
    """Return the text of the content item at *index*, asserting it is a TextContent."""
    item = result.content[index]
    assert isinstance(item, mt.TextContent)
    return item.text


class TestToolErrorMiddleware:
    """Behavioural tests for ToolErrorMiddleware.on_call_tool."""

    @pytest.mark.asyncio
    async def test_passes_through_on_success(self):
        """A result produced downstream is handed back untouched."""
        ctx = _make_context()
        downstream_result = ToolResult(content=[mt.TextContent(type='text', text='ok')])
        downstream = AsyncMock(return_value=downstream_result)

        outcome = await _make_middleware().on_call_tool(ctx, downstream)

        assert outcome is downstream_result
        converted = outcome.to_mcp_result()
        # Only a CallToolResult can carry the error flag; when it is one, the flag must be unset.
        if isinstance(converted, mt.CallToolResult):
            assert not converted.isError
        downstream.assert_awaited_once_with(ctx)

    @pytest.mark.asyncio
    async def test_catches_exception_returns_error_result(self):
        """An exception raised downstream becomes a single-item error result."""
        failure = McpError(ErrorData(code=-1, message='Connection closed'))
        downstream = AsyncMock(side_effect=failure)
        ctx = _make_context()

        outcome = await _make_middleware().on_call_tool(ctx, downstream)

        assert _is_error(outcome)
        assert len(outcome.content) == 1
        assert 'Connection closed' in _get_text(outcome)

    @pytest.mark.asyncio
    async def test_timeout_returns_error_result(self):
        """A call that outlives the timeout yields an error result naming the tool."""

        async def hang_forever(context: MiddlewareContext[mt.CallToolRequestParams]) -> ToolResult:
            await anyio.sleep(999)
            return ToolResult(content=[])  # unreachable

        short_fuse = _make_middleware(tool_call_timeout=0.1)
        ctx = _make_context(tool_name='slow_tool')

        outcome = await short_fuse.on_call_tool(ctx, hang_forever)

        assert _is_error(outcome)
        assert len(outcome.content) == 1
        assert 'slow_tool' in _get_text(outcome)

    @pytest.mark.asyncio
    async def test_credential_error_suggests_profile(self):
        """An HTTP 401 from downstream adds the long-lived-credentials hint."""
        unauthorized = Mock(spec=httpx.Response)
        unauthorized.status_code = 401
        http_error = httpx.HTTPStatusError('Unauthorized', request=Mock(), response=unauthorized)
        downstream = AsyncMock(side_effect=http_error)
        ctx = _make_context()

        outcome = await _make_middleware().on_call_tool(ctx, downstream)

        assert _is_error(outcome)
        message = _get_text(outcome)
        assert 'expired or invalid AWS credentials' in message
        assert '--profile' in message

    @pytest.mark.asyncio
    async def test_non_credential_error_no_suggestion(self):
        """An error unrelated to credentials carries no credential hint."""
        downstream = AsyncMock(side_effect=RuntimeError('transport died'))
        ctx = _make_context()

        outcome = await _make_middleware().on_call_tool(ctx, downstream)

        assert _is_error(outcome)
        assert '--profile' not in _get_text(outcome)
Loading