Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## v1.1.8 (2026-04-02)
## Unreleased

### Added

- Build and publish container image (#126)

### Fixed

- Use new streamable http client (#228)
- Add URL scheme validation to prevent credential interception (#169)
- Prevent credential exposure in logs (#167)
- Replace failing integ test (#178)
Expand Down
19 changes: 10 additions & 9 deletions mcp_proxy_for_aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# limitations under the License.

import boto3
import httpx
import logging
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from botocore.credentials import Credentials
from contextlib import _AsyncGeneratorContextManager
from datetime import timedelta
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
from mcp.client.streamable_http import GetSessionIdCallback, streamable_http_client
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth
Expand Down Expand Up @@ -118,13 +119,13 @@ def aws_iam_streamablehttp_client(
# Create a SigV4 authentication handler with AWS credentials
auth = SigV4HTTPXAuth(creds, aws_service, region)

# Create the HTTP client with authentication and configuration
httpx_timeout = httpx.Timeout(
timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
)
http_client = httpx_client_factory(headers=headers, timeout=httpx_timeout, auth=auth)

# Return the streamable HTTP client context manager with AWS IAM authentication
return streamablehttp_client(
url=endpoint,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
terminate_on_close=terminate_on_close,
httpx_client_factory=httpx_client_factory,
auth=auth,
return streamable_http_client(
url=endpoint, http_client=http_client, terminate_on_close=terminate_on_close
)
4 changes: 2 additions & 2 deletions mcp_proxy_for_aws/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
from fastmcp import Client
from fastmcp.client.transports import ClientTransport
from fastmcp.server.providers.proxy import ProxyClient as _ProxyClient
from fastmcp.server.providers.proxy import StatefulProxyClient
from mcp import McpError
from mcp.types import InitializeRequest, JSONRPCError, JSONRPCMessage
from typing_extensions import override
Expand All @@ -25,7 +25,7 @@
logger = logging.getLogger(__name__)


class AWSMCPProxyClient(_ProxyClient):
class AWSMCPProxyClient(StatefulProxyClient):
"""Proxy client that handles HTTP errors when connection fails."""

def __init__(self, transport: ClientTransport, max_connect_retry=3, **kwargs):
Expand Down
64 changes: 42 additions & 22 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def test_boto3_session_parameters(
mock_read, mock_write, mock_get_session = mock_streams

with patch('boto3.Session', return_value=mock_session) as mock_boto:
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand Down Expand Up @@ -94,9 +94,11 @@ async def test_sigv4_auth_is_created_and_used(mock_session, mock_streams, servic

with patch('boto3.Session', return_value=mock_session):
with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth') as mock_auth_cls:
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_auth = Mock()
mock_auth_cls.return_value = mock_auth
mock_http_client = Mock()
mock_factory = Mock(return_value=mock_http_client)
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand All @@ -106,17 +108,20 @@ async def test_sigv4_auth_is_created_and_used(mock_session, mock_streams, servic
endpoint='https://test.example.com/mcp',
aws_service=service_name,
aws_region=region,
httpx_client_factory=mock_factory,
):
pass

mock_auth_cls.assert_called_once_with(
# Auth should be constructed with the resolved credentials, service, and region,
# and passed into the streamable client.
# Auth should be constructed with the resolved credentials, service, and region
mock_session.get_credentials.return_value,
service_name,
region,
)
assert mock_stream_client.call_args[1]['auth'] is mock_auth
# Auth should be passed to the httpx client factory
assert mock_factory.call_args[1]['auth'] is mock_auth
# The created http client should be passed to streamable_http_client
assert mock_stream_client.call_args[1]['http_client'] is mock_http_client


@pytest.mark.asyncio
Expand All @@ -132,12 +137,14 @@ async def test_streamable_client_parameters(
mock_session, mock_streams, headers, timeout_value, sse_value, terminate_value
):
"""Test the correctness of streamablehttp_client parameters."""
# Verify that connection settings are forwarded as-is to the streamable HTTP client.
# timedelta values are allowed and compared directly here.
# Verify that connection settings are forwarded correctly to the httpx client factory
# and streamable HTTP client.
mock_read, mock_write, mock_get_session = mock_streams

with patch('boto3.Session', return_value=mock_session):
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_http_client = Mock()
mock_factory = Mock(return_value=mock_http_client)
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand All @@ -150,27 +157,34 @@ async def test_streamable_client_parameters(
timeout=timeout_value,
sse_read_timeout=sse_value,
terminate_on_close=terminate_value,
httpx_client_factory=mock_factory,
):
pass

call_kwargs = mock_stream_client.call_args[1]
# Confirm each parameter is forwarded unchanged.
assert call_kwargs['url'] == 'https://test.example.com/mcp'
assert call_kwargs['headers'] == headers
assert call_kwargs['timeout'] == timeout_value
assert call_kwargs['sse_read_timeout'] == sse_value
assert call_kwargs['terminate_on_close'] == terminate_value
# Verify headers and auth are passed to the factory
factory_call_kwargs = mock_factory.call_args[1]
assert factory_call_kwargs['headers'] == headers
# Timeout is passed to the factory (converted to httpx.Timeout)
assert factory_call_kwargs['timeout'] is not None

# Verify the created http client and other params are passed to streamable_http_client
stream_call_kwargs = mock_stream_client.call_args[1]
assert stream_call_kwargs['url'] == 'https://test.example.com/mcp'
assert stream_call_kwargs['http_client'] is mock_http_client
assert stream_call_kwargs['terminate_on_close'] == terminate_value


@pytest.mark.asyncio
async def test_custom_httpx_client_factory_is_passed(mock_session, mock_streams):
"""Test the passing of a custom HTTPX client factory."""
# The factory should be handed through to the underlying streamable client untouched.
# The factory should be used to create the http client.
mock_read, mock_write, mock_get_session = mock_streams
custom_factory = Mock()
mock_http_client = Mock()
custom_factory.return_value = mock_http_client

with patch('boto3.Session', return_value=mock_session):
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand All @@ -183,7 +197,10 @@ async def test_custom_httpx_client_factory_is_passed(mock_session, mock_streams)
):
pass

assert mock_stream_client.call_args[1]['httpx_client_factory'] is custom_factory
# Verify the custom factory was called
custom_factory.assert_called_once()
# Verify the http client from the factory was passed to streamable_http_client
assert mock_stream_client.call_args[1]['http_client'] is mock_http_client


@pytest.mark.asyncio
Expand All @@ -198,7 +215,7 @@ async def mock_aexit(*_):
cleanup_called = True

with patch('boto3.Session', return_value=mock_session):
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand All @@ -220,7 +237,7 @@ async def test_credentials_parameter_with_region(mock_streams):
creds = Credentials('test_key', 'test_secret', 'test_token')

with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth') as mock_auth_cls:
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_auth = Mock()
mock_auth_cls.return_value = mock_auth
mock_stream_client.return_value.__aenter__ = AsyncMock(
Expand Down Expand Up @@ -264,7 +281,7 @@ async def test_credentials_parameter_bypasses_boto3_session(mock_streams):

with patch('boto3.Session') as mock_boto:
with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth'):
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand Down Expand Up @@ -303,7 +320,9 @@ async def test_http_localhost_endpoint_allowed(mock_session, mock_streams):
mock_read, mock_write, mock_get_session = mock_streams

with patch('boto3.Session', return_value=mock_session):
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_http_client = Mock()
mock_factory = Mock(return_value=mock_http_client)
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand All @@ -313,5 +332,6 @@ async def test_http_localhost_endpoint_allowed(mock_session, mock_streams):
async with aws_iam_streamablehttp_client(
endpoint='http://localhost:8080/mcp',
aws_service='bedrock-agentcore',
httpx_client_factory=mock_factory,
):
pass
13 changes: 7 additions & 6 deletions tests/unit/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def test_proxy_client_connect_success():
mock_transport = Mock(spec=ClientTransport)
client = AWSMCPProxyClient(mock_transport)

with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', return_value='connected'):
with patch('mcp_proxy_for_aws.proxy.StatefulProxyClient._connect', return_value='connected'):
result = await client._connect()
assert result == 'connected'

Expand All @@ -51,7 +51,7 @@ async def test_proxy_client_connect_http_error_with_mcp_error():

http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)

with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=http_error):
with patch('mcp_proxy_for_aws.proxy.StatefulProxyClient._connect', side_effect=http_error):
with pytest.raises(McpError) as exc_info:
await client._connect()
assert exc_info.value.error.code == -32600
Expand All @@ -69,7 +69,7 @@ async def test_proxy_client_connect_http_error_non_mcp():

http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)

with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=http_error):
with patch('mcp_proxy_for_aws.proxy.StatefulProxyClient._connect', side_effect=http_error):
with pytest.raises(httpx.HTTPStatusError):
await client._connect()

Expand Down Expand Up @@ -180,7 +180,7 @@ async def test_proxy_client_connect_runtime_error_with_mcp_error():
runtime_error = RuntimeError('Connection failed')
runtime_error.__cause__ = mcp_error

with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=runtime_error):
with patch('mcp_proxy_for_aws.proxy.StatefulProxyClient._connect', side_effect=runtime_error):
with pytest.raises(McpError) as exc_info:
await client._connect()
assert exc_info.value.error.code == -32600
Expand All @@ -194,7 +194,7 @@ async def test_proxy_client_connect_runtime_error_max_retries():

runtime_error = RuntimeError('Connection failed')

with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=runtime_error):
with patch('mcp_proxy_for_aws.proxy.StatefulProxyClient._connect', side_effect=runtime_error):
with patch.object(client, '_disconnect', new_callable=AsyncMock) as mock_disconnect:
with pytest.raises(RuntimeError):
await client._connect()
Expand All @@ -218,7 +218,8 @@ async def mock_connect_side_effect(*args, **kwargs):
return 'connected'

with patch(
'mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=mock_connect_side_effect
'mcp_proxy_for_aws.proxy.StatefulProxyClient._connect',
side_effect=mock_connect_side_effect,
):
with patch.object(
client,
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading