Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 57 additions & 11 deletions internal/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ type DenyRequest struct {

type AdminServer struct {
Name string `json:"name"`
EndpointSlug string `json:"endpoint_slug"`
EndpointURL string `json:"endpoint_url,omitempty"`
Mode string `json:"mode"`
BaseURL string `json:"base_url,omitempty"`
AuthToken string `json:"auth_token,omitempty"`
Expand Down Expand Up @@ -3074,6 +3076,7 @@ type ServerAdminService struct {
type serverRepo interface {
ListServers(ctx context.Context, filter mcp.ServerFilter) ([]mcp.Upstream, int, error)
GetServerAny(ctx context.Context, name string) (mcp.Upstream, error)
GetServerByEndpointSlugAny(ctx context.Context, slug string) (mcp.Upstream, error)
UpsertServer(ctx context.Context, upstream mcp.Upstream) error
UpdateServerStatus(ctx context.Context, name string, status mcp.ServerStatus) error
DeleteServer(ctx context.Context, name string) error
Expand Down Expand Up @@ -3110,6 +3113,7 @@ func (s *ServerAdminService) Get(ctx context.Context, name string) (AdminServer,
// what we requested and what's actually live.
func (s *ServerAdminService) adminViewWithGrantedScopes(ctx context.Context, upstream mcp.Upstream) AdminServer {
view := toAdminServer(upstream)
view.EndpointURL = s.endpointURL(view.EndpointSlug)
if cred, err := s.oauthRepo.GetCredential(ctx, upstream.Name); err == nil {
view.OAuthGrantedScopes = cred.Scope
}
Expand All @@ -3128,6 +3132,9 @@ func (s *ServerAdminService) Upsert(ctx context.Context, name string, req AdminS
if serverName == "" {
return AdminServer{}, fmt.Errorf("name is required")
}
if mcp.EndpointSlug(serverName) == "" {
return AdminServer{}, fmt.Errorf("name must include at least one letter or number")
}
mode := strings.TrimSpace(req.Mode)
if mode == "" {
return AdminServer{}, fmt.Errorf("mode is required")
Expand All @@ -3140,18 +3147,26 @@ func (s *ServerAdminService) Upsert(ctx context.Context, name string, req AdminS
// chose not to re-send (notably client_secret, which is never echoed
// back to the browser).
existing, _ := s.repo.GetServerAny(ctx, serverName)
endpointSlug := existing.EndpointSlug
if endpointSlug == "" {
endpointSlug = mcp.EndpointSlug(serverName)
if err := s.validateEndpointSlugAvailable(ctx, serverName, endpointSlug); err != nil {
return AdminServer{}, err
}
}

upstream := mcp.Upstream{
Name: serverName,
Mode: mcp.UpstreamMode(mode),
BaseURL: strings.TrimRight(strings.TrimSpace(req.BaseURL), "/"),
AuthToken: req.AuthToken,
AuthHeaders: append([]mcp.AuthHeader(nil), req.AuthHeaders...),
Timeout: time.Duration(defaultIfZero(req.TimeoutSeconds, 30)) * time.Second,
Command: strings.TrimSpace(req.Command),
Args: append([]string(nil), req.Args...),
Env: cloneEnv(req.Env),
Enabled: enabled,
Name: serverName,
EndpointSlug: endpointSlug,
Mode: mcp.UpstreamMode(mode),
BaseURL: strings.TrimRight(strings.TrimSpace(req.BaseURL), "/"),
AuthToken: req.AuthToken,
AuthHeaders: append([]mcp.AuthHeader(nil), req.AuthHeaders...),
Timeout: time.Duration(defaultIfZero(req.TimeoutSeconds, 30)) * time.Second,
Command: strings.TrimSpace(req.Command),
Args: append([]string(nil), req.Args...),
Env: cloneEnv(req.Env),
Enabled: enabled,
}

// OAuth fields from the request body win, except client_secret which
Expand Down Expand Up @@ -3192,6 +3207,33 @@ func (s *ServerAdminService) Upsert(ctx context.Context, name string, req AdminS
return s.adminViewWithGrantedScopes(ctx, upstream), nil
}

func (s *ServerAdminService) validateEndpointSlugAvailable(ctx context.Context, serverName string, slug string) error {
item, err := s.repo.GetServerByEndpointSlugAny(ctx, slug)
if err == sql.ErrNoRows {
return nil
}
if err != nil {
return err
}
if item.Name != serverName {
if strings.EqualFold(item.Name, serverName) {
return fmt.Errorf("server name %q already exists", item.Name)
}
return fmt.Errorf("server name %q would use endpoint slug %q, already used by %q", serverName, slug, item.Name)
}
return nil
}

func (s *ServerAdminService) endpointURL(slug string) string {
if slug == "" {
return ""
}
if s.publicBaseURL == "" {
return "/mcp/" + slug
}
return s.publicBaseURL + "/mcp/" + slug
}

func (s *ServerAdminService) Delete(ctx context.Context, name string, disable bool) error {
if disable {
return s.repo.DisableServer(ctx, name)
Expand Down Expand Up @@ -3544,7 +3586,11 @@ func (h *Handler) externalInvocationDetail(w http.ResponseWriter, r *http.Reques
}

func toAdminServer(upstream mcp.Upstream) AdminServer {
return AdminServer{Name: upstream.Name, Mode: string(upstream.Mode), BaseURL: upstream.BaseURL, AuthToken: upstream.AuthToken, AuthHeaders: append([]mcp.AuthHeader(nil), upstream.AuthHeaders...), TimeoutSeconds: int(upstream.Timeout / time.Second), Command: upstream.Command, Args: append([]string(nil), upstream.Args...), Env: cloneEnv(upstream.Env), Enabled: upstream.Enabled, AuthType: string(upstream.Status.AuthType), ConnectionStatus: string(upstream.Status.ConnectionStatus), AuthStatus: string(upstream.Status.AuthStatus), ReauthNeeded: upstream.Status.ReauthNeeded, LastCheckedAt: upstream.Status.LastCheckedAt, LastCheckOK: upstream.Status.LastCheckOK, LastErrorSummary: upstream.Status.LastErrorSummary, ActionRequired: upstream.Status.ActionRequired, OAuthProviderID: upstream.OAuthProviderID, OAuthProviderLabel: upstream.OAuthProviderLabel, OAuthClientRegistration: string(upstream.OAuthClientRegistration), OAuthClientID: upstream.OAuthClientID, OAuthAuthorizeURL: upstream.OAuthAuthorizeURL, OAuthTokenURL: upstream.OAuthTokenURL, OAuthScopes: upstream.OAuthScopes, HasOAuthClientSecret: strings.TrimSpace(upstream.OAuthClientSecret) != ""}
endpointSlug := upstream.EndpointSlug
if endpointSlug == "" {
endpointSlug = mcp.EndpointSlug(upstream.Name)
}
return AdminServer{Name: upstream.Name, EndpointSlug: endpointSlug, Mode: string(upstream.Mode), BaseURL: upstream.BaseURL, AuthToken: upstream.AuthToken, AuthHeaders: append([]mcp.AuthHeader(nil), upstream.AuthHeaders...), TimeoutSeconds: int(upstream.Timeout / time.Second), Command: upstream.Command, Args: append([]string(nil), upstream.Args...), Env: cloneEnv(upstream.Env), Enabled: upstream.Enabled, AuthType: string(upstream.Status.AuthType), ConnectionStatus: string(upstream.Status.ConnectionStatus), AuthStatus: string(upstream.Status.AuthStatus), ReauthNeeded: upstream.Status.ReauthNeeded, LastCheckedAt: upstream.Status.LastCheckedAt, LastCheckOK: upstream.Status.LastCheckOK, LastErrorSummary: upstream.Status.LastErrorSummary, ActionRequired: upstream.Status.ActionRequired, OAuthProviderID: upstream.OAuthProviderID, OAuthProviderLabel: upstream.OAuthProviderLabel, OAuthClientRegistration: string(upstream.OAuthClientRegistration), OAuthClientID: upstream.OAuthClientID, OAuthAuthorizeURL: upstream.OAuthAuthorizeURL, OAuthTokenURL: upstream.OAuthTokenURL, OAuthScopes: upstream.OAuthScopes, HasOAuthClientSecret: strings.TrimSpace(upstream.OAuthClientSecret) != ""}
}

func validateUpstream(upstream mcp.Upstream) error {
Expand Down
147 changes: 147 additions & 0 deletions internal/api/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"atryum/internal/auth"
backendclient "atryum/internal/backend"
"atryum/internal/config"
"atryum/internal/invocation"
"atryum/internal/managedagents"
"atryum/internal/mcp"
Expand Down Expand Up @@ -330,6 +331,152 @@ func TestStartConnectUsesUpstreamMCPOAuthCallbackRedirectURI(t *testing.T) {
}
}

func TestServerAdminServiceSurfacesEndpointURLForSluggedServerName(t *testing.T) {
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("open db: %v", err)
}
defer db.Close()
if err := store.InitDB(db); err != nil {
t.Fatalf("InitDB: %v", err)
}

ctx := context.Background()
serverRepo := store.NewServerRepo(db)
oauthRepo := store.NewOAuthRepo(db)
svc := NewServerAdminService(serverRepo, oauthRepo, nil, 5*time.Second, "https://atryum.example")

server, err := svc.Upsert(ctx, "", AdminServerUpsertRequest{
Name: "Slack Local",
Mode: string(mcp.UpstreamModeHTTP),
BaseURL: "https://mcp.slack.test/mcp",
TimeoutSeconds: 10,
})
if err != nil {
t.Fatalf("Upsert: %v", err)
}
if server.Name != "Slack Local" {
t.Fatalf("name = %q, want Slack Local", server.Name)
}
if server.EndpointSlug != "slack-local" {
t.Fatalf("endpoint_slug = %q, want slack-local", server.EndpointSlug)
}
if server.EndpointURL != "https://atryum.example/mcp/slack-local" {
t.Fatalf("endpoint_url = %q", server.EndpointURL)
}

resolver := mcp.NewResolver(serverRepo, config.Config{})
upstream, err := resolver.ResolveContext(ctx, "slack-local")
if err != nil {
t.Fatalf("ResolveContext: %v", err)
}
if upstream.Name != "Slack Local" {
t.Fatalf("resolved name = %q, want Slack Local", upstream.Name)
}
}

func TestServerAdminServiceRejectsDuplicateEndpointSlug(t *testing.T) {
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("open db: %v", err)
}
defer db.Close()
if err := store.InitDB(db); err != nil {
t.Fatalf("InitDB: %v", err)
}

ctx := context.Background()
serverRepo := store.NewServerRepo(db)
oauthRepo := store.NewOAuthRepo(db)
svc := NewServerAdminService(serverRepo, oauthRepo, nil, 5*time.Second, "")

_, err = svc.Upsert(ctx, "", AdminServerUpsertRequest{
Name: "Slack Local",
Mode: string(mcp.UpstreamModeHTTP),
BaseURL: "https://mcp.slack.test/mcp",
})
if err != nil {
t.Fatalf("Upsert first server: %v", err)
}
_, err = svc.Upsert(ctx, "", AdminServerUpsertRequest{
Name: "slack-local",
Mode: string(mcp.UpstreamModeHTTP),
BaseURL: "https://other.example/mcp",
})
if err == nil || !strings.Contains(err.Error(), `endpoint slug "slack-local"`) {
t.Fatalf("expected duplicate endpoint slug error, got %v", err)
}
}

func TestServerAdminServiceRejectsCaseOnlyDuplicateNameWithClearError(t *testing.T) {
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("open db: %v", err)
}
defer db.Close()
if err := store.InitDB(db); err != nil {
t.Fatalf("InitDB: %v", err)
}

ctx := context.Background()
serverRepo := store.NewServerRepo(db)
oauthRepo := store.NewOAuthRepo(db)
svc := NewServerAdminService(serverRepo, oauthRepo, nil, 5*time.Second, "")

_, err = svc.Upsert(ctx, "", AdminServerUpsertRequest{
Name: "slack local",
Mode: string(mcp.UpstreamModeHTTP),
BaseURL: "https://mcp.slack.test/mcp",
})
if err != nil {
t.Fatalf("Upsert first server: %v", err)
}
_, err = svc.Upsert(ctx, "", AdminServerUpsertRequest{
Name: "Slack Local",
Mode: string(mcp.UpstreamModeHTTP),
BaseURL: "https://other.example/mcp",
})
if err == nil || !strings.Contains(err.Error(), `server name "slack local" already exists`) {
t.Fatalf("expected duplicate name error, got %v", err)
}
}

func TestServerAdminServiceUsesStoredEndpointSlug(t *testing.T) {
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("open db: %v", err)
}
defer db.Close()
if err := store.InitDB(db); err != nil {
t.Fatalf("InitDB: %v", err)
}

ctx := context.Background()
serverRepo := store.NewServerRepo(db)
oauthRepo := store.NewOAuthRepo(db)
if err := serverRepo.UpsertServer(ctx, mcp.Upstream{
Name: "Slack Local",
EndpointSlug: "slack-local-2",
Mode: mcp.UpstreamModeHTTP,
BaseURL: "https://mcp.slack.test/mcp",
Enabled: true,
}); err != nil {
t.Fatalf("UpsertServer: %v", err)
}

svc := NewServerAdminService(serverRepo, oauthRepo, nil, 5*time.Second, "https://atryum.example")
server, err := svc.Get(ctx, "Slack Local")
if err != nil {
t.Fatalf("Get: %v", err)
}
if server.EndpointSlug != "slack-local-2" {
t.Fatalf("endpoint_slug = %q, want slack-local-2", server.EndpointSlug)
}
if server.EndpointURL != "https://atryum.example/mcp/slack-local-2" {
t.Fatalf("endpoint_url = %q", server.EndpointURL)
}
}

func TestAdminServerTestDebugLogsRequestContext(t *testing.T) {
t.Setenv("ATRYUM_MCP_DEBUG", "1")

Expand Down
30 changes: 30 additions & 0 deletions internal/mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net/url"
"os"
"os/exec"
"regexp"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -95,6 +96,7 @@ type AuthHeader struct {

type Upstream struct {
Name string
EndpointSlug string
Mode UpstreamMode
BaseURL string
AuthToken string
Expand Down Expand Up @@ -187,6 +189,7 @@ func DecodeAuthHeaders(env map[string]string) []AuthHeader {

type ServerStore interface {
GetServer(ctx context.Context, name string) (Upstream, error)
GetServerByEndpointSlug(ctx context.Context, slug string) (Upstream, error)
ListServers(ctx context.Context, filter ServerFilter) ([]Upstream, int, error)
CountServers(ctx context.Context) (int, error)
CreateServer(ctx context.Context, upstream Upstream) error
Expand All @@ -198,6 +201,18 @@ type ServerFilter struct {
Enabled *bool
}

var endpointSlugInvalidChars = regexp.MustCompile(`[^a-z0-9._~-]+`)

func EndpointSlug(name string) string {
slug := strings.ToLower(strings.TrimSpace(name))
slug = endpointSlugInvalidChars.ReplaceAllString(slug, "-")
slug = strings.Trim(slug, "-")
if strings.Trim(slug, ".") == "" {
return ""
}
return slug
}

// CredentialStore returns the OAuth access token currently stored for a
// server. It is consulted on every ResolveContext so that proxied MCP
// requests (tools/list, tools/call, forward) get the live access token
Expand Down Expand Up @@ -310,9 +325,24 @@ func (r *Resolver) ResolveContext(ctx context.Context, name string) (Upstream, e
if err != nil && err != sql.ErrNoRows {
return Upstream{}, err
}
if slug := EndpointSlug(name); slug != "" {
upstream, err := r.store.GetServerByEndpointSlug(ctx, slug)
if err == nil {
return r.overlayCredential(ctx, upstream), nil
}
if err != nil && err != sql.ErrNoRows {
return Upstream{}, err
}
}
}
upstream, ok := r.bootstrap[name]
if !ok || !upstream.Enabled {
slug := EndpointSlug(name)
for _, candidate := range r.bootstrap {
if candidate.Enabled && EndpointSlug(candidate.Name) == slug {
return r.overlayCredential(ctx, candidate), nil
}
}
return Upstream{}, fmt.Errorf("upstream %q not configured or disabled", name)
}
return r.overlayCredential(ctx, upstream), nil
Expand Down
Loading
Loading