diff --git a/README.md b/README.md index ea5a7310..153d4e84 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,7 @@ docker build -t mcp-proxy-for-aws . | `--connect-timeout` | Set desired connect timeout in seconds | 60 |No | | `--read-timeout` | Set desired read timeout in seconds | 120 |No | | `--write-timeout` | Set desired write timeout in seconds | 180 |No | +| `--allow-switch-profile` | Enable per-call AWS profile switching by providing an allowlist of profile names. Each tool call can include a `profile` argument to route through a dedicated connection signed with that profile's credentials. | None (disabled) | No | ### Optional Environment Variables @@ -163,6 +164,38 @@ Add the following configuration to your MCP client config file (e.g., for Kiro C > [!NOTE] > Cline users should not use `--log-level` argument because Cline checks the log messages in stderr for text "error" (case insensitive). +#### Multi-account access with `--allow-switch-profile` + +The `--allow-switch-profile` flag lets individual tool calls route through different AWS profiles without restarting the proxy. This is useful when an AI agent needs to query resources across multiple AWS accounts in a single session. + +**How it interacts with `--profile`:** +- `--profile` sets the **default** identity used when a tool call does not specify a profile. +- `--allow-switch-profile` defines which additional profiles a tool call may request via a `profile` argument. Each profile gets its own dedicated connection to the backend. +- If a tool call omits `profile`, the default `--profile` connection is used. If it includes `profile`, the request is routed through the matching per-profile connection instead. + +```json +{ + "mcpServers": { + "": { + "disabled": false, + "type": "stdio", + "command": "uvx", + "args": [ + "mcp-proxy-for-aws@latest", + "", + "--profile", + "default", + "--allow-switch-profile", + "dev-profile", + "staging-profile" + ] + } + } +} +``` + +In the example above, tool calls without a `profile` argument use the `default` profile. A tool call that includes `"profile": "dev-profile"` is routed through a dedicated connection signed with `dev-profile` credentials. + #### Using Docker Using the pre-built public ECR image: diff --git a/mcp_proxy_for_aws/cli.py b/mcp_proxy_for_aws/cli.py index dc666659..f7bd10f0 100644 --- a/mcp_proxy_for_aws/cli.py +++ b/mcp_proxy_for_aws/cli.py @@ -159,4 +159,13 @@ def parse_args(): help='Write timeout (seconds) when connecting to endpoint (default: 180)', ) + parser.add_argument( + '--allow-switch-profile', + nargs='+', + default=None, + metavar='PROFILE', + help='Enable the switch_profile tool and restrict it to the specified AWS CLI profile names ' + '(e.g., --allow-switch-profile dev-profile staging-profile)', + ) + return parser.parse_args() diff --git a/mcp_proxy_for_aws/middleware/profile_switcher.py b/mcp_proxy_for_aws/middleware/profile_switcher.py new file mode 100644 index 00000000..1266e786 --- /dev/null +++ b/mcp_proxy_for_aws/middleware/profile_switcher.py @@ -0,0 +1,194 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Middleware that enables per-call AWS profile overrides via a ``profile`` argument. + +Pass ``profile`` as an extra argument on any tool call to route that single request +through a dedicated transport signed with the specified profile's credentials. The +argument is stripped before forwarding to the backend. + +Each profile gets its own lazily-created ``StreamableHttpTransport`` and MCP session, +so parallel subagents querying different accounts don't interfere with each other. +""" + +import asyncio +import httpx +import logging +import mcp.types as mt +from collections.abc import Sequence +from fastmcp import Client +from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext +from fastmcp.tools.tool import Tool, ToolResult +from mcp_proxy_for_aws.utils import create_transport_with_sigv4 +from typing import Any, cast +from typing_extensions import override + + +logger = logging.getLogger(__name__) + + +class ProfileOverrideMiddleware(Middleware): + """Middleware that intercepts ``profile`` on any tool call for per-request AWS identity switching. + + When a tool call includes a ``profile`` argument, the middleware: + + 1. Validates the profile against the allowed list + 2. Strips ``profile`` from the arguments + 3. Forwards the call through a dedicated per-profile MCP client + + Each profile gets its own transport and session to the backend so that + requests signed with different AWS identities don't collide. + """ + + def __init__( + self, + allowed_profiles: list[str], + service: str, + region: str, + metadata: dict[str, Any], + timeout: httpx.Timeout, + endpoint: str, + ) -> None: + """Initialize the middleware with connection and profile configuration.""" + super().__init__() + self._allowed_profiles = set(allowed_profiles) + self._endpoint = endpoint + self._service = service + self._region = region + self._metadata = metadata + self._timeout = timeout + self._profile_clients: dict[str, Client] = {} + self._lock = asyncio.Lock() + + # ── tool listing ──────────────────────────────────────────────── + + @override + async def on_list_tools( + self, + context: MiddlewareContext[mt.ListToolsRequest], + call_next: CallNext[mt.ListToolsRequest, Sequence[Tool]], + ) -> Sequence[Tool]: + """Inject ``profile`` into every tool's schema.""" + tools = await call_next(context) + + for tool in tools: + params = tool.parameters + if not isinstance(params, dict): + continue + if 'properties' not in params: + params['properties'] = {} + params['properties']['profile'] = { + 'type': 'string', + 'description': ( + 'AWS CLI profile to sign this request with. Omit to use the default profile.' + ), + 'enum': sorted(self._allowed_profiles), + } + + return list(tools) + + # ── tool invocation ───────────────────────────────────────────── + + @override + async def on_call_tool( + self, + context: MiddlewareContext[mt.CallToolRequestParams], + call_next: CallNext[mt.CallToolRequestParams, ToolResult], + ) -> ToolResult: + """Intercept ``profile`` and route through a dedicated per-profile client.""" + arguments = context.message.arguments + if isinstance(arguments, dict) and 'profile' in arguments: + profile = arguments['profile'] + return await self._call_with_profile(profile, context, call_next) + + return await call_next(context) + + # ── internals ───────────────────────────────────────────────── + + async def _get_profile_client(self, profile: str) -> Client: + """Get or create a dedicated MCP client for the given profile. + + Each profile gets its own ``StreamableHttpTransport`` and MCP session + so that requests signed with different AWS identities don't collide + on the same backend session. + """ + async with self._lock: + if profile not in self._profile_clients: + logger.info('Creating dedicated connection for profile %s', profile) + transport = create_transport_with_sigv4( + self._endpoint, + self._service, + self._region, + self._metadata, + self._timeout, + profile, + ) + client = Client(transport=transport) + await client.__aenter__() + self._profile_clients[profile] = client + return self._profile_clients[profile] + + async def disconnect_profile_clients(self) -> None: + """Disconnect all per-profile clients. Call during server shutdown.""" + for profile, client in self._profile_clients.items(): + try: + await client.__aexit__(None, None, None) + except Exception: + logger.exception('Failed to disconnect profile client %s', profile) + self._profile_clients.clear() + + async def _call_with_profile( + self, + profile: str, + context: MiddlewareContext[mt.CallToolRequestParams], + call_next: CallNext[mt.CallToolRequestParams, ToolResult], + ) -> ToolResult: + """Forward a tool call through a dedicated per-profile connection.""" + if profile not in self._allowed_profiles: + allowed = ', '.join(sorted(self._allowed_profiles)) + return ToolResult( + content=f'Error: profile {profile!r} is not in the allowed list. ' + f'Allowed profiles: {allowed}' + ) + + # Strip profile before forwarding to the backend + arguments: dict[str, Any] = dict(cast(dict[str, Any], context.message.arguments)) + arguments.pop('profile', None) + + logger.info( + 'Per-call profile override: routing through dedicated connection for %s', profile + ) + + try: + client = await self._get_profile_client(profile) + except Exception: + logger.exception('Failed to create connection for profile %s', profile) + return ToolResult( + content=f'Error: failed to create connection for profile {profile!r}. ' + 'Check that the profile is configured and credentials are valid.' + ) + + try: + result = await client.call_tool(context.message.name, arguments) + return ToolResult( + content=result.content, + structured_content=result.structured_content, + meta=result.meta, + ) + except Exception: + logger.exception('Error calling tool via profile %s', profile) + return ToolResult( + content=f'Error: tool call failed using profile {profile!r}. ' + 'The request could not be completed with the specified profile.' + ) diff --git a/mcp_proxy_for_aws/server.py b/mcp_proxy_for_aws/server.py index 032f2447..0bc18f53 100644 --- a/mcp_proxy_for_aws/server.py +++ b/mcp_proxy_for_aws/server.py @@ -32,6 +32,7 @@ from mcp_proxy_for_aws.cli import parse_args from mcp_proxy_for_aws.logging_config import configure_logging from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware +from mcp_proxy_for_aws.middleware.profile_switcher import ProfileOverrideMiddleware from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware from mcp_proxy_for_aws.proxy import AWSMCPProxy, AWSMCPProxyClientFactory from mcp_proxy_for_aws.utils import ( @@ -86,6 +87,7 @@ async def run_proxy(args) -> None: ) client_factory = AWSMCPProxyClientFactory(transport) + profile_middleware: ProfileOverrideMiddleware | None = None try: proxy = AWSMCPProxy( client_factory=client_factory, @@ -100,6 +102,18 @@ async def run_proxy(args) -> None: add_logging_middleware(proxy, args.log_level) add_tool_filtering_middleware(proxy, args.read_only) + allowed_profiles = getattr(args, 'allow_switch_profile', None) + if isinstance(allowed_profiles, list) and allowed_profiles: + profile_middleware = ProfileOverrideMiddleware( + allowed_profiles=allowed_profiles, + service=service, + region=region, + metadata=metadata, + timeout=timeout, + endpoint=args.endpoint, + ) + proxy.add_middleware(profile_middleware) + if args.retries: add_retry_middleware(proxy, args.retries) await proxy.run_async(transport='stdio', show_banner=False, log_level=args.log_level) @@ -107,6 +121,8 @@ async def run_proxy(args) -> None: logger.error('Cannot start proxy server: %s', e) raise e finally: + if profile_middleware: + await profile_middleware.disconnect_profile_clients() await client_factory.disconnect() diff --git a/tests/unit/test_profile_switcher.py b/tests/unit/test_profile_switcher.py new file mode 100644 index 00000000..e4d9a537 --- /dev/null +++ b/tests/unit/test_profile_switcher.py @@ -0,0 +1,278 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the ProfileOverrideMiddleware.""" + +import asyncio +import httpx +import pytest +from fastmcp.server.middleware import MiddlewareContext +from mcp_proxy_for_aws.middleware.profile_switcher import ProfileOverrideMiddleware +from unittest.mock import AsyncMock, MagicMock, Mock, patch + + +ALLOWED_PROFILES = [ + 'dev-profile', + 'staging-profile', +] + + +@pytest.fixture +def middleware(): + """Create a ProfileOverrideMiddleware instance.""" + return ProfileOverrideMiddleware( + allowed_profiles=ALLOWED_PROFILES, + service='lambda', + region='us-east-1', + metadata={'proxy': 'test'}, + timeout=httpx.Timeout(30), + endpoint='https://test.us-east-1.api.aws/mcp', + ) + + +@pytest.fixture +def mock_context(): + """Create a mock MiddlewareContext.""" + return Mock(spec=MiddlewareContext) + + +class TestOnListTools: + """Tests for the on_list_tools method.""" + + @pytest.mark.asyncio + async def test_injects_profile_property_into_tool_schemas(self, middleware, mock_context): + """Every proxied tool gets a profile property in its schema.""" + tool = Mock() + tool.name = 'some_tool' + tool.parameters = {'type': 'object', 'properties': {'arg': {'type': 'string'}}} + call_next = AsyncMock(return_value=[tool]) + + result = await middleware.on_list_tools(mock_context, call_next) + + assert len(result) == 1 + assert result[0].name == 'some_tool' + profile_schema = result[0].parameters['properties']['profile'] + assert profile_schema['type'] == 'string' + assert 'AWS CLI profile' in profile_schema['description'] + assert profile_schema['enum'] == sorted(ALLOWED_PROFILES) + call_next.assert_called_once_with(mock_context) + + @pytest.mark.asyncio + async def test_handles_empty_tool_list(self, middleware, mock_context): + """Empty tool list is returned as-is.""" + call_next = AsyncMock(return_value=[]) + + result = await middleware.on_list_tools(mock_context, call_next) + + assert len(result) == 0 + + @pytest.mark.asyncio + async def test_skips_tool_with_non_dict_parameters(self, middleware, mock_context): + """Tools whose parameters are not a dict are left unchanged.""" + tool = Mock() + tool.name = 'odd_tool' + tool.parameters = None + call_next = AsyncMock(return_value=[tool]) + + result = await middleware.on_list_tools(mock_context, call_next) + + assert len(result) == 1 + assert result[0].parameters is None + + @pytest.mark.asyncio + async def test_adds_properties_key_when_missing(self, middleware, mock_context): + """Profile is injected even when the schema has no properties key.""" + tool = Mock() + tool.name = 'bare_tool' + tool.parameters = {'type': 'object'} + call_next = AsyncMock(return_value=[tool]) + + result = await middleware.on_list_tools(mock_context, call_next) + + assert 'properties' in result[0].parameters + assert 'profile' in result[0].parameters['properties'] + + +class TestOnCallTool: + """Tests for the on_call_tool method.""" + + @pytest.mark.asyncio + async def test_passes_through_calls_without_profile(self, middleware, mock_context): + """Tool calls without profile are forwarded unchanged.""" + mock_context.message = Mock() + mock_context.message.name = 'some_tool' + mock_context.message.arguments = {'arg': 'value'} + expected_result = Mock() + call_next = AsyncMock(return_value=expected_result) + + result = await middleware.on_call_tool(mock_context, call_next) + + assert result == expected_result + call_next.assert_called_once_with(mock_context) + + @pytest.mark.asyncio + async def test_passes_through_calls_with_none_arguments(self, middleware, mock_context): + """Tool calls with None arguments are forwarded unchanged.""" + mock_context.message = Mock() + mock_context.message.name = 'some_tool' + mock_context.message.arguments = None + expected_result = Mock() + call_next = AsyncMock(return_value=expected_result) + + result = await middleware.on_call_tool(mock_context, call_next) + + assert result == expected_result + call_next.assert_called_once_with(mock_context) + + +class TestPerCallProfileOverride: + """Tests for the profile per-call override path.""" + + @pytest.mark.asyncio + async def test_profile_override_disallowed(self, middleware, mock_context): + """Profile with a disallowed profile returns an error.""" + mock_context.message = Mock() + mock_context.message.name = 'some_tool' + mock_context.message.arguments = {'arg': 'value', 'profile': 'evil-profile'} + call_next = AsyncMock() + + result = await middleware.on_call_tool(mock_context, call_next) + + assert 'not in the allowed list' in result.content[0].text + call_next.assert_not_called() + + @pytest.mark.asyncio + async def test_profile_override_strips_profile_arg(self, middleware, mock_context): + """Profile is stripped before forwarding to the backend.""" + mock_client = AsyncMock() + mock_call_result = MagicMock() + mock_call_result.content = 'result' + mock_call_result.structured_content = None + mock_call_result.meta = None + mock_client.call_tool.return_value = mock_call_result + + mock_context.message = Mock() + mock_context.message.name = 'some_tool' + mock_context.message.arguments = {'arg': 'value', 'profile': 'dev-profile'} + call_next = AsyncMock() + + with patch.object(middleware, '_get_profile_client', return_value=mock_client): + await middleware.on_call_tool(mock_context, call_next) + + mock_client.call_tool.assert_called_once_with('some_tool', {'arg': 'value'}) + call_next.assert_not_called() + + @pytest.mark.asyncio + async def test_profile_override_connection_failure(self, middleware, mock_context): + """Connection failure returns a sanitized error.""" + mock_context.message = Mock() + mock_context.message.name = 'some_tool' + mock_context.message.arguments = {'arg': 'value', 'profile': 'dev-profile'} + call_next = AsyncMock() + + with patch.object( + middleware, '_get_profile_client', side_effect=Exception('connection refused') + ): + result = await middleware.on_call_tool(mock_context, call_next) + + assert 'failed to create connection' in result.content[0].text + assert 'connection refused' not in result.content[0].text + + @pytest.mark.asyncio + async def test_profile_override_tool_call_failure(self, middleware, mock_context): + """Tool call failure returns a sanitized error.""" + mock_client = AsyncMock() + mock_client.call_tool.side_effect = Exception('backend error') + + mock_context.message = Mock() + mock_context.message.name = 'some_tool' + mock_context.message.arguments = {'arg': 'value', 'profile': 'dev-profile'} + call_next = AsyncMock() + + with patch.object(middleware, '_get_profile_client', return_value=mock_client): + result = await middleware.on_call_tool(mock_context, call_next) + + assert 'tool call failed' in result.content[0].text + assert 'backend error' not in result.content[0].text + + +class TestGetProfileClient: + """Tests for the _get_profile_client method.""" + + @pytest.mark.asyncio + async def test_lock_prevents_duplicate_client_creation(self, middleware): + """Concurrent calls for the same profile only create one client.""" + call_count = 0 + mock_client = AsyncMock() + + original_aenter = mock_client.__aenter__ + + async def slow_aenter(*args, **kwargs): + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.05) + return await original_aenter(*args, **kwargs) + + mock_client.__aenter__ = slow_aenter + + mock_transport = Mock() + + with patch( + 'mcp_proxy_for_aws.middleware.profile_switcher.create_transport_with_sigv4', + return_value=mock_transport, + ), patch( + 'mcp_proxy_for_aws.middleware.profile_switcher.Client', + return_value=mock_client, + ): + results = await asyncio.gather( + middleware._get_profile_client('dev-profile'), + middleware._get_profile_client('dev-profile'), + middleware._get_profile_client('dev-profile'), + ) + + # All calls return the same client + assert all(r is mock_client for r in results) + # Client was only created once despite 3 concurrent calls + assert call_count == 1 + + +class TestDisconnectProfileClients: + """Tests for the disconnect_profile_clients method.""" + + @pytest.mark.asyncio + async def test_disconnects_all_clients(self, middleware): + """All cached profile clients are closed and the cache is cleared.""" + client_a = AsyncMock() + client_b = AsyncMock() + middleware._profile_clients = {'profile-a': client_a, 'profile-b': client_b} + + await middleware.disconnect_profile_clients() + + client_a.__aexit__.assert_called_once_with(None, None, None) + client_b.__aexit__.assert_called_once_with(None, None, None) + assert middleware._profile_clients == {} + + @pytest.mark.asyncio + async def test_continues_on_client_error(self, middleware): + """A failing client does not prevent other clients from disconnecting.""" + client_good = AsyncMock() + client_bad = AsyncMock() + client_bad.__aexit__.side_effect = Exception('disconnect failed') + middleware._profile_clients = {'bad': client_bad, 'good': client_good} + + await middleware.disconnect_profile_clients() + + client_bad.__aexit__.assert_called_once_with(None, None, None) + client_good.__aexit__.assert_called_once_with(None, None, None) + assert middleware._profile_clients == {}