From 6f516969cbf8465c2f48387dae81d6d2f8fce2bc Mon Sep 17 00:00:00 2001 From: Anass Taher Date: Tue, 7 Apr 2026 17:25:15 +0200 Subject: [PATCH] fix: use new streamable http client and fix elicitation forwarding - Replace deprecated streamablehttp_client with streamable_http_client - Use StatefulProxyClient to fix elicitation, sampling, and other server-initiated requests failing through the proxy --- CHANGELOG.md | 3 +- mcp_proxy_for_aws/client.py | 19 +++++------ mcp_proxy_for_aws/proxy.py | 4 +-- tests/unit/test_client.py | 64 ++++++++++++++++++++++++------------- tests/unit/test_proxy.py | 13 ++++---- uv.lock | 2 +- 6 files changed, 64 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 416c01c..bee750a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## v1.1.8 (2026-04-02) +## Unreleased ### Added @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Use new streamable http client (#228) - Add URL scheme validation to prevent credential interception (#169) - Prevent credential exposure in logs (#167) - Replace failing integ test (#178) diff --git a/mcp_proxy_for_aws/client.py b/mcp_proxy_for_aws/client.py index 5ecc689..ff16c44 100644 --- a/mcp_proxy_for_aws/client.py +++ b/mcp_proxy_for_aws/client.py @@ -13,12 +13,13 @@ # limitations under the License. import boto3 +import httpx import logging from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from botocore.credentials import Credentials from contextlib import _AsyncGeneratorContextManager from datetime import timedelta -from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client +from mcp.client.streamable_http import GetSessionIdCallback, streamable_http_client from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import SessionMessage from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth @@ -118,13 +119,13 @@ def aws_iam_streamablehttp_client( # Create a SigV4 authentication handler with AWS credentials auth = SigV4HTTPXAuth(creds, aws_service, region) + # Create the HTTP client with authentication and configuration + httpx_timeout = httpx.Timeout( + timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + ) + http_client = httpx_client_factory(headers=headers, timeout=httpx_timeout, auth=auth) + # Return the streamable HTTP client context manager with AWS IAM authentication - return streamablehttp_client( - url=endpoint, - headers=headers, - timeout=timeout, - sse_read_timeout=sse_read_timeout, - terminate_on_close=terminate_on_close, - httpx_client_factory=httpx_client_factory, - auth=auth, + return streamable_http_client( + url=endpoint, http_client=http_client, terminate_on_close=terminate_on_close ) diff --git a/mcp_proxy_for_aws/proxy.py b/mcp_proxy_for_aws/proxy.py index b4c3db5..e632ba3 100644 --- a/mcp_proxy_for_aws/proxy.py +++ b/mcp_proxy_for_aws/proxy.py @@ -16,7 +16,7 @@ import logging from fastmcp import Client from fastmcp.client.transports import ClientTransport -from fastmcp.server.providers.proxy import ProxyClient as _ProxyClient +from fastmcp.server.providers.proxy import StatefulProxyClient from mcp import McpError from mcp.types import InitializeRequest, JSONRPCError, JSONRPCMessage from typing_extensions import override @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) -class AWSMCPProxyClient(_ProxyClient): +class AWSMCPProxyClient(StatefulProxyClient): """Proxy client that handles HTTP errors when connection fails.""" def __init__(self, transport: ClientTransport, max_connect_retry=3, **kwargs): diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 1e2ddab..323d1eb 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -60,7 +60,7 @@ async def test_boto3_session_parameters( mock_read, mock_write, mock_get_session = mock_streams with patch('boto3.Session', return_value=mock_session) as mock_boto: - with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client: mock_stream_client.return_value.__aenter__ = AsyncMock( return_value=(mock_read, mock_write, mock_get_session) ) @@ -94,9 +94,11 @@ async def test_sigv4_auth_is_created_and_used(mock_session, mock_streams, servic with patch('boto3.Session', return_value=mock_session): with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth') as mock_auth_cls: - with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client: mock_auth = Mock() mock_auth_cls.return_value = mock_auth + mock_http_client = Mock() + mock_factory = Mock(return_value=mock_http_client) mock_stream_client.return_value.__aenter__ = AsyncMock( return_value=(mock_read, mock_write, mock_get_session) ) @@ -106,17 +108,20 @@ async def test_sigv4_auth_is_created_and_used(mock_session, mock_streams, servic endpoint='https://test.example.com/mcp', aws_service=service_name, aws_region=region, + httpx_client_factory=mock_factory, ): pass mock_auth_cls.assert_called_once_with( - # Auth should be constructed with the resolved credentials, service, and region, - # and passed into the streamable client. + # Auth should be constructed with the resolved credentials, service, and region mock_session.get_credentials.return_value, service_name, region, ) - assert mock_stream_client.call_args[1]['auth'] is mock_auth + # Auth should be passed to the httpx client factory + assert mock_factory.call_args[1]['auth'] is mock_auth + # The created http client should be passed to streamable_http_client + assert mock_stream_client.call_args[1]['http_client'] is mock_http_client @pytest.mark.asyncio @@ -132,12 +137,14 @@ async def test_streamable_client_parameters( mock_session, mock_streams, headers, timeout_value, sse_value, terminate_value ): """Test the correctness of streamablehttp_client parameters.""" - # Verify that connection settings are forwarded as-is to the streamable HTTP client. - # timedelta values are allowed and compared directly here. + # Verify that connection settings are forwarded correctly to the httpx client factory + # and streamable HTTP client. mock_read, mock_write, mock_get_session = mock_streams with patch('boto3.Session', return_value=mock_session): - with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client: + mock_http_client = Mock() + mock_factory = Mock(return_value=mock_http_client) mock_stream_client.return_value.__aenter__ = AsyncMock( return_value=(mock_read, mock_write, mock_get_session) ) @@ -150,27 +157,34 @@ async def test_streamable_client_parameters( timeout=timeout_value, sse_read_timeout=sse_value, terminate_on_close=terminate_value, + httpx_client_factory=mock_factory, ): pass - call_kwargs = mock_stream_client.call_args[1] - # Confirm each parameter is forwarded unchanged. - assert call_kwargs['url'] == 'https://test.example.com/mcp' - assert call_kwargs['headers'] == headers - assert call_kwargs['timeout'] == timeout_value - assert call_kwargs['sse_read_timeout'] == sse_value - assert call_kwargs['terminate_on_close'] == terminate_value + # Verify headers and auth are passed to the factory + factory_call_kwargs = mock_factory.call_args[1] + assert factory_call_kwargs['headers'] == headers + # Timeout is passed to the factory (converted to httpx.Timeout) + assert factory_call_kwargs['timeout'] is not None + + # Verify the created http client and other params are passed to streamable_http_client + stream_call_kwargs = mock_stream_client.call_args[1] + assert stream_call_kwargs['url'] == 'https://test.example.com/mcp' + assert stream_call_kwargs['http_client'] is mock_http_client + assert stream_call_kwargs['terminate_on_close'] == terminate_value @pytest.mark.asyncio async def test_custom_httpx_client_factory_is_passed(mock_session, mock_streams): """Test the passing of a custom HTTPX client factory.""" - # The factory should be handed through to the underlying streamable client untouched. + # The factory should be used to create the http client. mock_read, mock_write, mock_get_session = mock_streams custom_factory = Mock() + mock_http_client = Mock() + custom_factory.return_value = mock_http_client with patch('boto3.Session', return_value=mock_session): - with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client: mock_stream_client.return_value.__aenter__ = AsyncMock( return_value=(mock_read, mock_write, mock_get_session) ) @@ -183,7 +197,10 @@ async def test_custom_httpx_client_factory_is_passed(mock_session, mock_streams) ): pass - assert mock_stream_client.call_args[1]['httpx_client_factory'] is custom_factory + # Verify the custom factory was called + custom_factory.assert_called_once() + # Verify the http client from the factory was passed to streamable_http_client + assert mock_stream_client.call_args[1]['http_client'] is mock_http_client @pytest.mark.asyncio @@ -198,7 +215,7 @@ async def mock_aexit(*_): cleanup_called = True with patch('boto3.Session', return_value=mock_session): - with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client: mock_stream_client.return_value.__aenter__ = AsyncMock( return_value=(mock_read, mock_write, mock_get_session) ) @@ -220,7 +237,7 @@ async def test_credentials_parameter_with_region(mock_streams): creds = Credentials('test_key', 'test_secret', 'test_token') with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth') as mock_auth_cls: - with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client: mock_auth = Mock() mock_auth_cls.return_value = mock_auth mock_stream_client.return_value.__aenter__ = AsyncMock( @@ -264,7 +281,7 @@ async def test_credentials_parameter_bypasses_boto3_session(mock_streams): with patch('boto3.Session') as mock_boto: with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth'): - with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client: mock_stream_client.return_value.__aenter__ = AsyncMock( return_value=(mock_read, mock_write, mock_get_session) ) @@ -303,7 +320,9 @@ async def test_http_localhost_endpoint_allowed(mock_session, mock_streams): mock_read, mock_write, mock_get_session = mock_streams with patch('boto3.Session', return_value=mock_session): - with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client: + mock_http_client = Mock() + mock_factory = Mock(return_value=mock_http_client) mock_stream_client.return_value.__aenter__ = AsyncMock( return_value=(mock_read, mock_write, mock_get_session) ) @@ -313,5 +332,6 @@ async def test_http_localhost_endpoint_allowed(mock_session, mock_streams): async with aws_iam_streamablehttp_client( endpoint='http://localhost:8080/mcp', aws_service='bedrock-agentcore', + httpx_client_factory=mock_factory, ): pass diff --git a/tests/unit/test_proxy.py b/tests/unit/test_proxy.py index 326734b..0458882 100644 --- a/tests/unit/test_proxy.py +++ b/tests/unit/test_proxy.py @@ -32,7 +32,7 @@ async def test_proxy_client_connect_success(): mock_transport = Mock(spec=ClientTransport) client = AWSMCPProxyClient(mock_transport) - with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', return_value='connected'): + with patch('mcp_proxy_for_aws.proxy.StatefulProxyClient._connect', return_value='connected'): result = await client._connect() assert result == 'connected' @@ -51,7 +51,7 @@ async def test_proxy_client_connect_http_error_with_mcp_error(): http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response) - with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=http_error): + with patch('mcp_proxy_for_aws.proxy.StatefulProxyClient._connect', side_effect=http_error): with pytest.raises(McpError) as exc_info: await client._connect() assert exc_info.value.error.code == -32600 @@ -69,7 +69,7 @@ async def test_proxy_client_connect_http_error_non_mcp(): http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response) - with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=http_error): + with patch('mcp_proxy_for_aws.proxy.StatefulProxyClient._connect', side_effect=http_error): with pytest.raises(httpx.HTTPStatusError): await client._connect() @@ -180,7 +180,7 @@ async def test_proxy_client_connect_runtime_error_with_mcp_error(): runtime_error = RuntimeError('Connection failed') runtime_error.__cause__ = mcp_error - with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=runtime_error): + with patch('mcp_proxy_for_aws.proxy.StatefulProxyClient._connect', side_effect=runtime_error): with pytest.raises(McpError) as exc_info: await client._connect() assert exc_info.value.error.code == -32600 @@ -194,7 +194,7 @@ async def test_proxy_client_connect_runtime_error_max_retries(): runtime_error = RuntimeError('Connection failed') - with patch('mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=runtime_error): + with patch('mcp_proxy_for_aws.proxy.StatefulProxyClient._connect', side_effect=runtime_error): with patch.object(client, '_disconnect', new_callable=AsyncMock) as mock_disconnect: with pytest.raises(RuntimeError): await client._connect() @@ -218,7 +218,8 @@ async def mock_connect_side_effect(*args, **kwargs): return 'connected' with patch( - 'mcp_proxy_for_aws.proxy._ProxyClient._connect', side_effect=mock_connect_side_effect + 'mcp_proxy_for_aws.proxy.StatefulProxyClient._connect', + side_effect=mock_connect_side_effect, ): with patch.object( client, diff --git a/uv.lock b/uv.lock index 2d6f89b..41bc1a4 100644 --- a/uv.lock +++ b/uv.lock @@ -2942,7 +2942,7 @@ dev = [ requires-dist = [ { name = "boto3", specifier = ">=1.41.0" }, { name = "botocore", extras = ["crt"], specifier = ">=1.41.0" }, - { name = "fastmcp", specifier = ">=3.2.0,<4" }, + { name = "fastmcp", specifier = "~=3.2.0" }, ] [package.metadata.requires-dev]