-
Notifications
You must be signed in to change notification settings - Fork 38
fix: stale credentials issue #216
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
anasstahr
wants to merge
10
commits into
main
Choose a base branch
from
fix/stale-credentials-issue
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 7 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
092b144
feat: add --tool-timeout flag and ToolTimeoutMiddleware
anasstahr aec7e88
Merge branch 'main' into fix/stale-credentials-issue
anasstahr ec3268f
refactor: rename ToolTimeoutMiddleware to ToolErrorMiddleware
anasstahr 56dc2e7
fix: sort imports in server.py
anasstahr 8c543fe
fix: remove Optional timeout, decouple tests from private class
anasstahr 0ad9953
fix: handle non-CallToolResult in success test
anasstahr e43fa1a
fix: remove trailing blank line in test file
anasstahr f4d3f8b
refactor: rename --tool-error-timeout to --tool-timeout
anasstahr a224341
merge: resolve CHANGELOG conflict with main
anasstahr 734c25e
refactor: raise ToolError instead of custom _FailedToolResult subclass
anasstahr File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import anyio | ||
| import httpx | ||
| import logging | ||
| import mcp.types as mt | ||
| from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext | ||
| from fastmcp.tools.tool import ToolResult | ||
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class _FailedToolResult(ToolResult): | ||
anasstahr marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """A ToolResult that signals an error via the MCP isError flag.""" | ||
|
|
||
| def to_mcp_result(self) -> mt.CallToolResult: | ||
| return mt.CallToolResult(content=self.content, isError=True) | ||
|
|
||
|
|
||
| class ToolErrorMiddleware(Middleware): | ||
| """Middleware that ensures tool calls never hang and always return a response. | ||
|
|
||
| Implements two layers of protection: | ||
| 1. Timeout — bounds how long a tool call can take, breaking any hang. | ||
| 2. Error propagation — catches any error and returns it as a ToolResult | ||
| so the agent always gets a response. | ||
|
|
||
| Reconnection is handled automatically by fastmcp on every tool call. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| tool_call_timeout: float = 300.0, | ||
| ) -> None: | ||
| """Initialize the middleware. | ||
|
|
||
| Args: | ||
| tool_call_timeout: Maximum seconds a tool call may take before being | ||
| cancelled. | ||
| """ | ||
| super().__init__() | ||
| self._tool_call_timeout = tool_call_timeout | ||
|
|
||
| async def on_call_tool( | ||
| self, | ||
| context: MiddlewareContext[mt.CallToolRequestParams], | ||
| call_next: CallNext[mt.CallToolRequestParams, ToolResult], | ||
| ) -> ToolResult: | ||
| """Wrap tool calls with timeout and error handling.""" | ||
| try: | ||
| with anyio.fail_after(self._tool_call_timeout): | ||
| return await call_next(context) | ||
| except Exception as e: | ||
| tool_name = context.message.name | ||
| logger.error('Tool call %r failed: %s.', tool_name, e) | ||
| message = f'Tool call {tool_name!r} failed: {e}. Please retry.' | ||
| if self._is_credential_error(e): | ||
| message += ( | ||
| ' This may be caused by expired or invalid AWS credentials.' | ||
| ' Consider using long-lived credentials such as an AWS profile' | ||
| ' (--profile) or IAM Identity Center (aws sso login).' | ||
| ) | ||
| return self._error_result(message) | ||
|
|
||
| @staticmethod | ||
| def _is_credential_error(error: Exception) -> bool: | ||
| """Check if the error is likely caused by expired or invalid credentials.""" | ||
| return isinstance(error, httpx.HTTPStatusError) and error.response.status_code in ( | ||
| 401, | ||
| 403, | ||
arnewouters marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ) | ||
|
|
||
| @staticmethod | ||
| def _error_result(message: str) -> ToolResult: | ||
| return _FailedToolResult( | ||
| content=[mt.TextContent(type='text', text=message)], | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,143 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Unit tests for ToolErrorMiddleware.""" | ||
|
|
||
| import anyio | ||
| import httpx | ||
| import mcp.types as mt | ||
| import pytest | ||
| from fastmcp.server.middleware import MiddlewareContext | ||
| from fastmcp.tools.tool import ToolResult | ||
| from mcp import McpError | ||
| from mcp.types import ErrorData | ||
| from mcp_proxy_for_aws.middleware.tool_error_middleware import ToolErrorMiddleware | ||
| from unittest.mock import AsyncMock, Mock | ||
|
|
||
|
|
||
| def _make_context(tool_name: str = 'test_tool') -> MiddlewareContext[mt.CallToolRequestParams]: | ||
| """Create a minimal MiddlewareContext for tool calls.""" | ||
| params = Mock(spec=mt.CallToolRequestParams) | ||
| params.name = tool_name | ||
| return MiddlewareContext[mt.CallToolRequestParams]( | ||
| message=params, | ||
| type='request', | ||
| method='tools/call', | ||
| ) | ||
|
|
||
|
|
||
| def _make_middleware(tool_call_timeout: float = 5.0) -> ToolErrorMiddleware: | ||
| """Create a ToolErrorMiddleware with mocked dependencies.""" | ||
| middleware = ToolErrorMiddleware( | ||
| tool_call_timeout=tool_call_timeout, | ||
| ) | ||
| return middleware | ||
|
|
||
|
|
||
| def _is_error(result: ToolResult) -> bool: | ||
| """Check if a ToolResult has the MCP isError flag set.""" | ||
| mcp_result = result.to_mcp_result() | ||
| assert isinstance(mcp_result, mt.CallToolResult) | ||
| return bool(mcp_result.isError) | ||
|
|
||
|
|
||
| def _get_text(result: ToolResult, index: int = 0) -> str: | ||
| """Extract text from a ToolResult content item.""" | ||
| content = result.content[index] | ||
| assert isinstance(content, mt.TextContent) | ||
| return content.text | ||
|
|
||
|
|
||
| class TestToolErrorMiddleware: | ||
| """Test cases for ToolErrorMiddleware.""" | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_passes_through_on_success(self): | ||
| """Successful tool calls pass through unchanged.""" | ||
| middleware = _make_middleware() | ||
| expected = ToolResult(content=[mt.TextContent(type='text', text='ok')]) | ||
| call_next = AsyncMock(return_value=expected) | ||
| context = _make_context() | ||
|
|
||
| result = await middleware.on_call_tool(context, call_next) | ||
|
|
||
| assert result is expected | ||
| mcp_result = result.to_mcp_result() | ||
| assert not isinstance(mcp_result, mt.CallToolResult) or not mcp_result.isError | ||
| call_next.assert_awaited_once_with(context) | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_catches_exception_returns_error_result(self): | ||
| """Exceptions are caught and returned as error ToolResults.""" | ||
| middleware = _make_middleware() | ||
| call_next = AsyncMock( | ||
| side_effect=McpError(ErrorData(code=-1, message='Connection closed')) | ||
| ) | ||
| context = _make_context() | ||
|
|
||
| result = await middleware.on_call_tool(context, call_next) | ||
|
|
||
| assert _is_error(result) | ||
| assert len(result.content) == 1 | ||
| text = _get_text(result) | ||
| assert 'Connection closed' in text | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_timeout_returns_error_result(self): | ||
| """Tool calls that exceed the timeout return an error ToolResult.""" | ||
| middleware = _make_middleware(tool_call_timeout=0.1) | ||
|
|
||
| async def hang_forever(context: MiddlewareContext[mt.CallToolRequestParams]) -> ToolResult: | ||
| await anyio.sleep(999) | ||
| return ToolResult(content=[]) # unreachable | ||
|
|
||
| context = _make_context(tool_name='slow_tool') | ||
|
|
||
| result = await middleware.on_call_tool(context, hang_forever) | ||
|
|
||
| assert _is_error(result) | ||
| assert len(result.content) == 1 | ||
| text = _get_text(result) | ||
| assert 'slow_tool' in text | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_credential_error_suggests_profile(self): | ||
| """Credential errors suggest using long-lived credentials.""" | ||
| middleware = _make_middleware() | ||
| response = Mock(spec=httpx.Response) | ||
| response.status_code = 401 | ||
| call_next = AsyncMock( | ||
| side_effect=httpx.HTTPStatusError('Unauthorized', request=Mock(), response=response) | ||
| ) | ||
| context = _make_context() | ||
|
|
||
| result = await middleware.on_call_tool(context, call_next) | ||
|
|
||
| assert _is_error(result) | ||
| text = _get_text(result) | ||
| assert 'expired or invalid AWS credentials' in text | ||
| assert '--profile' in text | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_non_credential_error_no_suggestion(self): | ||
| """Non-credential errors do not suggest credential remediation.""" | ||
| middleware = _make_middleware() | ||
| call_next = AsyncMock(side_effect=RuntimeError('transport died')) | ||
| context = _make_context() | ||
|
|
||
| result = await middleware.on_call_tool(context, call_next) | ||
|
|
||
| assert _is_error(result) | ||
| text = _get_text(result) | ||
| assert '--profile' not in text |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.