Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ docker build -t mcp-proxy-for-aws .
| `--connect-timeout` | Set desired connect timeout in seconds | 60 |No |
| `--read-timeout` | Set desired read timeout in seconds | 120 |No |
| `--write-timeout` | Set desired write timeout in seconds | 180 |No |
| `--allow-switch-profile` | Enable per-call AWS profile switching by providing an allowlist of profile names. Each tool call can include a `profile` argument to route through a dedicated connection signed with that profile's credentials. | None (disabled) | No |

### Optional Environment Variables

Expand Down Expand Up @@ -163,6 +164,38 @@ Add the following configuration to your MCP client config file (e.g., for Kiro C
> [!NOTE]
> Cline users should not use `--log-level` argument because Cline checks the log messages in stderr for text "error" (case insensitive).

#### Multi-account access with `--allow-switch-profile`

The `--allow-switch-profile` flag lets individual tool calls route through different AWS profiles without restarting the proxy. This is useful when an AI agent needs to query resources across multiple AWS accounts in a single session.

**How it interacts with `--profile`:**
- `--profile` sets the **default** identity used when a tool call does not specify a profile.
- `--allow-switch-profile` defines which additional profiles a tool call may request via a `profile` argument. Each profile gets its own dedicated connection to the backend.
- If a tool call omits `profile`, the default `--profile` connection is used. If it includes `profile`, the request is routed through the matching per-profile connection instead.

```json
{
"mcpServers": {
"<mcp server name>": {
"disabled": false,
"type": "stdio",
"command": "uvx",
"args": [
"mcp-proxy-for-aws@latest",
"<SigV4 MCP endpoint URL>",
"--profile",
"default",
"--allow-switch-profile",
"dev-profile",
"staging-profile"
]
}
}
}
```

In the example above, tool calls without a `profile` argument use the `default` profile. A tool call that includes `"profile": "dev-profile"` is routed through a dedicated connection signed with `dev-profile` credentials.

#### Using Docker

Using the pre-built public ECR image:
Expand Down
9 changes: 9 additions & 0 deletions mcp_proxy_for_aws/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,13 @@ def parse_args():
help='Write timeout (seconds) when connecting to endpoint (default: 180)',
)

parser.add_argument(
'--allow-switch-profile',
nargs='+',
default=None,
metavar='PROFILE',
help='Enable the switch_profile tool and restrict it to the specified AWS CLI profile names '
'(e.g., --allow-switch-profile dev-profile staging-profile)',
)

return parser.parse_args()
194 changes: 194 additions & 0 deletions mcp_proxy_for_aws/middleware/profile_switcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Middleware that enables per-call AWS profile overrides via a ``profile`` argument.

Pass ``profile`` as an extra argument on any tool call to route that single request
through a dedicated transport signed with the specified profile's credentials. The
argument is stripped before forwarding to the backend.

Each profile gets its own lazily-created ``StreamableHttpTransport`` and MCP session,
so parallel subagents querying different accounts don't interfere with each other.
"""

import asyncio
import httpx
import logging
import mcp.types as mt
from collections.abc import Sequence
from fastmcp import Client
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
from fastmcp.tools.tool import Tool, ToolResult
from mcp_proxy_for_aws.utils import create_transport_with_sigv4
from typing import Any, cast
from typing_extensions import override


logger = logging.getLogger(__name__)


class ProfileOverrideMiddleware(Middleware):
"""Middleware that intercepts ``profile`` on any tool call for per-request AWS identity switching.

When a tool call includes a ``profile`` argument, the middleware:

1. Validates the profile against the allowed list
2. Strips ``profile`` from the arguments
3. Forwards the call through a dedicated per-profile MCP client

Each profile gets its own transport and session to the backend so that
requests signed with different AWS identities don't collide.
"""

def __init__(
self,
allowed_profiles: list[str],
service: str,
region: str,
metadata: dict[str, Any],
timeout: httpx.Timeout,
endpoint: str,
) -> None:
"""Initialize the middleware with connection and profile configuration."""
super().__init__()
self._allowed_profiles = set(allowed_profiles)
self._endpoint = endpoint
self._service = service
self._region = region
self._metadata = metadata
self._timeout = timeout
self._profile_clients: dict[str, Client] = {}
self._lock = asyncio.Lock()

# ── tool listing ────────────────────────────────────────────────

@override
async def on_list_tools(
self,
context: MiddlewareContext[mt.ListToolsRequest],
call_next: CallNext[mt.ListToolsRequest, Sequence[Tool]],
) -> Sequence[Tool]:
"""Inject ``profile`` into every tool's schema."""
tools = await call_next(context)

for tool in tools:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious if we could use tool transformations instead of directly manipulating params?

params = tool.parameters
if not isinstance(params, dict):
continue
if 'properties' not in params:
params['properties'] = {}
params['properties']['profile'] = {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the upstream tool list is cached or shared, this mutates shared state. This is ok now because AWSProxyToolManager.get_tools() re-fetches, however this creates dependency on the upstream to always produce fresh dicts. I'd suggest deep-copy the parameters dict before mutating. It is cheap insurance and makes the middleware self-contained.

'type': 'string',
'description': (
'AWS CLI profile to sign this request with. Omit to use the default profile.'
),
'enum': sorted(self._allowed_profiles),
}

return list(tools)

# ── tool invocation ─────────────────────────────────────────────

@override
async def on_call_tool(
self,
context: MiddlewareContext[mt.CallToolRequestParams],
call_next: CallNext[mt.CallToolRequestParams, ToolResult],
) -> ToolResult:
"""Intercept ``profile`` and route through a dedicated per-profile client."""
arguments = context.message.arguments
if isinstance(arguments, dict) and 'profile' in arguments:
profile = arguments['profile']
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

profile can collide with tools that already have a profile parameter. If a backend tool already has a profile parameter in its schema, this middleware will silently overwrite it in on_list_tools and strip it in on_call_tool. This could break legitimate tool parameters.

return await self._call_with_profile(profile, context, call_next)

return await call_next(context)

# ── internals ─────────────────────────────────────────────────

async def _get_profile_client(self, profile: str) -> Client:
"""Get or create a dedicated MCP client for the given profile.

Each profile gets its own ``StreamableHttpTransport`` and MCP session
so that requests signed with different AWS identities don't collide
on the same backend session.
"""
async with self._lock:
if profile not in self._profile_clients:
logger.info('Creating dedicated connection for profile %s', profile)
transport = create_transport_with_sigv4(
self._endpoint,
self._service,
self._region,
self._metadata,
self._timeout,
profile,
)
client = Client(transport=transport)
await client.__aenter__()
self._profile_clients[profile] = client
return self._profile_clients[profile]

async def disconnect_profile_clients(self) -> None:
"""Disconnect all per-profile clients. Call during server shutdown."""
for profile, client in self._profile_clients.items():
try:
await client.__aexit__(None, None, None)
except Exception:
logger.exception('Failed to disconnect profile client %s', profile)
self._profile_clients.clear()

async def _call_with_profile(
self,
profile: str,
context: MiddlewareContext[mt.CallToolRequestParams],
call_next: CallNext[mt.CallToolRequestParams, ToolResult],
) -> ToolResult:
"""Forward a tool call through a dedicated per-profile connection."""
if profile not in self._allowed_profiles:
allowed = ', '.join(sorted(self._allowed_profiles))
return ToolResult(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you return ToolResult(s) on failed validations instead of raising ToolError?

content=f'Error: profile {profile!r} is not in the allowed list. '
f'Allowed profiles: {allowed}'
)

# Strip profile before forwarding to the backend
arguments: dict[str, Any] = dict(cast(dict[str, Any], context.message.arguments))
arguments.pop('profile', None)

logger.info(
'Per-call profile override: routing through dedicated connection for %s', profile
)

try:
client = await self._get_profile_client(profile)
except Exception:
logger.exception('Failed to create connection for profile %s', profile)
return ToolResult(
content=f'Error: failed to create connection for profile {profile!r}. '
'Check that the profile is configured and credentials are valid.'
)

try:
result = await client.call_tool(context.message.name, arguments)
return ToolResult(
content=result.content,
structured_content=result.structured_content,
meta=result.meta,
)
except Exception:
logger.exception('Error calling tool via profile %s', profile)
return ToolResult(
content=f'Error: tool call failed using profile {profile!r}. '
'The request could not be completed with the specified profile.'
)
16 changes: 16 additions & 0 deletions mcp_proxy_for_aws/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from mcp_proxy_for_aws.cli import parse_args
from mcp_proxy_for_aws.logging_config import configure_logging
from mcp_proxy_for_aws.middleware.initialize_middleware import InitializeMiddleware
from mcp_proxy_for_aws.middleware.profile_switcher import ProfileOverrideMiddleware
from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware
from mcp_proxy_for_aws.proxy import AWSMCPProxy, AWSMCPProxyClientFactory
from mcp_proxy_for_aws.utils import (
Expand Down Expand Up @@ -86,6 +87,7 @@ async def run_proxy(args) -> None:
)
client_factory = AWSMCPProxyClientFactory(transport)

profile_middleware: ProfileOverrideMiddleware | None = None
try:
proxy = AWSMCPProxy(
client_factory=client_factory,
Expand All @@ -100,13 +102,27 @@ async def run_proxy(args) -> None:
add_logging_middleware(proxy, args.log_level)
add_tool_filtering_middleware(proxy, args.read_only)

allowed_profiles = getattr(args, 'allow_switch_profile', None)
if isinstance(allowed_profiles, list) and allowed_profiles:
profile_middleware = ProfileOverrideMiddleware(
allowed_profiles=allowed_profiles,
service=service,
region=region,
metadata=metadata,
timeout=timeout,
endpoint=args.endpoint,
)
proxy.add_middleware(profile_middleware)

Comment on lines +105 to +116
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be extracted into a separate helper add_profile_override_middleware, to stay consistent with other middleware declarations

if args.retries:
add_retry_middleware(proxy, args.retries)
await proxy.run_async(transport='stdio', show_banner=False, log_level=args.log_level)
except Exception as e:
logger.error('Cannot start proxy server: %s', e)
raise e
finally:
if profile_middleware:
await profile_middleware.disconnect_profile_clients()
await client_factory.disconnect()


Expand Down
Loading