From 6a9e6016a5cbd4f7cb662e1c307a78955c30d3e6 Mon Sep 17 00:00:00 2001 From: Stefan Haubold Date: Fri, 8 May 2026 14:02:27 +0200 Subject: [PATCH] auth: review follow-ups (provider routing, URL normalization, expiry preflight, timeouts) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Six fixes called out across the security/architecture/correctness reviews of #cli-auth-consolidation. Each fix lands with focused tests; no behaviour changes outside the auth path. 1. Eliminate duplicate ENTIRE_AUTH_PROVIDER_VERSION read in cmd/entire/cli/api/auth_tokens.go. The provider table in cmd/entire/cli/auth.Provider now owns AuthTokensPath; api.Client takes it via WithAuthTokensPath. The api package no longer reads the env var. 2. Read provider version once at startup. CurrentProvider() resolves via sync.Once and freezes; tests inject via SetProviderForTest. resolveProvider is a pure function so the routing table is exercisable without env-var gymnastics. 3. Normalize URLs in tokenmanager same-host / aud / cache-key compares. normalizeOriginURL handles trailing slash, scheme/host case, and default ports (RFC 3986 §6.2.2.1 / §6.2.3); non-URL audiences pass through unchanged for byte-exact compare. 4. Preflight core-token expiry and clear the exchange cache on SaveCoreToken. Long-expired tokens surface as ErrNotLoggedIn (so "run login" UX kicks in) instead of confusing STS / 401 errors; a re-login can't return the previous user's exchanged tokens. 5. Tighten tokenstore malformed-JSON detection. Well-formed JSON without an access_token now surfaces as ErrMalformed. The shim's bare-string fallback rejects JSON-shaped content via looksLikeBareToken so "Authorization: Bearer {}" can't ship. 6. Add per-request timeouts (DefaultRequestTimeout = 30s) to deviceflow.Client and sts.Client via context.WithTimeout. The wrap lives at the method level so the deadline covers the body read, not just the dial. Tests pin both the firing path and the default/override resolution. Entire-Checkpoint: c79c0ff7d6c1 --- auth/deviceflow/deviceflow.go | 46 ++++- auth/deviceflow/deviceflow_test.go | 80 +++++++++ auth/sts/sts.go | 33 ++++ auth/sts/sts_test.go | 59 +++++++ auth/tokenmanager/tokenmanager.go | 102 ++++++++++- auth/tokenmanager/tokenmanager_test.go | 228 ++++++++++++++++++++++++- auth/tokenstore/keyring.go | 11 ++ auth/tokenstore/keyring_test.go | 31 ++++ cmd/entire/cli/api/auth_tokens.go | 43 +++-- cmd/entire/cli/api/auth_tokens_test.go | 100 +++++------ cmd/entire/cli/api/client.go | 20 +++ cmd/entire/cli/auth.go | 8 +- cmd/entire/cli/auth/client.go | 10 +- cmd/entire/cli/auth/exchange.go | 8 +- cmd/entire/cli/auth/provider.go | 111 +++++++++--- cmd/entire/cli/auth/provider_test.go | 105 +++++++----- cmd/entire/cli/auth/store.go | 29 ++++ cmd/entire/cli/auth/store_test.go | 42 +++++ cmd/entire/cli/logout.go | 4 +- 19 files changed, 910 insertions(+), 160 deletions(-) diff --git a/auth/deviceflow/deviceflow.go b/auth/deviceflow/deviceflow.go index 9a72f53340..2827e053e6 100644 --- a/auth/deviceflow/deviceflow.go +++ b/auth/deviceflow/deviceflow.go @@ -47,6 +47,15 @@ type DeviceCode struct { Interval int `json:"interval"` } +// DefaultRequestTimeout caps a single device-flow HTTP round-trip +// (StartDeviceAuth or one PollDeviceAuth call). Set conservatively: +// healthy device-flow endpoints respond in sub-seconds, so the cap +// mainly defends against slow-loris responses dripping bytes within +// MaxResponseBytes — see Client.RequestTimeout for the per-Client +// override. The polling-loop interval is the caller's concern; this +// timeout governs only the individual HTTP request. +const DefaultRequestTimeout = 30 * time.Second + // Client polls an RFC 8628 device authorization grant. // // All configuration is explicit; the package has no global state and @@ -60,6 +69,26 @@ type Client struct { UserAgent string DeviceCodePath string TokenPath string + + // RequestTimeout is the per-request deadline applied via + // context.WithTimeout on top of the caller's context. Zero falls + // back to DefaultRequestTimeout. Negative disables the cap (useful + // for tests that want to drive timing via the caller's ctx alone). + RequestTimeout time.Duration +} + +// requestTimeout resolves the effective per-request timeout: the +// configured RequestTimeout if positive, the package default if zero, +// or zero (no cap) if negative. +func (c *Client) requestTimeout() time.Duration { + switch { + case c.RequestTimeout < 0: + return 0 + case c.RequestTimeout == 0: + return DefaultRequestTimeout + default: + return c.RequestTimeout + } } // Sentinel errors returned by PollDeviceAuth when the token endpoint @@ -110,6 +139,12 @@ func errCodeToSentinel(code string) error { // server. The returned DeviceCode is opaque to the client; pass it // back unmodified on every PollDeviceAuth. func (c *Client) StartDeviceAuth(ctx context.Context) (*DeviceCode, error) { + if timeout := c.requestTimeout(); timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + body := url.Values{} body.Set("client_id", c.ClientID) if c.Scope != "" { @@ -141,6 +176,12 @@ func (c *Client) StartDeviceAuth(ctx context.Context) (*DeviceCode, error) { // the matching sentinel error from this package. Other failures // (network, malformed responses) are wrapped with context. func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*tokens.TokenSet, error) { + if timeout := c.requestTimeout(); timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + body := url.Values{} body.Set("grant_type", deviceCodeGrantType) body.Set("client_id", c.ClientID) @@ -195,7 +236,10 @@ func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*tokens } // postForm POSTs body as application/x-www-form-urlencoded to a path -// resolved against the client's BaseURL. +// resolved against the client's BaseURL. The caller is responsible +// for applying any per-request timeout via context.WithTimeout — the +// timeout must cover the body-read that happens after postForm +// returns, so cancel-on-return here would interrupt that read. func (c *Client) postForm(ctx context.Context, path string, body url.Values) (*http.Response, error) { endpoint, err := resolveURL(c.BaseURL, path) if err != nil { diff --git a/auth/deviceflow/deviceflow_test.go b/auth/deviceflow/deviceflow_test.go index 05fd091c57..a43dfd4709 100644 --- a/auth/deviceflow/deviceflow_test.go +++ b/auth/deviceflow/deviceflow_test.go @@ -360,3 +360,83 @@ func TestResolveURL(t *testing.T) { }) } } + +// TestPollDeviceAuth_RequestTimeoutFires pins the slow-loris defence: +// a handler that never finishes writing a response must surface as a +// context deadline error rather than blocking the polling loop forever. +func TestPollDeviceAuth_RequestTimeoutFires(t *testing.T) { + t.Parallel() + // Cleanup is LIFO and httptest.Server.Close waits for active + // handler goroutines, so close(hung) is registered AFTER + // newTestClient to fire first and let the handler exit before + // srv.Close runs. + hung := make(chan struct{}) + c := newTestClient(t, func(_ http.ResponseWriter, r *http.Request) { + select { + case <-hung: + case <-r.Context().Done(): + } + }) + t.Cleanup(func() { close(hung) }) + c.RequestTimeout = 50 * time.Millisecond + + _, err := c.PollDeviceAuth(context.Background(), "dev-1") + if err == nil { + t.Fatal("expected timeout error, got nil") + } + if !strings.Contains(err.Error(), "context deadline exceeded") { + t.Fatalf("err = %v, want context deadline exceeded", err) + } +} + +// TestStartDeviceAuth_RequestTimeoutFires mirrors the poll-side test +// for the device-code endpoint. +func TestStartDeviceAuth_RequestTimeoutFires(t *testing.T) { + t.Parallel() + // Cleanup is LIFO and httptest.Server.Close waits for active + // handler goroutines, so close(hung) is registered AFTER + // newTestClient to fire first and let the handler exit before + // srv.Close runs. + hung := make(chan struct{}) + c := newTestClient(t, func(_ http.ResponseWriter, r *http.Request) { + select { + case <-hung: + case <-r.Context().Done(): + } + }) + t.Cleanup(func() { close(hung) }) + c.RequestTimeout = 50 * time.Millisecond + + _, err := c.StartDeviceAuth(context.Background()) + if err == nil { + t.Fatal("expected timeout error, got nil") + } + if !strings.Contains(err.Error(), "context deadline exceeded") { + t.Fatalf("err = %v, want context deadline exceeded", err) + } +} + +// TestRequestTimeout_DefaultAndOverride exercises the timeout policy +// without doing IO — pure resolution of the (zero / negative / +// positive) input contract. +func TestRequestTimeout_DefaultAndOverride(t *testing.T) { + t.Parallel() + cases := []struct { + name string + in time.Duration + want time.Duration + }{ + {"zero -> default", 0, DefaultRequestTimeout}, + {"negative -> disabled", -1, 0}, + {"positive -> verbatim", 5 * time.Second, 5 * time.Second}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + c := &Client{RequestTimeout: tc.in} + if got := c.requestTimeout(); got != tc.want { + t.Fatalf("requestTimeout() = %v, want %v", got, tc.want) + } + }) + } +} diff --git a/auth/sts/sts.go b/auth/sts/sts.go index 72a53cb578..d8233cd285 100644 --- a/auth/sts/sts.go +++ b/auth/sts/sts.go @@ -68,6 +68,13 @@ func (r ExchangeRequest) validate() error { return nil } +// DefaultRequestTimeout caps a single token-exchange round-trip. Set +// conservatively: even with a slow auth host plus TLS handshake, a +// healthy exchange completes in sub-seconds. The cap mainly defends +// against slow-loris responses dripping bytes within MaxResponseBytes +// — see Client.RequestTimeout for the per-Client override. +const DefaultRequestTimeout = 30 * time.Second + // Client exchanges subject tokens for tokens of a different type at an // RFC 8693 token endpoint. // @@ -78,6 +85,12 @@ type Client struct { BaseURL string Path string UserAgent string + + // RequestTimeout is the per-Exchange deadline applied via + // context.WithTimeout on top of the caller's context. Zero falls + // back to DefaultRequestTimeout. Negative disables the cap (useful + // for tests that want to drive timing via the caller's ctx alone). + RequestTimeout time.Duration } // Exchange performs one RFC 8693 token exchange. @@ -98,6 +111,12 @@ func (c *Client) Exchange(ctx context.Context, req ExchangeRequest) (*tokens.Tok return nil, fmt.Errorf("token exchange: resolve URL: %w", err) } + if timeout := c.requestTimeout(); timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode())) if err != nil { return nil, fmt.Errorf("token exchange: create request: %w", err) @@ -176,6 +195,20 @@ func buildForm(req ExchangeRequest) url.Values { return form } +// requestTimeout resolves the effective per-request timeout: the +// configured RequestTimeout if positive, the package default if zero, +// or zero (no cap) if negative. +func (c *Client) requestTimeout() time.Duration { + switch { + case c.RequestTimeout < 0: + return 0 + case c.RequestTimeout == 0: + return DefaultRequestTimeout + default: + return c.RequestTimeout + } +} + func resolveURL(baseURL, path string) (string, error) { base, err := url.Parse(baseURL) if err != nil { diff --git a/auth/sts/sts_test.go b/auth/sts/sts_test.go index 5f8282aa20..5e4f348d5c 100644 --- a/auth/sts/sts_test.go +++ b/auth/sts/sts_test.go @@ -297,3 +297,62 @@ func TestExchange_NoExpiry(t *testing.T) { t.Fatalf("ExpiresAt = %v, want zero", got.ExpiresAt) } } + +// TestExchange_RequestTimeoutFires pins the slow-loris defence: a +// handler that never writes a response body must surface as a context +// deadline error rather than blocking the caller indefinitely. +// +// Cleanup order matters: t.Cleanup is LIFO, and httptest.Server.Close +// waits for in-flight handler goroutines to return. We register +// `close(hung)` AFTER newTestClient so it fires first and lets the +// handler exit before srv.Close runs. +func TestExchange_RequestTimeoutFires(t *testing.T) { + t.Parallel() + hung := make(chan struct{}) + + c := newTestClient(t, func(_ http.ResponseWriter, r *http.Request) { + select { + case <-hung: + case <-r.Context().Done(): + } + }) + t.Cleanup(func() { close(hung) }) + c.RequestTimeout = 50 * time.Millisecond + + _, err := c.Exchange(context.Background(), ExchangeRequest{ + SubjectToken: "sub", + SubjectTokenType: SubjectTokenTypeJWT, + RequestedTokenType: "urn:example:t", + }) + if err == nil { + t.Fatal("expected timeout error, got nil") + } + if !strings.Contains(err.Error(), "context deadline exceeded") { + t.Fatalf("err = %v, want context deadline exceeded", err) + } +} + +// TestRequestTimeout_DefaultAndOverride exercises the timeout policy +// without doing IO — pure resolution of the (zero / negative / +// positive) input contract. +func TestRequestTimeout_DefaultAndOverride(t *testing.T) { + t.Parallel() + cases := []struct { + name string + in time.Duration + want time.Duration + }{ + {"zero -> default", 0, DefaultRequestTimeout}, + {"negative -> disabled", -1, 0}, + {"positive -> verbatim", 5 * time.Second, 5 * time.Second}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + c := &Client{RequestTimeout: tc.in} + if got := c.requestTimeout(); got != tc.want { + t.Fatalf("requestTimeout() = %v, want %v", got, tc.want) + } + }) + } +} diff --git a/auth/tokenmanager/tokenmanager.go b/auth/tokenmanager/tokenmanager.go index 82aa426068..2d1af6c6af 100644 --- a/auth/tokenmanager/tokenmanager.go +++ b/auth/tokenmanager/tokenmanager.go @@ -18,7 +18,6 @@ import ( "fmt" "net/http" "net/url" - "slices" "strings" "sync" "time" @@ -137,8 +136,20 @@ func (m *Manager) Issuer() string { return m.cfg.Issuer } // SaveCoreToken persists the device-flow access token under the // configured Issuer. +// +// On successful save the in-memory exchange cache is cleared so a +// re-login under a different identity can't return the previous user's +// exchanged tokens. The cacheKey already binds entries to CoreToken so +// this is defence-in-depth against a future refactor that drops the +// core token from the cache key — see TestSaveCoreToken_ClearsExchangeCache. func (m *Manager) SaveCoreToken(accessToken string) error { - return m.cfg.Store.SaveTokens(m.cfg.Issuer, tokens.TokenSet{AccessToken: accessToken}) //nolint:wrapcheck // backend error already names the operation + if err := m.cfg.Store.SaveTokens(m.cfg.Issuer, tokens.TokenSet{AccessToken: accessToken}); err != nil { + return fmt.Errorf("save core token: %w", err) + } + m.mu.Lock() + m.cache = map[cacheKey]cachedToken{} + m.mu.Unlock() + return nil } // LookupCoreToken returns the stored core token, or "" if none is @@ -225,16 +236,26 @@ func (m *Manager) Token(ctx context.Context, req TokenRequest) (string, error) { if core == "" { return "", ErrNotLoggedIn } + // Preflight expiry: a long-stored core token would otherwise hit the + // resource (or STS) and surface as a confusing "invalid_grant" / + // "401". Parse-failure is intentionally not treated as expired — + // opaque (non-JWT) access tokens have no client-visible expiry, so + // we let them flow and trust the server to reject if necessary. + if coreTokenExpired(core, m.cfg.Now()) { + return "", ErrNotLoggedIn + } + + normResource := normalizeOriginURL(req.Resource) - if req.Audience == "" && m.cfg.Issuer == req.Resource { + if req.Audience == "" && normalizeOriginURL(m.cfg.Issuer) == normResource { return core, nil } - if req.Audience == "" && coreTokenAudienceIncludes(core, req.Resource) { + if req.Audience == "" && coreTokenAudienceIncludes(core, normResource) { return core, nil } resolved := m.resolve(req) - key := makeCacheKey(core, resolved) + key := makeCacheKey(core, resolved, normResource) if hit, ok := m.cacheLookup(key); ok { return hit, nil } @@ -258,12 +279,73 @@ func (m *Manager) resolve(req TokenRequest) TokenRequest { return req } +// coreTokenExpired reports whether the core token has an `exp` claim +// in the past at now. JWT parse failures (and tokens without an `exp` +// claim) are reported as not-expired so opaque access tokens flow +// through the rest of the resolution rules unchanged. +func coreTokenExpired(coreJWT string, now time.Time) bool { + claims, err := tokens.ParseClaims(coreJWT) + if err != nil { + return false + } + if claims.ExpiresAt.IsZero() { + return false + } + return !now.Before(claims.ExpiresAt) +} + +// coreTokenAudienceIncludes reports whether the core JWT's `aud` claim +// covers target. target is expected to already be in normalised form +// (see normalizeOriginURL); aud entries are normalised here so a +// trailing-slash / case difference between the AS and the caller +// doesn't force a needless STS exchange. func coreTokenAudienceIncludes(coreJWT, target string) bool { claims, err := tokens.ParseClaims(coreJWT) if err != nil { return false } - return slices.Contains(claims.Audience, target) + for _, aud := range claims.Audience { + if normalizeOriginURL(aud) == target { + return true + } + } + return false +} + +// normalizeOriginURL canonicalises an origin URL for equality +// comparisons. RFC 3986 §6.2.2.1 makes scheme and host case-insensitive +// and §6.2.3 makes the empty path equivalent to "/" — we collapse to +// no-trailing-slash. Default ports (80/http, 443/https) are stripped. +// +// On parse failure (or when the input lacks a scheme or host — common +// for non-URL audiences) the input is returned unchanged so callers +// fall back to byte-exact comparison. +func normalizeOriginURL(raw string) string { + u, err := url.Parse(raw) + if err != nil || u.Scheme == "" || u.Host == "" { + return raw + } + u.Scheme = strings.ToLower(u.Scheme) + + hostname := strings.ToLower(u.Hostname()) + port := u.Port() + dropPort := (u.Scheme == "http" && port == "80") || + (u.Scheme == "https" && port == "443") || + port == "" + + switch { + case dropPort && strings.Contains(hostname, ":"): // IPv6 without port + u.Host = "[" + hostname + "]" + case dropPort: + u.Host = hostname + case strings.Contains(hostname, ":"): // IPv6 with non-default port + u.Host = "[" + hostname + "]:" + port + default: + u.Host = hostname + ":" + port + } + + u.Path = strings.TrimRight(u.Path, "/") + return u.String() } // cachedToken is one entry in the per-process exchange cache. @@ -297,11 +379,13 @@ type cacheKey struct { // makeCacheKey builds a cacheKey from the (resolved) request. Includes // every wire-affecting field so different combinations don't shadow -// each other. -func makeCacheKey(coreToken string, req TokenRequest) cacheKey { +// each other. normalizedResource is the caller-supplied Resource after +// passing through normalizeOriginURL, so https://api.example.com and +// https://api.example.com/ share a single cache entry. +func makeCacheKey(coreToken string, req TokenRequest, normalizedResource string) cacheKey { return cacheKey{ CoreToken: coreToken, - Resource: req.Resource, + Resource: normalizedResource, Audience: req.Audience, RequestedTokenType: req.RequestedTokenType, Scope: req.Scope, diff --git a/auth/tokenmanager/tokenmanager_test.go b/auth/tokenmanager/tokenmanager_test.go index 916280c7d1..4529509312 100644 --- a/auth/tokenmanager/tokenmanager_test.go +++ b/auth/tokenmanager/tokenmanager_test.go @@ -41,10 +41,11 @@ func (s *memStore) DeleteTokens(profile string) error { } const ( - testIssuer = "https://auth.example.com" - testResource = "https://api.example.com" - testClientID = "test-cli" - testSTSPath = "/sts/token" + testIssuer = "https://auth.example.com" + testResource = "https://api.example.com" + testClientID = "test-cli" + testSTSPath = "/sts/token" + testExchangedTok = "exchanged" ) func makeJWTWithAudience(t *testing.T, aud []string) string { @@ -59,6 +60,25 @@ func makeJWTWithAudience(t *testing.T, aud []string) string { return header + "." + body + "." + sig } +// makeJWTWithExp builds an unsigned JWT carrying `exp` (and optionally +// `aud`). The signature segment is junk — tokenmanager never verifies +// it, ParseClaims is documented as unverified. +func makeJWTWithExp(t *testing.T, exp time.Time, aud []string) string { + t.Helper() + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + claims := map[string]any{"sub": "test", "exp": exp.Unix()} + if len(aud) > 0 { + claims["aud"] = aud + } + payload, err := json.Marshal(claims) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + body := base64.RawURLEncoding.EncodeToString(payload) + sig := base64.RawURLEncoding.EncodeToString([]byte("not-real")) + return header + "." + body + "." + sig +} + func newTestManager(t *testing.T, store tokenstore.Store, exchange func(context.Context, sts.ExchangeRequest) (*tokens.TokenSet, error)) *Manager { t.Helper() m, err := New(Config{ @@ -216,7 +236,7 @@ func TestToken_ExplicitAudienceBypassesAudienceShortcut(t *testing.T) { m := newTestManager(t, store, func(_ context.Context, req sts.ExchangeRequest) (*tokens.TokenSet, error) { calls++ got = req - return &tokens.TokenSet{AccessToken: "exchanged"}, nil + return &tokens.TokenSet{AccessToken: testExchangedTok}, nil }) token, err := m.Token(context.Background(), TokenRequest{Resource: testResource, Audience: requestedAudience}) @@ -224,7 +244,7 @@ func TestToken_ExplicitAudienceBypassesAudienceShortcut(t *testing.T) { t.Fatalf("Token: %v", err) } - if token != "exchanged" || calls != 1 { + if token != testExchangedTok || calls != 1 { t.Fatalf("Token returned %q with %d exchange calls, want exchanged token from one exchange", token, calls) } if got.Audience != requestedAudience { @@ -288,7 +308,7 @@ func TestToken_ExchangeIncludesResource(t *testing.T) { var got sts.ExchangeRequest m := newTestManager(t, store, func(_ context.Context, req sts.ExchangeRequest) (*tokens.TokenSet, error) { got = req - return &tokens.TokenSet{AccessToken: "exchanged"}, nil + return &tokens.TokenSet{AccessToken: testExchangedTok}, nil }) if _, err := m.TokenForResource(context.Background(), testResource); err != nil { @@ -383,7 +403,7 @@ func TestToken_CacheExpires(t *testing.T) { Now: func() time.Time { return now }, Exchange: func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { calls++ - return &tokens.TokenSet{AccessToken: "exchanged", ExpiresAt: now.Add(time.Minute)}, nil + return &tokens.TokenSet{AccessToken: testExchangedTok, ExpiresAt: now.Add(time.Minute)}, nil }, }) if err != nil { @@ -630,7 +650,7 @@ func TestCoreTokenAudienceShortcut_FallsThroughOnMalformedJWT(t *testing.T) { var exchangeCalls int m := newTestManager(t, store, func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { exchangeCalls++ - return &tokens.TokenSet{AccessToken: "exchanged"}, nil + return &tokens.TokenSet{AccessToken: testExchangedTok}, nil }) got, err := m.TokenForResource(context.Background(), testResource) @@ -666,3 +686,193 @@ func TestToken_StoreErrorSurfacesNotAsErrNotLoggedIn(t *testing.T) { t.Fatalf("err = %v, want underlying store error", err) } } + +// TestToken_ExpiredCoreReturnsNotLoggedIn pins the preflight behaviour: +// a core token whose JWT `exp` is in the past surfaces ErrNotLoggedIn +// before the request reaches STS or the resource. Without preflight, +// users see a confusing "invalid_grant" / "401" instead of "run login". +func TestToken_ExpiredCoreReturnsNotLoggedIn(t *testing.T) { + t.Parallel() + now := time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC) + expired := makeJWTWithExp(t, now.Add(-time.Hour), nil) + + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: expired} + + m, err := New(Config{ + Issuer: testIssuer, ClientID: testClientID, STSPath: testSTSPath, Store: store, + Now: func() time.Time { return now }, + Exchange: func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + t.Fatal("exchange must not run for an expired core token") + return nil, errors.New("unreachable") + }, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + _, err = m.TokenForResource(context.Background(), testResource) + if !errors.Is(err, ErrNotLoggedIn) { + t.Fatalf("err = %v, want ErrNotLoggedIn", err) + } +} + +// TestToken_OpaqueCorePassesPreflight guards the parse-failure branch: +// non-JWT (opaque) access tokens have no client-visible expiry, so +// they must NOT be classified as expired by the preflight check. +func TestToken_OpaqueCorePassesPreflight(t *testing.T) { + t.Parallel() + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: "opaque-not-a-jwt"} + + m, err := New(Config{ + Issuer: testIssuer, ClientID: testClientID, Store: store, + Exchange: func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + t.Fatal("same-host shortcut should win for opaque core token == issuer") + return nil, errors.New("unreachable") + }, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + got, err := m.TokenForResource(context.Background(), testIssuer) + if err != nil { + t.Fatalf("TokenForResource: %v", err) + } + if got != "opaque-not-a-jwt" { + t.Fatalf("got %q, want opaque core verbatim", got) + } +} + +// TestSaveCoreToken_ClearsExchangeCache pins the cache-invalidation +// contract on save: a re-login under a different identity must not +// return the previous user's exchanged tokens, even if a future +// refactor accidentally drops CoreToken from the cache key. +func TestSaveCoreToken_ClearsExchangeCache(t *testing.T) { + t.Parallel() + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: "user-a-core"} + + calls := 0 + m := newTestManager(t, store, func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + calls++ + return &tokens.TokenSet{AccessToken: "user-a-exchanged"}, nil + }) + + if _, err := m.TokenForResource(context.Background(), testResource); err != nil { + t.Fatalf("first call: %v", err) + } + if calls != 1 { + t.Fatalf("calls after first Token = %d, want 1", calls) + } + + if err := m.SaveCoreToken("user-b-core"); err != nil { + t.Fatalf("SaveCoreToken: %v", err) + } + + if _, err := m.TokenForResource(context.Background(), testResource); err != nil { + t.Fatalf("post-save call: %v", err) + } + if calls != 2 { + t.Fatalf("exchange calls after save = %d, want 2 (cache must be cleared on save)", calls) + } +} + +// TestNormalizeOriginURL covers the cases where same-host / aud-shortcut +// equality has historically misfired: trailing slash, scheme/host case, +// default-port presence. Inputs that don't parse as origin URLs must +// pass through unchanged so non-URL audiences keep byte-exact compare. +func TestNormalizeOriginURL(t *testing.T) { + t.Parallel() + cases := []struct { + name string + in string + want string + }{ + {"empty", "", ""}, + {"plain", "https://api.example.com", "https://api.example.com"}, + {"trailing slash", "https://api.example.com/", "https://api.example.com"}, + {"upper scheme", "HTTPS://api.example.com", "https://api.example.com"}, + {"upper host", "https://API.Example.COM", "https://api.example.com"}, + {"default https port", "https://api.example.com:443", "https://api.example.com"}, + {"default http port", "http://api.example.com:80/", "http://api.example.com"}, + {"non-default port preserved", "https://api.example.com:8443", "https://api.example.com:8443"}, + {"path preserved (sans trailing slash)", "https://api.example.com/v2/", "https://api.example.com/v2"}, + {"non-URL audience passes through", "urn:example:cli", "urn:example:cli"}, + {"bare string passes through", "some-audience", "some-audience"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := normalizeOriginURL(tc.in); got != tc.want { + t.Errorf("normalizeOriginURL(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + +// TestToken_SameHostShortcut_NormalisesURLs guards against a regression +// where a trailing-slash or case difference between Issuer and Resource +// forces a needless STS exchange (or fails outright on single-host +// deployments where STSPath is empty). +func TestToken_SameHostShortcut_NormalisesURLs(t *testing.T) { + t.Parallel() + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: "core-tok"} + + m, err := New(Config{ + Issuer: testIssuer, ClientID: testClientID, // STSPath intentionally empty + Store: store, + Exchange: func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + t.Fatal("exchange must not run when issuer == resource modulo trailing slash / case") + return nil, errors.New("unreachable") + }, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + for _, resource := range []string{ + testIssuer + "/", // trailing slash + strings.ToUpper(testIssuer[:8]) + testIssuer[8:], // uppercase scheme + } { + got, err := m.TokenForResource(context.Background(), resource) + if err != nil { + t.Fatalf("TokenForResource(%q): %v", resource, err) + } + if got != "core-tok" { + t.Fatalf("TokenForResource(%q) = %q, want core token verbatim", resource, got) + } + } +} + +// TestToken_CacheCollapsesURLEquivalents pins the cache key being +// computed off the normalised resource: two equivalent forms must +// share a single entry rather than each driving its own STS round-trip. +func TestToken_CacheCollapsesURLEquivalents(t *testing.T) { + t.Parallel() + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: "core-tok"} + + var calls int + m := newTestManager(t, store, func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + calls++ + return &tokens.TokenSet{AccessToken: testExchangedTok}, nil + }) + + first, err := m.TokenForResource(context.Background(), testResource) + if err != nil { + t.Fatalf("first call: %v", err) + } + second, err := m.TokenForResource(context.Background(), testResource+"/") + if err != nil { + t.Fatalf("second call: %v", err) + } + if first != testExchangedTok || second != testExchangedTok { + t.Fatalf("tokens = (%q, %q), want both exchanged", first, second) + } + if calls != 1 { + t.Fatalf("exchange calls = %d, want 1 (trailing-slash variant must hit cache)", calls) + } +} diff --git a/auth/tokenstore/keyring.go b/auth/tokenstore/keyring.go index f80a2e9ddf..8ec3ddfceb 100644 --- a/auth/tokenstore/keyring.go +++ b/auth/tokenstore/keyring.go @@ -114,6 +114,17 @@ func decodeTokenSet(raw string) (tokens.TokenSet, error) { return tokens.TokenSet{}, fmt.Errorf("%w: unmarshal TokenSet: %w", ErrMalformed, err) } + // json.Unmarshal happily decodes any well-formed JSON object to a + // zero keyringTokenSet — for example {} or an unrelated CLI's blob + // that happens to be keyed against the same (service, profile). We + // surface that as ErrMalformed so the embedding shim can route to + // its legacy/upgrade path rather than returning a TokenSet with an + // empty AccessToken (which the caller can't distinguish from a + // successful load of a freshly-cleared entry). + if strings.TrimSpace(wire.AccessToken) == "" { + return tokens.TokenSet{}, fmt.Errorf("%w: stored entry has no access_token", ErrMalformed) + } + t := tokens.TokenSet{ AccessToken: wire.AccessToken, RefreshToken: wire.RefreshToken, diff --git a/auth/tokenstore/keyring_test.go b/auth/tokenstore/keyring_test.go index c30ea5469f..f21402e48b 100644 --- a/auth/tokenstore/keyring_test.go +++ b/auth/tokenstore/keyring_test.go @@ -195,3 +195,34 @@ func TestKeyring_LoadTokens_BadExpiresAtReturnsErrMalformed(t *testing.T) { t.Fatalf("err = %v, want ErrMalformed", err) } } + +// TestKeyring_LoadTokens_EmptyAccessTokenReturnsErrMalformed pins the +// guard against well-formed JSON that decodes to a zero TokenSet. An +// unrelated CLI's blob keyed against the same service/profile, or a +// "{}" entry from a buggy save, would otherwise produce a TokenSet +// with empty AccessToken indistinguishable from a successful load — +// and the shim would then ship "Authorization: Bearer " on the wire. +func TestKeyring_LoadTokens_EmptyAccessTokenReturnsErrMalformed(t *testing.T) { + cases := []struct { + name string + body string + }{ + {"empty object", `{}`}, + {"unrelated fields only", `{"foo":"bar","count":3}`}, + {"explicit empty access_token", `{"access_token":""}`}, + {"whitespace access_token", `{"access_token":" "}`}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + service := "test-empty-" + tc.name + const profile = "https://example.com" + if err := keyring.Set(service, profile, tc.body); err != nil { + t.Fatalf("seed keyring: %v", err) + } + _, err := NewKeyring(service).LoadTokens(profile) + if !errors.Is(err, ErrMalformed) { + t.Fatalf("err = %v, want ErrMalformed", err) + } + }) + } +} diff --git a/cmd/entire/cli/api/auth_tokens.go b/cmd/entire/cli/api/auth_tokens.go index e54c69bc13..557c9e6ab3 100644 --- a/cmd/entire/cli/api/auth_tokens.go +++ b/cmd/entire/cli/api/auth_tokens.go @@ -2,10 +2,9 @@ package api import ( "context" + "errors" "fmt" "net/url" - "os" - "strings" ) // Token is a single API token row returned by the auth-tokens endpoint. @@ -25,25 +24,27 @@ type TokensResponse struct { Tokens []Token `json:"tokens"` } -// authTokensProviderVersionEnvVar must match the env var read by -// cmd/entire/cli/auth's currentProvider(). Duplicated here rather than -// imported because api/ is a leaf package and shouldn't take a -// dependency on auth/ for routing. -const authTokensProviderVersionEnvVar = "ENTIRE_AUTH_PROVIDER_VERSION" //nolint:gosec // env var name, not a credential +// errAuthTokensPathUnset surfaces when an auth-tokens method is called +// on a Client that wasn't given a base path. Construct via +// NewClientWithBaseURL(...).WithAuthTokensPath(...) — the active path +// lives in cmd/entire/cli/auth.CurrentProvider().AuthTokensPath, the +// single source of truth for provider-version routing. +var errAuthTokensPathUnset = errors.New("api: auth-tokens path is unset (call (*Client).WithAuthTokensPath before list/revoke)") -// authTokensBasePath returns the auth-tokens endpoint family base path -// for the active provider version. v1 (default) hits /api/v1/auth/tokens; -// v2 hits /api/auth/tokens (no version segment). -func authTokensBasePath() string { - if strings.TrimSpace(os.Getenv(authTokensProviderVersionEnvVar)) == "v2" { - return "/api/auth/tokens" +func (c *Client) authTokensBasePath() (string, error) { + if c.authTokensPath == "" { + return "", errAuthTokensPathUnset } - return "/api/v1/auth/tokens" + return c.authTokensPath, nil } // ListTokens returns the authenticated user's non-expired API tokens. func (c *Client) ListTokens(ctx context.Context) ([]Token, error) { - resp, err := c.Get(ctx, authTokensBasePath()) + base, err := c.authTokensBasePath() + if err != nil { + return nil, fmt.Errorf("list tokens: %w", err) + } + resp, err := c.Get(ctx, base) if err != nil { return nil, fmt.Errorf("list tokens: %w", err) } @@ -62,7 +63,11 @@ func (c *Client) ListTokens(ctx context.Context) ([]Token, error) { // RevokeCurrentToken revokes the bearer token used to authenticate this client. func (c *Client) RevokeCurrentToken(ctx context.Context) error { - resp, err := c.Delete(ctx, authTokensBasePath()+"/current") + base, err := c.authTokensBasePath() + if err != nil { + return fmt.Errorf("revoke current token: %w", err) + } + resp, err := c.Delete(ctx, base+"/current") if err != nil { return fmt.Errorf("revoke current token: %w", err) } @@ -76,7 +81,11 @@ func (c *Client) RevokeCurrentToken(ctx context.Context) error { // RevokeToken revokes the API token with the given id. func (c *Client) RevokeToken(ctx context.Context, id string) error { - resp, err := c.Delete(ctx, authTokensBasePath()+"/"+url.PathEscape(id)) + base, err := c.authTokensBasePath() + if err != nil { + return fmt.Errorf("revoke token %s: %w", id, err) + } + resp, err := c.Delete(ctx, base+"/"+url.PathEscape(id)) if err != nil { return fmt.Errorf("revoke token %s: %w", id, err) } diff --git a/cmd/entire/cli/api/auth_tokens_test.go b/cmd/entire/cli/api/auth_tokens_test.go index 12a9c4628f..07ccd694d2 100644 --- a/cmd/entire/cli/api/auth_tokens_test.go +++ b/cmd/entire/cli/api/auth_tokens_test.go @@ -9,8 +9,23 @@ import ( "testing" ) +const ( + testV1AuthTokensPath = "/api/v1/auth/tokens" + testV2AuthTokensPath = "/api/auth/tokens" +) + +// newAuthTokensTestClient builds a Client pointed at server.URL with +// the given auth-tokens base path. Used by all auth-tokens tests so +// the wiring matches production: callers chain WithAuthTokensPath at +// construction time. +func newAuthTokensTestClient(serverURL, authTokensPath string) *Client { + c := NewClient("tok").WithAuthTokensPath(authTokensPath) + c.baseURL = serverURL + return c +} + func TestClient_RevokeCurrentToken_SendsDeleteWithBearer(t *testing.T) { - t.Setenv(authTokensProviderVersionEnvVar, "") + t.Parallel() var gotMethod, gotPath, gotAuth string @@ -23,8 +38,7 @@ func TestClient_RevokeCurrentToken_SendsDeleteWithBearer(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) if err := c.RevokeCurrentToken(context.Background()); err != nil { t.Fatalf("RevokeCurrentToken() error = %v", err) @@ -51,8 +65,7 @@ func TestClient_RevokeCurrentToken_ReturnsHTTPErrorOn401(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) err := c.RevokeCurrentToken(context.Background()) if err == nil { @@ -71,7 +84,7 @@ func TestClient_RevokeCurrentToken_ReturnsHTTPErrorOn401(t *testing.T) { } func TestClient_ListTokens_DecodesResponse(t *testing.T) { - t.Setenv(authTokensProviderVersionEnvVar, "") + t.Parallel() var gotMethod, gotPath, gotAuth string @@ -87,8 +100,7 @@ func TestClient_ListTokens_DecodesResponse(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) tokens, err := c.ListTokens(context.Background()) if err != nil { @@ -129,8 +141,7 @@ func TestClient_ListTokens_ReturnsHTTPErrorOn401(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) _, err := c.ListTokens(context.Background()) if err == nil { @@ -142,7 +153,7 @@ func TestClient_ListTokens_ReturnsHTTPErrorOn401(t *testing.T) { } func TestClient_RevokeToken_SendsDeleteWithEscapedID(t *testing.T) { - t.Setenv(authTokensProviderVersionEnvVar, "") + t.Parallel() var gotMethod, gotEscapedPath, gotDecodedPath string @@ -155,8 +166,7 @@ func TestClient_RevokeToken_SendsDeleteWithEscapedID(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) // Use an id that needs URL escaping to verify we don't blindly concat. if err := c.RevokeToken(context.Background(), "abc/def 1"); err != nil { @@ -184,8 +194,7 @@ func TestClient_RevokeToken_ReturnsErrorBody(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) err := c.RevokeToken(context.Background(), "missing") if err == nil { @@ -199,39 +208,12 @@ func TestClient_RevokeToken_ReturnsErrorBody(t *testing.T) { } } -// TestAuthTokensBasePath_ProviderVersionRouting locks in the path -// switch so v2 doesn't silently regress to v1's path family. The whole -// reason the version env var exists is to route requests at this layer. -func TestAuthTokensBasePath_ProviderVersionRouting(t *testing.T) { - cases := []struct { - name string - version string - want string - }{ - {"unset defaults to v1", "", "/api/v1/auth/tokens"}, - {"v1 explicit", "v1", "/api/v1/auth/tokens"}, - {"v2", "v2", "/api/auth/tokens"}, - {"unrecognised defaults to v1", "v999", "/api/v1/auth/tokens"}, - // Whitespace trimming must match auth.currentProvider() — both - // trim, so the api and auth packages agree on what "v2" means. - // If either side stops trimming, these tests diverge first. - {"trims whitespace then matches v2", " v2 ", "/api/auth/tokens"}, - } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Setenv(authTokensProviderVersionEnvVar, tc.version) - if got := authTokensBasePath(); got != tc.want { - t.Fatalf("authTokensBasePath() = %q, want %q", got, tc.want) - } - }) - } -} - -// TestClient_ListTokens_RoutesV2Path is an end-to-end check that the -// version switch flows through the public Client API, not just the -// internal helper. -func TestClient_ListTokens_RoutesV2Path(t *testing.T) { - t.Setenv(authTokensProviderVersionEnvVar, "v2") +// TestClient_AuthTokens_RoutesV2Path verifies that whatever path the +// caller supplies via WithAuthTokensPath is what hits the wire. The +// provider table itself (which path corresponds to which version) is +// exercised by cmd/entire/cli/auth's resolveProvider tests. +func TestClient_AuthTokens_RoutesV2Path(t *testing.T) { + t.Parallel() var gotPath string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -241,8 +223,7 @@ func TestClient_ListTokens_RoutesV2Path(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV2AuthTokensPath) if _, err := c.ListTokens(context.Background()); err != nil { t.Fatalf("ListTokens: %v", err) @@ -251,3 +232,22 @@ func TestClient_ListTokens_RoutesV2Path(t *testing.T) { t.Fatalf("path = %q, want /api/auth/tokens (v2)", gotPath) } } + +// TestClient_AuthTokens_UnsetPathErrors guards against silently +// shipping a request to "" — we want a clear error pointing at the +// missing WithAuthTokensPath wiring. +func TestClient_AuthTokens_UnsetPathErrors(t *testing.T) { + t.Parallel() + + c := NewClient("tok") // no WithAuthTokensPath + + if _, err := c.ListTokens(context.Background()); !errors.Is(err, errAuthTokensPathUnset) { + t.Errorf("ListTokens err = %v, want errAuthTokensPathUnset", err) + } + if err := c.RevokeCurrentToken(context.Background()); !errors.Is(err, errAuthTokensPathUnset) { + t.Errorf("RevokeCurrentToken err = %v, want errAuthTokensPathUnset", err) + } + if err := c.RevokeToken(context.Background(), "any"); !errors.Is(err, errAuthTokensPathUnset) { + t.Errorf("RevokeToken err = %v, want errAuthTokensPathUnset", err) + } +} diff --git a/cmd/entire/cli/api/client.go b/cmd/entire/cli/api/client.go index b568214a77..98a7d76ab1 100644 --- a/cmd/entire/cli/api/client.go +++ b/cmd/entire/cli/api/client.go @@ -21,6 +21,26 @@ const ( type Client struct { httpClient *http.Client baseURL string + + // authTokensPath is the base path for the auth-tokens management + // endpoints (list / revoke). Set via WithAuthTokensPath when the + // client targets the auth host. Empty for data-API-only clients; + // auth-tokens methods error out if called against an empty path. + authTokensPath string +} + +// WithAuthTokensPath sets the base path used by ListTokens, +// RevokeCurrentToken, and RevokeToken. The path is supplied by the +// auth shim from auth.CurrentProvider().AuthTokensPath, which is the +// single source of truth for provider-version routing — the api +// package no longer reads ENTIRE_AUTH_PROVIDER_VERSION itself. +// +// Returns the receiver for chaining at construction: +// +// c := api.NewClientWithBaseURL(token, base).WithAuthTokensPath(p) +func (c *Client) WithAuthTokensPath(path string) *Client { + c.authTokensPath = path + return c } // NewClient creates a new authenticated API client with an explicit bearer diff --git a/cmd/entire/cli/auth.go b/cmd/entire/cli/auth.go index e5ba946e2c..d5a08469eb 100644 --- a/cmd/entire/cli/auth.go +++ b/cmd/entire/cli/auth.go @@ -100,7 +100,9 @@ func newAuthStatusCmd() *cobra.Command { } func defaultListTokens(ctx context.Context, token string) ([]api.Token, error) { - return api.NewClientWithBaseURL(token, api.AuthBaseURL()).ListTokens(ctx) //nolint:wrapcheck // ListTokens already wraps with action context + client := api.NewClientWithBaseURL(token, api.AuthBaseURL()). + WithAuthTokensPath(auth.CurrentProvider().AuthTokensPath) + return client.ListTokens(ctx) //nolint:wrapcheck // ListTokens already wraps with action context } func runAuthStatus(ctx context.Context, w io.Writer, store tokenStore, list authTokenLister, baseURL string) error { @@ -426,7 +428,9 @@ func newAuthRevokeCmd() *cobra.Command { } func defaultRevokeTokenByID(ctx context.Context, callerToken, id string) error { - return api.NewClientWithBaseURL(callerToken, api.AuthBaseURL()).RevokeToken(ctx, id) //nolint:wrapcheck // RevokeToken already wraps with action context + client := api.NewClientWithBaseURL(callerToken, api.AuthBaseURL()). + WithAuthTokensPath(auth.CurrentProvider().AuthTokensPath) + return client.RevokeToken(ctx, id) //nolint:wrapcheck // RevokeToken already wraps with action context } func runAuthRevoke( diff --git a/cmd/entire/cli/auth/client.go b/cmd/entire/cli/auth/client.go index f6ee0da399..80e5579fd8 100644 --- a/cmd/entire/cli/auth/client.go +++ b/cmd/entire/cli/auth/client.go @@ -45,15 +45,15 @@ type Client struct { // NewClient constructs a Client targeting the active provider version. // httpClient is used directly when non-nil; otherwise http.DefaultClient. func NewClient(httpClient *http.Client) *Client { - p := currentProvider() + p := CurrentProvider() return &Client{inner: &deviceflow.Client{ HTTP: httpClient, BaseURL: api.AuthBaseURL(), - ClientID: p.clientID, + ClientID: p.ClientID, Scope: "cli", - UserAgent: p.clientID, - DeviceCodePath: p.deviceCodePath, - TokenPath: p.tokenPath, + UserAgent: p.ClientID, + DeviceCodePath: p.DeviceCodePath, + TokenPath: p.TokenPath, }} } diff --git a/cmd/entire/cli/auth/exchange.go b/cmd/entire/cli/auth/exchange.go index f888425fef..23962fde6d 100644 --- a/cmd/entire/cli/auth/exchange.go +++ b/cmd/entire/cli/auth/exchange.go @@ -50,13 +50,13 @@ func defaultManager() (*tokenmanager.Manager, error) { return managerForTest, nil } managerOnce.Do(func() { - provider := currentProvider() + provider := CurrentProvider() m, err := tokenmanager.New(tokenmanager.Config{ Issuer: api.AuthBaseURL(), - ClientID: provider.clientID, - STSPath: provider.stsPath, + ClientID: provider.ClientID, + STSPath: provider.STSPath, Store: NewStore(), - UserAgent: provider.clientID, + UserAgent: provider.ClientID, Scope: "cli", }) manager = m diff --git a/cmd/entire/cli/auth/provider.go b/cmd/entire/cli/auth/provider.go index 804fe22977..b30ec8b6b8 100644 --- a/cmd/entire/cli/auth/provider.go +++ b/cmd/entire/cli/auth/provider.go @@ -3,6 +3,7 @@ package auth import ( "os" "strings" + "sync" ) // ProviderVersionEnvVar selects which OAuth surface this CLI talks to. @@ -15,46 +16,112 @@ import ( // This is a transition-period switch: once v2 is the universal default // the env var goes away. Surfaces are otherwise reachable as RFC 8628 // device-flow endpoints; the only differences are paths and client_id. +// +// Read once at process startup via CurrentProvider; later flips within +// the same process are intentionally ignored. Tests inject via +// SetProviderForTest rather than mutating the env mid-run. const ProviderVersionEnvVar = "ENTIRE_AUTH_PROVIDER_VERSION" -// providerConfig captures the per-surface bits of OAuth wiring. +// Provider captures the per-surface bits of OAuth wiring. // -// stsPath is the RFC 8693 token-exchange endpoint. v1 is the legacy +// STSPath is the RFC 8693 token-exchange endpoint. v1 is the legacy // single-host surface where the auth and data API live at the same // origin (entire.io); the same-host shortcut in tokenmanager.Token -// always wins and STS is never invoked, so v1.stsPath is left empty. +// always wins and STS is never invoked, so v1.STSPath is left empty. // v2 exposes a dedicated STS path because it's used in split-host // deployments (e.g. us.auth.partial.to mints, partial.to consumes). -type providerConfig struct { - clientID string - deviceCodePath string - tokenPath string - stsPath string +// +// AuthTokensPath is the base path for the auth-tokens management +// endpoint family (list / revoke). Routed at the api.Client layer via +// (*api.Client).WithAuthTokensPath so the provider table is the single +// source of truth — no env-var duplication between auth/ and api/. +type Provider struct { + ClientID string + DeviceCodePath string + TokenPath string + STSPath string + AuthTokensPath string } -var providers = map[string]providerConfig{ +var providers = map[string]Provider{ "v1": { //nolint:gosec // OAuth client_id and endpoint paths, not credentials - clientID: "entire-cli", - deviceCodePath: "/oauth/device/code", - tokenPath: "/oauth/token", + ClientID: "entire-cli", + DeviceCodePath: "/oauth/device/code", + TokenPath: "/oauth/token", + AuthTokensPath: "/api/v1/auth/tokens", }, "v2": { //nolint:gosec // OAuth client_id and endpoint paths, not credentials - clientID: "entire-cli", - deviceCodePath: "/api/auth/oauth/device/code", - tokenPath: "/api/auth/token", - stsPath: "/api/authz/sts/token", + ClientID: "entire-cli", + DeviceCodePath: "/api/auth/oauth/device/code", + TokenPath: "/api/auth/token", + STSPath: "/api/authz/sts/token", + AuthTokensPath: "/api/auth/tokens", }, } -// currentProvider returns the active providerConfig, defaulting to v1 -// when ENTIRE_AUTH_PROVIDER_VERSION is unset or holds an unrecognised -// value. Defaulting (rather than erroring) keeps old binaries safe if -// a future v3 ever lands. -func currentProvider() providerConfig { - switch strings.TrimSpace(os.Getenv(ProviderVersionEnvVar)) { +// resolveProvider returns the Provider matching version. Defaulting +// (rather than erroring) on unrecognised values keeps old binaries safe +// if a future v3 ever lands. Pure function — no env reads — so unit +// tests can exercise the routing table without env-var gymnastics. +func resolveProvider(version string) Provider { + switch strings.TrimSpace(version) { case "v2": return providers["v2"] default: return providers["v1"] } } + +var ( + providerOnce sync.Once + resolvedProvider Provider + + // providerForTest, when non-nil, short-circuits CurrentProvider so + // tests can install a specific Provider without racing the + // process-wide sync.Once (which freezes the first observation + // forever). Mutated only via SetProviderForTest. Production code + // never reads this var. + providerForTest *Provider + providerTestMu sync.Mutex +) + +// CurrentProvider returns the active Provider for this process. +// +// Resolution: read ENTIRE_AUTH_PROVIDER_VERSION exactly once on first +// call, freeze the result, and return the same Provider on every +// subsequent call. Tests that need a different provider must use +// SetProviderForTest before any auth call constructs the singleton. +func CurrentProvider() Provider { + providerTestMu.Lock() + override := providerForTest + providerTestMu.Unlock() + if override != nil { + return *override + } + providerOnce.Do(func() { + resolvedProvider = resolveProvider(os.Getenv(ProviderVersionEnvVar)) + }) + return resolvedProvider +} + +// SetProviderForTest installs p as the Provider returned by +// CurrentProvider for the duration of the test, and registers a +// t.Cleanup to remove the override. Test-only. +// +// Takes a tiny interface rather than *testing.T so production builds +// don't import testing. +func SetProviderForTest(t interface { + Helper() + Cleanup(f func()) +}, p Provider) { + t.Helper() + providerTestMu.Lock() + prev := providerForTest + providerForTest = &p + providerTestMu.Unlock() + t.Cleanup(func() { + providerTestMu.Lock() + providerForTest = prev + providerTestMu.Unlock() + }) +} diff --git a/cmd/entire/cli/auth/provider_test.go b/cmd/entire/cli/auth/provider_test.go index a67409aab2..8a47f6894e 100644 --- a/cmd/entire/cli/auth/provider_test.go +++ b/cmd/entire/cli/auth/provider_test.go @@ -17,69 +17,94 @@ const ( wantClientIDV2 = "entire-cli" ) -func TestCurrentProvider_DefaultsToV1(t *testing.T) { - t.Setenv(ProviderVersionEnvVar, "") - - p := currentProvider() - if p.clientID != wantClientIDV1 || p.deviceCodePath != "/oauth/device/code" || p.tokenPath != "/oauth/token" { +// resolveProvider is a pure function — no env reads — so the routing +// table can be exercised without t.Setenv (and without the +// process-wide sync.Once in CurrentProvider freezing the first +// observation forever). + +func TestResolveProvider_DefaultsToV1(t *testing.T) { + t.Parallel() + p := resolveProvider("") + if p.ClientID != wantClientIDV1 || p.DeviceCodePath != "/oauth/device/code" || p.TokenPath != "/oauth/token" { t.Fatalf("default provider = %+v, want v1 config", p) } + if p.AuthTokensPath != "/api/v1/auth/tokens" { + t.Fatalf("default AuthTokensPath = %q, want /api/v1/auth/tokens", p.AuthTokensPath) + } } -func TestCurrentProvider_V1Explicit(t *testing.T) { - t.Setenv(ProviderVersionEnvVar, "v1") - - p := currentProvider() - if p.clientID != wantClientIDV1 { - t.Fatalf("v1 clientID = %q", p.clientID) +func TestResolveProvider_V1Explicit(t *testing.T) { + t.Parallel() + p := resolveProvider("v1") + if p.ClientID != wantClientIDV1 { + t.Fatalf("v1 ClientID = %q", p.ClientID) } // v1 is single-host (entire.io); no STS surface, same-host shortcut - // always wins. Empty stsPath is the contract. - if p.stsPath != "" { - t.Fatalf("v1 stsPath = %q, want empty (single-host, no STS)", p.stsPath) + // always wins. Empty STSPath is the contract. + if p.STSPath != "" { + t.Fatalf("v1 STSPath = %q, want empty (single-host, no STS)", p.STSPath) } } -func TestCurrentProvider_V2(t *testing.T) { - t.Setenv(ProviderVersionEnvVar, "v2") - - p := currentProvider() - if p.clientID != wantClientIDV2 { - t.Fatalf("v2 clientID = %q, want %s", p.clientID, wantClientIDV2) +func TestResolveProvider_V2(t *testing.T) { + t.Parallel() + p := resolveProvider("v2") + if p.ClientID != wantClientIDV2 { + t.Fatalf("v2 ClientID = %q, want %s", p.ClientID, wantClientIDV2) } - if p.deviceCodePath != "/api/auth/oauth/device/code" { - t.Fatalf("v2 deviceCodePath = %q", p.deviceCodePath) + if p.DeviceCodePath != "/api/auth/oauth/device/code" { + t.Fatalf("v2 DeviceCodePath = %q", p.DeviceCodePath) } - if p.tokenPath != "/api/auth/token" { - t.Fatalf("v2 tokenPath = %q", p.tokenPath) + if p.TokenPath != "/api/auth/token" { + t.Fatalf("v2 TokenPath = %q", p.TokenPath) } - if p.stsPath != "/api/authz/sts/token" { - t.Fatalf("v2 stsPath = %q", p.stsPath) + if p.STSPath != "/api/authz/sts/token" { + t.Fatalf("v2 STSPath = %q", p.STSPath) + } + if p.AuthTokensPath != "/api/auth/tokens" { + t.Fatalf("v2 AuthTokensPath = %q, want /api/auth/tokens", p.AuthTokensPath) } } -func TestCurrentProvider_UnknownDefaultsToV1(t *testing.T) { - t.Setenv(ProviderVersionEnvVar, "v999") +func TestResolveProvider_UnknownDefaultsToV1(t *testing.T) { + t.Parallel() + p := resolveProvider("v999") + if p.ClientID != wantClientIDV1 { + t.Fatalf("unknown version should default to v1; got ClientID = %q", p.ClientID) + } +} - p := currentProvider() - if p.clientID != wantClientIDV1 { - t.Fatalf("unknown version should default to v1; got clientID = %q", p.clientID) +func TestResolveProvider_TrimsWhitespace(t *testing.T) { + t.Parallel() + p := resolveProvider(" v2 ") + if p.ClientID != wantClientIDV2 { + t.Fatalf("whitespace-padded v2 ClientID = %q, want %s", p.ClientID, wantClientIDV2) } } -func TestCurrentProvider_TrimsWhitespace(t *testing.T) { - t.Setenv(ProviderVersionEnvVar, " v2 ") +// TestSetProviderForTest_OverridesCurrentProvider locks in the test +// seam: any test that pins a provider via SetProviderForTest must see +// it from CurrentProvider regardless of process-wide singleton state. +func TestSetProviderForTest_OverridesCurrentProvider(t *testing.T) { + pinned := Provider{ + ClientID: "test-client", + DeviceCodePath: "/test/device", + TokenPath: "/test/token", + STSPath: "/test/sts", + AuthTokensPath: "/test/tokens", + } + SetProviderForTest(t, pinned) - p := currentProvider() - if p.clientID != wantClientIDV2 { - t.Fatalf("whitespace-padded v2 clientID = %q, want %s", p.clientID, wantClientIDV2) + got := CurrentProvider() + if got != pinned { + t.Fatalf("CurrentProvider() = %+v, want %+v", got, pinned) } } -func TestNewClient_HonoursProviderVersion(t *testing.T) { +func TestNewClient_HonoursPinnedProvider(t *testing.T) { t.Setenv(api.BaseURLEnvVar, "https://example.test") t.Setenv(api.AuthBaseURLEnvVar, "") - t.Setenv(ProviderVersionEnvVar, "v2") + SetProviderForTest(t, resolveProvider("v2")) c := NewClient(&http.Client{}) if c.inner.ClientID != wantClientIDV2 { @@ -96,10 +121,10 @@ func TestNewClient_HonoursProviderVersion(t *testing.T) { } } -func TestNewClient_DefaultsToV1(t *testing.T) { +func TestNewClient_DefaultsToV1WhenPinned(t *testing.T) { t.Setenv(api.BaseURLEnvVar, "https://example.test") t.Setenv(api.AuthBaseURLEnvVar, "") - t.Setenv(ProviderVersionEnvVar, "") + SetProviderForTest(t, resolveProvider("")) c := NewClient(nil) if c.inner.ClientID != wantClientIDV1 { diff --git a/cmd/entire/cli/auth/store.go b/cmd/entire/cli/auth/store.go index 4e77ecd979..db80889726 100644 --- a/cmd/entire/cli/auth/store.go +++ b/cmd/entire/cli/auth/store.go @@ -84,6 +84,9 @@ func (s *Store) GetToken(baseURL string) (string, error) { if kerr != nil { return "", fmt.Errorf("get token from keyring: %w", kerr) } + if !looksLikeBareToken(raw) { + return "", nil + } return raw, nil } @@ -109,6 +112,12 @@ func (s *Store) SaveTokens(profile string, t tokens.TokenSet) error { // keyring errors (transport, permission denied) propagate; only // ErrMalformed triggers the fallback. ErrNotFound surfaces verbatim // so the manager's "not logged in" branch still works. +// +// The fallback also guards against well-formed-but-empty entries (e.g. +// "{}" or an unrelated CLI's blob keyed against the same service/profile): +// those surface as ErrMalformed from the lib, but using their raw bytes +// as a bearer token would be wrong. looksLikeBareToken filters them out, +// converting back to ErrNotFound so the caller sees "not logged in". func (s *Store) LoadTokens(profile string) (tokens.TokenSet, error) { t, err := s.inner.LoadTokens(profile) if err == nil { @@ -125,9 +134,29 @@ func (s *Store) LoadTokens(profile string) (tokens.TokenSet, error) { if kerr != nil { return tokens.TokenSet{}, fmt.Errorf("get token from keyring: %w", kerr) } + if !looksLikeBareToken(raw) { + return tokens.TokenSet{}, tokenstore.ErrNotFound + } return tokens.TokenSet{AccessToken: raw}, nil } +// looksLikeBareToken reports whether raw is plausibly a pre-shim bare +// access token rather than a JSON blob. Real bearer tokens are JWTs or +// opaque ASCII strings; they don't start with the JSON object/array +// delimiters. Trims whitespace first so a stray newline doesn't fool +// the check. +func looksLikeBareToken(raw string) bool { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return false + } + switch trimmed[0] { + case '{', '[': + return false + } + return true +} + // DeleteTokens implements tokenstore.Store. func (s *Store) DeleteTokens(profile string) error { return s.inner.DeleteTokens(profile) //nolint:wrapcheck // shim returns the lib error verbatim diff --git a/cmd/entire/cli/auth/store_test.go b/cmd/entire/cli/auth/store_test.go index 5bbfd93fbe..7a767b5602 100644 --- a/cmd/entire/cli/auth/store_test.go +++ b/cmd/entire/cli/auth/store_test.go @@ -178,6 +178,48 @@ func TestStoreLoadTokens_LegacyBareStringFallback(t *testing.T) { } } +// TestStoreLoadTokens_RejectsJSONShapedFallback guards against a +// well-formed-but-empty JSON entry being mistakenly treated as a +// bare-string token. The lib surfaces these as ErrMalformed; the +// shim's bare-string fallback must filter out anything starting with +// '{' or '[' so the user sees "not logged in" rather than getting +// "Authorization: Bearer {}" on the wire. +func TestStoreLoadTokens_RejectsJSONShapedFallback(t *testing.T) { + for _, body := range []string{`{}`, `{"foo":"bar"}`, `[]`} { + const profile = "https://json-shaped.example.com" + service := "test-json-fallback-" + body[:1] + if err := keyring.Set(service, profile, body); err != nil { + t.Fatalf("seed keyring: %v", err) + } + + got, err := NewStoreWithService(service).LoadTokens(profile) + // We expect ErrNotFound — JSON-shaped malformed entries must + // not be routed through the bare-string fallback. + if err == nil { + t.Fatalf("LoadTokens(%q) returned %+v; want ErrNotFound", body, got) + } + if got.AccessToken != "" { + t.Fatalf("LoadTokens(%q) AccessToken = %q, want empty", body, got.AccessToken) + } + } +} + +func TestStoreGetToken_RejectsJSONShapedFallback(t *testing.T) { + const service = "test-json-getoken" + const profile = "https://json-shaped.example.com" + if err := keyring.Set(service, profile, `{"unrelated":"blob"}`); err != nil { + t.Fatalf("seed keyring: %v", err) + } + + got, err := NewStoreWithService(service).GetToken(profile) + if err != nil { + t.Fatalf("GetToken: %v", err) + } + if got != "" { + t.Fatalf("GetToken = %q, want empty (JSON blob must not be shipped as a bearer)", got) + } +} + func TestLookupCurrentToken(t *testing.T) { t.Setenv(api.BaseURLEnvVar, "http://localhost:8787") t.Setenv(api.AuthBaseURLEnvVar, "") diff --git a/cmd/entire/cli/logout.go b/cmd/entire/cli/logout.go index b34b7496e2..0d44365035 100644 --- a/cmd/entire/cli/logout.go +++ b/cmd/entire/cli/logout.go @@ -41,7 +41,9 @@ func newLogoutCmd() *cobra.Command { } func defaultRevokeCurrentToken(ctx context.Context, token string) error { - return api.NewClientWithBaseURL(token, api.AuthBaseURL()).RevokeCurrentToken(ctx) //nolint:wrapcheck // RevokeCurrentToken already wraps with action context + client := api.NewClientWithBaseURL(token, api.AuthBaseURL()). + WithAuthTokensPath(auth.CurrentProvider().AuthTokensPath) + return client.RevokeCurrentToken(ctx) //nolint:wrapcheck // RevokeCurrentToken already wraps with action context } func runLogout(ctx context.Context, outW, errW io.Writer, store tokenStore, revoke revokeCurrentFunc, baseURL string) error {