Skip to content

Commit 092b144

Browse files
committed
feat: add --tool-timeout flag and ToolTimeoutMiddleware
- Add optional --tool-timeout CLI flag to cap tool call duration - Rename error_handling middleware to ToolTimeoutMiddleware - Return graceful isError=True response instead of hanging - Suggest long-lived credentials on 401/403 errors - Document --tool-timeout in README troubleshooting
1 parent be8506a commit 092b144

File tree

7 files changed

+279
-2
lines changed

7 files changed

+279
-2
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8-
## v1.1.8 (2026-04-02)
8+
## Unreleased
9+
10+
### Fixed
11+
12+
- Simplify error middleware and suggest long-lived AWS credentials on auth errors (#216)
913

1014
### Added
1115

README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ docker build -t mcp-proxy-for-aws .
107107
| `--connect-timeout` | Set desired connect timeout in seconds | 60 |No |
108108
| `--read-timeout` | Set desired read timeout in seconds | 120 |No |
109109
| `--write-timeout` | Set desired write timeout in seconds | 180 |No |
110+
| `--tool-timeout` | Maximum seconds a tool call may take before being cancelled. When set, returns a graceful error to the agent instead of hanging indefinitely | Not set |No |
110111

111112
### Optional Environment Variables
112113

@@ -325,12 +326,18 @@ uv sync
325326

326327
## Troubleshooting
327328

328-
### Handling `Authentication error - Invalid credentials`
329+
### Authentication errors
329330
We try to autodetect the service from the URL. If this fails, ensure that `--service` is set correctly to the
330331
service you are attempting to connect to.
331332
Otherwise, the service you connect to will not be able to verify the SigV4 signature, resulting in an `Authentication error - Invalid credentials` error.
332333
Also ensure that you have valid IAM credentials on your machine before retrying.
333334

335+
For long-running sessions, consider using long-lived credentials:
336+
- Use an AWS profile via `--profile`
337+
- Use IAM Identity Center and run `aws sso login` before starting the proxy
338+
339+
### Client hangs on tool calls
340+
If your MCP client hangs waiting for a tool call response (e.g., due to expired credentials or an unresponsive endpoint), use `--tool-timeout` to set a maximum duration in seconds for each tool call. When the timeout is exceeded, the proxy returns a graceful error to the agent instead of hanging indefinitely.
334341

335342
## Development & Contributing
336343

mcp_proxy_for_aws/cli.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,4 +159,13 @@ def parse_args():
159159
help='Write timeout (seconds) when connecting to endpoint (default: 180)',
160160
)
161161

162+
parser.add_argument(
163+
'--tool-timeout',
164+
type=within_range(0),
165+
default=None,
166+
help='Maximum seconds a tool call may take before being cancelled. '
167+
'When set, wraps each tool call with a timeout and returns a graceful error '
168+
'to the agent instead of hanging. Not set by default.',
169+
)
170+
162171
return parser.parse_args()
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import anyio
16+
import httpx
17+
import logging
18+
import mcp.types as mt
19+
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
20+
from fastmcp.tools.tool import ToolResult
21+
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
class _FailedToolResult(ToolResult):
    """ToolResult variant whose MCP conversion sets the ``isError`` flag."""

    def to_mcp_result(self) -> mt.CallToolResult:
        """Convert to an MCP ``CallToolResult`` carrying this content with ``isError=True``."""
        failed = mt.CallToolResult(content=self.content, isError=True)
        return failed
31+
32+
33+
class ToolTimeoutMiddleware(Middleware):
    """Middleware that bounds tool-call duration and turns failures into results.

    Implements two layers of protection:
    1. Timeout — bounds how long a tool call can take (when a timeout is
       configured), breaking any hang.
    2. Error propagation — catches any ``Exception`` raised by the call,
       including the timeout itself, and returns it as an error ToolResult
       so the agent always gets a response.

    Reconnection is handled automatically by fastmcp on every tool call.
    """

    def __init__(
        self,
        tool_call_timeout: float | None = None,
    ) -> None:
        """Initialize the middleware.

        Args:
            tool_call_timeout: Maximum seconds a tool call may take before being
                cancelled. None means no timeout (not recommended).
        """
        super().__init__()
        self._tool_call_timeout = tool_call_timeout

    async def on_call_tool(
        self,
        context: MiddlewareContext[mt.CallToolRequestParams],
        call_next: CallNext[mt.CallToolRequestParams, ToolResult],
    ) -> ToolResult:
        """Wrap a tool call with the configured timeout and error handling.

        Args:
            context: Middleware context whose message names the tool being called.
            call_next: Continuation that performs the actual tool call.

        Returns:
            The downstream ToolResult on success; otherwise an error result
            (``isError=True``) whose text describes the failure.
        """
        try:
            # fail_after(None) applies no deadline, so this also covers the
            # "no timeout configured" case.
            with anyio.fail_after(self._tool_call_timeout):
                return await call_next(context)
        except Exception as e:
            tool_name = context.message.name
            detail = self._describe_error(e)
            logger.error('Tool call %r failed: %s.', tool_name, detail)
            message = f'Tool call {tool_name!r} failed: {detail}. Please retry.'
            if self._is_credential_error(e):
                message += (
                    ' This may be caused by expired or invalid AWS credentials.'
                    ' Consider using long-lived credentials such as an AWS profile'
                    ' (--profile) or IAM Identity Center (aws sso login).'
                )
            return self._error_result(message)

    def _describe_error(self, error: Exception) -> str:
        """Return a non-empty, human-readable description of *error*.

        The TimeoutError raised by ``anyio.fail_after`` carries no message, so
        ``str(error)`` alone would yield an empty description.
        """
        text = str(error)
        if text:
            return text
        if isinstance(error, TimeoutError) and self._tool_call_timeout is not None:
            return f'timed out after {self._tool_call_timeout} seconds'
        # Fall back to the exception type so the message is never blank.
        return type(error).__name__

    @staticmethod
    def _is_credential_error(error: Exception) -> bool:
        """Check if the error is likely caused by expired or invalid credentials."""
        return isinstance(error, httpx.HTTPStatusError) and error.response.status_code in (
            401,
            403,
        )

    @staticmethod
    def _error_result(message: str) -> ToolResult:
        """Build an error ToolResult containing *message* as its only text item."""
        return _FailedToolResult(
            content=[mt.TextContent(type='text', text=message)],
        )

mcp_proxy_for_aws/server.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from mcp_proxy_for_aws.logging_config import configure_logging
3333
from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware
3434
from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware
35+
from mcp_proxy_for_aws.middleware.tool_timeout_middleware import ToolTimeoutMiddleware
3536
from mcp_proxy_for_aws.proxy import AWSMCPProxy, AWSMCPProxyClientFactory
3637
from mcp_proxy_for_aws.utils import (
3738
create_transport_with_sigv4,
@@ -95,6 +96,7 @@ async def run_proxy(args) -> None:
9596
),
9697
)
9798
proxy.add_middleware(InitializeMiddleware(client_factory))
99+
add_tool_timeout_middleware(proxy, args.tool_timeout)
98100
add_logging_middleware(proxy, args.log_level)
99101
add_tool_filtering_middleware(proxy, args.read_only)
100102

@@ -108,6 +110,19 @@ async def run_proxy(args) -> None:
108110
await client_factory.disconnect()
109111

110112

113+
def add_tool_timeout_middleware(mcp: FastMCP, tool_timeout: float | None = None) -> None:
    """Register ToolTimeoutMiddleware on the server when a timeout is given.

    Args:
        mcp: The FastMCP instance that receives the middleware
        tool_timeout: Upper bound, in seconds, for a single tool call. When None,
            no middleware is registered.
    """
    if tool_timeout is not None:
        logger.info('Adding tool timeout middleware with tool_timeout=%s', tool_timeout)
        timeout_middleware = ToolTimeoutMiddleware(tool_call_timeout=tool_timeout)
        mcp.add_middleware(timeout_middleware)
124+
125+
111126
def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None:
112127
"""Add tool filtering middleware to target MCP server.
113128

tests/unit/test_cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def test_parse_args_minimal(self):
3838
assert args.connect_timeout == 60.0
3939
assert args.read_timeout == 120.0
4040
assert args.write_timeout == 180.0
41+
assert args.tool_timeout is None
4142

4243
@patch(
4344
'sys.argv',
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for ToolTimeoutMiddleware."""
16+
17+
import anyio
18+
import httpx
19+
import mcp.types as mt
20+
import pytest
21+
from fastmcp.server.middleware import MiddlewareContext
22+
from fastmcp.tools.tool import ToolResult
23+
from mcp import McpError
24+
from mcp.types import ErrorData
25+
from mcp_proxy_for_aws.middleware.tool_timeout_middleware import (
26+
ToolTimeoutMiddleware,
27+
_FailedToolResult,
28+
)
29+
from typing import Optional
30+
from unittest.mock import AsyncMock, Mock
31+
32+
33+
def _make_context(tool_name: str = 'test_tool') -> MiddlewareContext[mt.CallToolRequestParams]:
    """Build a bare-bones ``tools/call`` request context for the named tool."""
    request_params = Mock(spec=mt.CallToolRequestParams)
    request_params.name = tool_name
    context = MiddlewareContext[mt.CallToolRequestParams](
        message=request_params,
        type='request',
        method='tools/call',
    )
    return context
42+
43+
44+
def _make_middleware(tool_call_timeout: Optional[float] = 5.0) -> ToolTimeoutMiddleware:
    """Construct a ToolTimeoutMiddleware with the given per-call timeout."""
    return ToolTimeoutMiddleware(tool_call_timeout=tool_call_timeout)
50+
51+
52+
def _get_text(result: ToolResult, index: int = 0) -> str:
    """Return the text of the content item at *index*, asserting it is TextContent."""
    item = result.content[index]
    assert isinstance(item, mt.TextContent)
    return item.text
57+
58+
59+
class TestToolTimeoutMiddleware:
    """Behavioral tests for ToolTimeoutMiddleware.on_call_tool."""

    @pytest.mark.asyncio
    async def test_passes_through_on_success(self):
        """A successful downstream result is returned as-is."""
        ok_result = ToolResult(content=[mt.TextContent(type='text', text='ok')])
        downstream = AsyncMock(return_value=ok_result)
        ctx = _make_context()

        outcome = await _make_middleware().on_call_tool(ctx, downstream)

        assert outcome is ok_result
        assert not isinstance(outcome, _FailedToolResult)
        downstream.assert_awaited_once_with(ctx)

    @pytest.mark.asyncio
    async def test_catches_exception_returns_error_result(self):
        """An exception raised downstream becomes an error ToolResult."""
        failure = McpError(ErrorData(code=-1, message='Connection closed'))
        downstream = AsyncMock(side_effect=failure)

        outcome = await _make_middleware().on_call_tool(_make_context(), downstream)

        assert isinstance(outcome, _FailedToolResult)
        assert len(outcome.content) == 1
        assert 'Connection closed' in _get_text(outcome)

    @pytest.mark.asyncio
    async def test_timeout_returns_error_result(self):
        """A call that outlives the timeout yields an error ToolResult."""

        async def never_returns(
            context: MiddlewareContext[mt.CallToolRequestParams],
        ) -> ToolResult:
            await anyio.sleep(999)
            return ToolResult(content=[])  # unreachable

        timeout_middleware = _make_middleware(tool_call_timeout=0.1)
        ctx = _make_context(tool_name='slow_tool')

        outcome = await timeout_middleware.on_call_tool(ctx, never_returns)

        assert isinstance(outcome, _FailedToolResult)
        assert len(outcome.content) == 1
        assert 'slow_tool' in _get_text(outcome)

    @pytest.mark.asyncio
    async def test_credential_error_suggests_profile(self):
        """A 401 response adds the long-lived credentials hint to the message."""
        unauthorized = Mock(spec=httpx.Response)
        unauthorized.status_code = 401
        status_error = httpx.HTTPStatusError('Unauthorized', request=Mock(), response=unauthorized)
        downstream = AsyncMock(side_effect=status_error)

        outcome = await _make_middleware().on_call_tool(_make_context(), downstream)

        assert isinstance(outcome, _FailedToolResult)
        message = _get_text(outcome)
        assert 'expired or invalid AWS credentials' in message
        assert '--profile' in message

    @pytest.mark.asyncio
    async def test_non_credential_error_no_suggestion(self):
        """Errors unrelated to credentials carry no credential hint."""
        downstream = AsyncMock(side_effect=RuntimeError('transport died'))

        outcome = await _make_middleware().on_call_tool(_make_context(), downstream)

        assert isinstance(outcome, _FailedToolResult)
        assert '--profile' not in _get_text(outcome)

    @pytest.mark.asyncio
    async def test_no_timeout_when_none(self):
        """With tool_call_timeout=None a successful call still passes through."""
        ok_result = ToolResult(content=[mt.TextContent(type='text', text='ok')])
        downstream = AsyncMock(return_value=ok_result)
        unbounded = _make_middleware(tool_call_timeout=None)

        outcome = await unbounded.on_call_tool(_make_context(), downstream)

        assert outcome is ok_result

0 commit comments

Comments
 (0)