Skip to content

Commit 56f0492

Browse files
fix(auth): add asyncio lock to prevent duplicate profile client creation
Concurrent tool calls (e.g. from parallel subagents) could race in _get_profile_client, each creating a separate Client for the same profile. The loser's client would leak — connected but never tracked or cleaned up. Wrapping in an asyncio.Lock ensures only one client is created per profile. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 48a6eca commit 56f0492

File tree

2 files changed

+58
-14
lines changed

2 files changed

+58
-14
lines changed

mcp_proxy_for_aws/middleware/profile_switcher.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
so parallel subagents querying different accounts don't interfere with each other.
2323
"""
2424

25+
import asyncio
2526
import httpx
2627
import logging
2728
import mcp.types as mt
@@ -68,6 +69,7 @@ def __init__(
6869
self._metadata = metadata
6970
self._timeout = timeout
7071
self._profile_clients: dict[str, Client] = {}
72+
self._lock = asyncio.Lock()
7173

7274
# ── tool listing ────────────────────────────────────────────────
7375

@@ -121,20 +123,21 @@ async def _get_profile_client(self, profile: str) -> Client:
121123
so that requests signed with different AWS identities don't collide
122124
on the same backend session.
123125
"""
124-
if profile not in self._profile_clients:
125-
logger.info('Creating dedicated connection for profile %s', profile)
126-
transport = create_transport_with_sigv4(
127-
self._endpoint,
128-
self._service,
129-
self._region,
130-
self._metadata,
131-
self._timeout,
132-
profile,
133-
)
134-
client = Client(transport=transport)
135-
await client.__aenter__()
136-
self._profile_clients[profile] = client
137-
return self._profile_clients[profile]
126+
async with self._lock:
127+
if profile not in self._profile_clients:
128+
logger.info('Creating dedicated connection for profile %s', profile)
129+
transport = create_transport_with_sigv4(
130+
self._endpoint,
131+
self._service,
132+
self._region,
133+
self._metadata,
134+
self._timeout,
135+
profile,
136+
)
137+
client = Client(transport=transport)
138+
await client.__aenter__()
139+
self._profile_clients[profile] = client
140+
return self._profile_clients[profile]
138141

139142
async def disconnect_profile_clients(self) -> None:
140143
"""Disconnect all per-profile clients. Call during server shutdown."""

tests/unit/test_profile_switcher.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""Tests for the ProfileOverrideMiddleware."""
1616

17+
import asyncio
1718
import httpx
1819
import pytest
1920
from fastmcp.server.middleware import MiddlewareContext
@@ -206,6 +207,46 @@ async def test_profile_override_tool_call_failure(self, middleware, mock_context
206207
assert 'backend error' not in result.content[0].text
207208

208209

210+
class TestGetProfileClient:
211+
"""Tests for the _get_profile_client method."""
212+
213+
@pytest.mark.asyncio
214+
async def test_lock_prevents_duplicate_client_creation(self, middleware):
215+
"""Concurrent calls for the same profile only create one client."""
216+
call_count = 0
217+
mock_client = AsyncMock()
218+
219+
original_aenter = mock_client.__aenter__
220+
221+
async def slow_aenter(*args, **kwargs):
222+
nonlocal call_count
223+
call_count += 1
224+
await asyncio.sleep(0.05)
225+
return await original_aenter(*args, **kwargs)
226+
227+
mock_client.__aenter__ = slow_aenter
228+
229+
mock_transport = Mock()
230+
231+
with patch(
232+
'mcp_proxy_for_aws.middleware.profile_switcher.create_transport_with_sigv4',
233+
return_value=mock_transport,
234+
), patch(
235+
'mcp_proxy_for_aws.middleware.profile_switcher.Client',
236+
return_value=mock_client,
237+
):
238+
results = await asyncio.gather(
239+
middleware._get_profile_client('dev-profile'),
240+
middleware._get_profile_client('dev-profile'),
241+
middleware._get_profile_client('dev-profile'),
242+
)
243+
244+
# All calls return the same client
245+
assert all(r is mock_client for r in results)
246+
# Client was only created once despite 3 concurrent calls
247+
assert call_count == 1
248+
249+
209250
class TestDisconnectProfileClients:
210251
"""Tests for the disconnect_profile_clients method."""
211252

0 commit comments

Comments
 (0)