Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions internal/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type serverConfig struct {
// shared with the other layers of the application.
middlewares []func(http.Handler) http.Handler
authInfoHandler http.Handler
authzConfig *config.AuthzConfig
authConfig *config.AuthConfig
}

// WithMiddlewares adds middleware to the server
Expand All @@ -46,10 +46,12 @@ func WithAuthInfoHandler(handler http.Handler) ServerOption {
}
}

// WithAuthzConfig sets the authorization configuration for role-based access control
func WithAuthzConfig(authzCfg *config.AuthzConfig) ServerOption {
// WithAuthConfig sets the authentication configuration used by the v1 router.
// Both role-based access control (Authz) and publish-time claim enforcement
// (Mode) are derived from it.
func WithAuthConfig(authCfg *config.AuthConfig) ServerOption {
return func(cfg *serverConfig) {
cfg.authzConfig = authzCfg
cfg.authConfig = authCfg
}
}

Expand Down Expand Up @@ -82,7 +84,7 @@ func NewServer(svc service.RegistryService, opts ...ServerOption) *chi.Mux {

// Mount MCP Registry API v0.1 routes
r.Mount("/registry", v01.Router(svc))
r.Mount("/v1", apiv1.Router(svc, cfg.authzConfig))
r.Mount("/v1", apiv1.Router(svc, cfg.authConfig))

return r
}
Expand Down
5 changes: 5 additions & 0 deletions internal/api/v1/entries.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ func (routes *Routes) publishEntry(w http.ResponseWriter, r *http.Request) {
return
}

if routes.authEnabled && len(req.Claims) == 0 {
common.WriteErrorResponse(w, "claims are required when authentication is enabled", http.StatusBadRequest)
return
}

// Extract JWT claims for authorization (subset validation)
var jwtOpts []service.Option
if jwtClaims := auth.ClaimsFromContext(r.Context()); jwtClaims != nil {
Expand Down
114 changes: 114 additions & 0 deletions internal/api/v1/entries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ import (
"net/http/httptest"
"testing"

"github.com/golang-jwt/jwt/v5"
upstreamv0 "github.com/modelcontextprotocol/registry/pkg/api/v0"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"

"github.com/stacklok/toolhive-registry-server/internal/auth"
"github.com/stacklok/toolhive-registry-server/internal/config"
"github.com/stacklok/toolhive-registry-server/internal/service"
"github.com/stacklok/toolhive-registry-server/internal/service/mocks"
)
Expand Down Expand Up @@ -128,6 +131,117 @@ func TestPublishEntry(t *testing.T) {
}
}

func TestPublishEntryClaimsRequired(t *testing.T) {
t.Parallel()

oauthCfg := &config.AuthConfig{Mode: config.AuthModeOAuth}
anonCfg := &config.AuthConfig{Mode: config.AuthModeAnonymous}

tests := []struct {
name string
authCfg *config.AuthConfig
body []byte
jwtClaims jwt.MapClaims
setupMock func(*mocks.MockRegistryService)
wantStatus int
wantError string
}{
{
name: "auth enabled without claims returns 400",
authCfg: oauthCfg,
body: mustMarshal(publishEntryRequest{
Server: &upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"},
}),
jwtClaims: jwt.MapClaims{"sub": "user1"},
setupMock: func(_ *mocks.MockRegistryService) {},
wantStatus: http.StatusBadRequest,
wantError: "claims are required",
},
{
name: "auth enabled with empty claims returns 400",
authCfg: oauthCfg,
body: mustMarshal(publishEntryRequest{
Server: &upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"},
Claims: map[string]any{},
}),
jwtClaims: jwt.MapClaims{"sub": "user1"},
setupMock: func(_ *mocks.MockRegistryService) {},
wantStatus: http.StatusBadRequest,
wantError: "claims are required",
},
{
name: "auth enabled with claims succeeds",
authCfg: oauthCfg,
body: mustMarshal(publishEntryRequest{
Server: &upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"},
Claims: map[string]any{"sub": "user1"},
}),
jwtClaims: jwt.MapClaims{"sub": "user1"},
setupMock: func(m *mocks.MockRegistryService) {
m.EXPECT().PublishServerVersion(gomock.Any(), gomock.Any()).
Return(&upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"}, nil)
},
wantStatus: http.StatusCreated,
},
{
name: "anonymous mode without claims succeeds",
authCfg: anonCfg,
body: mustMarshal(publishEntryRequest{
Server: &upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"},
}),
setupMock: func(m *mocks.MockRegistryService) {
m.EXPECT().PublishServerVersion(gomock.Any(), gomock.Any()).
Return(&upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"}, nil)
},
wantStatus: http.StatusCreated,
},
{
name: "no auth config without claims succeeds",
authCfg: nil,
body: mustMarshal(publishEntryRequest{
Server: &upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"},
}),
setupMock: func(m *mocks.MockRegistryService) {
m.EXPECT().PublishServerVersion(gomock.Any(), gomock.Any()).
Return(&upstreamv0.ServerJSON{Name: "test/server", Version: "1.0.0"}, nil)
},
wantStatus: http.StatusCreated,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)

mockSvc := mocks.NewMockRegistryService(ctrl)
tt.setupMock(mockSvc)

router := Router(mockSvc, tt.authCfg)
req, err := http.NewRequest("POST", "/entries", bytes.NewReader(tt.body))
require.NoError(t, err)

if tt.jwtClaims != nil {
ctx := auth.ContextWithClaims(req.Context(), tt.jwtClaims)
req = req.WithContext(ctx)
}

rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)

assert.Equal(t, tt.wantStatus, rr.Code)

if tt.wantError != "" {
var response map[string]string
err = json.Unmarshal(rr.Body.Bytes(), &response)
require.NoError(t, err)
assert.Contains(t, response["error"], tt.wantError)
}
})
}
}

func TestDeletePublishedEntry(t *testing.T) {
t.Parallel()

Expand Down
24 changes: 19 additions & 5 deletions internal/api/v1/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,33 @@ import (

// Routes handles HTTP requests for API v1 endpoints.
type Routes struct {
service service.RegistryService
service service.RegistryService
authEnabled bool
}

// NewRoutes creates a new Routes instance with the given service.
func NewRoutes(svc service.RegistryService) *Routes {
// authEnabled reflects whether the server is configured to require
// authentication (i.e. Auth.Mode != anonymous); handlers use it to gate
// policies that only make sense when publishers are authenticated users.
func NewRoutes(svc service.RegistryService, authEnabled bool) *Routes {
return &Routes{
service: svc,
service: svc,
authEnabled: authEnabled,
}
}

// Router creates and configures the HTTP router for API v1 endpoints.
func Router(svc service.RegistryService, authzCfg *config.AuthzConfig) http.Handler {
routes := NewRoutes(svc)
// authCfg is the full authentication configuration; its Authz subtree drives
// role checks and its Mode determines whether publish-time claim requirements
// are enforced. A nil authCfg means no auth is configured (development).
func Router(svc service.RegistryService, authCfg *config.AuthConfig) http.Handler {
var authzCfg *config.AuthzConfig
authEnabled := false
if authCfg != nil {
authzCfg = authCfg.Authz
authEnabled = authCfg.Mode != config.AuthModeAnonymous
}
routes := NewRoutes(svc, authEnabled)

r := chi.NewRouter()

Expand Down
13 changes: 9 additions & 4 deletions internal/app/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,10 +453,15 @@ func buildHTTPServer(
authMw := auth.WrapWithPublicPaths(b.authMiddleware, publicPaths)
b.middlewares = append(b.middlewares, authMw)

// Extract authorization config if available
var authzCfg *config.AuthzConfig
// Extract authentication config if available; the v1 router derives both
// role checks (Authz) and publish-time claim enforcement (Mode) from it.
var authCfg *config.AuthConfig
if b.config != nil && b.config.Auth != nil {
authzCfg = b.config.Auth.Authz
authCfg = b.config.Auth
}
var authzCfg *config.AuthzConfig
if authCfg != nil {
authzCfg = authCfg.Authz
}

// Resolve roles for all authenticated requests so that downstream code
Expand All @@ -473,7 +478,7 @@ func buildHTTPServer(
serverOpts := []api.ServerOption{
api.WithMiddlewares(b.middlewares...),
api.WithAuthInfoHandler(b.authInfoHandler),
api.WithAuthzConfig(authzCfg),
api.WithAuthConfig(authCfg),
}
// Create router with middlewares
router := api.NewServer(svc, serverOpts...)
Expand Down
54 changes: 14 additions & 40 deletions internal/authz/authz_claims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,24 @@ func TestAuthzIntegration_EmptyClaimsBehavior(t *testing.T) {

waitForSync(t, env, superAdmin)

// Publish three entries as superAdmin (bypasses all claim checks)
t.Run("publish open-entry with empty claims", func(t *testing.T) {
// Publishing with empty claims is rejected when auth is enabled — empty
// claims would create entries invisible to everyone except super-admins.
t.Run("publish with empty claims rejected", func(t *testing.T) {
resp := doRequest(t, "POST", env.baseURL+"/v1/entries", superAdmin, publishReq{
Claims: map[string]any{},
Server: serverJSON("open-entry"),
})
assertStatus(t, resp, 201)
assertStatus(t, resp, 400)
})

t.Run("publish without claims rejected", func(t *testing.T) {
resp := doRequest(t, "POST", env.baseURL+"/v1/entries", superAdmin, publishReq{
Server: serverJSON("open-entry"),
})
assertStatus(t, resp, 400)
})

// Publish two entries with proper claims
t.Run("publish acme-entry with org claim", func(t *testing.T) {
resp := doRequest(t, "POST", env.baseURL+"/v1/entries", superAdmin, publishReq{
Claims: map[string]any{"org": "acme"},
Expand All @@ -158,35 +167,6 @@ func TestAuthzIntegration_EmptyClaimsBehavior(t *testing.T) {
})

// Visibility checks
// Entries with empty claims {} are stored as NULL in the database (the claims
// serializer treats empty maps as nil). The read-path filter returns false for
// NULL record claims (default-deny), so only superAdmin can see them.
t.Run("platform user does NOT see empty-claims entry", func(t *testing.T) {
url := fmt.Sprintf("%s/registry/acme-all/v0.1/servers?search=open-entry", env.baseURL)
resp := doRequest(t, "GET", url, platformWriter, nil)
body := assertStatus(t, resp, 200)
var result struct {
Metadata struct {
Count int `json:"count"`
} `json:"metadata"`
}
require.NoError(t, json.Unmarshal([]byte(body), &result))
assert.Equal(t, 0, result.Metadata.Count, "empty-claims entry should be invisible to regular users")
})

t.Run("data user does NOT see empty-claims entry", func(t *testing.T) {
url := fmt.Sprintf("%s/registry/acme-all/v0.1/servers?search=open-entry", env.baseURL)
resp := doRequest(t, "GET", url, dataWriter, nil)
body := assertStatus(t, resp, 200)
var result struct {
Metadata struct {
Count int `json:"count"`
} `json:"metadata"`
}
require.NoError(t, json.Unmarshal([]byte(body), &result))
assert.Equal(t, 0, result.Metadata.Count, "empty-claims entry should be invisible to regular users")
})

t.Run("platform user sees acme entry", func(t *testing.T) {
url := fmt.Sprintf("%s/registry/acme-all/v0.1/servers?search=acme-entry", env.baseURL)
resp := doRequest(t, "GET", url, platformWriter, nil)
Expand Down Expand Up @@ -221,21 +201,15 @@ func TestAuthzIntegration_EmptyClaimsBehavior(t *testing.T) {
assert.Equal(t, 0, result.Metadata.Count, "expected 0 servers for data user searching platform-entry")
})

t.Run("superAdmin sees all three", func(t *testing.T) {
t.Run("superAdmin sees both entries", func(t *testing.T) {
url := fmt.Sprintf("%s/registry/acme-all/v0.1/servers?search=-entry", env.baseURL)
resp := doRequest(t, "GET", url, superAdmin, nil)
body := assertStatus(t, resp, 200)
assert.Contains(t, body, "open-entry")
assert.Contains(t, body, "acme-entry")
assert.Contains(t, body, "platform-entry")
})

// Cleanup: delete all three entries
t.Run("cleanup open-entry", func(t *testing.T) {
resp := doRequest(t, "DELETE", env.baseURL+"/v1/entries/server/io.test%2Fopen-entry/versions/1.0.0", superAdmin, nil)
assertStatus(t, resp, 204)
})

// Cleanup
t.Run("cleanup acme-entry", func(t *testing.T) {
resp := doRequest(t, "DELETE", env.baseURL+"/v1/entries/server/io.test%2Facme-entry/versions/1.0.0", superAdmin, nil)
assertStatus(t, resp, 204)
Expand Down
4 changes: 2 additions & 2 deletions internal/authz/authz_sources_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TestAuthzIntegration_SourceCRUD(t *testing.T) {
// Use file-data sources (not managed) since the base config already has a
// managed source and at most one is allowed globally.
fileDataSource := map[string]any{
"file": map[string]any{"data": `{"version":"1.0.0","last_updated":"2025-01-15T10:30:00Z","servers":{}}`},
"file": map[string]any{"data": `{"version":"1.0.0","meta":{"last_updated":"2025-01-15T10:30:00Z"},"data":{"servers":[{"name":"io.test/placeholder","version":"1.0.0","description":"placeholder","packages":[{"registryType":"oci","identifier":"ghcr.io/test/placeholder:latest","transport":{"type":"stdio"}}]}]}}`},
"claims": map[string]any{"org": "acme", "team": "platform"},
}

Expand All @@ -48,7 +48,7 @@ func TestAuthzIntegration_SourceCRUD(t *testing.T) {

t.Run("create source with non-subset claims", func(t *testing.T) {
resp := doRequest(t, "PUT", env.baseURL+"/v1/sources/bad-src", platformAdmin, map[string]any{
"file": map[string]any{"data": `{"version":"1.0.0","last_updated":"2025-01-15T10:30:00Z","servers":{}}`},
"file": map[string]any{"data": `{"version":"1.0.0","meta":{"last_updated":"2025-01-15T10:30:00Z"},"data":{"servers":[{"name":"io.test/placeholder","version":"1.0.0","description":"placeholder","packages":[{"registryType":"oci","identifier":"ghcr.io/test/placeholder:latest","transport":{"type":"stdio"}}]}]}}`},
"claims": map[string]any{"org": "acme", "team": "finance"},
})
assertStatus(t, resp, 403)
Expand Down
2 changes: 1 addition & 1 deletion internal/service/db/admin_claims_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func managedSourceReq(claims map[string]any) *service.SourceCreateRequest {
func fileDataSourceReq(claims map[string]any) *service.SourceCreateRequest {
return &service.SourceCreateRequest{
File: &config.FileConfig{
Data: `{"version":"1.0.0","last_updated":"2025-01-15T10:30:00Z","servers":{}}`,
Data: `{"version":"1.0.0","meta":{"last_updated":"2025-01-15T10:30:00Z"},"data":{"servers":[{"name":"io.test/placeholder","version":"1.0.0","description":"placeholder","packages":[{"registryType":"oci","identifier":"ghcr.io/test/placeholder:latest","transport":{"type":"stdio"}}]}]}}`,
},
Claims: claims,
}
Expand Down
27 changes: 27 additions & 0 deletions internal/service/db/claims_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,33 @@ import (
// Write-path validation (gate checks, publish/delete authorization)
// ---------------------------------------------------------------------------

// checkClaimConsistency verifies that incoming claims match the existing entry's claims.
// Both-nil is OK (no claims on either side). Both-non-nil are compared for equality.
// One-nil-one-non-nil is a mismatch — the publisher must be explicit and consistent.
func checkClaimConsistency(incomingJSON, existingJSON []byte) error {
incomingEmpty := len(incomingJSON) == 0
existingEmpty := len(existingJSON) == 0

if incomingEmpty && existingEmpty {
return nil
}
if incomingEmpty != existingEmpty {
return fmt.Errorf("%w: claims do not match existing entry", service.ErrClaimsMismatch)
}

var incoming, existing map[string]any
if err := json.Unmarshal(incomingJSON, &incoming); err != nil {
return fmt.Errorf("failed to unmarshal incoming claims: %w", err)
}
if err := json.Unmarshal(existingJSON, &existing); err != nil {
return fmt.Errorf("failed to unmarshal existing claims: %w", err)
}
if !claimsEqual(incoming, existing) {
return fmt.Errorf("%w: claims do not match existing entry", service.ErrClaimsMismatch)
}
return nil
}

// validateClaimsSubset checks that callerClaims covers resourceClaims.
// Returns nil if:
// - callerClaims is nil (anonymous mode — no auth enforcement)
Expand Down
Loading
Loading