diff --git a/internal/api/server.go b/internal/api/server.go index d47cf3c9..869a0cfd 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -29,7 +29,7 @@ type serverConfig struct { // shared with the other layers of the application. middlewares []func(http.Handler) http.Handler authInfoHandler http.Handler - authzConfig *config.AuthzConfig + authConfig *config.AuthConfig } // WithMiddlewares adds middleware to the server @@ -46,10 +46,12 @@ func WithAuthInfoHandler(handler http.Handler) ServerOption { } } -// WithAuthzConfig sets the authorization configuration for role-based access control -func WithAuthzConfig(authzCfg *config.AuthzConfig) ServerOption { +// WithAuthConfig sets the authentication configuration used by the v1 router. +// Both role-based access control (Authz) and publish-time claim enforcement +// (Mode) are derived from it. +func WithAuthConfig(authCfg *config.AuthConfig) ServerOption { return func(cfg *serverConfig) { - cfg.authzConfig = authzCfg + cfg.authConfig = authCfg } } @@ -82,7 +84,7 @@ func NewServer(svc service.RegistryService, opts ...ServerOption) *chi.Mux { // Mount MCP Registry API v0.1 routes r.Mount("/registry", v01.Router(svc)) - r.Mount("/v1", apiv1.Router(svc, cfg.authzConfig)) + r.Mount("/v1", apiv1.Router(svc, cfg.authConfig)) return r } diff --git a/internal/api/v1/entries.go b/internal/api/v1/entries.go index b4027280..ef32f0ea 100644 --- a/internal/api/v1/entries.go +++ b/internal/api/v1/entries.go @@ -48,6 +48,11 @@ func (routes *Routes) publishEntry(w http.ResponseWriter, r *http.Request) { return } + if routes.authEnabled && len(req.Claims) == 0 { + common.WriteErrorResponse(w, "claims are required when authentication is enabled", http.StatusBadRequest) + return + } + // Extract JWT claims for authorization (subset validation) var jwtOpts []service.Option if jwtClaims := auth.ClaimsFromContext(r.Context()); jwtClaims != nil { diff --git a/internal/api/v1/entries_test.go b/internal/api/v1/entries_test.go index 8e952dde..683efeb6 100644 --- a/internal/api/v1/entries_test.go +++ b/internal/api/v1/entries_test.go @@ -8,11 +8,14 @@ import ( "net/http/httptest" "testing" + "github.com/golang-jwt/jwt/v5" upstreamv0 "github.com/modelcontextprotocol/registry/pkg/api/v0" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "github.com/stacklok/toolhive-registry-server/internal/auth" + "github.com/stacklok/toolhive-registry-server/internal/config" "github.com/stacklok/toolhive-registry-server/internal/service" "github.com/stacklok/toolhive-registry-server/internal/service/mocks" ) @@ -128,6 +131,117 @@ func TestPublishEntry(t *testing.T) { } } +func TestPublishEntryClaimsRequired(t *testing.T) { + t.Parallel() + + oauthCfg := &config.AuthConfig{Mode: config.AuthModeOAuth} + anonCfg := &config.AuthConfig{Mode: config.AuthModeAnonymous} + + tests := []struct { + name string + authCfg *config.AuthConfig + body []byte + jwtClaims jwt.MapClaims + setupMock func(*mocks.MockRegistryService) + wantStatus int + wantError string + }{ + { + name: "auth enabled without claims returns 400", + authCfg: oauthCfg, + body: mustMarshal(publishEntryRequest{ + Server: &upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"}, + }), + jwtClaims: jwt.MapClaims{"sub": "user1"}, + setupMock: func(_ *mocks.MockRegistryService) {}, + wantStatus: http.StatusBadRequest, + wantError: "claims are required", + }, + { + name: "auth enabled with empty claims returns 400", + authCfg: oauthCfg, + body: mustMarshal(publishEntryRequest{ + Server: &upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"}, + Claims: map[string]any{}, + }), + jwtClaims: jwt.MapClaims{"sub": "user1"}, + setupMock: func(_ *mocks.MockRegistryService) {}, + wantStatus: http.StatusBadRequest, + wantError: "claims are required", + }, + { + name: "auth enabled with claims succeeds", + authCfg: oauthCfg, + body: mustMarshal(publishEntryRequest{ + Server: &upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"}, + Claims: map[string]any{"sub": "user1"}, + }), + jwtClaims: jwt.MapClaims{"sub": "user1"}, + setupMock: func(m *mocks.MockRegistryService) { + m.EXPECT().PublishServerVersion(gomock.Any(), gomock.Any()). + Return(&upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"}, nil) + }, + wantStatus: http.StatusCreated, + }, + { + name: "anonymous mode without claims succeeds", + authCfg: anonCfg, + body: mustMarshal(publishEntryRequest{ + Server: &upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"}, + }), + setupMock: func(m *mocks.MockRegistryService) { + m.EXPECT().PublishServerVersion(gomock.Any(), gomock.Any()). + Return(&upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"}, nil) + }, + wantStatus: http.StatusCreated, + }, + { + name: "no auth config without claims succeeds", + authCfg: nil, + body: mustMarshal(publishEntryRequest{ + Server: &upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"}, + }), + setupMock: func(m *mocks.MockRegistryService) { + m.EXPECT().PublishServerVersion(gomock.Any(), gomock.Any()). + Return(&upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"}, nil) + }, + wantStatus: http.StatusCreated, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockSvc := mocks.NewMockRegistryService(ctrl) + tt.setupMock(mockSvc) + + router := Router(mockSvc, tt.authCfg) + req, err := http.NewRequest("POST", "/entries", bytes.NewReader(tt.body)) + require.NoError(t, err) + + if tt.jwtClaims != nil { + ctx := auth.ContextWithClaims(req.Context(), tt.jwtClaims) + req = req.WithContext(ctx) + } + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + assert.Equal(t, tt.wantStatus, rr.Code) + + if tt.wantError != "" { + var response map[string]string + err = json.Unmarshal(rr.Body.Bytes(), &response) + require.NoError(t, err) + assert.Contains(t, response["error"], tt.wantError) + } + }) + } +} + func TestDeletePublishedEntry(t *testing.T) { t.Parallel() diff --git a/internal/api/v1/routes.go b/internal/api/v1/routes.go index 8f3ab8c1..610b0dd9 100644 --- a/internal/api/v1/routes.go +++ b/internal/api/v1/routes.go @@ -14,19 +14,33 @@ import ( // Routes handles HTTP requests for API v1 endpoints. type Routes struct { - service service.RegistryService + service service.RegistryService + authEnabled bool } // NewRoutes creates a new Routes instance with the given service. -func NewRoutes(svc service.RegistryService) *Routes { +// authEnabled reflects whether the server is configured to require +// authentication (i.e. Auth.Mode != anonymous); handlers use it to gate +// policies that only make sense when publishers are authenticated users. +func NewRoutes(svc service.RegistryService, authEnabled bool) *Routes { return &Routes{ - service: svc, + service: svc, + authEnabled: authEnabled, } } // Router creates and configures the HTTP router for API v1 endpoints. -func Router(svc service.RegistryService, authzCfg *config.AuthzConfig) http.Handler { - routes := NewRoutes(svc) +// authCfg is the full authentication configuration; its Authz subtree drives +// role checks and its Mode determines whether publish-time claim requirements +// are enforced. A nil authCfg means no auth is configured (development). +func Router(svc service.RegistryService, authCfg *config.AuthConfig) http.Handler { + var authzCfg *config.AuthzConfig + authEnabled := false + if authCfg != nil { + authzCfg = authCfg.Authz + authEnabled = authCfg.Mode != config.AuthModeAnonymous + } + routes := NewRoutes(svc, authEnabled) r := chi.NewRouter() diff --git a/internal/app/builder.go b/internal/app/builder.go index a2c65c27..ca68f5cf 100644 --- a/internal/app/builder.go +++ b/internal/app/builder.go @@ -453,10 +453,15 @@ func buildHTTPServer( authMw := auth.WrapWithPublicPaths(b.authMiddleware, publicPaths) b.middlewares = append(b.middlewares, authMw) - // Extract authorization config if available - var authzCfg *config.AuthzConfig + // Extract authentication config if available; the v1 router derives both + // role checks (Authz) and publish-time claim enforcement (Mode) from it. + var authCfg *config.AuthConfig if b.config != nil && b.config.Auth != nil { - authzCfg = b.config.Auth.Authz + authCfg = b.config.Auth + } + var authzCfg *config.AuthzConfig + if authCfg != nil { + authzCfg = authCfg.Authz } // Resolve roles for all authenticated requests so that downstream code @@ -473,7 +478,7 @@ func buildHTTPServer( serverOpts := []api.ServerOption{ api.WithMiddlewares(b.middlewares...), api.WithAuthInfoHandler(b.authInfoHandler), - api.WithAuthzConfig(authzCfg), + api.WithAuthConfig(authCfg), } // Create router with middlewares router := api.NewServer(svc, serverOpts...) diff --git a/internal/authz/authz_claims_test.go b/internal/authz/authz_claims_test.go index 7331dd9d..376e01d1 100644 --- a/internal/authz/authz_claims_test.go +++ b/internal/authz/authz_claims_test.go @@ -132,15 +132,24 @@ func TestAuthzIntegration_EmptyClaimsBehavior(t *testing.T) { waitForSync(t, env, superAdmin) - // Publish three entries as superAdmin (bypasses all claim checks) - t.Run("publish open-entry with empty claims", func(t *testing.T) { + // Publishing with empty claims is rejected when auth is enabled — empty + // claims would create entries invisible to everyone except super-admins. + t.Run("publish with empty claims rejected", func(t *testing.T) { resp := doRequest(t, "POST", env.baseURL+"/v1/entries", superAdmin, publishReq{ Claims: map[string]any{}, Server: serverJSON("open-entry"), }) - assertStatus(t, resp, 201) + assertStatus(t, resp, 400) + }) + + t.Run("publish without claims rejected", func(t *testing.T) { + resp := doRequest(t, "POST", env.baseURL+"/v1/entries", superAdmin, publishReq{ + Server: serverJSON("open-entry"), + }) + assertStatus(t, resp, 400) }) + // Publish two entries with proper claims t.Run("publish acme-entry with org claim", func(t *testing.T) { resp := doRequest(t, "POST", env.baseURL+"/v1/entries", superAdmin, publishReq{ Claims: map[string]any{"org": "acme"}, @@ -158,35 +167,6 @@ func TestAuthzIntegration_EmptyClaimsBehavior(t *testing.T) { }) // Visibility checks - // Entries with empty claims {} are stored as NULL in the database (the claims - // serializer treats empty maps as nil). The read-path filter returns false for - // NULL record claims (default-deny), so only superAdmin can see them. - t.Run("platform user does NOT see empty-claims entry", func(t *testing.T) { - url := fmt.Sprintf("%s/registry/acme-all/v0.1/servers?search=open-entry", env.baseURL) - resp := doRequest(t, "GET", url, platformWriter, nil) - body := assertStatus(t, resp, 200) - var result struct { - Metadata struct { - Count int `json:"count"` - } `json:"metadata"` - } - require.NoError(t, json.Unmarshal([]byte(body), &result)) - assert.Equal(t, 0, result.Metadata.Count, "empty-claims entry should be invisible to regular users") - }) - - t.Run("data user does NOT see empty-claims entry", func(t *testing.T) { - url := fmt.Sprintf("%s/registry/acme-all/v0.1/servers?search=open-entry", env.baseURL) - resp := doRequest(t, "GET", url, dataWriter, nil) - body := assertStatus(t, resp, 200) - var result struct { - Metadata struct { - Count int `json:"count"` - } `json:"metadata"` - } - require.NoError(t, json.Unmarshal([]byte(body), &result)) - assert.Equal(t, 0, result.Metadata.Count, "empty-claims entry should be invisible to regular users") - }) - t.Run("platform user sees acme entry", func(t *testing.T) { url := fmt.Sprintf("%s/registry/acme-all/v0.1/servers?search=acme-entry", env.baseURL) resp := doRequest(t, "GET", url, platformWriter, nil) @@ -221,21 +201,15 @@ func TestAuthzIntegration_EmptyClaimsBehavior(t *testing.T) { assert.Equal(t, 0, result.Metadata.Count, "expected 0 servers for data user searching platform-entry") }) - t.Run("superAdmin sees all three", func(t *testing.T) { + t.Run("superAdmin sees both entries", func(t *testing.T) { url := fmt.Sprintf("%s/registry/acme-all/v0.1/servers?search=-entry", env.baseURL) resp := doRequest(t, "GET", url, superAdmin, nil) body := assertStatus(t, resp, 200) - assert.Contains(t, body, "open-entry") assert.Contains(t, body, "acme-entry") assert.Contains(t, body, "platform-entry") }) - // Cleanup: delete all three entries - t.Run("cleanup open-entry", func(t *testing.T) { - resp := doRequest(t, "DELETE", env.baseURL+"/v1/entries/server/io.test%2Fopen-entry/versions/1.0.0", superAdmin, nil) - assertStatus(t, resp, 204) - }) - + // Cleanup t.Run("cleanup acme-entry", func(t *testing.T) { resp := doRequest(t, "DELETE", env.baseURL+"/v1/entries/server/io.test%2Facme-entry/versions/1.0.0", superAdmin, nil) assertStatus(t, resp, 204) diff --git a/internal/authz/authz_sources_test.go b/internal/authz/authz_sources_test.go index 91ba71e4..f92f95ce 100644 --- a/internal/authz/authz_sources_test.go +++ b/internal/authz/authz_sources_test.go @@ -37,7 +37,7 @@ func TestAuthzIntegration_SourceCRUD(t *testing.T) { // Use file-data sources (not managed) since the base config already has a // managed source and at most one is allowed globally. fileDataSource := map[string]any{ - "file": map[string]any{"data": `{"version":"1.0.0","last_updated":"2025-01-15T10:30:00Z","servers":{}}`}, + "file": map[string]any{"data": `{"version":"1.0.0","meta":{"last_updated":"2025-01-15T10:30:00Z"},"data":{"servers":[{"name":"io.test/placeholder","version":"1.0.0","description":"placeholder","packages":[{"registryType":"oci","identifier":"ghcr.io/test/placeholder:latest","transport":{"type":"stdio"}}]}]}}`}, "claims": map[string]any{"org": "acme", "team": "platform"}, } @@ -48,7 +48,7 @@ func TestAuthzIntegration_SourceCRUD(t *testing.T) { t.Run("create source with non-subset claims", func(t *testing.T) { resp := doRequest(t, "PUT", env.baseURL+"/v1/sources/bad-src", platformAdmin, map[string]any{ - "file": map[string]any{"data": `{"version":"1.0.0","last_updated":"2025-01-15T10:30:00Z","servers":{}}`}, + "file": map[string]any{"data": `{"version":"1.0.0","meta":{"last_updated":"2025-01-15T10:30:00Z"},"data":{"servers":[{"name":"io.test/placeholder","version":"1.0.0","description":"placeholder","packages":[{"registryType":"oci","identifier":"ghcr.io/test/placeholder:latest","transport":{"type":"stdio"}}]}]}}`}, "claims": map[string]any{"org": "acme", "team": "finance"}, }) assertStatus(t, resp, 403) diff --git a/internal/service/db/admin_claims_test.go b/internal/service/db/admin_claims_test.go index be17432b..e1453810 100644 --- a/internal/service/db/admin_claims_test.go +++ b/internal/service/db/admin_claims_test.go @@ -88,7 +88,7 @@ func managedSourceReq(claims map[string]any) *service.SourceCreateRequest { func fileDataSourceReq(claims map[string]any) *service.SourceCreateRequest { return &service.SourceCreateRequest{ File: &config.FileConfig{ - Data: `{"version":"1.0.0","last_updated":"2025-01-15T10:30:00Z","servers":{}}`, + Data: `{"version":"1.0.0","meta":{"last_updated":"2025-01-15T10:30:00Z"},"data":{"servers":[{"name":"io.test/placeholder","version":"1.0.0","description":"placeholder","packages":[{"registryType":"oci","identifier":"ghcr.io/test/placeholder:latest","transport":{"type":"stdio"}}]}]}}`, }, Claims: claims, } diff --git a/internal/service/db/claims_filter.go b/internal/service/db/claims_filter.go index 65286fcb..425ed631 100644 --- a/internal/service/db/claims_filter.go +++ b/internal/service/db/claims_filter.go @@ -14,6 +14,33 @@ import ( // Write-path validation (gate checks, publish/delete authorization) // --------------------------------------------------------------------------- +// checkClaimConsistency verifies that incoming claims match the existing entry's claims. +// Both-nil is OK (no claims on either side). Both-non-nil are compared for equality. +// One-nil-one-non-nil is a mismatch — the publisher must be explicit and consistent. +func checkClaimConsistency(incomingJSON, existingJSON []byte) error { + incomingEmpty := len(incomingJSON) == 0 + existingEmpty := len(existingJSON) == 0 + + if incomingEmpty && existingEmpty { + return nil + } + if incomingEmpty != existingEmpty { + return fmt.Errorf("%w: claims do not match existing entry", service.ErrClaimsMismatch) + } + + var incoming, existing map[string]any + if err := json.Unmarshal(incomingJSON, &incoming); err != nil { + return fmt.Errorf("failed to unmarshal incoming claims: %w", err) + } + if err := json.Unmarshal(existingJSON, &existing); err != nil { + return fmt.Errorf("failed to unmarshal existing claims: %w", err) + } + if !claimsEqual(incoming, existing) { + return fmt.Errorf("%w: claims do not match existing entry", service.ErrClaimsMismatch) + } + return nil +} + // validateClaimsSubset checks that callerClaims covers resourceClaims. // Returns nil if: // - callerClaims is nil (anonymous mode — no auth enforcement) diff --git a/internal/service/db/claims_filter_validate_test.go b/internal/service/db/claims_filter_validate_test.go index 92d8d672..d438f002 100644 --- a/internal/service/db/claims_filter_validate_test.go +++ b/internal/service/db/claims_filter_validate_test.go @@ -200,6 +200,80 @@ func TestClaimsFromCtx(t *testing.T) { } } +func TestCheckClaimConsistency(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + incomingJSON []byte + existingJSON []byte + wantErr error + }{ + { + name: "both nil returns nil", + incomingJSON: nil, + existingJSON: nil, + wantErr: nil, + }, + { + name: "both empty slices returns nil", + incomingJSON: []byte{}, + existingJSON: []byte{}, + wantErr: nil, + }, + { + name: "incoming nil existing has claims returns ErrClaimsMismatch", + incomingJSON: nil, + existingJSON: mustMarshalAuthz(t, map[string]any{"sub": "user1"}), + wantErr: service.ErrClaimsMismatch, + }, + { + name: "incoming has claims existing nil returns ErrClaimsMismatch", + incomingJSON: mustMarshalAuthz(t, map[string]any{"sub": "user1"}), + existingJSON: nil, + wantErr: service.ErrClaimsMismatch, + }, + { + name: "both have identical claims returns nil", + incomingJSON: mustMarshalAuthz(t, map[string]any{"sub": "user1", "org": "acme"}), + existingJSON: mustMarshalAuthz(t, map[string]any{"sub": "user1", "org": "acme"}), + wantErr: nil, + }, + { + name: "both have different claims returns ErrClaimsMismatch", + incomingJSON: mustMarshalAuthz(t, map[string]any{"sub": "user1"}), + existingJSON: mustMarshalAuthz(t, map[string]any{"sub": "user2"}), + wantErr: service.ErrClaimsMismatch, + }, + { + name: "incoming has extra key returns ErrClaimsMismatch", + incomingJSON: mustMarshalAuthz(t, map[string]any{"sub": "user1", "org": "acme"}), + existingJSON: mustMarshalAuthz(t, map[string]any{"sub": "user1"}), + wantErr: service.ErrClaimsMismatch, + }, + { + name: "existing has extra key returns ErrClaimsMismatch", + incomingJSON: mustMarshalAuthz(t, map[string]any{"sub": "user1"}), + existingJSON: mustMarshalAuthz(t, map[string]any{"sub": "user1", "org": "acme"}), + wantErr: service.ErrClaimsMismatch, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := checkClaimConsistency(tt.incomingJSON, tt.existingJSON) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} + // mustMarshalAuthz is a test helper that marshals v to JSON or fails the test. func mustMarshalAuthz(t *testing.T, v any) []byte { t.Helper() diff --git a/internal/service/db/impl_mcp.go b/internal/service/db/impl_mcp.go index cd188adc..755d5a62 100644 --- a/internal/service/db/impl_mcp.go +++ b/internal/service/db/impl_mcp.go @@ -406,17 +406,8 @@ func insertServerVersionData( }) } else if err == nil { entryID = existing.ID - // Verify claim consistency: subsequent publishes of the same entry must - // carry exactly the same claims as the original. We use strict equality - // (not superset) so that claims cannot drift across publishes — any change - // requires an explicit delete-and-recreate. - if claimsJSON != nil && existing.Claims != nil { - var existingClaims, incoming map[string]any - _ = json.Unmarshal(existing.Claims, &existingClaims) - _ = json.Unmarshal(claimsJSON, &incoming) - if !claimsEqual(incoming, existingClaims) { - return uuid.Nil, fmt.Errorf("%w: claims do not match existing entry", service.ErrClaimsMismatch) - } + if err := checkClaimConsistency(claimsJSON, existing.Claims); err != nil { + return uuid.Nil, err } } if err != nil { diff --git a/internal/service/db/impl_skills.go b/internal/service/db/impl_skills.go index ce79a543..6970a241 100644 --- a/internal/service/db/impl_skills.go +++ b/internal/service/db/impl_skills.go @@ -457,17 +457,8 @@ func (s *dbService) executePublishSkillTransaction( }) } else if err == nil { entryID = existing.ID - // Verify claim consistency: subsequent publishes of the same entry must - // carry exactly the same claims as the original. We use strict equality - // (not superset) so that claims cannot drift across publishes — any change - // requires an explicit delete-and-recreate. - if claimsJSON != nil && existing.Claims != nil { - var existingClaims, incoming map[string]any - _ = json.Unmarshal(existing.Claims, &existingClaims) - _ = json.Unmarshal(claimsJSON, &incoming) - if !claimsEqual(incoming, existingClaims) { - return "", fmt.Errorf("%w: claims do not match existing entry", service.ErrClaimsMismatch) - } + if err := checkClaimConsistency(claimsJSON, existing.Claims); err != nil { + return "", err } } if err != nil {