Skip to content

Commit 1bea695

Browse files
committed
fix: add kwargs to http client factory
1 parent ab72798 commit 1bea695

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

mcp_proxy_for_aws/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def client_factory(
5757
headers: Optional[Dict[str, str]] = None,
5858
timeout: Optional[httpx.Timeout] = None,
5959
auth: Optional[httpx.Auth] = None,
60+
**kw,
6061
) -> httpx.AsyncClient:
6162
return create_sigv4_client(
6263
service=service,
@@ -66,6 +67,7 @@ def client_factory(
6667
timeout=custom_timeout,
6768
metadata=metadata,
6869
auth=auth,
70+
**kw,
6971
)
7072

7173
return StreamableHttpTransport(

tests/integ/mcp/simple_mcp_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ def _build_args(endpoint: str, region_name: str, metadata: Optional[Dict[str, st
8989
'DEBUG',
9090
'--region',
9191
region_name,
92+
'--profile',
93+
'github-integ',
9294
]
9395

9496
# Add metadata arguments if provided

tests/unit/test_utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,42 @@ def test_create_transport_with_sigv4_no_profile(
117117
# If we can't access the factory directly, just verify the transport was created
118118
assert result is not None
119119

120+
@patch('mcp_proxy_for_aws.utils.create_aws_session')
121+
@patch('mcp_proxy_for_aws.utils.create_sigv4_client')
122+
def test_create_transport_with_sigv4_kwargs_passthrough(
123+
self, mock_create_sigv4_client, mock_create_session
124+
):
125+
"""Test that kwargs are passed through to create_sigv4_client."""
126+
from httpx import Timeout
127+
128+
mock_session = MagicMock()
129+
mock_create_session.return_value = mock_session
130+
131+
url = 'https://test-service.us-west-2.api.aws/mcp'
132+
service = 'test-service'
133+
region = 'test-region'
134+
metadata = {'AWS_REGION': 'test-region'}
135+
custom_timeout = Timeout(60.0)
136+
137+
result = create_transport_with_sigv4(url, service, region, metadata, custom_timeout)
138+
139+
if hasattr(result, 'httpx_client_factory') and result.httpx_client_factory:
140+
factory = result.httpx_client_factory
141+
factory(headers=None, timeout=None, auth=None, follow_redirects=True)
142+
143+
mock_create_sigv4_client.assert_called_once_with(
144+
service=service,
145+
session=mock_session,
146+
region=region,
147+
headers=None,
148+
timeout=custom_timeout,
149+
auth=None,
150+
metadata=metadata,
151+
follow_redirects=True,
152+
)
153+
else:
154+
assert result is not None
155+
120156

121157
class TestValidateRequiredArgs:
122158
"""Test cases for validate_service_name function."""

0 commit comments

Comments
 (0)