Skip to content

Commit 3d7b916

Browse files
authored
Wire per-user identity into rate limit middleware (#4718)
The rate limit middleware was passing an empty string for userID, meaning per-user buckets (added in #4704) were never checked. This wires identity.Subject from the auth context into the limiter so per-user rate limits are enforced at runtime. Extract auth.IdentityFromContext() in rateLimitHandler and pass identity.Subject as the userID to limiter.Allow() Hoist the parameterized mock OIDC server from virtualmcp/helpers.go into the shared testutil package for reuse across test suites Add E2E acceptance test: deploy MCPServer with perUser rate limit + inline OIDC auth, verify user-a is rejected after limit, user-b succeeds independently Closes #4550
1 parent 25b5c78 commit 3d7b916

6 files changed

Lines changed: 651 additions & 304 deletions

File tree

pkg/ratelimit/middleware.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/redis/go-redis/v9"
1717

1818
v1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1"
19+
"github.com/stacklok/toolhive/pkg/auth"
1920
"github.com/stacklok/toolhive/pkg/mcp"
2021
"github.com/stacklok/toolhive/pkg/transport/types"
2122
)
@@ -119,7 +120,14 @@ func rateLimitHandler(limiter Limiter) types.MiddlewareFunction {
119120
return
120121
}
121122

122-
decision, err := limiter.Allow(r.Context(), parsed.ResourceID, "")
123+
// When no identity is present (unauthenticated), userID stays empty
124+
// and per-user buckets are skipped — only shared limits apply. CEL
125+
// validation ensures perUser rate limits require auth to be enabled.
126+
var userID string
127+
if identity, ok := auth.IdentityFromContext(r.Context()); ok {
128+
userID = identity.Subject
129+
}
130+
decision, err := limiter.Allow(r.Context(), parsed.ResourceID, userID)
123131
if err != nil {
124132
slog.Warn("rate limit check failed, allowing request", "error", err)
125133
next.ServeHTTP(w, r)

pkg/ratelimit/middleware_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/stretchr/testify/assert"
1717
"github.com/stretchr/testify/require"
1818

19+
"github.com/stacklok/toolhive/pkg/auth"
1920
"github.com/stacklok/toolhive/pkg/mcp"
2021
)
2122

@@ -29,6 +30,25 @@ func (d *dummyLimiter) Allow(context.Context, string, string) (*Decision, error)
2930
return d.decision, d.err
3031
}
3132

33+
// recordingLimiter captures the arguments passed to Allow.
34+
type recordingLimiter struct {
35+
toolName string
36+
userID string
37+
}
38+
39+
func (r *recordingLimiter) Allow(_ context.Context, toolName, userID string) (*Decision, error) {
40+
r.toolName = toolName
41+
r.userID = userID
42+
return &Decision{Allowed: true}, nil
43+
}
44+
45+
// withIdentity adds an auth.Identity with the given subject to the request context.
46+
func withIdentity(r *http.Request, subject string) *http.Request {
47+
identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: subject}}
48+
ctx := auth.WithIdentity(r.Context(), identity)
49+
return r.WithContext(ctx)
50+
}
51+
3252
// withParsedMCPRequest adds a ParsedMCPRequest to the request context.
3353
func withParsedMCPRequest(r *http.Request, method, resourceID string, id any) *http.Request {
3454
parsed := &mcp.ParsedMCPRequest{
@@ -148,3 +168,43 @@ func TestRateLimitHandler_NonToolCallPassesThrough(t *testing.T) {
148168
assert.True(t, nextCalled, "non-tools/call should pass through regardless of limiter")
149169
assert.Equal(t, http.StatusOK, w.Code)
150170
}
171+
172+
func TestRateLimitHandler_PassesUserID(t *testing.T) {
173+
t.Parallel()
174+
175+
recorder := &recordingLimiter{}
176+
handler := rateLimitHandler(recorder)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
177+
w.WriteHeader(http.StatusOK)
178+
}))
179+
180+
req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
181+
req = withParsedMCPRequest(req, "tools/call", "echo", 1)
182+
req = withIdentity(req, "alice@example.com")
183+
w := httptest.NewRecorder()
184+
185+
handler.ServeHTTP(w, req)
186+
187+
assert.Equal(t, http.StatusOK, w.Code)
188+
assert.Equal(t, "echo", recorder.toolName)
189+
assert.Equal(t, "alice@example.com", recorder.userID)
190+
}
191+
192+
func TestRateLimitHandler_NoIdentityPassesEmptyUserID(t *testing.T) {
193+
t.Parallel()
194+
195+
recorder := &recordingLimiter{}
196+
handler := rateLimitHandler(recorder)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
197+
w.WriteHeader(http.StatusOK)
198+
}))
199+
200+
req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
201+
req = withParsedMCPRequest(req, "tools/call", "echo", 1)
202+
// No identity in context — unauthenticated request.
203+
w := httptest.NewRecorder()
204+
205+
handler.ServeHTTP(w, req)
206+
207+
assert.Equal(t, http.StatusOK, w.Code)
208+
assert.Equal(t, "echo", recorder.toolName)
209+
assert.Empty(t, recorder.userID, "unauthenticated requests should pass empty userID")
210+
}

test/e2e/thv-operator/acceptance_tests/helpers.go

Lines changed: 113 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/onsi/gomega"
1717
appsv1 "k8s.io/api/apps/v1"
1818
corev1 "k8s.io/api/core/v1"
19+
apierrors "k8s.io/apimachinery/pkg/api/errors"
1920
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2021
"k8s.io/apimachinery/pkg/util/intstr"
2122
"k8s.io/utils/ptr"
@@ -24,9 +25,9 @@ import (
2425
"github.com/stacklok/toolhive/test/e2e/images"
2526
)
2627

27-
// DeployRedis creates a Redis Deployment and Service in the given namespace.
28-
// No password is configured — matches the default empty THV_SESSION_REDIS_PASSWORD.
29-
func DeployRedis(ctx context.Context, c client.Client, namespace string, timeout, pollingInterval time.Duration) {
28+
// EnsureRedis creates a Redis Deployment and Service if they don't already exist,
29+
// then waits for Redis to be ready. Safe to call concurrently from multiple test blocks.
30+
func EnsureRedis(ctx context.Context, c client.Client, namespace string, timeout, pollingInterval time.Duration) {
3031
labels := map[string]string{"app": "redis"}
3132

3233
deployment := &appsv1.Deployment{
@@ -51,7 +52,9 @@ func DeployRedis(ctx context.Context, c client.Client, namespace string, timeout
5152
},
5253
},
5354
}
54-
gomega.Expect(c.Create(ctx, deployment)).To(gomega.Succeed())
55+
if err := c.Create(ctx, deployment); err != nil && !apierrors.IsAlreadyExists(err) {
56+
gomega.Expect(err).ToNot(gomega.HaveOccurred())
57+
}
5558

5659
service := &corev1.Service{
5760
ObjectMeta: metav1.ObjectMeta{
@@ -65,7 +68,9 @@ func DeployRedis(ctx context.Context, c client.Client, namespace string, timeout
6568
},
6669
},
6770
}
68-
gomega.Expect(c.Create(ctx, service)).To(gomega.Succeed())
71+
if err := c.Create(ctx, service); err != nil && !apierrors.IsAlreadyExists(err) {
72+
gomega.Expect(err).ToNot(gomega.HaveOccurred())
73+
}
6974

7075
ginkgo.By("Waiting for Redis to be ready")
7176
gomega.Eventually(func() error {
@@ -99,7 +104,7 @@ func CleanupRedis(ctx context.Context, c client.Client, namespace string) {
99104
}
100105

101106
// SendToolCall sends a JSON-RPC tools/call request and returns the HTTP status code and body.
102-
func SendToolCall(httpClient *http.Client, port int32, toolName string, requestID int) (int, []byte) {
107+
func SendToolCall(ctx context.Context, httpClient *http.Client, port int32, toolName string, requestID int) (int, []byte) {
103108
reqBody := map[string]any{
104109
"jsonrpc": "2.0",
105110
"id": requestID,
@@ -113,7 +118,7 @@ func SendToolCall(httpClient *http.Client, port int32, toolName string, requestI
113118
gomega.Expect(err).ToNot(gomega.HaveOccurred())
114119

115120
url := fmt.Sprintf("http://localhost:%d/mcp", port)
116-
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, url, bytes.NewReader(bodyBytes))
121+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
117122
gomega.Expect(err).ToNot(gomega.HaveOccurred())
118123
req.Header.Set("Content-Type", "application/json")
119124
req.Header.Set("Accept", "application/json, text/event-stream")
@@ -127,3 +132,104 @@ func SendToolCall(httpClient *http.Client, port int32, toolName string, requestI
127132

128133
return resp.StatusCode, respBody
129134
}
135+
136+
// SendInitialize sends a JSON-RPC initialize request and returns the session ID
137+
// from the Mcp-Session header. This must be called before tools/call when auth is enabled.
138+
func SendInitialize(
139+
ctx context.Context, httpClient *http.Client, port int32, bearerToken string,
140+
) (sessionID string) {
141+
reqBody := map[string]any{
142+
"jsonrpc": "2.0",
143+
"id": 0,
144+
"method": "initialize",
145+
"params": map[string]any{
146+
"protocolVersion": "2025-03-26",
147+
"capabilities": map[string]any{},
148+
"clientInfo": map[string]any{
149+
"name": "e2e-test",
150+
"version": "1.0.0",
151+
},
152+
},
153+
}
154+
bodyBytes, err := json.Marshal(reqBody)
155+
gomega.Expect(err).ToNot(gomega.HaveOccurred())
156+
157+
url := fmt.Sprintf("http://localhost:%d/mcp", port)
158+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
159+
gomega.Expect(err).ToNot(gomega.HaveOccurred())
160+
req.Header.Set("Content-Type", "application/json")
161+
req.Header.Set("Accept", "application/json, text/event-stream")
162+
if bearerToken != "" {
163+
req.Header.Set("Authorization", "Bearer "+bearerToken)
164+
}
165+
166+
resp, err := httpClient.Do(req)
167+
gomega.Expect(err).ToNot(gomega.HaveOccurred())
168+
defer func() { _ = resp.Body.Close() }()
169+
170+
gomega.Expect(resp.StatusCode).To(gomega.Equal(http.StatusOK),
171+
"initialize should succeed")
172+
173+
sessionID = resp.Header.Get("Mcp-Session-Id")
174+
gomega.Expect(sessionID).ToNot(gomega.BeEmpty(),
175+
"initialize response should include Mcp-Session-Id header")
176+
177+
return sessionID
178+
}
179+
180+
// SendAuthenticatedToolCallWithSession sends a JSON-RPC tools/call with Bearer token and session ID.
181+
func SendAuthenticatedToolCallWithSession(
182+
ctx context.Context, httpClient *http.Client, port int32, toolName string, requestID int, bearerToken, sessionID string,
183+
) (int, []byte, string) {
184+
reqBody := map[string]any{
185+
"jsonrpc": "2.0",
186+
"id": requestID,
187+
"method": "tools/call",
188+
"params": map[string]any{
189+
"name": toolName,
190+
"arguments": map[string]any{"input": "test"},
191+
},
192+
}
193+
bodyBytes, err := json.Marshal(reqBody)
194+
gomega.Expect(err).ToNot(gomega.HaveOccurred())
195+
196+
url := fmt.Sprintf("http://localhost:%d/mcp", port)
197+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
198+
gomega.Expect(err).ToNot(gomega.HaveOccurred())
199+
req.Header.Set("Content-Type", "application/json")
200+
req.Header.Set("Accept", "application/json, text/event-stream")
201+
req.Header.Set("Authorization", "Bearer "+bearerToken)
202+
if sessionID != "" {
203+
req.Header.Set("Mcp-Session-Id", sessionID)
204+
}
205+
206+
resp, err := httpClient.Do(req)
207+
gomega.Expect(err).ToNot(gomega.HaveOccurred())
208+
defer func() { _ = resp.Body.Close() }()
209+
210+
retryAfter := resp.Header.Get("Retry-After")
211+
212+
respBody, err := io.ReadAll(resp.Body)
213+
gomega.Expect(err).ToNot(gomega.HaveOccurred())
214+
215+
return resp.StatusCode, respBody, retryAfter
216+
}
217+
218+
// GetOIDCToken fetches a JWT from the mock OIDC server for the given subject.
219+
func GetOIDCToken(ctx context.Context, httpClient *http.Client, oidcNodePort int32, subject string) string {
220+
url := fmt.Sprintf("http://localhost:%d/token?subject=%s", oidcNodePort, subject)
221+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
222+
gomega.Expect(err).ToNot(gomega.HaveOccurred())
223+
224+
resp, err := httpClient.Do(req)
225+
gomega.Expect(err).ToNot(gomega.HaveOccurred())
226+
defer func() { _ = resp.Body.Close() }()
227+
228+
var tokenResp struct {
229+
AccessToken string `json:"access_token"`
230+
}
231+
gomega.Expect(json.NewDecoder(resp.Body).Decode(&tokenResp)).To(gomega.Succeed())
232+
gomega.Expect(tokenResp.AccessToken).ToNot(gomega.BeEmpty(), "OIDC server should return a token")
233+
234+
return tokenResp.AccessToken
235+
}

0 commit comments

Comments
 (0)