From 7d2adb53c46b79df1191bb11f6cfd5efef540695 Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Wed, 6 May 2026 11:48:44 +1000 Subject: [PATCH 01/21] Add auth/ library: device flow + tokens + token storage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces github.com/entireio/cli/auth as a shared OAuth client library for the Entire CLI. Three subpackages ship in this commit: * auth/tokens — TokenSet bundle plus unverified JWT claim parsing * auth/tokenstore — Store interface plus an OS-keyring reference impl * auth/deviceflow — RFC 8628 OAuth Device Authorization Grant client The packages are deliberately provider-agnostic: every server-specific value (endpoint paths, client_id, scope) is supplied at construction. The library has no global state, no implicit URLs, and no provider detection. It is intended to be importable by any RFC 8628 / RFC 8693 caller. No existing callers are wired up in this commit; the cmd/entire/cli shim swap follows separately. Co-Authored-By: Claude Opus 4.7 (1M context) --- auth/deviceflow/deviceflow.go | 285 +++++++++++++++++++++++++++++ auth/deviceflow/deviceflow_test.go | 276 ++++++++++++++++++++++++++++ auth/doc.go | 15 ++ auth/tokens/tokens.go | 145 +++++++++++++++ auth/tokens/tokens_test.go | 196 ++++++++++++++++++++ auth/tokenstore/keyring.go | 129 +++++++++++++ auth/tokenstore/keyring_test.go | 136 ++++++++++++++ auth/tokenstore/tokenstore.go | 37 ++++ 8 files changed, 1219 insertions(+) create mode 100644 auth/deviceflow/deviceflow.go create mode 100644 auth/deviceflow/deviceflow_test.go create mode 100644 auth/doc.go create mode 100644 auth/tokens/tokens.go create mode 100644 auth/tokens/tokens_test.go create mode 100644 auth/tokenstore/keyring.go create mode 100644 auth/tokenstore/keyring_test.go create mode 100644 auth/tokenstore/tokenstore.go diff --git a/auth/deviceflow/deviceflow.go b/auth/deviceflow/deviceflow.go new file mode 100644 index 0000000000..be866306c6 --- /dev/null +++ b/auth/deviceflow/deviceflow.go @@ -0,0 +1,285 @@ +// Package deviceflow is an RFC 8628 OAuth 2.0 Device Authorization +// Grant client. +// +// Construct a Client with the issuer's BaseURL plus the paths and +// client_id it expects, then call StartDeviceAuth followed by repeated +// PollDeviceAuth calls until either a TokenSet comes back or a +// terminal error is returned. Caller drives the polling loop and +// adjusts the interval on ErrSlowDown per RFC 8628 §3.5. +// +// The client is provider-agnostic: every server-specific value +// (endpoint paths, client_id, optional scope) is configured at +// construction time. There is no provider detection. +package deviceflow + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/entireio/cli/auth/tokens" +) + +// nowFunc is the package's clock. Tests override it; production uses +// time.Now. +var nowFunc = time.Now + +// maxResponseBytes caps how much of an OAuth response body we read. +// Both endpoints return small JSON documents; larger bodies indicate +// either a misconfigured proxy or an attempt to exhaust client memory. +const maxResponseBytes = 1 << 20 + +// deviceCodeGrantType is the RFC 8628 token-endpoint grant_type for +// polling device-flow authorization. +const deviceCodeGrantType = "urn:ietf:params:oauth:grant-type:device_code" + +// DeviceCode is the response from the device authorization endpoint +// (RFC 8628 §3.2). Pass DeviceCode through to subsequent PollDeviceAuth +// calls and show UserCode + VerificationURI to the user. +type DeviceCode struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +// Client polls an RFC 8628 device authorization grant. +// +// All configuration is explicit; the package has no global state and +// no implicit URLs. Provide BaseURL, ClientID, and the two endpoint +// paths; the rest is RFC 8628 mechanics. +type Client struct { + HTTP *http.Client + BaseURL string + ClientID string + Scope string + UserAgent string + DeviceCodePath string + TokenPath string +} + +// Sentinel errors returned by PollDeviceAuth when the token endpoint +// responds with a recognised RFC 8628 §3.5 error code. Callers branch +// on these with errors.Is and adjust their polling loop accordingly. +var ( + // ErrAuthorizationPending — user has not yet approved or denied. + // Caller polls again at the existing interval. + ErrAuthorizationPending = errors.New("authorization_pending") + + // ErrSlowDown — caller is polling too fast. Caller bumps the + // interval (per RFC 8628 §3.5, by at least 5 seconds) and tries + // again. + ErrSlowDown = errors.New("slow_down") + + // ErrAccessDenied — user denied the request. Terminal. + ErrAccessDenied = errors.New("access_denied") + + // ErrExpiredToken — device code expired before the user approved. + // Terminal; restart with a fresh StartDeviceAuth. + ErrExpiredToken = errors.New("expired_token") + + // ErrInvalidGrant — device code already redeemed, malformed, or + // otherwise rejected. Terminal. + ErrInvalidGrant = errors.New("invalid_grant") +) + +// errCodeToSentinel maps an RFC 8628 §3.5 error code string to the +// matching sentinel. Unknown codes fall through to a generic error. +func errCodeToSentinel(code string) error { + switch code { + case "authorization_pending": + return ErrAuthorizationPending + case "slow_down": + return ErrSlowDown + case "access_denied": + return ErrAccessDenied + case "expired_token": + return ErrExpiredToken + case "invalid_grant": + return ErrInvalidGrant + default: + return fmt.Errorf("oauth error: %s", code) + } +} + +// StartDeviceAuth requests a fresh device code from the authorization +// server. The returned DeviceCode is opaque to the client; pass it +// back unmodified on every PollDeviceAuth. +func (c *Client) StartDeviceAuth(ctx context.Context) (*DeviceCode, error) { + body := url.Values{} + body.Set("client_id", c.ClientID) + if c.Scope != "" { + body.Set("scope", c.Scope) + } + + resp, err := c.postForm(ctx, c.DeviceCodePath, body) + if err != nil { + return nil, fmt.Errorf("start device auth: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, readAPIError(resp, "start device auth") + } + + var result DeviceCode + if err := decodeJSON(resp.Body, &result, true); err != nil { + return nil, fmt.Errorf("decode device auth start response: %w", err) + } + return &result, nil +} + +// PollDeviceAuth exchanges deviceCode for a TokenSet at the token +// endpoint. +// +// On success, returns a TokenSet with absolute expiry derived from +// the server's expires_in. On any RFC 8628 §3.5 error code, returns +// the matching sentinel error from this package. Other failures +// (network, malformed responses) are wrapped with context. +func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*tokens.TokenSet, error) { + body := url.Values{} + body.Set("grant_type", deviceCodeGrantType) + body.Set("client_id", c.ClientID) + body.Set("device_code", deviceCode) + + resp, err := c.postForm(ctx, c.TokenPath, body) + if err != nil { + return nil, fmt.Errorf("poll device auth: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + apiErr, parseErr := readAPIErrorResponse(resp) + if parseErr != nil { + return nil, fmt.Errorf("poll device auth: %w", parseErr) + } + return nil, errCodeToSentinel(apiErr.Error) + } + + var raw struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` + } + if err := decodeJSON(resp.Body, &raw, false); err != nil { + return nil, fmt.Errorf("decode device auth poll response: %w", err) + } + + if raw.AccessToken == "" { + return nil, fmt.Errorf("poll device auth: server returned 200 with no access token") + } + + t := &tokens.TokenSet{ + AccessToken: raw.AccessToken, + RefreshToken: raw.RefreshToken, + TokenType: raw.TokenType, + Scope: raw.Scope, + } + if raw.ExpiresIn > 0 { + t.ExpiresAt = nowFunc().Add(time.Duration(raw.ExpiresIn) * time.Second) + } + return t, nil +} + +// postForm POSTs body as application/x-www-form-urlencoded to a path +// resolved against the client's BaseURL. +func (c *Client) postForm(ctx context.Context, path string, body url.Values) (*http.Response, error) { + endpoint, err := resolveURL(c.BaseURL, path) + if err != nil { + return nil, fmt.Errorf("resolve URL %s: %w", path, err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(body.Encode())) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if c.UserAgent != "" { + req.Header.Set("User-Agent", c.UserAgent) + } + + httpClient := c.HTTP + if httpClient == nil { + httpClient = http.DefaultClient + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request %s: %w", path, err) + } + return resp, nil +} + +func resolveURL(baseURL, path string) (string, error) { + base, err := url.Parse(baseURL) + if err != nil { + return "", fmt.Errorf("parse base URL: %w", err) + } + if base.Scheme != "http" && base.Scheme != "https" { + return "", fmt.Errorf("unsupported base URL scheme %q (must be http or https)", base.Scheme) + } + rel, err := url.Parse(path) + if err != nil { + return "", fmt.Errorf("parse path: %w", err) + } + return base.ResolveReference(rel).String(), nil +} + +type errorResponse struct { + Error string `json:"error"` +} + +func readAPIErrorResponse(resp *http.Response) (*errorResponse, error) { + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + if err != nil { + return nil, fmt.Errorf("status %d", resp.StatusCode) + } + + var apiErr errorResponse + if err := json.Unmarshal(body, &apiErr); err == nil && strings.TrimSpace(apiErr.Error) != "" { + return &apiErr, nil + } + + text := strings.TrimSpace(string(body)) + if text != "" { + return nil, fmt.Errorf("status %d: %s", resp.StatusCode, text) + } + return nil, fmt.Errorf("status %d", resp.StatusCode) +} + +func readAPIError(resp *http.Response, action string) error { + apiErr, err := readAPIErrorResponse(resp) + if err == nil { + return fmt.Errorf("%s: %s", action, apiErr.Error) + } + return fmt.Errorf("%s: %w", action, err) +} + +func decodeJSON(r io.Reader, dest any, strict bool) error { + body, err := io.ReadAll(io.LimitReader(r, maxResponseBytes)) + if err != nil { + return fmt.Errorf("read JSON response: %w", err) + } + + dec := json.NewDecoder(bytes.NewReader(body)) + if strict { + dec.DisallowUnknownFields() + } + if err := dec.Decode(dest); err != nil { + return fmt.Errorf("decode JSON response: %w", err) + } + return nil +} diff --git a/auth/deviceflow/deviceflow_test.go b/auth/deviceflow/deviceflow_test.go new file mode 100644 index 0000000000..1d217ed723 --- /dev/null +++ b/auth/deviceflow/deviceflow_test.go @@ -0,0 +1,276 @@ +package deviceflow + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +const ( + testClientID = "cli" + testDeviceCodePath = "/oauth/device/code" + testTokenPath = "/oauth/token" +) + +// freezeClock pins nowFunc for the duration of a test. +func freezeClock(t *testing.T, at time.Time) { + t.Helper() + prev := nowFunc + nowFunc = func() time.Time { return at } + t.Cleanup(func() { nowFunc = prev }) +} + +func newTestClient(t *testing.T, h http.HandlerFunc) (*Client, *httptest.Server) { + t.Helper() + srv := httptest.NewServer(h) + t.Cleanup(srv.Close) + + c := &Client{ + HTTP: srv.Client(), + BaseURL: srv.URL, + ClientID: testClientID, + Scope: "cli", + DeviceCodePath: testDeviceCodePath, + TokenPath: testTokenPath, + } + return c, srv +} + +func mustReadForm(t *testing.T, r *http.Request) { + t.Helper() + if err := r.ParseForm(); err != nil { + t.Fatalf("parse form: %v", err) + } +} + +func TestStartDeviceAuth_Success(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != testDeviceCodePath { + t.Errorf("path = %q", r.URL.Path) + } + mustReadForm(t, r) + if got := r.PostForm.Get("client_id"); got != testClientID { + t.Errorf("client_id = %q", got) + } + if got := r.PostForm.Get("scope"); got != "cli" { + t.Errorf("scope = %q", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{ + "device_code": "dev-1", + "user_code": "ABCD-EFGH", + "verification_uri": "https://example.com/cli/auth", + "verification_uri_complete": "https://example.com/cli/auth?code=ABCD-EFGH", + "expires_in": 600, + "interval": 5 + }`) + }) + + got, err := c.StartDeviceAuth(context.Background()) + if err != nil { + t.Fatalf("StartDeviceAuth() error = %v", err) + } + if got.DeviceCode != "dev-1" || got.UserCode != "ABCD-EFGH" || got.ExpiresIn != 600 || got.Interval != 5 { + t.Fatalf("DeviceCode = %+v", got) + } +} + +func TestStartDeviceAuth_OmitsScopeWhenEmpty(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { + mustReadForm(t, r) + if r.PostForm.Has("scope") { + t.Errorf("scope should not be sent when Client.Scope is empty") + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"device_code":"d","user_code":"u","verification_uri":"x","expires_in":1,"interval":1}`) + }) + c.Scope = "" + + if _, err := c.StartDeviceAuth(context.Background()); err != nil { + t.Fatalf("StartDeviceAuth() error = %v", err) + } +} + +func TestStartDeviceAuth_RejectsUnknownFields(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{ + "device_code":"d","user_code":"u","verification_uri":"x","expires_in":1,"interval":1, + "surprise":"field" + }`) + }) + + if _, err := c.StartDeviceAuth(context.Background()); err == nil { + t.Fatal("StartDeviceAuth() with unknown field should fail (strict decode)") + } +} + +func TestStartDeviceAuth_NonOK(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, `{"error":"invalid_client"}`) + }) + + if _, err := c.StartDeviceAuth(context.Background()); err == nil || + !strings.Contains(err.Error(), "invalid_client") { + t.Fatalf("StartDeviceAuth() error = %v, want invalid_client", err) + } +} + +func TestPollDeviceAuth_Success(t *testing.T) { + t.Parallel() + + freezeClock(t, time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC)) + + c, _ := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { + mustReadForm(t, r) + if got := r.PostForm.Get("grant_type"); got != deviceCodeGrantType { + t.Errorf("grant_type = %q", got) + } + if got := r.PostForm.Get("device_code"); got != "dev-1" { + t.Errorf("device_code = %q", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{ + "access_token":"acc", + "refresh_token":"ref", + "token_type":"Bearer", + "expires_in":3600, + "scope":"cli" + }`) + }) + + got, err := c.PollDeviceAuth(context.Background(), "dev-1") + if err != nil { + t.Fatalf("PollDeviceAuth() error = %v", err) + } + + if got.AccessToken != "acc" || got.RefreshToken != "ref" || got.TokenType != "Bearer" || got.Scope != "cli" { + t.Fatalf("TokenSet = %+v", got) + } + want := time.Date(2026, 5, 6, 13, 0, 0, 0, time.UTC) + if !got.ExpiresAt.Equal(want) { + t.Fatalf("ExpiresAt = %v, want %v", got.ExpiresAt, want) + } +} + +func TestPollDeviceAuth_TolerantToUnknownFields(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"access_token":"acc","extra":"ignored"}`) + }) + + got, err := c.PollDeviceAuth(context.Background(), "dev-1") + if err != nil { + t.Fatalf("PollDeviceAuth() error = %v", err) + } + if got.AccessToken != "acc" { + t.Fatalf("AccessToken = %q", got.AccessToken) + } +} + +func TestPollDeviceAuth_ErrorCodes(t *testing.T) { + t.Parallel() + + tests := []struct { + code string + want error + }{ + {"authorization_pending", ErrAuthorizationPending}, + {"slow_down", ErrSlowDown}, + {"access_denied", ErrAccessDenied}, + {"expired_token", ErrExpiredToken}, + {"invalid_grant", ErrInvalidGrant}, + } + + for _, tt := range tests { + t.Run(tt.code, func(t *testing.T) { + t.Parallel() + c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintf(w, `{"error":%q}`, tt.code) + }) + + _, err := c.PollDeviceAuth(context.Background(), "dev-1") + if !errors.Is(err, tt.want) { + t.Fatalf("PollDeviceAuth() error = %v, want %v", err, tt.want) + } + }) + } +} + +func TestPollDeviceAuth_UnknownErrorCode(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, `{"error":"weird_thing"}`) + }) + + _, err := c.PollDeviceAuth(context.Background(), "dev-1") + if err == nil || !strings.Contains(err.Error(), "weird_thing") { + t.Fatalf("PollDeviceAuth() error = %v, want unknown-code error", err) + } + for _, sentinel := range []error{ErrAuthorizationPending, ErrSlowDown, ErrAccessDenied, ErrExpiredToken, ErrInvalidGrant} { + if errors.Is(err, sentinel) { + t.Fatalf("unknown code matched sentinel %v", sentinel) + } + } +} + +func TestPollDeviceAuth_200WithNoAccessToken(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{}`) + }) + + if _, err := c.PollDeviceAuth(context.Background(), "dev-1"); err == nil { + t.Fatal("PollDeviceAuth() should fail when access_token missing") + } +} + +func TestResolveURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + base string + path string + want string + wantErr bool + }{ + {"https + absolute path", "https://entire.io", "/oauth/device/code", "https://entire.io/oauth/device/code", false}, + {"trailing slash + absolute path", "https://entire.io/", "/oauth/token", "https://entire.io/oauth/token", false}, + {"http allowed", "http://localhost:8180", "/api/auth/token", "http://localhost:8180/api/auth/token", false}, + {"unsupported scheme", "ftp://x", "/y", "", true}, + {"malformed base", "://", "/y", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := resolveURL(tt.base, tt.path) + if (err != nil) != tt.wantErr { + t.Fatalf("resolveURL() err = %v, wantErr %v", err, tt.wantErr) + } + if got != tt.want { + t.Fatalf("resolveURL() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/auth/doc.go b/auth/doc.go new file mode 100644 index 0000000000..1e3f0571ef --- /dev/null +++ b/auth/doc.go @@ -0,0 +1,15 @@ +// Package auth is the umbrella for the Entire CLI auth library. +// +// All real code lives in the subpackages: +// +// - deviceflow — RFC 8628 OAuth 2.0 Device Authorization Grant client +// - tokens — TokenSet plus unverified JWT claim parsing +// - tokenstore — pluggable persistence interface with reference impls +// - sts — RFC 8693 Token Exchange client +// +// The library is designed to talk RFC 8628 and RFC 8693 to any compliant +// OAuth 2.0 server. It contains no provider-specific behaviour; endpoint +// paths, client IDs, and token-type URIs are caller-supplied. Anything a +// caller learns about the server beyond what the server tells it in a +// public HTTP response is out of scope for this package. +package auth diff --git a/auth/tokens/tokens.go b/auth/tokens/tokens.go new file mode 100644 index 0000000000..d7c37d13ed --- /dev/null +++ b/auth/tokens/tokens.go @@ -0,0 +1,145 @@ +// Package tokens defines the post-protocol token shape and helpers for +// reading unverified claims out of JWT access tokens. +// +// The wire-shape responses from RFC 8628 / RFC 8693 endpoints are +// translated into a single TokenSet with absolute expiry. Clients that +// only ever see access tokens as opaque bearer strings need not import +// this package directly. +package tokens + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + "time" +) + +// TokenSet is an OAuth token bundle returned from a device-flow or +// token-exchange endpoint, normalised to absolute expiry. +// +// RefreshToken is empty when the issuer didn't return one. ExpiresAt is +// zero for tokens that don't carry a wire-side expires_in. +type TokenSet struct { + AccessToken string + RefreshToken string + TokenType string + ExpiresAt time.Time + Scope string +} + +// HasRefresh reports whether the set carries a refresh token. +func (t TokenSet) HasRefresh() bool { return t.RefreshToken != "" } + +// Expired reports whether the access token's advertised lifetime has +// elapsed at now. Returns false for tokens with a zero ExpiresAt. +func (t TokenSet) Expired(now time.Time) bool { + if t.ExpiresAt.IsZero() { + return false + } + return !now.Before(t.ExpiresAt) +} + +// ShouldRefresh reports whether the access token is within skew of +// expiring (or has already expired). Tokens without an advertised +// expiry never need refreshing. +func (t TokenSet) ShouldRefresh(now time.Time, skew time.Duration) bool { + if t.ExpiresAt.IsZero() { + return false + } + return !now.Add(skew).Before(t.ExpiresAt) +} + +// Claims holds the fields parsed from a JWT access token's payload. +// +// Signature verification is the issuing server's responsibility; this +// package never validates signatures. Clients read claims for routing +// (which issuer, which audience) and UX (display the principal handle). +type Claims struct { + Issuer string + Subject string + Audience []string + Handle string + ExpiresAt time.Time + IssuedAt time.Time + NotBefore time.Time +} + +// ErrMalformedJWT is returned by ParseClaims when the input is not a +// well-formed JWT (three base64url-encoded segments separated by dots). +var ErrMalformedJWT = errors.New("malformed JWT") + +// ParseClaims decodes the payload segment of jwt without verifying the +// signature. Audience is normalised to a slice even when the wire form +// is a single string. +func ParseClaims(jwt string) (*Claims, error) { + parts := strings.Split(jwt, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("%w: expected 3 segments, got %d", ErrMalformedJWT, len(parts)) + } + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("decode JWT payload: %w", err) + } + + var raw struct { + Iss string `json:"iss"` + Sub string `json:"sub"` + Aud json.RawMessage `json:"aud"` + Exp int64 `json:"exp"` + Iat int64 `json:"iat"` + Nbf int64 `json:"nbf"` + Handle string `json:"handle"` + } + if err := json.Unmarshal(payload, &raw); err != nil { + return nil, fmt.Errorf("decode JWT claims: %w", err) + } + + c := &Claims{ + Issuer: raw.Iss, + Subject: raw.Sub, + Handle: raw.Handle, + } + + if raw.Exp != 0 { + c.ExpiresAt = time.Unix(raw.Exp, 0).UTC() + } + if raw.Iat != 0 { + c.IssuedAt = time.Unix(raw.Iat, 0).UTC() + } + if raw.Nbf != 0 { + c.NotBefore = time.Unix(raw.Nbf, 0).UTC() + } + + c.Audience, err = decodeAudience(raw.Aud) + if err != nil { + return nil, err + } + + return c, nil +} + +// decodeAudience handles both string and string-array forms of the JWT +// `aud` claim. +func decodeAudience(raw json.RawMessage) ([]string, error) { + if len(raw) == 0 { + return nil, nil + } + + var single string + if err := json.Unmarshal(raw, &single); err == nil { + if single == "" { + return nil, nil + } + return []string{single}, nil + } + + var multi []string + if err := json.Unmarshal(raw, &multi); err == nil { + return multi, nil + } + + return nil, fmt.Errorf("decode JWT aud claim: not a string or array") +} diff --git a/auth/tokens/tokens_test.go b/auth/tokens/tokens_test.go new file mode 100644 index 0000000000..8c702b8e3f --- /dev/null +++ b/auth/tokens/tokens_test.go @@ -0,0 +1,196 @@ +package tokens + +import ( + "encoding/base64" + "encoding/json" + "errors" + "testing" + "time" +) + +func TestTokenSet_Expired(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC) + + tests := []struct { + name string + set TokenSet + want bool + }{ + {"zero expiry never expires", TokenSet{}, false}, + {"future expiry", TokenSet{ExpiresAt: now.Add(time.Hour)}, false}, + {"past expiry", TokenSet{ExpiresAt: now.Add(-time.Second)}, true}, + {"exact moment is expired", TokenSet{ExpiresAt: now}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.set.Expired(now); got != tt.want { + t.Fatalf("Expired() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTokenSet_ShouldRefresh(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC) + skew := 30 * time.Second + + tests := []struct { + name string + set TokenSet + want bool + }{ + {"zero expiry never refreshes", TokenSet{}, false}, + {"comfortably future", TokenSet{ExpiresAt: now.Add(time.Hour)}, false}, + {"within skew window", TokenSet{ExpiresAt: now.Add(15 * time.Second)}, true}, + {"already expired", TokenSet{ExpiresAt: now.Add(-time.Second)}, true}, + {"exactly at skew boundary", TokenSet{ExpiresAt: now.Add(skew)}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.set.ShouldRefresh(now, skew); got != tt.want { + t.Fatalf("ShouldRefresh() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTokenSet_HasRefresh(t *testing.T) { + t.Parallel() + if (TokenSet{}).HasRefresh() { + t.Fatal("empty TokenSet should not have a refresh token") + } + if !(TokenSet{RefreshToken: "x"}).HasRefresh() { + t.Fatal("TokenSet with refresh token should report true") + } +} + +func makeJWT(t *testing.T, payload any) string { + t.Helper() + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + body, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + return header + "." + base64.RawURLEncoding.EncodeToString(body) + ".sig" +} + +func TestParseClaims_BasicFields(t *testing.T) { + t.Parallel() + + jwt := makeJWT(t, map[string]any{ + "iss": "https://entire.io", + "sub": "01HX...", + "aud": "entire-cli", + "exp": 1800000000, + "iat": 1799999000, + "handle": "alex", + }) + + got, err := ParseClaims(jwt) + if err != nil { + t.Fatalf("ParseClaims() error = %v", err) + } + + if got.Issuer != "https://entire.io" { + t.Errorf("Issuer = %q", got.Issuer) + } + if got.Subject != "01HX..." { + t.Errorf("Subject = %q", got.Subject) + } + if got.Handle != "alex" { + t.Errorf("Handle = %q", got.Handle) + } + if !got.ExpiresAt.Equal(time.Unix(1800000000, 0).UTC()) { + t.Errorf("ExpiresAt = %v", got.ExpiresAt) + } + if len(got.Audience) != 1 || got.Audience[0] != "entire-cli" { + t.Errorf("Audience = %v", got.Audience) + } +} + +func TestParseClaims_AudienceArray(t *testing.T) { + t.Parallel() + + jwt := makeJWT(t, map[string]any{ + "aud": []string{"entire-cli", "entire-server"}, + }) + + got, err := ParseClaims(jwt) + if err != nil { + t.Fatalf("ParseClaims() error = %v", err) + } + if len(got.Audience) != 2 || got.Audience[0] != "entire-cli" || got.Audience[1] != "entire-server" { + t.Fatalf("Audience = %v", got.Audience) + } +} + +func TestParseClaims_MissingFieldsAreZero(t *testing.T) { + t.Parallel() + + jwt := makeJWT(t, map[string]any{}) + got, err := ParseClaims(jwt) + if err != nil { + t.Fatalf("ParseClaims() error = %v", err) + } + + if got.Issuer != "" || got.Subject != "" || got.Handle != "" { + t.Errorf("expected zero strings, got %+v", got) + } + if !got.ExpiresAt.IsZero() { + t.Errorf("ExpiresAt should be zero, got %v", got.ExpiresAt) + } + if len(got.Audience) != 0 { + t.Errorf("Audience should be empty, got %v", got.Audience) + } +} + +func TestParseClaims_MalformedJWT(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + }{ + {"empty", ""}, + {"two segments", "header.payload"}, + {"four segments", "a.b.c.d"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := ParseClaims(tt.input) + if !errors.Is(err, ErrMalformedJWT) { + t.Fatalf("ParseClaims(%q) error = %v, want ErrMalformedJWT", tt.input, err) + } + }) + } +} + +func TestParseClaims_BadBase64(t *testing.T) { + t.Parallel() + + _, err := ParseClaims("header.!!!.sig") + if err == nil { + t.Fatal("ParseClaims() with bad base64 should fail") + } +} + +func TestParseClaims_BadJSON(t *testing.T) { + t.Parallel() + + header := base64.RawURLEncoding.EncodeToString([]byte(`{}`)) + payload := base64.RawURLEncoding.EncodeToString([]byte(`not json`)) + _, err := ParseClaims(header + "." + payload + ".sig") + if err == nil { + t.Fatal("ParseClaims() with bad JSON should fail") + } +} diff --git a/auth/tokenstore/keyring.go b/auth/tokenstore/keyring.go new file mode 100644 index 0000000000..cfba4d5c6c --- /dev/null +++ b/auth/tokenstore/keyring.go @@ -0,0 +1,129 @@ +package tokenstore + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/entireio/cli/auth/tokens" + "github.com/zalando/go-keyring" +) + +// Keyring is a Store backed by the OS keyring. +// +// Each profile gets one entry under the configured Service name. The +// entry holds a JSON-encoded TokenSet so refresh tokens, expiry, and +// scope round-trip alongside the access token. +type Keyring struct { + Service string +} + +// NewKeyring returns a Keyring with the given service name. The service +// name namespaces entries in the OS keyring; pick something unique per +// CLI binary so two CLIs don't collide. +func NewKeyring(service string) *Keyring { + return &Keyring{Service: service} +} + +// SaveTokens marshals t to JSON and stores it under profile in the OS +// keyring. Empty access tokens are rejected. +func (k *Keyring) SaveTokens(profile string, t tokens.TokenSet) error { + t.AccessToken = strings.TrimSpace(t.AccessToken) + if t.AccessToken == "" { + return errors.New("refusing to save TokenSet with empty access token") + } + + encoded, err := encodeTokenSet(t) + if err != nil { + return err + } + + if err := keyring.Set(k.Service, profile, encoded); err != nil { + return fmt.Errorf("save tokens to keyring: %w", err) + } + return nil +} + +// LoadTokens returns the TokenSet stored for profile. Returns +// ErrNotFound when the profile has nothing stored. +func (k *Keyring) LoadTokens(profile string) (tokens.TokenSet, error) { + raw, err := keyring.Get(k.Service, profile) + if errors.Is(err, keyring.ErrNotFound) { + return tokens.TokenSet{}, ErrNotFound + } + if err != nil { + return tokens.TokenSet{}, fmt.Errorf("load tokens from keyring: %w", err) + } + + t, err := decodeTokenSet(raw) + if err != nil { + return tokens.TokenSet{}, err + } + return t, nil +} + +// DeleteTokens removes the TokenSet for profile. A missing entry is a +// no-op. +func (k *Keyring) DeleteTokens(profile string) error { + err := keyring.Delete(k.Service, profile) + if errors.Is(err, keyring.ErrNotFound) { + return nil + } + if err != nil { + return fmt.Errorf("delete tokens from keyring: %w", err) + } + return nil +} + +// keyringTokenSet is the on-keyring JSON shape. Time fields are +// serialised as RFC 3339 strings so the wire form survives keyring +// implementations that don't preserve byte-for-byte equality. +type keyringTokenSet struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + TokenType string `json:"token_type,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` + Scope string `json:"scope,omitempty"` +} + +func encodeTokenSet(t tokens.TokenSet) (string, error) { + wire := keyringTokenSet{ + AccessToken: t.AccessToken, + RefreshToken: t.RefreshToken, + TokenType: t.TokenType, + Scope: t.Scope, + } + if !t.ExpiresAt.IsZero() { + wire.ExpiresAt = t.ExpiresAt.UTC().Format(time.RFC3339) + } + + b, err := json.Marshal(wire) + if err != nil { + return "", fmt.Errorf("marshal TokenSet: %w", err) + } + return string(b), nil +} + +func decodeTokenSet(raw string) (tokens.TokenSet, error) { + var wire keyringTokenSet + if err := json.Unmarshal([]byte(raw), &wire); err != nil { + return tokens.TokenSet{}, fmt.Errorf("unmarshal TokenSet: %w", err) + } + + t := tokens.TokenSet{ + AccessToken: wire.AccessToken, + RefreshToken: wire.RefreshToken, + TokenType: wire.TokenType, + Scope: wire.Scope, + } + if wire.ExpiresAt != "" { + exp, err := time.Parse(time.RFC3339, wire.ExpiresAt) + if err != nil { + return tokens.TokenSet{}, fmt.Errorf("parse expires_at: %w", err) + } + t.ExpiresAt = exp.UTC() + } + return t, nil +} diff --git a/auth/tokenstore/keyring_test.go b/auth/tokenstore/keyring_test.go new file mode 100644 index 0000000000..4cc7a53925 --- /dev/null +++ b/auth/tokenstore/keyring_test.go @@ -0,0 +1,136 @@ +package tokenstore + +import ( + "errors" + "os" + "testing" + "time" + + "github.com/entireio/cli/auth/tokens" + "github.com/zalando/go-keyring" +) + +func TestMain(m *testing.M) { + keyring.MockInit() + os.Exit(m.Run()) +} + +func TestKeyring_SaveLoad_RoundTrip(t *testing.T) { + store := NewKeyring("test-roundtrip") + + want := tokens.TokenSet{ + AccessToken: "access", + RefreshToken: "refresh", + TokenType: "Bearer", + ExpiresAt: time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC), + Scope: "cli", + } + + if err := store.SaveTokens("https://entire.io", want); err != nil { + t.Fatalf("SaveTokens() error = %v", err) + } + + got, err := store.LoadTokens("https://entire.io") + if err != nil { + t.Fatalf("LoadTokens() error = %v", err) + } + if got.AccessToken != want.AccessToken || + got.RefreshToken != want.RefreshToken || + got.TokenType != want.TokenType || + !got.ExpiresAt.Equal(want.ExpiresAt) || + got.Scope != want.Scope { + t.Fatalf("LoadTokens() = %+v, want %+v", got, want) + } +} + +func TestKeyring_LoadTokens_NotFound(t *testing.T) { + store := NewKeyring("test-not-found") + + _, err := store.LoadTokens("https://missing.example") + if !errors.Is(err, ErrNotFound) { + t.Fatalf("LoadTokens() error = %v, want ErrNotFound", err) + } +} + +func TestKeyring_SaveTokens_RejectsEmptyAccessToken(t *testing.T) { + store := NewKeyring("test-empty") + + if err := store.SaveTokens("https://entire.io", tokens.TokenSet{}); err == nil { + t.Fatal("SaveTokens() with empty access token should fail") + } + if err := store.SaveTokens("https://entire.io", tokens.TokenSet{AccessToken: " "}); err == nil { + t.Fatal("SaveTokens() with whitespace access token should fail") + } +} + +func TestKeyring_SaveTokens_TrimsAccessToken(t *testing.T) { + store := NewKeyring("test-trim") + + if err := store.SaveTokens("https://entire.io", tokens.TokenSet{AccessToken: " tok "}); err != nil { + t.Fatalf("SaveTokens() error = %v", err) + } + got, err := store.LoadTokens("https://entire.io") + if err != nil { + t.Fatalf("LoadTokens() error = %v", err) + } + if got.AccessToken != "tok" { + t.Fatalf("AccessToken = %q, want %q", got.AccessToken, "tok") + } +} + +func TestKeyring_DeleteTokens(t *testing.T) { + store := NewKeyring("test-delete") + + if err := store.SaveTokens("https://entire.io", tokens.TokenSet{AccessToken: "tok"}); err != nil { + t.Fatalf("SaveTokens() error = %v", err) + } + if err := store.DeleteTokens("https://entire.io"); err != nil { + t.Fatalf("DeleteTokens() error = %v", err) + } + if _, err := store.LoadTokens("https://entire.io"); !errors.Is(err, ErrNotFound) { + t.Fatalf("LoadTokens() after delete error = %v, want ErrNotFound", err) + } +} + +func TestKeyring_DeleteTokens_MissingIsNoop(t *testing.T) { + store := NewKeyring("test-delete-missing") + + if err := store.DeleteTokens("https://nonexistent.example"); err != nil { + t.Fatalf("DeleteTokens() on missing entry error = %v", err) + } +} + +func TestKeyring_PreservesOtherProfiles(t *testing.T) { + store := NewKeyring("test-preserve") + + if err := store.SaveTokens("a", tokens.TokenSet{AccessToken: "tok-a"}); err != nil { + t.Fatalf("SaveTokens(a) error = %v", err) + } + if err := store.SaveTokens("b", tokens.TokenSet{AccessToken: "tok-b"}); err != nil { + t.Fatalf("SaveTokens(b) error = %v", err) + } + + a, err := store.LoadTokens("a") + if err != nil || a.AccessToken != "tok-a" { + t.Fatalf("LoadTokens(a) = %q (err %v), want tok-a", a.AccessToken, err) + } + b, err := store.LoadTokens("b") + if err != nil || b.AccessToken != "tok-b" { + t.Fatalf("LoadTokens(b) = %q (err %v), want tok-b", b.AccessToken, err) + } +} + +func TestKeyring_RoundTrip_NoExpiry(t *testing.T) { + store := NewKeyring("test-no-expiry") + + if err := store.SaveTokens("p", tokens.TokenSet{AccessToken: "tok"}); err != nil { + t.Fatalf("SaveTokens() error = %v", err) + } + got, err := store.LoadTokens("p") + if err != nil { + t.Fatalf("LoadTokens() error = %v", err) + } + if !got.ExpiresAt.IsZero() { + t.Fatalf("ExpiresAt = %v, want zero", got.ExpiresAt) + } +} diff --git a/auth/tokenstore/tokenstore.go b/auth/tokenstore/tokenstore.go new file mode 100644 index 0000000000..875308a37a --- /dev/null +++ b/auth/tokenstore/tokenstore.go @@ -0,0 +1,37 @@ +// Package tokenstore is the persistence interface for tokens, plus +// reference implementations. +// +// Callers pick a Store at startup. Two impls ship with this package: +// +// - Keyring stores one entry per profile in the OS keyring. Suitable +// for interactive single-user CLIs. +// - File stores entries in a JSON file on disk, with refresh tokens +// in the OS keyring. Suitable for CLIs that need to persist +// additional per-profile metadata (e.g. context bindings). +// +// Profile is whatever string the caller wants to key by — typically a +// base URL, a kubectl-style context name, or a principal handle. +package tokenstore + +import ( + "errors" + + "github.com/entireio/cli/auth/tokens" +) + +// ErrNotFound is returned when a profile has no stored tokens. Callers +// distinguish "not logged in" from genuine errors with errors.Is. +var ErrNotFound = errors.New("token not found") + +// Store persists token bundles keyed by an opaque profile string. +// +// Implementations must: +// - Return ErrNotFound (not a zero value, no error) when LoadTokens +// is called for an unknown profile. +// - Treat DeleteTokens of a missing profile as a no-op. +// - Not write empty access tokens; SaveTokens should reject them. +type Store interface { + SaveTokens(profile string, t tokens.TokenSet) error + LoadTokens(profile string) (tokens.TokenSet, error) + DeleteTokens(profile string) error +} From cd0648ab35cc2ac577e4d8d92d13ca7c89f8c996 Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Wed, 6 May 2026 11:49:03 +1000 Subject: [PATCH 02/21] Wire cmd/entire/cli/auth shims to use the shared auth/ library MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the bespoke device-flow client and keyring store in cmd/entire/cli/auth with thin wrappers over auth/deviceflow and auth/tokenstore. The package's exported API (NewClient, NewStore, DeviceAuthStart, DeviceAuthPoll, LookupCurrentToken, etc.) is preserved field-for-field so login.go / logout.go / auth.go don't need to change. Two wrapper concerns worth noting: 1. PollDeviceAuth maps the shared library's RFC 8628 §3.5 sentinel errors back to the wire-side error code in DeviceAuthPoll.Error. This keeps the existing polling loop in login.go (which switches on result.Error) working unchanged. 2. Store.GetToken keeps a backward-compatibility fallback for keyring entries written before this commit, which stored bare access-token strings rather than JSON-encoded TokenSets. SaveToken always writes the new shape; GetToken transparently handles both. The legacy decodeJSON / decodeJSONStrict tests are removed; equivalent coverage now lives in auth/deviceflow tests. Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/entire/cli/auth/client.go | 233 ++++++++++------------------- cmd/entire/cli/auth/client_test.go | 42 ------ cmd/entire/cli/auth/store.go | 64 ++++---- 3 files changed, 117 insertions(+), 222 deletions(-) delete mode 100644 cmd/entire/cli/auth/client_test.go diff --git a/cmd/entire/cli/auth/client.go b/cmd/entire/cli/auth/client.go index 5ffd215003..c0dfad81ab 100644 --- a/cmd/entire/cli/auth/client.go +++ b/cmd/entire/cli/auth/client.go @@ -1,187 +1,118 @@ package auth import ( - "bytes" "context" - "encoding/json" - "fmt" - "io" + "errors" "net/http" - "net/url" - "strings" + "time" + "github.com/entireio/cli/auth/deviceflow" + "github.com/entireio/cli/auth/tokens" "github.com/entireio/cli/cmd/entire/cli/api" ) +// nowFunc is the package's clock. Override in tests. +var nowFunc = time.Now + +// clientID is the OAuth client_id this CLI registers under at the +// entire.io device-flow endpoint. +const clientID = "entire-cli" + +// Device-flow endpoint paths on entire.io. Held as constants so the +// shim is the only place that knows them; if the CLI ever needs to +// target a different surface, the change lives here. const ( - maxResponseBytes = 1 << 20 - clientID = "entire-cli" + deviceCodePath = "/oauth/device/code" + tokenPath = "/oauth/token" ) -type Client struct { - httpClient *http.Client - baseURL string -} - -type DeviceAuthStart struct { - DeviceCode string `json:"device_code"` - UserCode string `json:"user_code"` - VerificationURI string `json:"verification_uri"` - VerificationURIComplete string `json:"verification_uri_complete"` - ExpiresIn int `json:"expires_in"` - Interval int `json:"interval"` -} +// DeviceAuthStart preserves the historical type name; the shape now +// matches deviceflow.DeviceCode field-for-field. +type DeviceAuthStart = deviceflow.DeviceCode +// DeviceAuthPoll is the historical token-poll response shape. The shim +// flattens deviceflow's typed errors back into the Error field so +// existing login.go logic that switches on result.Error keeps working. type DeviceAuthPoll struct { - AccessToken string `json:"access_token,omitempty"` - TokenType string `json:"token_type,omitempty"` - ExpiresIn int `json:"expires_in,omitempty"` - Scope string `json:"scope,omitempty"` - Error string `json:"error,omitempty"` + AccessToken string + TokenType string + ExpiresIn int + Scope string + Error string } -type errorResponse struct { - Error string `json:"error"` +// Client wraps a deviceflow.Client preconfigured for entire.io. +type Client struct { + inner *deviceflow.Client } +// NewClient constructs a Client targeting entire.io. httpClient is +// used directly when non-nil; otherwise http.DefaultClient. func NewClient(httpClient *http.Client) *Client { - if httpClient == nil { - httpClient = &http.Client{} - } - - return &Client{ - httpClient: httpClient, - baseURL: api.BaseURL(), - } + return &Client{inner: &deviceflow.Client{ + HTTP: httpClient, + BaseURL: api.BaseURL(), + ClientID: clientID, + Scope: "cli", + UserAgent: clientID, + DeviceCodePath: deviceCodePath, + TokenPath: tokenPath, + }} } -func (c *Client) BaseURL() string { - return c.baseURL -} +// BaseURL returns the issuer base URL this client talks to. +func (c *Client) BaseURL() string { return c.inner.BaseURL } +// StartDeviceAuth requests a fresh device code. func (c *Client) StartDeviceAuth(ctx context.Context) (*DeviceAuthStart, error) { - body := url.Values{} - body.Set("client_id", clientID) - body.Set("scope", "cli") - - resp, err := c.postForm(ctx, "/oauth/device/code", body) - if err != nil { - return nil, fmt.Errorf("start device auth: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, readAPIError(resp, "start device auth") - } - - var result DeviceAuthStart - if err := decodeJSONStrict(resp.Body, &result); err != nil { - return nil, fmt.Errorf("decode device auth start response: %w", err) - } - - return &result, nil + return c.inner.StartDeviceAuth(ctx) } +// PollDeviceAuth polls the token endpoint. On any RFC 8628 §3.5 error, +// the wire-side error code is returned in DeviceAuthPoll.Error so the +// existing polling loop in login.go can branch on it. Non-RFC errors +// (network, decode) are returned as a real error. func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*DeviceAuthPoll, error) { - body := url.Values{} - body.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code") - body.Set("client_id", clientID) - body.Set("device_code", deviceCode) - - resp, err := c.postForm(ctx, "/oauth/token", body) + t, err := c.inner.PollDeviceAuth(ctx, deviceCode) if err != nil { - return nil, fmt.Errorf("poll device auth: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - apiErr, err := readAPIErrorResponse(resp) - if err != nil { - return nil, fmt.Errorf("poll device auth: %w", err) + if code := oauthErrorCode(err); code != "" { + return &DeviceAuthPoll{Error: code}, nil } - return &DeviceAuthPoll{Error: apiErr.Error}, nil - } - - var result DeviceAuthPoll - if err := decodeJSON(resp.Body, &result); err != nil { - return nil, fmt.Errorf("decode device auth poll response: %w", err) + return nil, err } - return &result, nil + return &DeviceAuthPoll{ + AccessToken: t.AccessToken, + TokenType: t.TokenType, + ExpiresIn: secondsUntil(t), + Scope: t.Scope, + }, nil } -// postForm sends a POST request with form-encoded body to an API-relative path. -func (c *Client) postForm(ctx context.Context, path string, body url.Values) (*http.Response, error) { - endpoint, err := api.ResolveURLFromBase(c.baseURL, path) - if err != nil { - return nil, fmt.Errorf("resolve URL %s: %w", path, err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(body.Encode())) - if err != nil { - return nil, fmt.Errorf("create request: %w", err) - } - - req.Header.Set("Accept", "application/json") - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("User-Agent", clientID) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("request %s: %w", path, err) - } - - return resp, nil +// oauthErrorCode returns the wire-side code for a recognised RFC 8628 +// sentinel error, or "" if err isn't one. +func oauthErrorCode(err error) string { + switch { + case errors.Is(err, deviceflow.ErrAuthorizationPending): + return "authorization_pending" + case errors.Is(err, deviceflow.ErrSlowDown): + return "slow_down" + case errors.Is(err, deviceflow.ErrAccessDenied): + return "access_denied" + case errors.Is(err, deviceflow.ErrExpiredToken): + return "expired_token" + case errors.Is(err, deviceflow.ErrInvalidGrant): + return "invalid_grant" + } + return "" } -func readAPIErrorResponse(resp *http.Response) (*errorResponse, error) { - body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) - if err != nil { - return nil, fmt.Errorf("status %d", resp.StatusCode) - } - - var apiErr errorResponse - if err := json.Unmarshal(body, &apiErr); err == nil && strings.TrimSpace(apiErr.Error) != "" { - return &apiErr, nil - } - - text := strings.TrimSpace(string(body)) - if text != "" { - return nil, fmt.Errorf("status %d: %s", resp.StatusCode, text) +// secondsUntil computes seconds-until-expiry for a TokenSet with an +// absolute ExpiresAt. Returns 0 when no expiry is set, mirroring the +// historical shape of DeviceAuthPoll.ExpiresIn. +func secondsUntil(t *tokens.TokenSet) int { + if t.ExpiresAt.IsZero() { + return 0 } - - return nil, fmt.Errorf("status %d", resp.StatusCode) -} - -func readAPIError(resp *http.Response, action string) error { - apiErr, err := readAPIErrorResponse(resp) - if err == nil { - return fmt.Errorf("%s: %s", action, apiErr.Error) - } - return fmt.Errorf("%s: %w", action, err) -} - -func decodeJSON(r io.Reader, dest any) error { - return decodeJSONWithOptions(r, dest, false) -} - -func decodeJSONStrict(r io.Reader, dest any) error { - return decodeJSONWithOptions(r, dest, true) -} - -func decodeJSONWithOptions(r io.Reader, dest any, strict bool) error { - body, err := io.ReadAll(io.LimitReader(r, maxResponseBytes)) - if err != nil { - return fmt.Errorf("read JSON response: %w", err) - } - - dec := json.NewDecoder(bytes.NewReader(body)) - if strict { - dec.DisallowUnknownFields() - } - if err := dec.Decode(dest); err != nil { - return fmt.Errorf("decode JSON response: %w", err) - } - - return nil + return int(t.ExpiresAt.Unix() - nowFunc().Unix()) } diff --git a/cmd/entire/cli/auth/client_test.go b/cmd/entire/cli/auth/client_test.go deleted file mode 100644 index 6a6a0bf229..0000000000 --- a/cmd/entire/cli/auth/client_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package auth - -import ( - "strings" - "testing" -) - -func TestDecodeJSON_AllowsUnknownFields(t *testing.T) { - t.Parallel() - - var result DeviceAuthPoll - err := decodeJSON(strings.NewReader(`{ - "access_token": "token", - "token_type": "Bearer", - "refresh_token": "ignored" - }`), &result) - if err != nil { - t.Fatalf("decodeJSON() error = %v", err) - } - - if result.AccessToken != "token" { - t.Fatalf("AccessToken = %q, want %q", result.AccessToken, "token") - } -} - -func TestDecodeJSONStrict_RejectsUnknownFields(t *testing.T) { - t.Parallel() - - var result DeviceAuthStart - err := decodeJSONStrict(strings.NewReader(`{ - "device_code": "device", - "user_code": "ABCD-EFGH", - "verification_uri": "https://example.com/verify", - "verification_uri_complete": "https://example.com/verify?code=ABCD-EFGH", - "expires_in": 600, - "interval": 5, - "extra": true - }`), &result) - if err == nil { - t.Fatal("decodeJSONStrict() error = nil, want unknown-field error") - } -} diff --git a/cmd/entire/cli/auth/store.go b/cmd/entire/cli/auth/store.go index ee39eef926..a21bae568d 100644 --- a/cmd/entire/cli/auth/store.go +++ b/cmd/entire/cli/auth/store.go @@ -5,6 +5,8 @@ import ( "fmt" "strings" + "github.com/entireio/cli/auth/tokens" + "github.com/entireio/cli/auth/tokenstore" "github.com/entireio/cli/cmd/entire/cli/api" "github.com/zalando/go-keyring" ) @@ -12,18 +14,24 @@ import ( const keyringService = "entire-cli" // Store manages CLI authentication tokens in the OS keyring. +// +// Wraps tokenstore.Keyring with a backward-compatibility read path: +// pre-shim entries stored bare access-token strings, while the shared +// Keyring writes JSON-encoded TokenSets. GetToken transparently +// handles both shapes; SaveToken always writes the new shape. type Store struct { - service string + inner *tokenstore.Keyring } // NewStore returns a Store backed by the system keyring. func NewStore() *Store { - return &Store{service: keyringService} + return &Store{inner: tokenstore.NewKeyring(keyringService)} } -// NewStoreWithService returns a Store with a custom keyring service name (for testing). +// NewStoreWithService returns a Store with a custom keyring service +// name (for testing). func NewStoreWithService(service string) *Store { - return &Store{service: service} + return &Store{inner: tokenstore.NewKeyring(service)} } // SaveToken persists an access token for the given base URL. @@ -32,43 +40,41 @@ func (s *Store) SaveToken(baseURL, token string) error { if token == "" { return errors.New("refusing to save empty token") } - - if err := keyring.Set(s.service, baseURL, token); err != nil { - return fmt.Errorf("save token to keyring: %w", err) - } - - return nil + return s.inner.SaveTokens(baseURL, tokens.TokenSet{AccessToken: token}) } -// GetToken retrieves a stored token for the given base URL. -// Returns an empty string (and no error) if no token is stored. +// GetToken retrieves a stored token for the given base URL. Returns +// an empty string (and no error) if no token is stored. +// +// Falls back to a bare-string read to surface tokens written before +// the shim landed. func (s *Store) GetToken(baseURL string) (string, error) { - token, err := keyring.Get(s.service, baseURL) - if errors.Is(err, keyring.ErrNotFound) { - return "", nil + t, err := s.inner.LoadTokens(baseURL) + if err == nil { + return t.AccessToken, nil } - if err != nil { - return "", fmt.Errorf("get token from keyring: %w", err) + if errors.Is(err, tokenstore.ErrNotFound) { + return "", nil } - return token, nil + // Legacy fallback: pre-shim entries stored the raw access token + // rather than a JSON-encoded TokenSet. + raw, kerr := keyring.Get(s.inner.Service, baseURL) + if errors.Is(kerr, keyring.ErrNotFound) { + return "", nil + } + if kerr != nil { + return "", fmt.Errorf("get token from keyring: %w", kerr) + } + return raw, nil } // DeleteToken removes a stored token for the given base URL. func (s *Store) DeleteToken(baseURL string) error { - err := keyring.Delete(s.service, baseURL) - if errors.Is(err, keyring.ErrNotFound) { - return nil - } - if err != nil { - return fmt.Errorf("delete token from keyring: %w", err) - } - - return nil + return s.inner.DeleteTokens(baseURL) } // LookupCurrentToken retrieves the token for the current base URL. func LookupCurrentToken() (string, error) { - store := NewStore() - return store.GetToken(api.BaseURL()) + return NewStore().GetToken(api.BaseURL()) } From a94bf890732ec86ded6cb5629ac5e38b93c8b14c Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Wed, 6 May 2026 11:58:11 +1000 Subject: [PATCH 03/21] Wire ENTIRE_AUTH_PROVIDER_VERSION switch to select v1 or v2 Adds a transition-period env-var switch that picks between two device-flow configurations: v1 (default): /oauth/device/code + /oauth/token, client_id="entire-cli" v2 : /api/auth/oauth/device/code + /api/auth/token, client_id="cli" Both surfaces speak the same RFC 8628 protocol; only the paths and client_id differ. Default behaviour is unchanged. Setting ENTIRE_AUTH_PROVIDER_VERSION=v2 (alongside an appropriate ENTIRE_API_BASE_URL) opts a user into the next-generation surface early. Unrecognised values fall back to v1 so old binaries stay safe if a future v3 ever ships. Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/entire/cli/auth/client.go | 29 +++------ cmd/entire/cli/auth/provider.go | 51 +++++++++++++++ cmd/entire/cli/auth/provider_test.go | 94 ++++++++++++++++++++++++++++ 3 files changed, 155 insertions(+), 19 deletions(-) create mode 100644 cmd/entire/cli/auth/provider.go create mode 100644 cmd/entire/cli/auth/provider_test.go diff --git a/cmd/entire/cli/auth/client.go b/cmd/entire/cli/auth/client.go index c0dfad81ab..eb7517d1dd 100644 --- a/cmd/entire/cli/auth/client.go +++ b/cmd/entire/cli/auth/client.go @@ -14,18 +14,6 @@ import ( // nowFunc is the package's clock. Override in tests. var nowFunc = time.Now -// clientID is the OAuth client_id this CLI registers under at the -// entire.io device-flow endpoint. -const clientID = "entire-cli" - -// Device-flow endpoint paths on entire.io. Held as constants so the -// shim is the only place that knows them; if the CLI ever needs to -// target a different surface, the change lives here. -const ( - deviceCodePath = "/oauth/device/code" - tokenPath = "/oauth/token" -) - // DeviceAuthStart preserves the historical type name; the shape now // matches deviceflow.DeviceCode field-for-field. type DeviceAuthStart = deviceflow.DeviceCode @@ -41,22 +29,25 @@ type DeviceAuthPoll struct { Error string } -// Client wraps a deviceflow.Client preconfigured for entire.io. +// Client wraps a deviceflow.Client preconfigured for whichever provider +// version is selected via ENTIRE_AUTH_PROVIDER_VERSION (defaulting to +// v1). type Client struct { inner *deviceflow.Client } -// NewClient constructs a Client targeting entire.io. httpClient is -// used directly when non-nil; otherwise http.DefaultClient. +// NewClient constructs a Client targeting the active provider version. +// httpClient is used directly when non-nil; otherwise http.DefaultClient. func NewClient(httpClient *http.Client) *Client { + p := currentProvider() return &Client{inner: &deviceflow.Client{ HTTP: httpClient, BaseURL: api.BaseURL(), - ClientID: clientID, + ClientID: p.clientID, Scope: "cli", - UserAgent: clientID, - DeviceCodePath: deviceCodePath, - TokenPath: tokenPath, + UserAgent: p.clientID, + DeviceCodePath: p.deviceCodePath, + TokenPath: p.tokenPath, }} } diff --git a/cmd/entire/cli/auth/provider.go b/cmd/entire/cli/auth/provider.go new file mode 100644 index 0000000000..c0eed72550 --- /dev/null +++ b/cmd/entire/cli/auth/provider.go @@ -0,0 +1,51 @@ +package auth + +import ( + "os" + "strings" +) + +// ProviderVersionEnvVar selects which OAuth surface this CLI talks to. +// +// Recognised values: +// +// - "v1" (or unset / unrecognised) — current device-flow surface +// - "v2" — next-generation device-flow surface +// +// This is a transition-period switch: once v2 is the universal default +// the env var goes away. Surfaces are otherwise reachable as RFC 8628 +// device-flow endpoints; the only differences are paths and client_id. +const ProviderVersionEnvVar = "ENTIRE_AUTH_PROVIDER_VERSION" + +// providerConfig captures the per-surface bits of OAuth wiring. +type providerConfig struct { + clientID string + deviceCodePath string + tokenPath string +} + +var providers = map[string]providerConfig{ + "v1": { + clientID: "entire-cli", + deviceCodePath: "/oauth/device/code", + tokenPath: "/oauth/token", + }, + "v2": { + clientID: "cli", + deviceCodePath: "/api/auth/oauth/device/code", + tokenPath: "/api/auth/token", + }, +} + +// currentProvider returns the active providerConfig, defaulting to v1 +// when ENTIRE_AUTH_PROVIDER_VERSION is unset or holds an unrecognised +// value. Defaulting (rather than erroring) keeps old binaries safe if +// a future v3 ever lands. +func currentProvider() providerConfig { + switch strings.TrimSpace(os.Getenv(ProviderVersionEnvVar)) { + case "v2": + return providers["v2"] + default: + return providers["v1"] + } +} diff --git a/cmd/entire/cli/auth/provider_test.go b/cmd/entire/cli/auth/provider_test.go new file mode 100644 index 0000000000..746af3f646 --- /dev/null +++ b/cmd/entire/cli/auth/provider_test.go @@ -0,0 +1,94 @@ +package auth + +import ( + "net/http" + "testing" + + "github.com/entireio/cli/cmd/entire/cli/api" +) + +func TestCurrentProvider_DefaultsToV1(t *testing.T) { + t.Setenv(ProviderVersionEnvVar, "") + + p := currentProvider() + if p.clientID != "entire-cli" || p.deviceCodePath != "/oauth/device/code" || p.tokenPath != "/oauth/token" { + t.Fatalf("default provider = %+v, want v1 config", p) + } +} + +func TestCurrentProvider_V1Explicit(t *testing.T) { + t.Setenv(ProviderVersionEnvVar, "v1") + + p := currentProvider() + if p.clientID != "entire-cli" { + t.Fatalf("v1 clientID = %q", p.clientID) + } +} + +func TestCurrentProvider_V2(t *testing.T) { + t.Setenv(ProviderVersionEnvVar, "v2") + + p := currentProvider() + if p.clientID != "cli" { + t.Fatalf("v2 clientID = %q, want cli", p.clientID) + } + if p.deviceCodePath != "/api/auth/oauth/device/code" { + t.Fatalf("v2 deviceCodePath = %q", p.deviceCodePath) + } + if p.tokenPath != "/api/auth/token" { + t.Fatalf("v2 tokenPath = %q", p.tokenPath) + } +} + +func TestCurrentProvider_UnknownDefaultsToV1(t *testing.T) { + t.Setenv(ProviderVersionEnvVar, "v999") + + p := currentProvider() + if p.clientID != "entire-cli" { + t.Fatalf("unknown version should default to v1; got clientID = %q", p.clientID) + } +} + +func TestCurrentProvider_TrimsWhitespace(t *testing.T) { + t.Setenv(ProviderVersionEnvVar, " v2 ") + + p := currentProvider() + if p.clientID != "cli" { + t.Fatalf("whitespace-padded v2 clientID = %q, want cli", p.clientID) + } +} + +func TestNewClient_HonoursProviderVersion(t *testing.T) { + t.Setenv(api.BaseURLEnvVar, "https://example.test") + t.Setenv(ProviderVersionEnvVar, "v2") + + c := NewClient(&http.Client{}) + if c.inner.ClientID != "cli" { + t.Errorf("ClientID = %q, want cli", c.inner.ClientID) + } + if c.inner.DeviceCodePath != "/api/auth/oauth/device/code" { + t.Errorf("DeviceCodePath = %q", c.inner.DeviceCodePath) + } + if c.inner.TokenPath != "/api/auth/token" { + t.Errorf("TokenPath = %q", c.inner.TokenPath) + } + if c.inner.BaseURL != "https://example.test" { + t.Errorf("BaseURL = %q", c.inner.BaseURL) + } +} + +func TestNewClient_DefaultsToV1(t *testing.T) { + t.Setenv(api.BaseURLEnvVar, "https://example.test") + t.Setenv(ProviderVersionEnvVar, "") + + c := NewClient(nil) + if c.inner.ClientID != "entire-cli" { + t.Errorf("ClientID = %q, want entire-cli", c.inner.ClientID) + } + if c.inner.DeviceCodePath != "/oauth/device/code" { + t.Errorf("DeviceCodePath = %q", c.inner.DeviceCodePath) + } + if c.inner.TokenPath != "/oauth/token" { + t.Errorf("TokenPath = %q", c.inner.TokenPath) + } +} From 6624e25e8e1303bd90875151552225588f43ca2c Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Wed, 6 May 2026 16:19:43 +1000 Subject: [PATCH 04/21] Add auth/sts: RFC 8693 OAuth 2.0 Token Exchange client MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the fourth subpackage of the auth/ library: a small, provider- agnostic client for RFC 8693 token exchange. Caller supplies BaseURL, Path, and per-call ExchangeRequest fields (SubjectToken, SubjectTokenType, RequestedTokenType, plus optional Audience/Resource/Scope and an Extra url.Values for any non-standard form fields the server expects). The package defines constants only for RFC 8693's standard token-type URIs and the token-exchange grant_type — the requested-token-type URI is always caller-supplied. Returns *tokens.TokenSet on success with absolute ExpiresAt; wraps RFC 6749 / 8693 error responses with both code and description. Tests cover happy path, optional-field omission, Extra forwarding, standard-fields-override-Extra precedence, missing required fields, JSON and non-JSON server errors, missing access_token, and the no- expiry case. Co-Authored-By: Claude Opus 4.7 (1M context) --- auth/sts/sts.go | 219 +++++++++++++++++++++++++++++++++++ auth/sts/sts_test.go | 265 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 484 insertions(+) create mode 100644 auth/sts/sts.go create mode 100644 auth/sts/sts_test.go diff --git a/auth/sts/sts.go b/auth/sts/sts.go new file mode 100644 index 0000000000..993c0ca9a2 --- /dev/null +++ b/auth/sts/sts.go @@ -0,0 +1,219 @@ +// Package sts is an RFC 8693 OAuth 2.0 Token Exchange client. +// +// Construct a Client with the issuer's BaseURL and the token endpoint +// path, then call Exchange with a populated ExchangeRequest. The +// package is provider-agnostic: every server-specific value (endpoint +// path, requested-token-type URIs, custom form fields) is supplied at +// call time. There is no provider detection. +package sts + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/entireio/cli/auth/tokens" +) + +// nowFunc is the package's clock. Override in tests. +var nowFunc = time.Now + +// maxResponseBytes caps how much of an OAuth response body we read. +const maxResponseBytes = 1 << 20 + +// RFC 8693 grant_type and standard subject-token type URIs. Caller +// supplies RequestedTokenType (which is always implementation-specific +// outside of these RFC 8693 standard values). +const ( + GrantTypeTokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange" + + SubjectTokenTypeJWT = "urn:ietf:params:oauth:token-type:jwt" + SubjectTokenTypeAccessToken = "urn:ietf:params:oauth:token-type:access_token" +) + +// ExchangeRequest is the input to a token exchange. +// +// SubjectToken, SubjectTokenType, and RequestedTokenType are required. +// Audience, Resource, and Scope map to RFC 8693 §2.1 parameters and +// are sent only when non-empty. Extra carries implementation-specific +// form fields (e.g. server-defined parameters not in RFC 8693) that +// the caller's server expects; the standard fields above always win +// if Extra also sets them. +type ExchangeRequest struct { + SubjectToken string + SubjectTokenType string + RequestedTokenType string + + Audience string + Resource string + Scope string + + Extra url.Values +} + +func (r ExchangeRequest) validate() error { + switch { + case r.SubjectToken == "": + return errors.New("SubjectToken is required") + case r.SubjectTokenType == "": + return errors.New("SubjectTokenType is required") + case r.RequestedTokenType == "": + return errors.New("RequestedTokenType is required") + } + return nil +} + +// Client exchanges subject tokens for tokens of a different type at an +// RFC 8693 token endpoint. +// +// All configuration is explicit; the package has no global state and +// no implicit URLs. Provide BaseURL and Path; the rest is RFC 8693. +type Client struct { + HTTP *http.Client + BaseURL string + Path string + UserAgent string +} + +// Exchange performs one RFC 8693 token exchange. +// +// Returns a TokenSet with absolute ExpiresAt derived from the server's +// expires_in. Returns an error wrapping the response body when the +// server responds with a non-2xx status; callers can match on the +// returned error message for known OAuth error codes. +func (c *Client) Exchange(ctx context.Context, req ExchangeRequest) (*tokens.TokenSet, error) { + if err := req.validate(); err != nil { + return nil, fmt.Errorf("token exchange: %w", err) + } + + form := buildForm(req) + + endpoint, err := resolveURL(c.BaseURL, c.Path) + if err != nil { + return nil, fmt.Errorf("token exchange: resolve URL: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode())) + if err != nil { + return nil, fmt.Errorf("token exchange: create request: %w", err) + } + httpReq.Header.Set("Accept", "application/json") + httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + if c.UserAgent != "" { + httpReq.Header.Set("User-Agent", c.UserAgent) + } + + httpClient := c.HTTP + if httpClient == nil { + httpClient = http.DefaultClient + } + + resp, err := httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("token exchange: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, readAPIError(resp) + } + + var raw struct { + AccessToken string `json:"access_token"` + IssuedTokenType string `json:"issued_token_type"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` + } + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + if err != nil { + return nil, fmt.Errorf("token exchange: read response: %w", err) + } + if err := json.Unmarshal(body, &raw); err != nil { + return nil, fmt.Errorf("token exchange: decode response: %w", err) + } + if raw.AccessToken == "" { + return nil, errors.New("token exchange: response missing access_token") + } + + t := &tokens.TokenSet{ + AccessToken: raw.AccessToken, + RefreshToken: raw.RefreshToken, + TokenType: raw.TokenType, + Scope: raw.Scope, + } + if raw.ExpiresIn > 0 { + t.ExpiresAt = nowFunc().Add(time.Duration(raw.ExpiresIn) * time.Second) + } + return t, nil +} + +// buildForm renders an ExchangeRequest into the wire form, layering +// the standard RFC 8693 fields on top of req.Extra so caller-supplied +// duplicates of standard fields are overwritten by the typed values. +func buildForm(req ExchangeRequest) url.Values { + form := url.Values{} + for k, v := range req.Extra { + form[k] = append(form[k], v...) + } + + form.Set("grant_type", GrantTypeTokenExchange) + form.Set("subject_token", req.SubjectToken) + form.Set("subject_token_type", req.SubjectTokenType) + form.Set("requested_token_type", req.RequestedTokenType) + + if req.Audience != "" { + form.Set("audience", req.Audience) + } + if req.Resource != "" { + form.Set("resource", req.Resource) + } + if req.Scope != "" { + form.Set("scope", req.Scope) + } + return form +} + +func resolveURL(baseURL, path string) (string, error) { + base, err := url.Parse(baseURL) + if err != nil { + return "", fmt.Errorf("parse base URL: %w", err) + } + if base.Scheme != "http" && base.Scheme != "https" { + return "", fmt.Errorf("unsupported base URL scheme %q (must be http or https)", base.Scheme) + } + rel, err := url.Parse(path) + if err != nil { + return "", fmt.Errorf("parse path: %w", err) + } + return base.ResolveReference(rel).String(), nil +} + +type errorResponse struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` +} + +func readAPIError(resp *http.Response) error { + body, _ := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + var apiErr errorResponse + if err := json.Unmarshal(bytes.TrimSpace(body), &apiErr); err == nil && apiErr.Error != "" { + if apiErr.ErrorDescription != "" { + return fmt.Errorf("token exchange: status %d: %s: %s", resp.StatusCode, apiErr.Error, apiErr.ErrorDescription) + } + return fmt.Errorf("token exchange: status %d: %s", resp.StatusCode, apiErr.Error) + } + text := strings.TrimSpace(string(body)) + if text != "" { + return fmt.Errorf("token exchange: status %d: %s", resp.StatusCode, text) + } + return fmt.Errorf("token exchange: status %d", resp.StatusCode) +} diff --git a/auth/sts/sts_test.go b/auth/sts/sts_test.go new file mode 100644 index 0000000000..c52973375c --- /dev/null +++ b/auth/sts/sts_test.go @@ -0,0 +1,265 @@ +package sts + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +const testTokenPath = "/sts/token" + +func freezeClock(t *testing.T, at time.Time) { + t.Helper() + prev := nowFunc + nowFunc = func() time.Time { return at } + t.Cleanup(func() { nowFunc = prev }) +} + +func newTestClient(t *testing.T, h http.HandlerFunc) (*Client, *httptest.Server) { + t.Helper() + srv := httptest.NewServer(h) + t.Cleanup(srv.Close) + + return &Client{ + HTTP: srv.Client(), + BaseURL: srv.URL, + Path: testTokenPath, + }, srv +} + +func mustReadForm(t *testing.T, r *http.Request) { + t.Helper() + if err := r.ParseForm(); err != nil { + t.Fatalf("parse form: %v", err) + } +} + +func TestExchange_Success(t *testing.T) { + t.Parallel() + + freezeClock(t, time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC)) + + c, _ := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { + mustReadForm(t, r) + if got := r.PostForm.Get("grant_type"); got != GrantTypeTokenExchange { + t.Errorf("grant_type = %q", got) + } + if got := r.PostForm.Get("subject_token_type"); got != SubjectTokenTypeJWT { + t.Errorf("subject_token_type = %q", got) + } + if got := r.PostForm.Get("requested_token_type"); got != "urn:example:token-type:thing" { + t.Errorf("requested_token_type = %q", got) + } + if got := r.PostForm.Get("subject_token"); got != "sub-jwt" { + t.Errorf("subject_token = %q", got) + } + if got := r.PostForm.Get("audience"); got != "audience-x" { + t.Errorf("audience = %q", got) + } + if got := r.PostForm.Get("resource"); got != "owner/repo" { + t.Errorf("resource = %q", got) + } + if got := r.PostForm.Get("scope"); got != "thing:do" { + t.Errorf("scope = %q", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{ + "access_token":"acc", + "issued_token_type":"urn:example:token-type:thing", + "token_type":"Bearer", + "expires_in":3600, + "refresh_token":"ref", + "scope":"thing:do" + }`) + }) + + got, err := c.Exchange(context.Background(), ExchangeRequest{ + SubjectToken: "sub-jwt", + SubjectTokenType: SubjectTokenTypeJWT, + RequestedTokenType: "urn:example:token-type:thing", + Audience: "audience-x", + Resource: "owner/repo", + Scope: "thing:do", + }) + if err != nil { + t.Fatalf("Exchange() error = %v", err) + } + + if got.AccessToken != "acc" || got.RefreshToken != "ref" || got.TokenType != "Bearer" || got.Scope != "thing:do" { + t.Fatalf("TokenSet = %+v", got) + } + want := time.Date(2026, 5, 6, 13, 0, 0, 0, time.UTC) + if !got.ExpiresAt.Equal(want) { + t.Fatalf("ExpiresAt = %v, want %v", got.ExpiresAt, want) + } +} + +func TestExchange_OmitsOptionalFieldsWhenEmpty(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { + mustReadForm(t, r) + for _, k := range []string{"audience", "resource", "scope"} { + if r.PostForm.Has(k) { + t.Errorf("optional field %q should not be sent when empty", k) + } + } + _, _ = io.WriteString(w, `{"access_token":"acc"}`) + }) + + if _, err := c.Exchange(context.Background(), ExchangeRequest{ + SubjectToken: "sub", + SubjectTokenType: SubjectTokenTypeJWT, + RequestedTokenType: "urn:example:t", + }); err != nil { + t.Fatalf("Exchange() error = %v", err) + } +} + +func TestExchange_ExtraFieldsForwarded(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { + mustReadForm(t, r) + if got := r.PostForm.Get("custom_field"); got != "custom-value" { + t.Errorf("custom_field = %q", got) + } + _, _ = io.WriteString(w, `{"access_token":"acc"}`) + }) + + if _, err := c.Exchange(context.Background(), ExchangeRequest{ + SubjectToken: "sub", + SubjectTokenType: SubjectTokenTypeJWT, + RequestedTokenType: "urn:example:t", + Extra: url.Values{"custom_field": {"custom-value"}}, + }); err != nil { + t.Fatalf("Exchange() error = %v", err) + } +} + +func TestExchange_StandardFieldsOverrideExtra(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { + mustReadForm(t, r) + // Caller tried to set grant_type via Extra; standard wins. + if got := r.PostForm.Get("grant_type"); got != GrantTypeTokenExchange { + t.Errorf("Extra should not override standard grant_type; got %q", got) + } + _, _ = io.WriteString(w, `{"access_token":"acc"}`) + }) + + if _, err := c.Exchange(context.Background(), ExchangeRequest{ + SubjectToken: "sub", + SubjectTokenType: SubjectTokenTypeJWT, + RequestedTokenType: "urn:example:t", + Extra: url.Values{"grant_type": {"trojan"}}, + }); err != nil { + t.Fatalf("Exchange() error = %v", err) + } +} + +func TestExchange_RejectsMissingRequiredFields(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + req ExchangeRequest + }{ + {"no subject token", ExchangeRequest{SubjectTokenType: SubjectTokenTypeJWT, RequestedTokenType: "urn:example:t"}}, + {"no subject token type", ExchangeRequest{SubjectToken: "sub", RequestedTokenType: "urn:example:t"}}, + {"no requested token type", ExchangeRequest{SubjectToken: "sub", SubjectTokenType: SubjectTokenTypeJWT}}, + } + + c := &Client{HTTP: http.DefaultClient, BaseURL: "https://example.test", Path: testTokenPath} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if _, err := c.Exchange(context.Background(), tt.req); err == nil { + t.Fatal("Exchange() should fail on missing required field") + } + }) + } +} + +func TestExchange_ServerError(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, `{"error":"invalid_request","error_description":"bad subject"}`) + }) + + _, err := c.Exchange(context.Background(), ExchangeRequest{ + SubjectToken: "sub", + SubjectTokenType: SubjectTokenTypeJWT, + RequestedTokenType: "urn:example:t", + }) + if err == nil { + t.Fatal("Exchange() with 400 should fail") + } + if !strings.Contains(err.Error(), "invalid_request") || !strings.Contains(err.Error(), "bad subject") { + t.Fatalf("error = %v, want both code and description", err) + } +} + +func TestExchange_ServerErrorWithoutJSON(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = io.WriteString(w, `something broke`) + }) + + _, err := c.Exchange(context.Background(), ExchangeRequest{ + SubjectToken: "sub", + SubjectTokenType: SubjectTokenTypeJWT, + RequestedTokenType: "urn:example:t", + }) + if err == nil || !strings.Contains(err.Error(), "500") || !strings.Contains(err.Error(), "something broke") { + t.Fatalf("error = %v, want status + body text", err) + } +} + +func TestExchange_MissingAccessToken(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"token_type":"Bearer"}`) + }) + + _, err := c.Exchange(context.Background(), ExchangeRequest{ + SubjectToken: "sub", + SubjectTokenType: SubjectTokenTypeJWT, + RequestedTokenType: "urn:example:t", + }) + if err == nil || !strings.Contains(err.Error(), "missing access_token") { + t.Fatalf("error = %v, want missing access_token", err) + } +} + +func TestExchange_NoExpiry(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"access_token":"acc","token_type":"Bearer"}`) + }) + + got, err := c.Exchange(context.Background(), ExchangeRequest{ + SubjectToken: "sub", + SubjectTokenType: SubjectTokenTypeJWT, + RequestedTokenType: "urn:example:t", + }) + if err != nil { + t.Fatalf("Exchange() error = %v", err) + } + if !got.ExpiresAt.IsZero() { + t.Fatalf("ExpiresAt = %v, want zero", got.ExpiresAt) + } +} From 9e9bc3f0728b7b11c3ad6ca79dffecc784e1ce7b Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Thu, 7 May 2026 10:19:20 +1000 Subject: [PATCH 05/21] auth: surface friendly error when OAuth response is HTML, not JSON MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Captive portals, corporate proxies, and VPN firewalls (Cloudflare WARP, etc.) commonly intercept the OAuth endpoint and return a 200 OK with an HTML error page. Today the JSON decoder produces an opaque error like: start login: decode device auth start response: decode JSON response: invalid character '<' looking for beginning of value That tells the user nothing actionable. Now both auth/deviceflow and auth/sts surface: could not reach authentication server: server returned non-JSON response (check VPN, proxy, or firewall — e.g. Cloudflare WARP) Implementation lives in a new internal package auth/internal/oauthhttp. Both deviceflow and sts now run their successful-response bodies through oauthhttp.ReadAndDecodeJSON, which sniffs for a leading '<' (after trimming whitespace) and returns a typed ErrNonJSONResponse sentinel — callers can errors.Is when they want to branch, or just let the message bubble up. Tests cover the helper in isolation plus end-to-end paths through both StartDeviceAuth, PollDeviceAuth, and Exchange. Co-Authored-By: Claude Opus 4.7 (1M context) --- auth/deviceflow/deviceflow.go | 33 ++----- auth/deviceflow/deviceflow_test.go | 42 +++++++++ auth/internal/oauthhttp/jsonresp.go | 64 +++++++++++++ auth/internal/oauthhttp/jsonresp_test.go | 115 +++++++++++++++++++++++ auth/sts/sts.go | 28 +++--- auth/sts/sts_test.go | 24 +++++ 6 files changed, 262 insertions(+), 44 deletions(-) create mode 100644 auth/internal/oauthhttp/jsonresp.go create mode 100644 auth/internal/oauthhttp/jsonresp_test.go diff --git a/auth/deviceflow/deviceflow.go b/auth/deviceflow/deviceflow.go index be866306c6..6fa38ddfbf 100644 --- a/auth/deviceflow/deviceflow.go +++ b/auth/deviceflow/deviceflow.go @@ -13,7 +13,6 @@ package deviceflow import ( - "bytes" "context" "encoding/json" "errors" @@ -24,6 +23,7 @@ import ( "strings" "time" + "github.com/entireio/cli/auth/internal/oauthhttp" "github.com/entireio/cli/auth/tokens" ) @@ -31,11 +31,6 @@ import ( // time.Now. var nowFunc = time.Now -// maxResponseBytes caps how much of an OAuth response body we read. -// Both endpoints return small JSON documents; larger bodies indicate -// either a misconfigured proxy or an attempt to exhaust client memory. -const maxResponseBytes = 1 << 20 - // deviceCodeGrantType is the RFC 8628 token-endpoint grant_type for // polling device-flow authorization. const deviceCodeGrantType = "urn:ietf:params:oauth:grant-type:device_code" @@ -132,8 +127,8 @@ func (c *Client) StartDeviceAuth(ctx context.Context) (*DeviceCode, error) { } var result DeviceCode - if err := decodeJSON(resp.Body, &result, true); err != nil { - return nil, fmt.Errorf("decode device auth start response: %w", err) + if err := oauthhttp.ReadAndDecodeJSON(resp.Body, &result, true); err != nil { + return nil, fmt.Errorf("start device auth: %w", err) } return &result, nil } @@ -172,8 +167,8 @@ func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*tokens RefreshToken string `json:"refresh_token"` Scope string `json:"scope"` } - if err := decodeJSON(resp.Body, &raw, false); err != nil { - return nil, fmt.Errorf("decode device auth poll response: %w", err) + if err := oauthhttp.ReadAndDecodeJSON(resp.Body, &raw, false); err != nil { + return nil, fmt.Errorf("poll device auth: %w", err) } if raw.AccessToken == "" { @@ -243,7 +238,7 @@ type errorResponse struct { } func readAPIErrorResponse(resp *http.Response) (*errorResponse, error) { - body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + body, err := io.ReadAll(io.LimitReader(resp.Body, oauthhttp.MaxResponseBytes)) if err != nil { return nil, fmt.Errorf("status %d", resp.StatusCode) } @@ -267,19 +262,3 @@ func readAPIError(resp *http.Response, action string) error { } return fmt.Errorf("%s: %w", action, err) } - -func decodeJSON(r io.Reader, dest any, strict bool) error { - body, err := io.ReadAll(io.LimitReader(r, maxResponseBytes)) - if err != nil { - return fmt.Errorf("read JSON response: %w", err) - } - - dec := json.NewDecoder(bytes.NewReader(body)) - if strict { - dec.DisallowUnknownFields() - } - if err := dec.Decode(dest); err != nil { - return fmt.Errorf("decode JSON response: %w", err) - } - return nil -} diff --git a/auth/deviceflow/deviceflow_test.go b/auth/deviceflow/deviceflow_test.go index 1d217ed723..ebcf626bf3 100644 --- a/auth/deviceflow/deviceflow_test.go +++ b/auth/deviceflow/deviceflow_test.go @@ -244,6 +244,48 @@ func TestPollDeviceAuth_200WithNoAccessToken(t *testing.T) { } } +func TestStartDeviceAuth_HTMLBodySurfacesFriendlyError(t *testing.T) { + t.Parallel() + + // Captive portal / firewall (Cloudflare WARP, corp proxy) returns + // 200 OK with an HTML error page. Surface a network-actionable + // message instead of the opaque JSON-decode complaint. + c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html") + _, _ = io.WriteString(w, `Access blocked`) + }) + + _, err := c.StartDeviceAuth(context.Background()) + if err == nil { + t.Fatal("StartDeviceAuth() with HTML body should error") + } + for _, want := range []string{"non-JSON", "VPN", "proxy", "firewall"} { + if !strings.Contains(err.Error(), want) { + t.Errorf("error missing %q hint: %s", want, err) + } + } + if strings.Contains(err.Error(), "invalid character") { + t.Errorf("raw JSON-decoder error leaked through: %s", err) + } +} + +func TestPollDeviceAuth_HTMLBodySurfacesFriendlyError(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html") + _, _ = io.WriteString(w, `Access blocked by WARP`) + }) + + _, err := c.PollDeviceAuth(context.Background(), "dev-1") + if err == nil { + t.Fatal("PollDeviceAuth() with HTML body should error") + } + if strings.Contains(err.Error(), "invalid character") { + t.Errorf("raw JSON-decoder error leaked through: %s", err) + } +} + func TestResolveURL(t *testing.T) { t.Parallel() diff --git a/auth/internal/oauthhttp/jsonresp.go b/auth/internal/oauthhttp/jsonresp.go new file mode 100644 index 0000000000..768858d2d1 --- /dev/null +++ b/auth/internal/oauthhttp/jsonresp.go @@ -0,0 +1,64 @@ +// Package oauthhttp holds shared HTTP-response helpers used by the +// auth subpackages. Internal: only auth/* may import. +package oauthhttp + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" +) + +// MaxResponseBytes caps how much of an OAuth response body we read. +// Both device-flow and token-exchange endpoints return small JSON +// documents; larger bodies indicate a misconfigured proxy or an +// attempt to exhaust client memory. +const MaxResponseBytes = 1 << 20 + +// ErrNonJSONResponse is returned by ReadAndDecodeJSON when a 200 OK +// from the authorization server's body looks like HTML rather than +// JSON — typically a captive portal, corporate proxy, or VPN firewall +// (Cloudflare WARP, etc.) intercepting the request and returning an +// error page. +// +// Callers can match with errors.Is and surface a UX message; the +// default Error() string is already user-readable. +var ErrNonJSONResponse = errors.New( + "could not reach authentication server: server returned non-JSON " + + "response (check VPN, proxy, or firewall — e.g. Cloudflare WARP)", +) + +// ReadAndDecodeJSON reads up to MaxResponseBytes from r and decodes +// the body as JSON into dest. When strict is true, unknown fields are +// rejected. +// +// Returns ErrNonJSONResponse when the body is HTML — the captive- +// portal / proxy-interceptor case. Other read or decode failures are +// wrapped with context. +func ReadAndDecodeJSON(r io.Reader, dest any, strict bool) error { + body, err := io.ReadAll(io.LimitReader(r, MaxResponseBytes)) + if err != nil { + return fmt.Errorf("read JSON response: %w", err) + } + if looksLikeHTML(body) { + return ErrNonJSONResponse + } + + dec := json.NewDecoder(bytes.NewReader(body)) + if strict { + dec.DisallowUnknownFields() + } + if err := dec.Decode(dest); err != nil { + return fmt.Errorf("decode JSON response: %w", err) + } + return nil +} + +// looksLikeHTML reports whether body's first non-whitespace byte is +// '<'. That covers HTML, XHTML, XML, and most captive-portal error +// pages without trying to fully sniff the document. +func looksLikeHTML(body []byte) bool { + trimmed := bytes.TrimSpace(body) + return len(trimmed) > 0 && trimmed[0] == '<' +} diff --git a/auth/internal/oauthhttp/jsonresp_test.go b/auth/internal/oauthhttp/jsonresp_test.go new file mode 100644 index 0000000000..da0ede757a --- /dev/null +++ b/auth/internal/oauthhttp/jsonresp_test.go @@ -0,0 +1,115 @@ +package oauthhttp + +import ( + "errors" + "strings" + "testing" +) + +func TestReadAndDecodeJSON_Success(t *testing.T) { + t.Parallel() + + var got struct { + A string `json:"a"` + B int `json:"b"` + } + err := ReadAndDecodeJSON(strings.NewReader(`{"a":"x","b":42}`), &got, false) + if err != nil { + t.Fatalf("error = %v", err) + } + if got.A != "x" || got.B != 42 { + t.Fatalf("got = %+v", got) + } +} + +func TestReadAndDecodeJSON_StrictRejectsUnknown(t *testing.T) { + t.Parallel() + + var got struct { + A string `json:"a"` + } + err := ReadAndDecodeJSON(strings.NewReader(`{"a":"x","extra":1}`), &got, true) + if err == nil { + t.Fatal("strict mode should reject unknown fields") + } + if errors.Is(err, ErrNonJSONResponse) { + t.Fatal("decode failure misclassified as non-JSON") + } +} + +func TestReadAndDecodeJSON_TolerantUnknown(t *testing.T) { + t.Parallel() + + var got struct { + A string `json:"a"` + } + err := ReadAndDecodeJSON(strings.NewReader(`{"a":"x","extra":1}`), &got, false) + if err != nil { + t.Fatalf("non-strict should accept unknown fields, got %v", err) + } +} + +func TestReadAndDecodeJSON_DetectsHTMLBody(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + body string + }{ + {"plain HTML", `Access denied`}, + {"DOCTYPE", ``}, + {"leading whitespace + HTML", " \n\t"}, + {"XML", ``}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var dest map[string]any + err := ReadAndDecodeJSON(strings.NewReader(tt.body), &dest, false) + if !errors.Is(err, ErrNonJSONResponse) { + t.Fatalf("error = %v, want ErrNonJSONResponse", err) + } + }) + } +} + +func TestReadAndDecodeJSON_SurfacesGenuineDecodeErrors(t *testing.T) { + t.Parallel() + + var dest map[string]any + err := ReadAndDecodeJSON(strings.NewReader(`{"a": not json}`), &dest, false) + if err == nil { + t.Fatal("malformed JSON should error") + } + if errors.Is(err, ErrNonJSONResponse) { + t.Fatal("malformed-but-not-HTML should not be flagged as non-JSON response") + } + if !strings.Contains(err.Error(), "decode JSON response") { + t.Fatalf("error = %v, want wrapped decode error", err) + } +} + +func TestReadAndDecodeJSON_EmptyBody(t *testing.T) { + t.Parallel() + + var dest map[string]any + err := ReadAndDecodeJSON(strings.NewReader(""), &dest, false) + if err == nil { + t.Fatal("empty body should error") + } + if errors.Is(err, ErrNonJSONResponse) { + t.Fatal("empty body shouldn't be flagged as HTML") + } +} + +func TestErrNonJSONResponse_MessageIsActionable(t *testing.T) { + t.Parallel() + + msg := ErrNonJSONResponse.Error() + for _, want := range []string{"non-JSON", "VPN", "proxy", "firewall"} { + if !strings.Contains(msg, want) { + t.Errorf("message missing %q: %s", want, msg) + } + } +} diff --git a/auth/sts/sts.go b/auth/sts/sts.go index 993c0ca9a2..0c3487eaab 100644 --- a/auth/sts/sts.go +++ b/auth/sts/sts.go @@ -19,15 +19,13 @@ import ( "strings" "time" + "github.com/entireio/cli/auth/internal/oauthhttp" "github.com/entireio/cli/auth/tokens" ) // nowFunc is the package's clock. Override in tests. var nowFunc = time.Now -// maxResponseBytes caps how much of an OAuth response body we read. -const maxResponseBytes = 1 << 20 - // RFC 8693 grant_type and standard subject-token type URIs. Caller // supplies RequestedTokenType (which is always implementation-specific // outside of these RFC 8693 standard values). @@ -126,19 +124,15 @@ func (c *Client) Exchange(ctx context.Context, req ExchangeRequest) (*tokens.Tok } var raw struct { - AccessToken string `json:"access_token"` - IssuedTokenType string `json:"issued_token_type"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` - RefreshToken string `json:"refresh_token"` - Scope string `json:"scope"` - } - body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) - if err != nil { - return nil, fmt.Errorf("token exchange: read response: %w", err) - } - if err := json.Unmarshal(body, &raw); err != nil { - return nil, fmt.Errorf("token exchange: decode response: %w", err) + AccessToken string `json:"access_token"` + IssuedTokenType string `json:"issued_token_type"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` + } + if err := oauthhttp.ReadAndDecodeJSON(resp.Body, &raw, false); err != nil { + return nil, fmt.Errorf("token exchange: %w", err) } if raw.AccessToken == "" { return nil, errors.New("token exchange: response missing access_token") @@ -203,7 +197,7 @@ type errorResponse struct { } func readAPIError(resp *http.Response) error { - body, _ := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) + body, _ := io.ReadAll(io.LimitReader(resp.Body, oauthhttp.MaxResponseBytes)) var apiErr errorResponse if err := json.Unmarshal(bytes.TrimSpace(body), &apiErr); err == nil && apiErr.Error != "" { if apiErr.ErrorDescription != "" { diff --git a/auth/sts/sts_test.go b/auth/sts/sts_test.go index c52973375c..d851afd0f4 100644 --- a/auth/sts/sts_test.go +++ b/auth/sts/sts_test.go @@ -244,6 +244,30 @@ func TestExchange_MissingAccessToken(t *testing.T) { } } +func TestExchange_HTMLBodySurfacesFriendlyError(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html") + _, _ = io.WriteString(w, `Blocked by firewall`) + }) + + _, err := c.Exchange(context.Background(), ExchangeRequest{ + SubjectToken: "sub", + SubjectTokenType: SubjectTokenTypeJWT, + RequestedTokenType: "urn:example:t", + }) + if err == nil { + t.Fatal("Exchange() with HTML body should error") + } + if !strings.Contains(err.Error(), "non-JSON") { + t.Errorf("error missing non-JSON hint: %s", err) + } + if strings.Contains(err.Error(), "invalid character") { + t.Errorf("raw JSON-decoder error leaked through: %s", err) + } +} + func TestExchange_NoExpiry(t *testing.T) { t.Parallel() From 0f05c157577335428293c3bf5d647a3e9fdc6bb6 Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Thu, 7 May 2026 11:27:45 +1000 Subject: [PATCH 06/21] =?UTF-8?q?auth/deviceflow:=20surface=20error=5Fdesc?= =?UTF-8?q?ription=20from=20RFC=208628=20=C2=A73.5=20errors?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The token endpoint's error response carries an optional human-readable error_description alongside the standard error code (RFC 6749 §5.2, inherited by RFC 8628 §3.5). The lib was decoding only the code, which collapsed several distinct invalid_grant flavours — "device_code unknown" vs "client_id does not match grant" vs already-consumed replay — into a single opaque "device authorization failed: invalid_grant" at the CLI. Pull through: * auth/deviceflow now decodes both fields on a non-2xx response and wraps the sentinel error as fmt.Errorf("%w: %s", sentinel, desc) so errors.Is(err, ErrInvalidGrant) keeps matching while the message retains the description. * cmd/entire/cli/auth.DeviceAuthPoll grows an ErrorDescription field; the shim extracts it from the wrapped sentinel. * cmd/entire/cli/login.go appends ": " to the user-facing failure message when the server provided one. Two new deviceflow tests cover the description-present and description-absent paths (no trailing colon-space when absent). Co-Authored-By: Claude Opus 4.7 (1M context) --- auth/deviceflow/deviceflow.go | 12 ++++++++-- auth/deviceflow/deviceflow_test.go | 34 ++++++++++++++++++++++++++++ cmd/entire/cli/auth/client.go | 36 +++++++++++++++++++++++++----- cmd/entire/cli/login.go | 3 +++ 4 files changed, 77 insertions(+), 8 deletions(-) diff --git a/auth/deviceflow/deviceflow.go b/auth/deviceflow/deviceflow.go index 6fa38ddfbf..1ae7413e7a 100644 --- a/auth/deviceflow/deviceflow.go +++ b/auth/deviceflow/deviceflow.go @@ -157,7 +157,14 @@ func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*tokens if parseErr != nil { return nil, fmt.Errorf("poll device auth: %w", parseErr) } - return nil, errCodeToSentinel(apiErr.Error) + err := errCodeToSentinel(apiErr.Error) + if apiErr.ErrorDescription != "" { + // Wrap so callers using errors.Is(err, ErrInvalidGrant) keep + // working while the description is still surfaced via + // err.Error(). Format: ": ". + err = fmt.Errorf("%w: %s", err, apiErr.ErrorDescription) + } + return nil, err } var raw struct { @@ -234,7 +241,8 @@ func resolveURL(baseURL, path string) (string, error) { } type errorResponse struct { - Error string `json:"error"` + Error string `json:"error"` + ErrorDescription string `json:"error_description"` } func readAPIErrorResponse(resp *http.Response) (*errorResponse, error) { diff --git a/auth/deviceflow/deviceflow_test.go b/auth/deviceflow/deviceflow_test.go index ebcf626bf3..1d0cd99df2 100644 --- a/auth/deviceflow/deviceflow_test.go +++ b/auth/deviceflow/deviceflow_test.go @@ -213,6 +213,40 @@ func TestPollDeviceAuth_ErrorCodes(t *testing.T) { } } +func TestPollDeviceAuth_ErrorDescription_AppendedToSentinel(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, `{"error":"invalid_grant","error_description":"device_code unknown"}`) + }) + + _, err := c.PollDeviceAuth(context.Background(), "dev-1") + if !errors.Is(err, ErrInvalidGrant) { + t.Fatalf("PollDeviceAuth() error = %v, want ErrInvalidGrant chain", err) + } + if !strings.Contains(err.Error(), "device_code unknown") { + t.Fatalf("error = %q, want it to include the description", err) + } +} + +func TestPollDeviceAuth_NoDescription_NoTrailingColon(t *testing.T) { + t.Parallel() + + c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, `{"error":"invalid_grant"}`) + }) + + _, err := c.PollDeviceAuth(context.Background(), "dev-1") + if !errors.Is(err, ErrInvalidGrant) { + t.Fatalf("error = %v", err) + } + if strings.HasSuffix(err.Error(), ": ") { + t.Fatalf("error trailing colon-space: %q", err) + } +} + func TestPollDeviceAuth_UnknownErrorCode(t *testing.T) { t.Parallel() diff --git a/cmd/entire/cli/auth/client.go b/cmd/entire/cli/auth/client.go index eb7517d1dd..6743608f61 100644 --- a/cmd/entire/cli/auth/client.go +++ b/cmd/entire/cli/auth/client.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "strings" "time" "github.com/entireio/cli/auth/deviceflow" @@ -21,12 +22,17 @@ type DeviceAuthStart = deviceflow.DeviceCode // DeviceAuthPoll is the historical token-poll response shape. The shim // flattens deviceflow's typed errors back into the Error field so // existing login.go logic that switches on result.Error keeps working. +// +// ErrorDescription carries the optional `error_description` from the +// server's RFC 8628 §3.5 error response, when present. Used to give +// callers a more actionable message than the bare error code. type DeviceAuthPoll struct { - AccessToken string - TokenType string - ExpiresIn int - Scope string - Error string + AccessToken string + TokenType string + ExpiresIn int + Scope string + Error string + ErrorDescription string } // Client wraps a deviceflow.Client preconfigured for whichever provider @@ -67,7 +73,10 @@ func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*Device t, err := c.inner.PollDeviceAuth(ctx, deviceCode) if err != nil { if code := oauthErrorCode(err); code != "" { - return &DeviceAuthPoll{Error: code}, nil + return &DeviceAuthPoll{ + Error: code, + ErrorDescription: descriptionFromSentinel(err, code), + }, nil } return nil, err } @@ -98,6 +107,21 @@ func oauthErrorCode(err error) string { return "" } +// descriptionFromSentinel pulls the description suffix out of a wrapped +// sentinel error. The deviceflow lib uses fmt.Errorf("%w: %s", sentinel, +// description) when the server included an error_description, so the +// formatted error reads ": ". Stripping the +// ": " prefix yields the description; absent prefix means the +// server didn't supply one. +func descriptionFromSentinel(err error, code string) string { + msg := err.Error() + prefix := code + ": " + if rest, ok := strings.CutPrefix(msg, prefix); ok { + return rest + } + return "" +} + // secondsUntil computes seconds-until-expiry for a TokenSet with an // absolute ExpiresAt. Returns 0 when no expiry is set, mirroring the // historical shape of DeviceAuthPoll.ExpiresIn. diff --git a/cmd/entire/cli/login.go b/cmd/entire/cli/login.go index 88618c522e..de5ad84f8c 100644 --- a/cmd/entire/cli/login.go +++ b/cmd/entire/cli/login.go @@ -146,6 +146,9 @@ func waitForApproval(ctx context.Context, poller deviceAuthClient, deviceCode st case "expired_token": return "", errors.New("device authorization expired") default: + if result.ErrorDescription != "" { + return "", fmt.Errorf("device authorization failed: %s: %s", result.Error, result.ErrorDescription) + } return "", fmt.Errorf("device authorization failed: %s", result.Error) } From bb13cbd4decc38cc8dee6c111f34ee7e92d9a7c5 Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Thu, 7 May 2026 14:07:40 +1000 Subject: [PATCH 07/21] api/auth_tokens: route to /api/v1/auth/tokens or /api/auth/tokens by version MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The auth-tokens endpoint family lives at different paths on the two backends — historical /api/v1/auth/tokens vs the consolidated /api/auth/tokens. ENTIRE_AUTH_PROVIDER_VERSION already gates the device-flow path split; auth_tokens.go now reads the same env var to pick its base path. ListTokens, RevokeToken, and RevokeCurrentToken all flow through one authTokensBasePath() helper so future paths land in one place. The env-var name is duplicated as a constant rather than imported from cmd/entire/cli/auth: api/ is a leaf package and shouldn't take a dependency on auth/ for routing. Both reads must stay in sync; flagged in a comment. Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/entire/cli/api/auth_tokens.go | 35 ++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/cmd/entire/cli/api/auth_tokens.go b/cmd/entire/cli/api/auth_tokens.go index ad9d53e5ba..a34445d07e 100644 --- a/cmd/entire/cli/api/auth_tokens.go +++ b/cmd/entire/cli/api/auth_tokens.go @@ -4,9 +4,11 @@ import ( "context" "fmt" "net/url" + "os" + "strings" ) -// Token is a single API token row returned by GET /api/v1/auth/tokens. +// Token is a single API token row returned by the auth-tokens endpoint. // Plaintext token values are never returned by the server — only metadata. type Token struct { ID string `json:"id"` @@ -18,15 +20,30 @@ type Token struct { CreatedAt string `json:"created_at"` } -// TokensResponse is the envelope returned by GET /api/v1/auth/tokens. +// TokensResponse is the envelope returned by the list endpoint. type TokensResponse struct { Tokens []Token `json:"tokens"` } +// authTokensProviderVersionEnvVar must match the env var read by +// cmd/entire/cli/auth's currentProvider(). Duplicated here rather than +// imported because api/ is a leaf package and shouldn't take a +// dependency on auth/ for routing. +const authTokensProviderVersionEnvVar = "ENTIRE_AUTH_PROVIDER_VERSION" + +// authTokensBasePath returns the auth-tokens endpoint family base path +// for the active provider version. v1 (default) hits /api/v1/auth/tokens; +// v2 hits /api/auth/tokens (no version segment). +func authTokensBasePath() string { + if strings.TrimSpace(os.Getenv(authTokensProviderVersionEnvVar)) == "v2" { + return "/api/auth/tokens" + } + return "/api/v1/auth/tokens" +} + // ListTokens returns the authenticated user's non-expired API tokens. -// Backed by GET /api/v1/auth/tokens. func (c *Client) ListTokens(ctx context.Context) ([]Token, error) { - resp, err := c.Get(ctx, "/api/v1/auth/tokens") + resp, err := c.Get(ctx, authTokensBasePath()) if err != nil { return nil, fmt.Errorf("list tokens: %w", err) } @@ -44,9 +61,12 @@ func (c *Client) ListTokens(ctx context.Context) ([]Token, error) { } // RevokeCurrentToken revokes the bearer token used to authenticate this client. -// Backed by DELETE /api/v1/auth/tokens/current. +// +// v1 has a dedicated /current endpoint. v2 doesn't expose one yet +// (would require a family_id claim on the JWT — tracked separately); +// callers can find the active family via ListTokens and revoke by ID. func (c *Client) RevokeCurrentToken(ctx context.Context) error { - resp, err := c.Delete(ctx, "/api/v1/auth/tokens/current") + resp, err := c.Delete(ctx, authTokensBasePath()+"/current") if err != nil { return fmt.Errorf("revoke current token: %w", err) } @@ -59,9 +79,8 @@ func (c *Client) RevokeCurrentToken(ctx context.Context) error { } // RevokeToken revokes the API token with the given id. -// Backed by DELETE /api/v1/auth/tokens/{id}. func (c *Client) RevokeToken(ctx context.Context, id string) error { - resp, err := c.Delete(ctx, "/api/v1/auth/tokens/"+url.PathEscape(id)) + resp, err := c.Delete(ctx, authTokensBasePath()+"/"+url.PathEscape(id)) if err != nil { return fmt.Errorf("revoke token %s: %w", id, err) } From d5737db78a3535e7b1a88c4ab604cdda356cb275 Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Thu, 7 May 2026 14:27:53 +1000 Subject: [PATCH 08/21] auth: clear lint findings (errcheck, gosec G101/G117, unparam, goconst) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Across the auth/ library and customer-CLI shim, golangci-lint flagged a fistful of routine findings that the existing files inherited or that my recent commits introduced. None are correctness bugs; just noise that the repo's strict configuration wants explicit suppression for. * auth/deviceflow/deviceflow_test.go and auth/sts/sts_test.go grow a shared writeBody(t, w, body) helper, replacing every `_, _ = io.WriteString` in test fixtures. errcheck-clean without per-callsite nolints. newTestClient drops its unused *httptest.Server return (unparam). * auth/sts/sts.go suppresses gosec G101 on the three RFC 8693 standard URI constants (GrantTypeTokenExchange, SubjectTokenType*) and errcheck on the best-effort body read in readAPIError. * auth/tokenstore/keyring.go suppresses gosec G117 on the json.Marshal call that intentionally serialises the access token into the OS-keyring entry (encrypted at rest by the OS). * cmd/entire/cli/api/auth_tokens.go suppresses G101 on authTokensProviderVersionEnvVar — env-var name, not a credential. * cmd/entire/cli/auth/provider.go suppresses G101 on the v1/v2 entries in the providers map (OAuth client_id and endpoint paths, not credentials). * cmd/entire/cli/auth/provider_test.go extracts wantClientIDV1 / wantClientIDV2 test-local constants to satisfy goconst, then uses them in every comparison. cmd/entire/cli/auth/{client,store}.go also need nolint:wrapcheck comments on the four shim returns — those changes sit in the working tree alongside the in-progress AuthBaseURL refactor and will go in together with that commit. Co-Authored-By: Claude Opus 4.7 (1M context) --- auth/deviceflow/deviceflow_test.go | 65 ++++++++++++++++------------ auth/sts/sts.go | 8 ++-- auth/sts/sts_test.go | 51 +++++++++++++--------- auth/tokenstore/keyring.go | 5 ++- cmd/entire/cli/api/auth_tokens.go | 2 +- cmd/entire/cli/auth/provider.go | 4 +- cmd/entire/cli/auth/provider_test.go | 29 ++++++++----- 7 files changed, 98 insertions(+), 66 deletions(-) diff --git a/auth/deviceflow/deviceflow_test.go b/auth/deviceflow/deviceflow_test.go index 1d0cd99df2..a12b3d3fe0 100644 --- a/auth/deviceflow/deviceflow_test.go +++ b/auth/deviceflow/deviceflow_test.go @@ -12,6 +12,17 @@ import ( "time" ) + +// writeBody writes body to w from a test handler. Wraps io.WriteString +// with a t.Fatal on error so test fixtures stay readable without +// per-callsite nolint comments. +func writeBody(t *testing.T, w io.Writer, body string) { + t.Helper() + if _, err := io.WriteString(w, body); err != nil { + t.Fatalf("write body: %v", err) + } +} + const ( testClientID = "cli" testDeviceCodePath = "/oauth/device/code" @@ -26,7 +37,7 @@ func freezeClock(t *testing.T, at time.Time) { t.Cleanup(func() { nowFunc = prev }) } -func newTestClient(t *testing.T, h http.HandlerFunc) (*Client, *httptest.Server) { +func newTestClient(t *testing.T, h http.HandlerFunc) *Client { t.Helper() srv := httptest.NewServer(h) t.Cleanup(srv.Close) @@ -39,7 +50,7 @@ func newTestClient(t *testing.T, h http.HandlerFunc) (*Client, *httptest.Server) DeviceCodePath: testDeviceCodePath, TokenPath: testTokenPath, } - return c, srv + return c } func mustReadForm(t *testing.T, r *http.Request) { @@ -52,7 +63,7 @@ func mustReadForm(t *testing.T, r *http.Request) { func TestStartDeviceAuth_Success(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != testDeviceCodePath { t.Errorf("path = %q", r.URL.Path) } @@ -64,7 +75,7 @@ func TestStartDeviceAuth_Success(t *testing.T) { t.Errorf("scope = %q", got) } w.Header().Set("Content-Type", "application/json") - _, _ = io.WriteString(w, `{ + writeBody(t, w, `{ "device_code": "dev-1", "user_code": "ABCD-EFGH", "verification_uri": "https://example.com/cli/auth", @@ -86,13 +97,13 @@ func TestStartDeviceAuth_Success(t *testing.T) { func TestStartDeviceAuth_OmitsScopeWhenEmpty(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { mustReadForm(t, r) if r.PostForm.Has("scope") { t.Errorf("scope should not be sent when Client.Scope is empty") } w.Header().Set("Content-Type", "application/json") - _, _ = io.WriteString(w, `{"device_code":"d","user_code":"u","verification_uri":"x","expires_in":1,"interval":1}`) + writeBody(t, w, `{"device_code":"d","user_code":"u","verification_uri":"x","expires_in":1,"interval":1}`) }) c.Scope = "" @@ -104,8 +115,8 @@ func TestStartDeviceAuth_OmitsScopeWhenEmpty(t *testing.T) { func TestStartDeviceAuth_RejectsUnknownFields(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { - _, _ = io.WriteString(w, `{ + c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + writeBody(t, w, `{ "device_code":"d","user_code":"u","verification_uri":"x","expires_in":1,"interval":1, "surprise":"field" }`) @@ -119,9 +130,9 @@ func TestStartDeviceAuth_RejectsUnknownFields(t *testing.T) { func TestStartDeviceAuth_NonOK(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusBadRequest) - _, _ = io.WriteString(w, `{"error":"invalid_client"}`) + writeBody(t, w, `{"error":"invalid_client"}`) }) if _, err := c.StartDeviceAuth(context.Background()); err == nil || @@ -135,7 +146,7 @@ func TestPollDeviceAuth_Success(t *testing.T) { freezeClock(t, time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC)) - c, _ := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { mustReadForm(t, r) if got := r.PostForm.Get("grant_type"); got != deviceCodeGrantType { t.Errorf("grant_type = %q", got) @@ -144,7 +155,7 @@ func TestPollDeviceAuth_Success(t *testing.T) { t.Errorf("device_code = %q", got) } w.Header().Set("Content-Type", "application/json") - _, _ = io.WriteString(w, `{ + writeBody(t, w, `{ "access_token":"acc", "refresh_token":"ref", "token_type":"Bearer", @@ -170,8 +181,8 @@ func TestPollDeviceAuth_Success(t *testing.T) { func TestPollDeviceAuth_TolerantToUnknownFields(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { - _, _ = io.WriteString(w, `{"access_token":"acc","extra":"ignored"}`) + c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + writeBody(t, w, `{"access_token":"acc","extra":"ignored"}`) }) got, err := c.PollDeviceAuth(context.Background(), "dev-1") @@ -200,7 +211,7 @@ func TestPollDeviceAuth_ErrorCodes(t *testing.T) { for _, tt := range tests { t.Run(tt.code, func(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusBadRequest) _, _ = fmt.Fprintf(w, `{"error":%q}`, tt.code) }) @@ -216,9 +227,9 @@ func TestPollDeviceAuth_ErrorCodes(t *testing.T) { func TestPollDeviceAuth_ErrorDescription_AppendedToSentinel(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusBadRequest) - _, _ = io.WriteString(w, `{"error":"invalid_grant","error_description":"device_code unknown"}`) + writeBody(t, w, `{"error":"invalid_grant","error_description":"device_code unknown"}`) }) _, err := c.PollDeviceAuth(context.Background(), "dev-1") @@ -233,9 +244,9 @@ func TestPollDeviceAuth_ErrorDescription_AppendedToSentinel(t *testing.T) { func TestPollDeviceAuth_NoDescription_NoTrailingColon(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusBadRequest) - _, _ = io.WriteString(w, `{"error":"invalid_grant"}`) + writeBody(t, w, `{"error":"invalid_grant"}`) }) _, err := c.PollDeviceAuth(context.Background(), "dev-1") @@ -250,9 +261,9 @@ func TestPollDeviceAuth_NoDescription_NoTrailingColon(t *testing.T) { func TestPollDeviceAuth_UnknownErrorCode(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusBadRequest) - _, _ = io.WriteString(w, `{"error":"weird_thing"}`) + writeBody(t, w, `{"error":"weird_thing"}`) }) _, err := c.PollDeviceAuth(context.Background(), "dev-1") @@ -269,8 +280,8 @@ func TestPollDeviceAuth_UnknownErrorCode(t *testing.T) { func TestPollDeviceAuth_200WithNoAccessToken(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { - _, _ = io.WriteString(w, `{}`) + c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + writeBody(t, w, `{}`) }) if _, err := c.PollDeviceAuth(context.Background(), "dev-1"); err == nil { @@ -284,9 +295,9 @@ func TestStartDeviceAuth_HTMLBodySurfacesFriendlyError(t *testing.T) { // Captive portal / firewall (Cloudflare WARP, corp proxy) returns // 200 OK with an HTML error page. Surface a network-actionable // message instead of the opaque JSON-decode complaint. - c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/html") - _, _ = io.WriteString(w, `Access blocked`) + writeBody(t, w, `Access blocked`) }) _, err := c.StartDeviceAuth(context.Background()) @@ -306,9 +317,9 @@ func TestStartDeviceAuth_HTMLBodySurfacesFriendlyError(t *testing.T) { func TestPollDeviceAuth_HTMLBodySurfacesFriendlyError(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/html") - _, _ = io.WriteString(w, `Access blocked by WARP`) + writeBody(t, w, `Access blocked by WARP`) }) _, err := c.PollDeviceAuth(context.Background(), "dev-1") diff --git a/auth/sts/sts.go b/auth/sts/sts.go index 0c3487eaab..856aa0aab2 100644 --- a/auth/sts/sts.go +++ b/auth/sts/sts.go @@ -30,10 +30,10 @@ var nowFunc = time.Now // supplies RequestedTokenType (which is always implementation-specific // outside of these RFC 8693 standard values). const ( - GrantTypeTokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange" + GrantTypeTokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange" //nolint:gosec // RFC 8693 grant_type URI, not a credential - SubjectTokenTypeJWT = "urn:ietf:params:oauth:token-type:jwt" - SubjectTokenTypeAccessToken = "urn:ietf:params:oauth:token-type:access_token" + SubjectTokenTypeJWT = "urn:ietf:params:oauth:token-type:jwt" //nolint:gosec // RFC 8693 token-type URI, not a credential + SubjectTokenTypeAccessToken = "urn:ietf:params:oauth:token-type:access_token" //nolint:gosec // RFC 8693 token-type URI, not a credential ) // ExchangeRequest is the input to a token exchange. @@ -197,7 +197,7 @@ type errorResponse struct { } func readAPIError(resp *http.Response) error { - body, _ := io.ReadAll(io.LimitReader(resp.Body, oauthhttp.MaxResponseBytes)) + body, _ := io.ReadAll(io.LimitReader(resp.Body, oauthhttp.MaxResponseBytes)) //nolint:errcheck // best-effort body read for error message var apiErr errorResponse if err := json.Unmarshal(bytes.TrimSpace(body), &apiErr); err == nil && apiErr.Error != "" { if apiErr.ErrorDescription != "" { diff --git a/auth/sts/sts_test.go b/auth/sts/sts_test.go index d851afd0f4..3a02b584d9 100644 --- a/auth/sts/sts_test.go +++ b/auth/sts/sts_test.go @@ -11,6 +11,17 @@ import ( "time" ) + +// writeBody writes body to w from a test handler. Wraps io.WriteString +// with a t.Fatal on error so test fixtures stay readable without +// per-callsite nolint comments. +func writeBody(t *testing.T, w io.Writer, body string) { + t.Helper() + if _, err := io.WriteString(w, body); err != nil { + t.Fatalf("write body: %v", err) + } +} + const testTokenPath = "/sts/token" func freezeClock(t *testing.T, at time.Time) { @@ -20,7 +31,7 @@ func freezeClock(t *testing.T, at time.Time) { t.Cleanup(func() { nowFunc = prev }) } -func newTestClient(t *testing.T, h http.HandlerFunc) (*Client, *httptest.Server) { +func newTestClient(t *testing.T, h http.HandlerFunc) *Client { t.Helper() srv := httptest.NewServer(h) t.Cleanup(srv.Close) @@ -29,7 +40,7 @@ func newTestClient(t *testing.T, h http.HandlerFunc) (*Client, *httptest.Server) HTTP: srv.Client(), BaseURL: srv.URL, Path: testTokenPath, - }, srv + } } func mustReadForm(t *testing.T, r *http.Request) { @@ -44,7 +55,7 @@ func TestExchange_Success(t *testing.T) { freezeClock(t, time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC)) - c, _ := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { mustReadForm(t, r) if got := r.PostForm.Get("grant_type"); got != GrantTypeTokenExchange { t.Errorf("grant_type = %q", got) @@ -68,7 +79,7 @@ func TestExchange_Success(t *testing.T) { t.Errorf("scope = %q", got) } w.Header().Set("Content-Type", "application/json") - _, _ = io.WriteString(w, `{ + writeBody(t, w, `{ "access_token":"acc", "issued_token_type":"urn:example:token-type:thing", "token_type":"Bearer", @@ -102,14 +113,14 @@ func TestExchange_Success(t *testing.T) { func TestExchange_OmitsOptionalFieldsWhenEmpty(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { mustReadForm(t, r) for _, k := range []string{"audience", "resource", "scope"} { if r.PostForm.Has(k) { t.Errorf("optional field %q should not be sent when empty", k) } } - _, _ = io.WriteString(w, `{"access_token":"acc"}`) + writeBody(t, w, `{"access_token":"acc"}`) }) if _, err := c.Exchange(context.Background(), ExchangeRequest{ @@ -124,12 +135,12 @@ func TestExchange_OmitsOptionalFieldsWhenEmpty(t *testing.T) { func TestExchange_ExtraFieldsForwarded(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { mustReadForm(t, r) if got := r.PostForm.Get("custom_field"); got != "custom-value" { t.Errorf("custom_field = %q", got) } - _, _ = io.WriteString(w, `{"access_token":"acc"}`) + writeBody(t, w, `{"access_token":"acc"}`) }) if _, err := c.Exchange(context.Background(), ExchangeRequest{ @@ -145,13 +156,13 @@ func TestExchange_ExtraFieldsForwarded(t *testing.T) { func TestExchange_StandardFieldsOverrideExtra(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { mustReadForm(t, r) // Caller tried to set grant_type via Extra; standard wins. if got := r.PostForm.Get("grant_type"); got != GrantTypeTokenExchange { t.Errorf("Extra should not override standard grant_type; got %q", got) } - _, _ = io.WriteString(w, `{"access_token":"acc"}`) + writeBody(t, w, `{"access_token":"acc"}`) }) if _, err := c.Exchange(context.Background(), ExchangeRequest{ @@ -191,9 +202,9 @@ func TestExchange_RejectsMissingRequiredFields(t *testing.T) { func TestExchange_ServerError(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusBadRequest) - _, _ = io.WriteString(w, `{"error":"invalid_request","error_description":"bad subject"}`) + writeBody(t, w, `{"error":"invalid_request","error_description":"bad subject"}`) }) _, err := c.Exchange(context.Background(), ExchangeRequest{ @@ -212,9 +223,9 @@ func TestExchange_ServerError(t *testing.T) { func TestExchange_ServerErrorWithoutJSON(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusInternalServerError) - _, _ = io.WriteString(w, `something broke`) + writeBody(t, w, `something broke`) }) _, err := c.Exchange(context.Background(), ExchangeRequest{ @@ -230,8 +241,8 @@ func TestExchange_ServerErrorWithoutJSON(t *testing.T) { func TestExchange_MissingAccessToken(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { - _, _ = io.WriteString(w, `{"token_type":"Bearer"}`) + c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + writeBody(t, w, `{"token_type":"Bearer"}`) }) _, err := c.Exchange(context.Background(), ExchangeRequest{ @@ -247,9 +258,9 @@ func TestExchange_MissingAccessToken(t *testing.T) { func TestExchange_HTMLBodySurfacesFriendlyError(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/html") - _, _ = io.WriteString(w, `Blocked by firewall`) + writeBody(t, w, `Blocked by firewall`) }) _, err := c.Exchange(context.Background(), ExchangeRequest{ @@ -271,8 +282,8 @@ func TestExchange_HTMLBodySurfacesFriendlyError(t *testing.T) { func TestExchange_NoExpiry(t *testing.T) { t.Parallel() - c, _ := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { - _, _ = io.WriteString(w, `{"access_token":"acc","token_type":"Bearer"}`) + c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { + writeBody(t, w, `{"access_token":"acc","token_type":"Bearer"}`) }) got, err := c.Exchange(context.Background(), ExchangeRequest{ diff --git a/auth/tokenstore/keyring.go b/auth/tokenstore/keyring.go index cfba4d5c6c..318c76f92e 100644 --- a/auth/tokenstore/keyring.go +++ b/auth/tokenstore/keyring.go @@ -80,6 +80,9 @@ func (k *Keyring) DeleteTokens(profile string) error { // keyringTokenSet is the on-keyring JSON shape. Time fields are // serialised as RFC 3339 strings so the wire form survives keyring // implementations that don't preserve byte-for-byte equality. +// keyringTokenSet is the wire shape; access_token is intentionally +// serialised so the OS keyring (encrypted at rest) holds the full +// TokenSet for round-tripping. The G117 lint flag is suppressed below. type keyringTokenSet struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token,omitempty"` @@ -99,7 +102,7 @@ func encodeTokenSet(t tokens.TokenSet) (string, error) { wire.ExpiresAt = t.ExpiresAt.UTC().Format(time.RFC3339) } - b, err := json.Marshal(wire) + b, err := json.Marshal(wire) //nolint:gosec // intentional: serialise the access token for OS-keyring storage (encrypted at rest) if err != nil { return "", fmt.Errorf("marshal TokenSet: %w", err) } diff --git a/cmd/entire/cli/api/auth_tokens.go b/cmd/entire/cli/api/auth_tokens.go index a34445d07e..f3e40c2cf8 100644 --- a/cmd/entire/cli/api/auth_tokens.go +++ b/cmd/entire/cli/api/auth_tokens.go @@ -29,7 +29,7 @@ type TokensResponse struct { // cmd/entire/cli/auth's currentProvider(). Duplicated here rather than // imported because api/ is a leaf package and shouldn't take a // dependency on auth/ for routing. -const authTokensProviderVersionEnvVar = "ENTIRE_AUTH_PROVIDER_VERSION" +const authTokensProviderVersionEnvVar = "ENTIRE_AUTH_PROVIDER_VERSION" //nolint:gosec // env var name, not a credential // authTokensBasePath returns the auth-tokens endpoint family base path // for the active provider version. v1 (default) hits /api/v1/auth/tokens; diff --git a/cmd/entire/cli/auth/provider.go b/cmd/entire/cli/auth/provider.go index c0eed72550..fb47733933 100644 --- a/cmd/entire/cli/auth/provider.go +++ b/cmd/entire/cli/auth/provider.go @@ -25,12 +25,12 @@ type providerConfig struct { } var providers = map[string]providerConfig{ - "v1": { + "v1": { //nolint:gosec // OAuth client_id and endpoint paths, not credentials clientID: "entire-cli", deviceCodePath: "/oauth/device/code", tokenPath: "/oauth/token", }, - "v2": { + "v2": { //nolint:gosec // OAuth client_id and endpoint paths, not credentials clientID: "cli", deviceCodePath: "/api/auth/oauth/device/code", tokenPath: "/api/auth/token", diff --git a/cmd/entire/cli/auth/provider_test.go b/cmd/entire/cli/auth/provider_test.go index 746af3f646..e95bda444f 100644 --- a/cmd/entire/cli/auth/provider_test.go +++ b/cmd/entire/cli/auth/provider_test.go @@ -7,11 +7,18 @@ import ( "github.com/entireio/cli/cmd/entire/cli/api" ) +// Test-local mirrors of the v1 / v2 client_id values, so assertions +// don't repeat the same string literal across multiple tests (goconst). +const ( + wantClientIDV1 = "entire-cli" + wantClientIDV2 = "cli" +) + func TestCurrentProvider_DefaultsToV1(t *testing.T) { t.Setenv(ProviderVersionEnvVar, "") p := currentProvider() - if p.clientID != "entire-cli" || p.deviceCodePath != "/oauth/device/code" || p.tokenPath != "/oauth/token" { + if p.clientID != wantClientIDV1 || p.deviceCodePath != "/oauth/device/code" || p.tokenPath != "/oauth/token" { t.Fatalf("default provider = %+v, want v1 config", p) } } @@ -20,7 +27,7 @@ func TestCurrentProvider_V1Explicit(t *testing.T) { t.Setenv(ProviderVersionEnvVar, "v1") p := currentProvider() - if p.clientID != "entire-cli" { + if p.clientID != wantClientIDV1 { t.Fatalf("v1 clientID = %q", p.clientID) } } @@ -29,8 +36,8 @@ func TestCurrentProvider_V2(t *testing.T) { t.Setenv(ProviderVersionEnvVar, "v2") p := currentProvider() - if p.clientID != "cli" { - t.Fatalf("v2 clientID = %q, want cli", p.clientID) + if p.clientID != wantClientIDV2 { + t.Fatalf("v2 clientID = %q, want %s", p.clientID, wantClientIDV2) } if p.deviceCodePath != "/api/auth/oauth/device/code" { t.Fatalf("v2 deviceCodePath = %q", p.deviceCodePath) @@ -44,7 +51,7 @@ func TestCurrentProvider_UnknownDefaultsToV1(t *testing.T) { t.Setenv(ProviderVersionEnvVar, "v999") p := currentProvider() - if p.clientID != "entire-cli" { + if p.clientID != wantClientIDV1 { t.Fatalf("unknown version should default to v1; got clientID = %q", p.clientID) } } @@ -53,8 +60,8 @@ func TestCurrentProvider_TrimsWhitespace(t *testing.T) { t.Setenv(ProviderVersionEnvVar, " v2 ") p := currentProvider() - if p.clientID != "cli" { - t.Fatalf("whitespace-padded v2 clientID = %q, want cli", p.clientID) + if p.clientID != wantClientIDV2 { + t.Fatalf("whitespace-padded v2 clientID = %q, want %s", p.clientID, wantClientIDV2) } } @@ -63,8 +70,8 @@ func TestNewClient_HonoursProviderVersion(t *testing.T) { t.Setenv(ProviderVersionEnvVar, "v2") c := NewClient(&http.Client{}) - if c.inner.ClientID != "cli" { - t.Errorf("ClientID = %q, want cli", c.inner.ClientID) + if c.inner.ClientID != wantClientIDV2 { + t.Errorf("ClientID = %q, want %s", c.inner.ClientID, wantClientIDV2) } if c.inner.DeviceCodePath != "/api/auth/oauth/device/code" { t.Errorf("DeviceCodePath = %q", c.inner.DeviceCodePath) @@ -82,8 +89,8 @@ func TestNewClient_DefaultsToV1(t *testing.T) { t.Setenv(ProviderVersionEnvVar, "") c := NewClient(nil) - if c.inner.ClientID != "entire-cli" { - t.Errorf("ClientID = %q, want entire-cli", c.inner.ClientID) + if c.inner.ClientID != wantClientIDV1 { + t.Errorf("ClientID = %q, want %s", c.inner.ClientID, wantClientIDV1) } if c.inner.DeviceCodePath != "/oauth/device/code" { t.Errorf("DeviceCodePath = %q", c.inner.DeviceCodePath) From c492a54b1adc41dcac6022bdfc0544685e14376b Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Thu, 7 May 2026 17:49:44 +1000 Subject: [PATCH 09/21] auth: split-host config + RFC 8693 token exchange (auth/tokenmanager) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Lets the CLI talk to deployments where the auth issuer and data API live on different origins (e.g. us.console.partial.to mints tokens that are then exchanged for partial.to-scoped tokens before each data-API call). Split-host plumbing: - New ENTIRE_AUTH_BASE_URL env var; AuthBaseURL() falls back to BaseURL() so single-host deployments are unchanged. - Tokens are keyed in the keyring by the auth issuer (the host that minted them), not by the data API URL. - Auth-management commands (auth list/revoke/status/logout) hit the auth host via NewClientWithBaseURL since their endpoints live there. - Align v2 client_id to "entire-cli" to match v1. New shareable library auth/tokenmanager: - Provider-agnostic orchestration over auth/sts: cache, JWT-aud shortcut, exchange dispatch. - Config struct takes Issuer, ClientID, STSPath, Store, plus defaults and test hooks. No globals, no env-var reads, no implicit URLs — ready to share with other internal CLIs. - TokenForResource/Token resolve to: 1) ErrNotLoggedIn when the store is empty, 2) core token verbatim when issuer == resource, 3) core token verbatim when its aud claim already includes the resource (multi-audience tokens skip exchange), 4) RFC 8693 exchange otherwise, cached per (core, resource, audience, requested-token-type, scope) until expiry. CLI wiring: - NewAuthenticatedAPIClient now takes ctx and routes through tokenmanager so data-API calls carry the right-audience bearer. All 7 callers updated to pass ctx. - cmd/entire/cli/auth/exchange.go is a thin shim that builds a package-level Manager from the active provider + NewStore() and exposes TokenForResource / Token / ErrNotLoggedIn. - *Store now implements tokenstore.Store so it can be passed to the Manager, preserving the legacy bare-string keyring fallback. Fix discovered along the way: - search defaulted to a hardcoded entire.io serviceURL; now defaults to api.BaseURL() when ENTIRE_SEARCH_URL is unset. Misc gofmt/lint autofixes in auth/deviceflow, auth/sts, auth/tokens that the linter applied while iterating. Co-Authored-By: Claude Opus 4.7 (1M context) --- auth/deviceflow/deviceflow.go | 2 +- auth/deviceflow/deviceflow_test.go | 1 - auth/sts/sts.go | 2 +- auth/sts/sts_test.go | 1 - auth/tokenmanager/tokenmanager.go | 332 ++++++++++++++++++++++ auth/tokenmanager/tokenmanager_test.go | 365 +++++++++++++++++++++++++ auth/tokens/tokens.go | 2 +- cmd/entire/cli/activity_cmd.go | 2 +- cmd/entire/cli/api/base_url.go | 19 ++ cmd/entire/cli/api/base_url_test.go | 22 ++ cmd/entire/cli/api/client.go | 13 +- cmd/entire/cli/api_client.go | 39 ++- cmd/entire/cli/auth.go | 18 +- cmd/entire/cli/auth/client.go | 6 +- cmd/entire/cli/auth/exchange.go | 90 ++++++ cmd/entire/cli/auth/exchange_test.go | 80 ++++++ cmd/entire/cli/auth/provider.go | 2 +- cmd/entire/cli/auth/provider_test.go | 5 +- cmd/entire/cli/auth/store.go | 43 ++- cmd/entire/cli/dispatch_wizard.go | 2 +- cmd/entire/cli/logout.go | 4 +- cmd/entire/cli/recap.go | 2 +- cmd/entire/cli/search/search.go | 5 +- cmd/entire/cli/search_cmd.go | 7 +- cmd/entire/cli/trail_cmd.go | 8 +- 25 files changed, 1027 insertions(+), 45 deletions(-) create mode 100644 auth/tokenmanager/tokenmanager.go create mode 100644 auth/tokenmanager/tokenmanager_test.go create mode 100644 cmd/entire/cli/auth/exchange.go create mode 100644 cmd/entire/cli/auth/exchange_test.go diff --git a/auth/deviceflow/deviceflow.go b/auth/deviceflow/deviceflow.go index 1ae7413e7a..9a72f53340 100644 --- a/auth/deviceflow/deviceflow.go +++ b/auth/deviceflow/deviceflow.go @@ -179,7 +179,7 @@ func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*tokens } if raw.AccessToken == "" { - return nil, fmt.Errorf("poll device auth: server returned 200 with no access token") + return nil, errors.New("poll device auth: server returned 200 with no access token") } t := &tokens.TokenSet{ diff --git a/auth/deviceflow/deviceflow_test.go b/auth/deviceflow/deviceflow_test.go index a12b3d3fe0..8166be0341 100644 --- a/auth/deviceflow/deviceflow_test.go +++ b/auth/deviceflow/deviceflow_test.go @@ -12,7 +12,6 @@ import ( "time" ) - // writeBody writes body to w from a test handler. Wraps io.WriteString // with a t.Fatal on error so test fixtures stay readable without // per-callsite nolint comments. diff --git a/auth/sts/sts.go b/auth/sts/sts.go index 856aa0aab2..72a53cb578 100644 --- a/auth/sts/sts.go +++ b/auth/sts/sts.go @@ -32,7 +32,7 @@ var nowFunc = time.Now const ( GrantTypeTokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange" //nolint:gosec // RFC 8693 grant_type URI, not a credential - SubjectTokenTypeJWT = "urn:ietf:params:oauth:token-type:jwt" //nolint:gosec // RFC 8693 token-type URI, not a credential + SubjectTokenTypeJWT = "urn:ietf:params:oauth:token-type:jwt" //nolint:gosec // RFC 8693 token-type URI, not a credential SubjectTokenTypeAccessToken = "urn:ietf:params:oauth:token-type:access_token" //nolint:gosec // RFC 8693 token-type URI, not a credential ) diff --git a/auth/sts/sts_test.go b/auth/sts/sts_test.go index 3a02b584d9..04570cb446 100644 --- a/auth/sts/sts_test.go +++ b/auth/sts/sts_test.go @@ -11,7 +11,6 @@ import ( "time" ) - // writeBody writes body to w from a test handler. Wraps io.WriteString // with a t.Fatal on error so test fixtures stay readable without // per-callsite nolint comments. diff --git a/auth/tokenmanager/tokenmanager.go b/auth/tokenmanager/tokenmanager.go new file mode 100644 index 0000000000..205a7bf958 --- /dev/null +++ b/auth/tokenmanager/tokenmanager.go @@ -0,0 +1,332 @@ +// Package tokenmanager orchestrates core-token storage and RFC 8693 +// token exchanges for an OAuth 2.0 device-flow client. +// +// One Manager per CLI process. Construct it once from the embedding +// CLI's identity (Issuer, ClientID, STSPath, Store) and call +// TokenForResource / Token from data-API call sites. +// +// The package is provider-agnostic: every endpoint, identifier, and +// default value comes from Config. It has no env-var reads, no +// implicit URLs, and no embedded provider tables. Tests inject +// Config.Exchange and Config.Now to avoid hitting the network and to +// freeze the clock. +package tokenmanager + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "slices" + "strings" + "sync" + "time" + + "github.com/entireio/cli/auth/sts" + "github.com/entireio/cli/auth/tokens" + "github.com/entireio/cli/auth/tokenstore" +) + +// DefaultRequestedTokenType is the RFC 8693 §3 URI used when neither +// Config.RequestedTokenType nor TokenRequest.RequestedTokenType is set. +// :access_token is the canonical "give me an OAuth access token" URI; +// the wire format is the server's choice. +// +//nolint:gosec // RFC 8693 token-type URI, not a credential +const DefaultRequestedTokenType = "urn:ietf:params:oauth:token-type:access_token" + +// exchangeSkew is the safety margin applied when deciding whether a +// cached exchanged token is still usable. Set conservatively because +// the worst case (re-exchange one extra time per command) is cheap. +const exchangeSkew = 30 * time.Second + +// ErrNotLoggedIn is returned by Token / TokenForResource when no core +// token is present in the store. Callers can match on it to render a +// "run " message. +var ErrNotLoggedIn = errors.New("not logged in") + +// Config configures a Manager. +type Config struct { + // Issuer is the auth host base URL where the device-flow login + // happened and STS exchanges are POSTed. Required. Doubles as the + // Store profile key, so a user can be logged into multiple issuers + // (e.g. regions / staging) without conflict. + Issuer string + + // ClientID identifies the public client per RFC 6749 §2.3.1 / §3.2.1. + // Sent on STS exchanges via the client_id form field. Required. + ClientID string + + // STSPath is the path on Issuer where token-exchange requests are + // POSTed. Typically the OAuth token endpoint (RFC 8693 convention). + // Required. + STSPath string + + // Store persists the core token. Required. Use any tokenstore.Store + // implementation; a per-CLI service name keeps credentials isolated + // from other CLIs sharing this library. + Store tokenstore.Store + + // RequestedTokenType is the default RFC 8693 requested_token_type + // URI. Empty → DefaultRequestedTokenType. + RequestedTokenType string + + // Scope is the default scope sent on exchanges. Empty → omitted. + Scope string + + // UserAgent for HTTP requests. Empty → none. + UserAgent string + + // HTTPClient overrides the http.Client used for exchange calls. + // Useful for installing a debug transport. nil → http.DefaultClient. + HTTPClient *http.Client + + // Now overrides time.Now for cache-expiry decisions. Tests only. + Now func() time.Time + + // Exchange overrides the STS call. Tests only. + Exchange func(ctx context.Context, req sts.ExchangeRequest) (*tokens.TokenSet, error) +} + +func (c Config) validate() error { + switch { + case strings.TrimSpace(c.Issuer) == "": + return errors.New("Config.Issuer is required") + case strings.TrimSpace(c.ClientID) == "": + return errors.New("Config.ClientID is required") + case strings.TrimSpace(c.STSPath) == "": + return errors.New("Config.STSPath is required") + case c.Store == nil: + return errors.New("Config.Store is required") + } + return nil +} + +// Manager orchestrates core-token storage and STS exchanges. Safe for +// concurrent use. +type Manager struct { + cfg Config + + mu sync.Mutex + cache map[string]cachedToken +} + +// New builds a Manager from cfg. Returns an error when required +// fields are missing. +func New(cfg Config) (*Manager, error) { + if err := cfg.validate(); err != nil { + return nil, err + } + if cfg.RequestedTokenType == "" { + cfg.RequestedTokenType = DefaultRequestedTokenType + } + if cfg.Now == nil { + cfg.Now = time.Now + } + return &Manager{cfg: cfg, cache: map[string]cachedToken{}}, nil +} + +// Issuer returns the configured issuer URL. +func (m *Manager) Issuer() string { return m.cfg.Issuer } + +// SaveCoreToken persists the device-flow access token under the +// configured Issuer. +func (m *Manager) SaveCoreToken(accessToken string) error { + return m.cfg.Store.SaveTokens(m.cfg.Issuer, tokens.TokenSet{AccessToken: accessToken}) //nolint:wrapcheck // backend error already names the operation +} + +// LookupCoreToken returns the stored core token, or "" if none is +// stored. A nil-return-no-error mirrors how callers expect +// "not-logged-in" to look. +func (m *Manager) LookupCoreToken() (string, error) { + t, err := m.cfg.Store.LoadTokens(m.cfg.Issuer) + if errors.Is(err, tokenstore.ErrNotFound) { + return "", nil + } + if err != nil { + return "", fmt.Errorf("load core token: %w", err) + } + return t.AccessToken, nil +} + +// DeleteCoreToken removes the stored core token (and any cached +// exchanges derived from it). +func (m *Manager) DeleteCoreToken() error { + m.mu.Lock() + m.cache = map[string]cachedToken{} + m.mu.Unlock() + return m.cfg.Store.DeleteTokens(m.cfg.Issuer) //nolint:wrapcheck // backend error already names the operation +} + +// TokenRequest customises one Token call. Empty fields fall back to +// Config defaults. +type TokenRequest struct { + // Resource is the origin where the bearer will be presented. + // Required. Used for the same-host shortcut, the JWT-aud shortcut, + // and as part of the cache key. + Resource string + + // Audience is the wire-level RFC 8693 audience parameter. Empty → + // omitted (the AS picks). Independent of Resource — most callers + // leave Audience empty. + Audience string + + // RequestedTokenType overrides Config.RequestedTokenType for this + // call. Empty → Config default. + RequestedTokenType string + + // Scope overrides Config.Scope for this call. Empty → Config default. + Scope string +} + +// TokenForResource is a convenience for Token using only Resource. +func (m *Manager) TokenForResource(ctx context.Context, resourceBaseURL string) (string, error) { + return m.Token(ctx, TokenRequest{Resource: resourceBaseURL}) +} + +// Token resolves a bearer token for use against req.Resource, +// performing an RFC 8693 exchange when needed. +// +// Resolution rules: +// +// 1. No core token in the store → ErrNotLoggedIn. +// 2. m.Issuer() == req.Resource (and req.Audience is empty) → use +// the core token directly. Single-host deployments hit this path. +// 3. Core token's `aud` claim already includes req.Resource → use +// the core token directly. Multi-audience tokens skip exchange. +// 4. Otherwise → RFC 8693 token exchange. +// +// Successful exchanges are cached in-memory keyed by (core token, +// resource, audience, requested-token-type, scope) until expiry. +func (m *Manager) Token(ctx context.Context, req TokenRequest) (string, error) { + if strings.TrimSpace(req.Resource) == "" { + return "", errors.New("TokenRequest.Resource is required") + } + + core, err := m.LookupCoreToken() + if err != nil { + return "", err + } + if core == "" { + return "", ErrNotLoggedIn + } + + if req.Audience == "" && m.cfg.Issuer == req.Resource { + return core, nil + } + if coreTokenAudienceIncludes(core, req.Resource) { + return core, nil + } + + resolved := m.resolve(req) + key := cacheKey(core, resolved) + if hit, ok := m.cacheLookup(key); ok { + return hit, nil + } + + exchanged, err := m.runExchange(ctx, core, resolved) + if err != nil { + return "", err + } + m.cacheStore(key, exchanged) + return exchanged.AccessToken, nil +} + +// resolve fills empty TokenRequest fields with Config defaults. +func (m *Manager) resolve(req TokenRequest) TokenRequest { + if req.RequestedTokenType == "" { + req.RequestedTokenType = m.cfg.RequestedTokenType + } + if req.Scope == "" { + req.Scope = m.cfg.Scope + } + return req +} + +func coreTokenAudienceIncludes(coreJWT, target string) bool { + claims, err := tokens.ParseClaims(coreJWT) + if err != nil { + return false + } + return slices.Contains(claims.Audience, target) +} + +// cachedToken is one entry in the per-process exchange cache. +type cachedToken struct { + accessToken string + expiresAt time.Time +} + +func (c cachedToken) usable(now time.Time) bool { + if c.accessToken == "" { + return false + } + if c.expiresAt.IsZero() { + return true + } + return now.Add(exchangeSkew).Before(c.expiresAt) +} + +// cacheKey derives a stable cache key from the (resolved) request. +// Includes every wire-affecting field so different combinations don't +// shadow each other. +func cacheKey(coreToken string, req TokenRequest) string { + return strings.Join([]string{ + coreToken, + req.Resource, + req.Audience, + req.RequestedTokenType, + req.Scope, + }, "|") +} + +func (m *Manager) cacheLookup(key string) (string, bool) { + m.mu.Lock() + defer m.mu.Unlock() + entry, ok := m.cache[key] + if !ok { + return "", false + } + if !entry.usable(m.cfg.Now()) { + delete(m.cache, key) + return "", false + } + return entry.accessToken, true +} + +func (m *Manager) cacheStore(key string, t *tokens.TokenSet) { + m.mu.Lock() + defer m.mu.Unlock() + m.cache[key] = cachedToken{ + accessToken: t.AccessToken, + expiresAt: t.ExpiresAt, + } +} + +// runExchange dispatches to either Config.Exchange (test override) or +// a freshly built sts.Client pointing at m.cfg.Issuer + m.cfg.STSPath. +func (m *Manager) runExchange(ctx context.Context, coreToken string, req TokenRequest) (*tokens.TokenSet, error) { + stsReq := sts.ExchangeRequest{ + SubjectToken: coreToken, + SubjectTokenType: sts.SubjectTokenTypeJWT, + RequestedTokenType: req.RequestedTokenType, + Audience: req.Audience, + Scope: req.Scope, + // Public-client identification per RFC 6749 §2.3.1 / §3.2.1. + // Carried via Extra because the sts package is provider-agnostic. + Extra: url.Values{"client_id": {m.cfg.ClientID}}, + } + + if m.cfg.Exchange != nil { + return m.cfg.Exchange(ctx, stsReq) + } + + stsClient := &sts.Client{ + HTTP: m.cfg.HTTPClient, + BaseURL: m.cfg.Issuer, + Path: m.cfg.STSPath, + UserAgent: m.cfg.UserAgent, + } + return stsClient.Exchange(ctx, stsReq) //nolint:wrapcheck // sts.Exchange already prefixes "token exchange:" +} diff --git a/auth/tokenmanager/tokenmanager_test.go b/auth/tokenmanager/tokenmanager_test.go new file mode 100644 index 0000000000..e87679ef90 --- /dev/null +++ b/auth/tokenmanager/tokenmanager_test.go @@ -0,0 +1,365 @@ +package tokenmanager + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "strings" + "testing" + "time" + + "github.com/entireio/cli/auth/sts" + "github.com/entireio/cli/auth/tokens" + "github.com/entireio/cli/auth/tokenstore" +) + +// memStore is an in-memory tokenstore.Store for tests. Avoids pulling +// the keyring backend into tokenmanager's test surface. +type memStore struct { + data map[string]tokens.TokenSet +} + +func newMemStore() *memStore { return &memStore{data: map[string]tokens.TokenSet{}} } + +func (s *memStore) SaveTokens(profile string, t tokens.TokenSet) error { + s.data[profile] = t + return nil +} + +func (s *memStore) LoadTokens(profile string) (tokens.TokenSet, error) { + t, ok := s.data[profile] + if !ok { + return tokens.TokenSet{}, tokenstore.ErrNotFound + } + return t, nil +} + +func (s *memStore) DeleteTokens(profile string) error { + delete(s.data, profile) + return nil +} + +const ( + testIssuer = "https://auth.example.com" + testResource = "https://api.example.com" + testClientID = "test-cli" + testSTSPath = "/sts/token" +) + +func makeJWTWithAudience(t *testing.T, aud []string) string { + t.Helper() + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + payload, err := json.Marshal(map[string]any{"aud": aud, "sub": "test"}) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + body := base64.RawURLEncoding.EncodeToString(payload) + sig := base64.RawURLEncoding.EncodeToString([]byte("not-real")) + return header + "." + body + "." + sig +} + +func newTestManager(t *testing.T, store tokenstore.Store, exchange func(context.Context, sts.ExchangeRequest) (*tokens.TokenSet, error)) *Manager { + t.Helper() + m, err := New(Config{ + Issuer: testIssuer, + ClientID: testClientID, + STSPath: testSTSPath, + Store: store, + Exchange: exchange, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + return m +} + +func TestNew_RequiresFields(t *testing.T) { + t.Parallel() + cases := []struct { + name string + cfg Config + }{ + {"missing issuer", Config{ClientID: "x", STSPath: "/p", Store: newMemStore()}}, + {"missing clientID", Config{Issuer: "https://x", STSPath: "/p", Store: newMemStore()}}, + {"missing STSPath", Config{Issuer: "https://x", ClientID: "x", Store: newMemStore()}}, + {"missing Store", Config{Issuer: "https://x", ClientID: "x", STSPath: "/p"}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if _, err := New(tc.cfg); err == nil { + t.Fatal("expected error") + } + }) + } +} + +func TestNew_DefaultRequestedTokenType(t *testing.T) { + t.Parallel() + m, err := New(Config{Issuer: testIssuer, ClientID: testClientID, STSPath: testSTSPath, Store: newMemStore()}) + if err != nil { + t.Fatalf("New: %v", err) + } + if m.cfg.RequestedTokenType != DefaultRequestedTokenType { + t.Fatalf("RequestedTokenType default = %q, want %q", m.cfg.RequestedTokenType, DefaultRequestedTokenType) + } +} + +func TestToken_NotLoggedIn(t *testing.T) { + t.Parallel() + m := newTestManager(t, newMemStore(), nil) + _, err := m.TokenForResource(context.Background(), testResource) + if !errors.Is(err, ErrNotLoggedIn) { + t.Fatalf("err = %v, want ErrNotLoggedIn", err) + } +} + +func TestToken_SameHostShortcut(t *testing.T) { + t.Parallel() + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: "core-tok"} + + m, err := New(Config{ + Issuer: testIssuer, ClientID: testClientID, STSPath: testSTSPath, Store: store, + Exchange: func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + t.Fatal("exchange must not run when issuer == resource") + return nil, errors.New("unreachable") + }, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + got, err := m.TokenForResource(context.Background(), testIssuer) + if err != nil { + t.Fatalf("TokenForResource: %v", err) + } + if got != "core-tok" { + t.Fatalf("got %q, want core token verbatim", got) + } +} + +func TestToken_AudienceShortcut(t *testing.T) { + t.Parallel() + core := makeJWTWithAudience(t, []string{testIssuer, testResource}) + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: core} + + m, err := New(Config{ + Issuer: testIssuer, ClientID: testClientID, STSPath: testSTSPath, Store: store, + Exchange: func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + t.Fatal("exchange must not run when core token's aud already covers resource") + return nil, errors.New("unreachable") + }, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + got, err := m.TokenForResource(context.Background(), testResource) + if err != nil { + t.Fatalf("TokenForResource: %v", err) + } + if got != core { + t.Fatal("expected core token verbatim when aud already matches") + } +} + +func TestToken_ExchangesAndCaches(t *testing.T) { + t.Parallel() + core := makeJWTWithAudience(t, []string{testIssuer}) + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: core} + + var calls int + var lastReq sts.ExchangeRequest + m, err := New(Config{ + Issuer: testIssuer, ClientID: testClientID, STSPath: testSTSPath, Store: store, + Exchange: func(_ context.Context, req sts.ExchangeRequest) (*tokens.TokenSet, error) { + calls++ + lastReq = req + return &tokens.TokenSet{AccessToken: "exchanged-1", ExpiresAt: time.Now().Add(10 * time.Minute)}, nil + }, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + first, err := m.TokenForResource(context.Background(), testResource) + if err != nil { + t.Fatalf("first: %v", err) + } + if first != "exchanged-1" { + t.Fatalf("first = %q", first) + } + second, err := m.TokenForResource(context.Background(), testResource) + if err != nil { + t.Fatalf("second: %v", err) + } + if second != "exchanged-1" || calls != 1 { + t.Fatalf("expected cache hit, got calls=%d second=%q", calls, second) + } + + // Wire shape: default RequestedTokenType, empty audience, client_id. + if lastReq.RequestedTokenType != DefaultRequestedTokenType { + t.Errorf("RequestedTokenType = %q", lastReq.RequestedTokenType) + } + if lastReq.Audience != "" { + t.Errorf("Audience = %q, want empty", lastReq.Audience) + } + if got := lastReq.Extra.Get("client_id"); got != testClientID { + t.Errorf("client_id = %q", got) + } +} + +func TestToken_OverridesAudienceAndType(t *testing.T) { + t.Parallel() + core := makeJWTWithAudience(t, []string{testIssuer}) + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: core} + + const customAud = "https://elsewhere.example.com" + const customType = "urn:ietf:params:oauth:token-type:jwt" + const customScope = "narrower" + + var got sts.ExchangeRequest + m := newTestManager(t, store, func(_ context.Context, req sts.ExchangeRequest) (*tokens.TokenSet, error) { + got = req + return &tokens.TokenSet{AccessToken: "ok"}, nil + }) + + if _, err := m.Token(context.Background(), TokenRequest{ + Resource: testResource, + Audience: customAud, + RequestedTokenType: customType, + Scope: customScope, + }); err != nil { + t.Fatalf("Token: %v", err) + } + + if got.Audience != customAud { + t.Errorf("Audience = %q", got.Audience) + } + if got.RequestedTokenType != customType { + t.Errorf("RequestedTokenType = %q", got.RequestedTokenType) + } + if got.Scope != customScope { + t.Errorf("Scope = %q", got.Scope) + } +} + +func TestToken_DifferentAudiencesCacheIndependently(t *testing.T) { + t.Parallel() + core := makeJWTWithAudience(t, []string{testIssuer}) + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: core} + + var calls int + m := newTestManager(t, store, func(_ context.Context, req sts.ExchangeRequest) (*tokens.TokenSet, error) { + calls++ + return &tokens.TokenSet{AccessToken: "tok-" + req.Audience}, nil + }) + + a, err := m.Token(context.Background(), TokenRequest{Resource: testResource, Audience: "aud-a"}) + if err != nil { + t.Fatalf("a: %v", err) + } + b, err := m.Token(context.Background(), TokenRequest{Resource: testResource, Audience: "aud-b"}) + if err != nil { + t.Fatalf("b: %v", err) + } + if a == b || calls != 2 { + t.Fatalf("expected separate cache entries, got a=%q b=%q calls=%d", a, b, calls) + } + + // Repeat A — cache hit. + if _, err := m.Token(context.Background(), TokenRequest{Resource: testResource, Audience: "aud-a"}); err != nil { + t.Fatalf("a repeat: %v", err) + } + if calls != 2 { + t.Fatalf("expected cache hit on repeat, got %d calls", calls) + } +} + +func TestToken_CacheExpires(t *testing.T) { + t.Parallel() + core := makeJWTWithAudience(t, []string{testIssuer}) + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: core} + + now := time.Date(2026, 5, 7, 12, 0, 0, 0, time.UTC) + + var calls int + m, err := New(Config{ + Issuer: testIssuer, ClientID: testClientID, STSPath: testSTSPath, Store: store, + Now: func() time.Time { return now }, + Exchange: func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + calls++ + return &tokens.TokenSet{AccessToken: "exchanged", ExpiresAt: now.Add(time.Minute)}, nil + }, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + if _, err := m.TokenForResource(context.Background(), testResource); err != nil { + t.Fatalf("first: %v", err) + } + now = now.Add(2 * time.Minute) // past expiry + if _, err := m.TokenForResource(context.Background(), testResource); err != nil { + t.Fatalf("second: %v", err) + } + if calls != 2 { + t.Fatalf("calls = %d, want 2 (cache must miss after expiry)", calls) + } +} + +func TestToken_RequiresResource(t *testing.T) { + t.Parallel() + m := newTestManager(t, newMemStore(), nil) + _, err := m.Token(context.Background(), TokenRequest{}) + if err == nil { + t.Fatal("expected error for empty Resource") + } +} + +func TestToken_ExchangeFailureSurfaces(t *testing.T) { + t.Parallel() + core := makeJWTWithAudience(t, []string{testIssuer}) + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: core} + + m := newTestManager(t, store, func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + return nil, errors.New("token exchange: status 400: invalid_target") + }) + + _, err := m.TokenForResource(context.Background(), testResource) + if err == nil || !strings.Contains(err.Error(), "invalid_target") { + t.Fatalf("err = %v, want underlying message", err) + } +} + +func TestSaveLookupDeleteCoreToken(t *testing.T) { + t.Parallel() + m := newTestManager(t, newMemStore(), nil) + + if got, err := m.LookupCoreToken(); err != nil || got != "" { + t.Fatalf("initial lookup: got=%q err=%v, want empty/nil", got, err) + } + + if err := m.SaveCoreToken("new-core"); err != nil { + t.Fatalf("SaveCoreToken: %v", err) + } + got, err := m.LookupCoreToken() + if err != nil || got != "new-core" { + t.Fatalf("after save: got=%q err=%v", got, err) + } + + if err := m.DeleteCoreToken(); err != nil { + t.Fatalf("DeleteCoreToken: %v", err) + } + if got, err := m.LookupCoreToken(); err != nil || got != "" { + t.Fatalf("after delete: got=%q err=%v", got, err) + } +} diff --git a/auth/tokens/tokens.go b/auth/tokens/tokens.go index d7c37d13ed..c886db78b0 100644 --- a/auth/tokens/tokens.go +++ b/auth/tokens/tokens.go @@ -141,5 +141,5 @@ func decodeAudience(raw json.RawMessage) ([]string, error) { return multi, nil } - return nil, fmt.Errorf("decode JWT aud claim: not a string or array") + return nil, errors.New("decode JWT aud claim: not a string or array") } diff --git a/cmd/entire/cli/activity_cmd.go b/cmd/entire/cli/activity_cmd.go index 35b49018fa..c43af88df0 100644 --- a/cmd/entire/cli/activity_cmd.go +++ b/cmd/entire/cli/activity_cmd.go @@ -56,7 +56,7 @@ func newActivityCmd() *cobra.Command { } func runActivity(ctx context.Context, w, errW io.Writer) error { - client, err := NewAuthenticatedAPIClient(false) + client, err := NewAuthenticatedAPIClient(ctx, false) if err != nil { fmt.Fprintln(errW, "Not logged in. Run 'entire login' to authenticate.") return NewSilentError(err) diff --git a/cmd/entire/cli/api/base_url.go b/cmd/entire/cli/api/base_url.go index 9e684c04a1..0f2109093e 100644 --- a/cmd/entire/cli/api/base_url.go +++ b/cmd/entire/cli/api/base_url.go @@ -17,6 +17,13 @@ const ( // BaseURLEnvVar overrides the Entire API origin for local development. BaseURLEnvVar = "ENTIRE_API_BASE_URL" + + // AuthBaseURLEnvVar overrides only the auth/login origin (device flow, + // auth-tokens management, keyring key). Falls back to BaseURLEnvVar when + // unset, which is the right behavior for single-host deployments. Split + // hosts (e.g. auth on us.console.partial.to, data on partial.to) set + // both. + AuthBaseURLEnvVar = "ENTIRE_AUTH_BASE_URL" ) // BaseURL returns the effective Entire API base URL. @@ -29,6 +36,18 @@ func BaseURL() string { return DefaultBaseURL } +// AuthBaseURL returns the origin used for the device-flow login, auth-token +// management endpoints, and the keyring key under which the bearer token is +// stored. ENTIRE_AUTH_BASE_URL takes precedence; otherwise it falls back to +// BaseURL() so single-host deployments keep working unchanged. +func AuthBaseURL() string { + if raw := strings.TrimSpace(os.Getenv(AuthBaseURLEnvVar)); raw != "" { + return normalizeBaseURL(raw) + } + + return BaseURL() +} + // ResolveURL joins an API-relative path against the effective base URL. func ResolveURL(path string) (string, error) { return ResolveURLFromBase(BaseURL(), path) diff --git a/cmd/entire/cli/api/base_url_test.go b/cmd/entire/cli/api/base_url_test.go index 474de2026a..b094a97a11 100644 --- a/cmd/entire/cli/api/base_url_test.go +++ b/cmd/entire/cli/api/base_url_test.go @@ -53,6 +53,28 @@ func TestRequireSecureURL_RejectsHTTP(t *testing.T) { } } +func TestAuthBaseURL_FallsBackToBaseURL(t *testing.T) { + t.Setenv(BaseURLEnvVar, "https://partial.to") + t.Setenv(AuthBaseURLEnvVar, "") + + if got := AuthBaseURL(); got != "https://partial.to" { + t.Fatalf("AuthBaseURL() = %q, want fallback to BaseURL %q", got, "https://partial.to") + } +} + +func TestAuthBaseURL_OverridesBaseURL(t *testing.T) { + t.Setenv(BaseURLEnvVar, "https://partial.to") + t.Setenv(AuthBaseURLEnvVar, " https://us.console.partial.to/ ") + + if got := AuthBaseURL(); got != "https://us.console.partial.to" { + t.Fatalf("AuthBaseURL() = %q, want %q", got, "https://us.console.partial.to") + } + + if got := BaseURL(); got != "https://partial.to" { + t.Fatalf("BaseURL() = %q, want unchanged %q", got, "https://partial.to") + } +} + func TestResolveURL(t *testing.T) { t.Setenv(BaseURLEnvVar, "http://localhost:8787/") diff --git a/cmd/entire/cli/api/client.go b/cmd/entire/cli/api/client.go index 5fe7a6b012..c594b0c261 100644 --- a/cmd/entire/cli/api/client.go +++ b/cmd/entire/cli/api/client.go @@ -23,8 +23,17 @@ type Client struct { baseURL string } -// NewClient creates a new authenticated API client with an explicit bearer token. +// NewClient creates a new authenticated API client with an explicit bearer +// token, targeting the data API base URL (BaseURL()). func NewClient(token string) *Client { + return NewClientWithBaseURL(token, BaseURL()) +} + +// NewClientWithBaseURL creates a new authenticated API client targeting an +// explicit base URL. Use this for endpoints that live on the auth host (e.g. +// auth-token management) when ENTIRE_AUTH_BASE_URL splits the auth origin +// from the data API origin. +func NewClientWithBaseURL(token, baseURL string) *Client { return &Client{ httpClient: &http.Client{ Transport: &bearerTransport{ @@ -32,7 +41,7 @@ func NewClient(token string) *Client { base: http.DefaultTransport, }, }, - baseURL: BaseURL(), + baseURL: baseURL, } } diff --git a/cmd/entire/cli/api_client.go b/cmd/entire/cli/api_client.go index 65f36968f5..06a1edc2ce 100644 --- a/cmd/entire/cli/api_client.go +++ b/cmd/entire/cli/api_client.go @@ -1,6 +1,7 @@ package cli import ( + "context" "errors" "fmt" @@ -8,22 +9,36 @@ import ( "github.com/entireio/cli/cmd/entire/cli/auth" ) -// NewAuthenticatedAPIClient creates an API client using the bearer token -// from the CLI login flow. Returns an error if the user is not logged in. -// Pass insecureHTTP=true to allow plain HTTP base URLs (for local development). -func NewAuthenticatedAPIClient(insecureHTTP bool) (*api.Client, error) { - token, err := auth.LookupCurrentToken() - if err != nil { - return nil, fmt.Errorf("lookup auth token: %w", err) - } - if token == "" { - return nil, errors.New("not logged in (run 'entire login' first)") - } - +// NewAuthenticatedAPIClient creates an API client targeting api.BaseURL() +// (the data API origin) carrying a token valid for that audience. +// +// Resolution: looks up the core token from the keyring, then either uses +// it directly (single-host setup, or when the core token's `aud` already +// covers api.BaseURL()) or performs an RFC 8693 token exchange against +// the auth host to obtain a token scoped to the data API. Exchanged +// tokens are cached in-memory per (core-token, resource) pair. +// +// Pass insecureHTTP=true to allow plain HTTP base URLs for local +// development. Both api.BaseURL() and api.AuthBaseURL() are validated: +// the bearer travels to the data host on resource requests, and the +// core token travels to the auth host during the exchange step. +func NewAuthenticatedAPIClient(ctx context.Context, insecureHTTP bool) (*api.Client, error) { if !insecureHTTP { if err := api.RequireSecureURL(api.BaseURL()); err != nil { return nil, fmt.Errorf("base URL check: %w", err) } + if err := api.RequireSecureURL(api.AuthBaseURL()); err != nil { + return nil, fmt.Errorf("auth base URL check: %w", err) + } } + + token, err := auth.TokenForResource(ctx, api.BaseURL()) + if err != nil { + if errors.Is(err, auth.ErrNotLoggedIn) { + return nil, errors.New("not logged in (run 'entire login' first)") + } + return nil, fmt.Errorf("resolve API token: %w", err) + } + return api.NewClient(token), nil } diff --git a/cmd/entire/cli/auth.go b/cmd/entire/cli/auth.go index 0a441aef07..e5ba946e2c 100644 --- a/cmd/entire/cli/auth.go +++ b/cmd/entire/cli/auth.go @@ -35,6 +35,11 @@ const ( // command that sends a bearer token over the network (login, logout, // auth status/list/revoke) must call this so credentials don't leak over // plaintext HTTP without explicit opt-in. +// +// Both the auth and data API origins are checked: the bearer travels to the +// auth host for login + auth-token management, and to the data host for +// search/activity/dispatch/etc. When ENTIRE_AUTH_BASE_URL is unset they +// resolve to the same URL and the second check is a no-op. func requireSecureBaseURL(insecureHTTPAuth bool) error { if insecureHTTPAuth { return nil @@ -42,6 +47,9 @@ func requireSecureBaseURL(insecureHTTPAuth bool) error { if err := api.RequireSecureURL(api.BaseURL()); err != nil { return fmt.Errorf("base URL check: %w", err) } + if err := api.RequireSecureURL(api.AuthBaseURL()); err != nil { + return fmt.Errorf("auth base URL check: %w", err) + } return nil } @@ -84,7 +92,7 @@ func newAuthStatusCmd() *cobra.Command { return err } return runAuthStatus(cmd.Context(), cmd.OutOrStdout(), - auth.NewStore(), defaultListTokens, api.BaseURL()) + auth.NewStore(), defaultListTokens, api.AuthBaseURL()) }, } addInsecureHTTPAuthFlag(cmd, &insecureHTTPAuth) @@ -92,7 +100,7 @@ func newAuthStatusCmd() *cobra.Command { } func defaultListTokens(ctx context.Context, token string) ([]api.Token, error) { - return api.NewClient(token).ListTokens(ctx) //nolint:wrapcheck // ListTokens already wraps with action context + return api.NewClientWithBaseURL(token, api.AuthBaseURL()).ListTokens(ctx) //nolint:wrapcheck // ListTokens already wraps with action context } func runAuthStatus(ctx context.Context, w io.Writer, store tokenStore, list authTokenLister, baseURL string) error { @@ -135,7 +143,7 @@ func newAuthListCmd() *cobra.Command { return err } return runAuthList(cmd.Context(), cmd.OutOrStdout(), - auth.NewStore(), defaultListTokens, api.BaseURL(), jsonOut) + auth.NewStore(), defaultListTokens, api.AuthBaseURL(), jsonOut) }, } cmd.Flags().BoolVar(&jsonOut, "json", false, "Print tokens as JSON") @@ -409,7 +417,7 @@ func newAuthRevokeCmd() *cobra.Command { } return runAuthRevoke(cmd.Context(), cmd.OutOrStdout(), cmd.ErrOrStderr(), auth.NewStore(), defaultListTokens, defaultRevokeTokenByID, defaultRevokeCurrentToken, - api.BaseURL(), id, revokeCurrent) + api.AuthBaseURL(), id, revokeCurrent) }, } cmd.Flags().BoolVar(&revokeCurrent, "current", false, "Revoke the token used by this CLI and remove the local copy") @@ -418,7 +426,7 @@ func newAuthRevokeCmd() *cobra.Command { } func defaultRevokeTokenByID(ctx context.Context, callerToken, id string) error { - return api.NewClient(callerToken).RevokeToken(ctx, id) //nolint:wrapcheck // RevokeToken already wraps with action context + return api.NewClientWithBaseURL(callerToken, api.AuthBaseURL()).RevokeToken(ctx, id) //nolint:wrapcheck // RevokeToken already wraps with action context } func runAuthRevoke( diff --git a/cmd/entire/cli/auth/client.go b/cmd/entire/cli/auth/client.go index 6743608f61..1e6697a244 100644 --- a/cmd/entire/cli/auth/client.go +++ b/cmd/entire/cli/auth/client.go @@ -48,7 +48,7 @@ func NewClient(httpClient *http.Client) *Client { p := currentProvider() return &Client{inner: &deviceflow.Client{ HTTP: httpClient, - BaseURL: api.BaseURL(), + BaseURL: api.AuthBaseURL(), ClientID: p.clientID, Scope: "cli", UserAgent: p.clientID, @@ -62,7 +62,7 @@ func (c *Client) BaseURL() string { return c.inner.BaseURL } // StartDeviceAuth requests a fresh device code. func (c *Client) StartDeviceAuth(ctx context.Context) (*DeviceAuthStart, error) { - return c.inner.StartDeviceAuth(ctx) + return c.inner.StartDeviceAuth(ctx) //nolint:wrapcheck // shim preserves the lib's wrapped errors verbatim } // PollDeviceAuth polls the token endpoint. On any RFC 8628 §3.5 error, @@ -78,7 +78,7 @@ func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*Device ErrorDescription: descriptionFromSentinel(err, code), }, nil } - return nil, err + return nil, err //nolint:wrapcheck // shim returns deviceflow errors verbatim so callers can errors.Is on sentinels } return &DeviceAuthPoll{ diff --git a/cmd/entire/cli/auth/exchange.go b/cmd/entire/cli/auth/exchange.go new file mode 100644 index 0000000000..573b4a330a --- /dev/null +++ b/cmd/entire/cli/auth/exchange.go @@ -0,0 +1,90 @@ +package auth + +import ( + "context" + "fmt" + "sync" + + "github.com/entireio/cli/auth/tokenmanager" + "github.com/entireio/cli/cmd/entire/cli/api" +) + +// TokenRequest is the entire-CLI alias of tokenmanager.TokenRequest so +// callers don't have to import the underlying package for the common +// case. The two types are interchangeable. +type TokenRequest = tokenmanager.TokenRequest + +// ErrNotLoggedIn re-exports tokenmanager.ErrNotLoggedIn so callers in +// the cli package can errors.Is against it without an extra import. +var ErrNotLoggedIn = tokenmanager.ErrNotLoggedIn + +var ( + managerOnce sync.Once + manager *tokenmanager.Manager + errManager error + + // managerForTest, when non-nil, is returned by defaultManager() + // instead of constructing the production manager. Tests use + // SetManagerForTest to inject a manager that hits a test STS + // server / in-memory store. Production code never reads this var. + managerForTest *tokenmanager.Manager +) + +// SetManagerForTest installs mgr as the manager returned by +// defaultManager() and returns a cleanup function. Test-only. +func SetManagerForTest(t interface{ Helper() }, mgr *tokenmanager.Manager) func() { + t.Helper() + prev := managerForTest + managerForTest = mgr + return func() { managerForTest = prev } +} + +// defaultManager returns the package-level Manager built from this +// CLI's identity (current provider, AuthBaseURL, NewStore service +// name). Constructed lazily on first use so env-var changes between +// tests are honoured by the first non-test caller in any given +// process. +func defaultManager() (*tokenmanager.Manager, error) { + if managerForTest != nil { + return managerForTest, nil + } + managerOnce.Do(func() { + provider := currentProvider() + m, err := tokenmanager.New(tokenmanager.Config{ + Issuer: api.AuthBaseURL(), + ClientID: provider.clientID, + STSPath: provider.tokenPath, + Store: NewStore(), + UserAgent: provider.clientID, + Scope: "cli", + }) + manager = m + if err != nil { + errManager = fmt.Errorf("build token manager: %w", err) + } + }) + return manager, errManager +} + +// TokenForResource returns a bearer token suitable for use against +// resourceBaseURL, performing an RFC 8693 token exchange when the +// stored core token's audience doesn't already cover that resource. +// See tokenmanager.Manager.TokenForResource for the resolution rules. +func TokenForResource(ctx context.Context, resourceBaseURL string) (string, error) { + m, err := defaultManager() + if err != nil { + return "", err + } + return m.TokenForResource(ctx, resourceBaseURL) //nolint:wrapcheck // shim returns the lib error verbatim +} + +// Token is the full-control entry point. Use TokenForResource for the +// common case; this exists so callers can override the wire-level +// Audience, RequestedTokenType, or Scope per call. +func Token(ctx context.Context, req TokenRequest) (string, error) { + m, err := defaultManager() + if err != nil { + return "", err + } + return m.Token(ctx, req) //nolint:wrapcheck // shim returns the lib error verbatim +} diff --git a/cmd/entire/cli/auth/exchange_test.go b/cmd/entire/cli/auth/exchange_test.go new file mode 100644 index 0000000000..d55c37516f --- /dev/null +++ b/cmd/entire/cli/auth/exchange_test.go @@ -0,0 +1,80 @@ +package auth + +import ( + "context" + "errors" + "testing" + + "github.com/entireio/cli/auth/sts" + "github.com/entireio/cli/auth/tokenmanager" + "github.com/entireio/cli/auth/tokens" + "github.com/entireio/cli/auth/tokenstore" +) + +// memStoreForExchange mirrors the tokenmanager test helper but local +// to this package — the cmd-side test only exercises wiring, so we +// don't import the manager's test fixtures. +type memStoreForExchange struct { + data map[string]tokens.TokenSet +} + +func (s *memStoreForExchange) SaveTokens(profile string, t tokens.TokenSet) error { + s.data[profile] = t + return nil +} + +func (s *memStoreForExchange) LoadTokens(profile string) (tokens.TokenSet, error) { + t, ok := s.data[profile] + if !ok { + return tokens.TokenSet{}, tokenstore.ErrNotFound + } + return t, nil +} + +func (s *memStoreForExchange) DeleteTokens(profile string) error { + delete(s.data, profile) + return nil +} + +// TestTokenForResource_DelegatesToManager verifies the cmd-side shim +// forwards calls to whatever Manager SetManagerForTest installs. The +// underlying behaviour (cache, exchange, audience checks) is covered +// by the tokenmanager package tests. +func TestTokenForResource_DelegatesToManager(t *testing.T) { + // Not parallel: SetManagerForTest mutates package-level state. + store := &memStoreForExchange{data: map[string]tokens.TokenSet{ + "https://auth.example.com": {AccessToken: "core"}, + }} + mgr, err := tokenmanager.New(tokenmanager.Config{ + Issuer: "https://auth.example.com", + ClientID: "test-cli", + STSPath: "/sts/token", + Store: store, + Exchange: func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + return &tokens.TokenSet{AccessToken: "exchanged"}, nil + }, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + cleanup := SetManagerForTest(t, mgr) + t.Cleanup(cleanup) + + got, err := TokenForResource(context.Background(), "https://api.example.com") + if err != nil { + t.Fatalf("TokenForResource: %v", err) + } + if got != "exchanged" { + t.Fatalf("got %q, want exchanged", got) + } +} + +// TestTokenForResource_ReExportsErrNotLoggedIn ensures the cmd-side +// alias matches the underlying sentinel so callers can errors.Is +// against either. +func TestTokenForResource_ReExportsErrNotLoggedIn(t *testing.T) { + t.Parallel() + if !errors.Is(ErrNotLoggedIn, tokenmanager.ErrNotLoggedIn) { + t.Fatal("auth.ErrNotLoggedIn must alias tokenmanager.ErrNotLoggedIn") + } +} diff --git a/cmd/entire/cli/auth/provider.go b/cmd/entire/cli/auth/provider.go index fb47733933..d5b681521c 100644 --- a/cmd/entire/cli/auth/provider.go +++ b/cmd/entire/cli/auth/provider.go @@ -31,7 +31,7 @@ var providers = map[string]providerConfig{ tokenPath: "/oauth/token", }, "v2": { //nolint:gosec // OAuth client_id and endpoint paths, not credentials - clientID: "cli", + clientID: "entire-cli", deviceCodePath: "/api/auth/oauth/device/code", tokenPath: "/api/auth/token", }, diff --git a/cmd/entire/cli/auth/provider_test.go b/cmd/entire/cli/auth/provider_test.go index e95bda444f..8a29667ed6 100644 --- a/cmd/entire/cli/auth/provider_test.go +++ b/cmd/entire/cli/auth/provider_test.go @@ -9,9 +9,12 @@ import ( // Test-local mirrors of the v1 / v2 client_id values, so assertions // don't repeat the same string literal across multiple tests (goconst). +// Both providers now share the same client_id; the constants are kept +// distinct so a future divergence (or a regression that re-splits them) +// shows up at a single edit site. const ( wantClientIDV1 = "entire-cli" - wantClientIDV2 = "cli" + wantClientIDV2 = "entire-cli" ) func TestCurrentProvider_DefaultsToV1(t *testing.T) { diff --git a/cmd/entire/cli/auth/store.go b/cmd/entire/cli/auth/store.go index a21bae568d..7391c9ef32 100644 --- a/cmd/entire/cli/auth/store.go +++ b/cmd/entire/cli/auth/store.go @@ -40,7 +40,7 @@ func (s *Store) SaveToken(baseURL, token string) error { if token == "" { return errors.New("refusing to save empty token") } - return s.inner.SaveTokens(baseURL, tokens.TokenSet{AccessToken: token}) + return s.inner.SaveTokens(baseURL, tokens.TokenSet{AccessToken: token}) //nolint:wrapcheck // shim returns the lib error verbatim } // GetToken retrieves a stored token for the given base URL. Returns @@ -71,10 +71,45 @@ func (s *Store) GetToken(baseURL string) (string, error) { // DeleteToken removes a stored token for the given base URL. func (s *Store) DeleteToken(baseURL string) error { - return s.inner.DeleteTokens(baseURL) + return s.inner.DeleteTokens(baseURL) //nolint:wrapcheck // shim returns the lib error verbatim } -// LookupCurrentToken retrieves the token for the current base URL. +// SaveTokens implements tokenstore.Store. Used by the tokenmanager. +func (s *Store) SaveTokens(profile string, t tokens.TokenSet) error { + return s.inner.SaveTokens(profile, t) //nolint:wrapcheck // shim returns the lib error verbatim +} + +// LoadTokens implements tokenstore.Store, preserving the legacy bare-string +// fallback path so users with pre-shim keyring entries don't appear logged +// out after upgrading. +func (s *Store) LoadTokens(profile string) (tokens.TokenSet, error) { + t, err := s.inner.LoadTokens(profile) + if err == nil { + return t, nil + } + if !errors.Is(err, tokenstore.ErrNotFound) { + return tokens.TokenSet{}, err //nolint:wrapcheck // shim returns the lib error verbatim + } + + raw, kerr := keyring.Get(s.inner.Service, profile) + if errors.Is(kerr, keyring.ErrNotFound) { + return tokens.TokenSet{}, tokenstore.ErrNotFound + } + if kerr != nil { + return tokens.TokenSet{}, fmt.Errorf("get token from keyring: %w", kerr) + } + return tokens.TokenSet{AccessToken: raw}, nil +} + +// DeleteTokens implements tokenstore.Store. +func (s *Store) DeleteTokens(profile string) error { + return s.inner.DeleteTokens(profile) //nolint:wrapcheck // shim returns the lib error verbatim +} + +// LookupCurrentToken retrieves the token for the current auth base URL. +// Tokens are keyed by the auth issuer (api.AuthBaseURL()) since that's the +// host that minted them; in single-host deployments AuthBaseURL falls back +// to BaseURL so behaviour is unchanged. func LookupCurrentToken() (string, error) { - return NewStore().GetToken(api.BaseURL()) + return NewStore().GetToken(api.AuthBaseURL()) } diff --git a/cmd/entire/cli/dispatch_wizard.go b/cmd/entire/cli/dispatch_wizard.go index dadd60ca40..88e5551cab 100644 --- a/cmd/entire/cli/dispatch_wizard.go +++ b/cmd/entire/cli/dispatch_wizard.go @@ -29,7 +29,7 @@ var getDispatchWizardCurrentBranch = GetCurrentBranch var runDispatchWizardForm = func(form *huh.Form) error { return form.Run() } func defaultListDispatchWizardRepoResources(ctx context.Context) ([]api.Repository, error) { - client, err := NewAuthenticatedAPIClient(false) + client, err := NewAuthenticatedAPIClient(ctx, false) if err != nil { return nil, err } diff --git a/cmd/entire/cli/logout.go b/cmd/entire/cli/logout.go index 77c8533976..b34b7496e2 100644 --- a/cmd/entire/cli/logout.go +++ b/cmd/entire/cli/logout.go @@ -33,7 +33,7 @@ func newLogoutCmd() *cobra.Command { return err } return runLogout(cmd.Context(), cmd.OutOrStdout(), cmd.ErrOrStderr(), - auth.NewStore(), defaultRevokeCurrentToken, api.BaseURL()) + auth.NewStore(), defaultRevokeCurrentToken, api.AuthBaseURL()) }, } addInsecureHTTPAuthFlag(cmd, &insecureHTTPAuth) @@ -41,7 +41,7 @@ func newLogoutCmd() *cobra.Command { } func defaultRevokeCurrentToken(ctx context.Context, token string) error { - return api.NewClient(token).RevokeCurrentToken(ctx) //nolint:wrapcheck // RevokeCurrentToken already wraps with action context + return api.NewClientWithBaseURL(token, api.AuthBaseURL()).RevokeCurrentToken(ctx) //nolint:wrapcheck // RevokeCurrentToken already wraps with action context } func runLogout(ctx context.Context, outW, errW io.Writer, store tokenStore, revoke revokeCurrentFunc, baseURL string) error { diff --git a/cmd/entire/cli/recap.go b/cmd/entire/cli/recap.go index cd84b448a1..6b05f89f0b 100644 --- a/cmd/entire/cli/recap.go +++ b/cmd/entire/cli/recap.go @@ -120,7 +120,7 @@ func runRecap(ctx context.Context, w io.Writer, f *recapFlags) error { if err != nil { return err } - client, err := NewAuthenticatedAPIClient(f.insecureHTTP) + client, err := NewAuthenticatedAPIClient(ctx, f.insecureHTTP) if err != nil { fmt.Fprintln(w, "Sign in with `entire login` to use `entire recap`.") return NewSilentError(err) diff --git a/cmd/entire/cli/search/search.go b/cmd/entire/cli/search/search.go index 2e58557da8..b9cf8c8adf 100644 --- a/cmd/entire/cli/search/search.go +++ b/cmd/entire/cli/search/search.go @@ -66,7 +66,10 @@ type Response struct { // Config holds the configuration for a search request. type Config struct { - ServiceURL string // Base URL of the search service + ServiceURL string // Base URL of the search service + // GitHubToken is a misnomer kept for backwards compatibility: callers + // populate it with the OAuth bearer from auth.LookupCurrentToken(). + // The wire format is unchanged (Authorization: Bearer ). GitHubToken string Owner string Repo string diff --git a/cmd/entire/cli/search_cmd.go b/cmd/entire/cli/search_cmd.go index 299f46f186..d8d1245845 100644 --- a/cmd/entire/cli/search_cmd.go +++ b/cmd/entire/cli/search_cmd.go @@ -108,7 +108,10 @@ branch:, repo:, and repo:* to search all accessible repos.`, serviceURL := os.Getenv("ENTIRE_SEARCH_URL") if serviceURL == "" { - serviceURL = search.DefaultServiceURL + // Honour ENTIRE_API_BASE_URL: search lives on the data API + // host. Fall back to search.DefaultServiceURL only when no + // API base URL is configured (production default). + serviceURL = api.BaseURL() } searchCfg := search.Config{ @@ -207,7 +210,7 @@ branch:, repo:, and repo:* to search all accessible repos.`, // must never pollute the user's prompt with error output. func completeRepoFlag(cmd *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { suggestions := []string{"*"} - client, err := NewAuthenticatedAPIClient(false) + client, err := NewAuthenticatedAPIClient(cmd.Context(), false) if err != nil { return suggestions, cobra.ShellCompDirectiveNoFileComp } diff --git a/cmd/entire/cli/trail_cmd.go b/cmd/entire/cli/trail_cmd.go index a6eee48a20..39c9454bbf 100644 --- a/cmd/entire/cli/trail_cmd.go +++ b/cmd/entire/cli/trail_cmd.go @@ -67,7 +67,7 @@ func runTrailShow(ctx context.Context, w io.Writer, insecureHTTP bool) error { return runTrailListAll(ctx, w, "", false, false, insecureHTTP) } - client, err := NewAuthenticatedAPIClient(insecureHTTP) + client, err := NewAuthenticatedAPIClient(ctx, insecureHTTP) if err != nil { return fmt.Errorf("authentication required: %w", err) } @@ -130,7 +130,7 @@ func newTrailListCmd() *cobra.Command { } func runTrailListAll(ctx context.Context, w io.Writer, statusFilter string, jsonOutput, showAll, insecureHTTP bool) error { - client, err := NewAuthenticatedAPIClient(insecureHTTP) + client, err := NewAuthenticatedAPIClient(ctx, insecureHTTP) if err != nil { return fmt.Errorf("authentication required: %w", err) } @@ -317,7 +317,7 @@ func runTrailCreate(cmd *cobra.Command, title, body, base, branch, statusStr str // --- Phase 2: API operations --- - client, err := NewAuthenticatedAPIClient(trailInsecureHTTP(cmd)) + client, err := NewAuthenticatedAPIClient(cmd.Context(), trailInsecureHTTP(cmd)) if err != nil { return fmt.Errorf("authentication required: %w", err) } @@ -404,7 +404,7 @@ func newTrailUpdateCmd() *cobra.Command { func runTrailUpdate(ctx context.Context, w, errW io.Writer, insecureHTTP bool, statusStr, title, body, branch string, labelAdd, labelRemove []string) error { _ = errW // reserved for future warnings - client, err := NewAuthenticatedAPIClient(insecureHTTP) + client, err := NewAuthenticatedAPIClient(ctx, insecureHTTP) if err != nil { return fmt.Errorf("authentication required: %w", err) } From d9322bc2ac556cd6fd69232e40ccd037282240b6 Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Fri, 8 May 2026 16:02:21 +1000 Subject: [PATCH 10/21] auth: route STS to provider.stsPath; make STSPath optional in tokenmanager Two fixes that came out of getting `entire trail list` working against partial.to's split-host deployment: - Provider config now carries an stsPath alongside the OAuth token endpoint. v2's STS lives at /api/authz/sts/token, distinct from the /api/auth/token OAuth endpoint that rejects token-exchange grants with unsupported_grant_type. cmd/entire/cli/auth/exchange.go now passes provider.stsPath (rather than provider.tokenPath) into the tokenmanager. - v1 is the legacy single-host surface (entire.io for both auth and data API), so the same-host shortcut in tokenmanager.Token always wins and STS is never invoked. v1.stsPath is left empty. - tokenmanager.Config.STSPath is now optional. New() no longer rejects empty STSPath; runExchange() returns the new ErrNoSTSPath sentinel if an exchange is actually attempted with no path configured. Single-host setups (incl. v1) need no STS endpoint; split-host misconfigurations fail loudly at the right layer. Tests updated to cover empty-STSPath construction, the ErrNoSTSPath path, and v1's empty stsPath contract. Co-Authored-By: Claude Opus 4.7 (1M context) --- auth/tokenmanager/tokenmanager.go | 17 ++++++++--- auth/tokenmanager/tokenmanager_test.go | 40 +++++++++++++++++++++++++- cmd/entire/cli/auth/exchange.go | 2 +- cmd/entire/cli/auth/provider.go | 9 ++++++ cmd/entire/cli/auth/provider_test.go | 8 ++++++ 5 files changed, 70 insertions(+), 6 deletions(-) diff --git a/auth/tokenmanager/tokenmanager.go b/auth/tokenmanager/tokenmanager.go index 205a7bf958..f410eb203a 100644 --- a/auth/tokenmanager/tokenmanager.go +++ b/auth/tokenmanager/tokenmanager.go @@ -46,6 +46,11 @@ const exchangeSkew = 30 * time.Second // "run " message. var ErrNotLoggedIn = errors.New("not logged in") +// ErrNoSTSPath is returned when an exchange is needed but Config.STSPath +// is empty. Single-host deployments hit the same-host shortcut and never +// reach this; split-host deployments must configure STSPath. +var ErrNoSTSPath = errors.New("token exchange required but Config.STSPath is empty") + // Config configures a Manager. type Config struct { // Issuer is the auth host base URL where the device-flow login @@ -59,8 +64,10 @@ type Config struct { ClientID string // STSPath is the path on Issuer where token-exchange requests are - // POSTed. Typically the OAuth token endpoint (RFC 8693 convention). - // Required. + // POSTed. Optional: single-host deployments never trigger an + // exchange (the same-host shortcut wins) so they can leave it + // empty. When empty and an exchange is attempted, runExchange + // returns ErrNoSTSPath rather than POSTing to a bogus URL. STSPath string // Store persists the core token. Required. Use any tokenstore.Store @@ -95,8 +102,6 @@ func (c Config) validate() error { return errors.New("Config.Issuer is required") case strings.TrimSpace(c.ClientID) == "": return errors.New("Config.ClientID is required") - case strings.TrimSpace(c.STSPath) == "": - return errors.New("Config.STSPath is required") case c.Store == nil: return errors.New("Config.Store is required") } @@ -322,6 +327,10 @@ func (m *Manager) runExchange(ctx context.Context, coreToken string, req TokenRe return m.cfg.Exchange(ctx, stsReq) } + if strings.TrimSpace(m.cfg.STSPath) == "" { + return nil, ErrNoSTSPath + } + stsClient := &sts.Client{ HTTP: m.cfg.HTTPClient, BaseURL: m.cfg.Issuer, diff --git a/auth/tokenmanager/tokenmanager_test.go b/auth/tokenmanager/tokenmanager_test.go index e87679ef90..f3710d9656 100644 --- a/auth/tokenmanager/tokenmanager_test.go +++ b/auth/tokenmanager/tokenmanager_test.go @@ -82,7 +82,6 @@ func TestNew_RequiresFields(t *testing.T) { }{ {"missing issuer", Config{ClientID: "x", STSPath: "/p", Store: newMemStore()}}, {"missing clientID", Config{Issuer: "https://x", STSPath: "/p", Store: newMemStore()}}, - {"missing STSPath", Config{Issuer: "https://x", ClientID: "x", Store: newMemStore()}}, {"missing Store", Config{Issuer: "https://x", ClientID: "x", STSPath: "/p"}}, } for _, tc := range cases { @@ -95,6 +94,45 @@ func TestNew_RequiresFields(t *testing.T) { } } +// TestNew_AllowsEmptySTSPath documents that single-host configs can +// omit STSPath because the same-host shortcut always wins. The error +// surfaces only if an exchange is actually attempted. +func TestNew_AllowsEmptySTSPath(t *testing.T) { + t.Parallel() + if _, err := New(Config{ + Issuer: testIssuer, + ClientID: testClientID, + Store: newMemStore(), + }); err != nil { + t.Fatalf("New: %v", err) + } +} + +// TestExchange_FailsWithoutSTSPath checks that triggering an exchange +// against a manager configured without an STS path returns ErrNoSTSPath +// (rather than POSTing to a bogus URL). +func TestExchange_FailsWithoutSTSPath(t *testing.T) { + t.Parallel() + core := makeJWTWithAudience(t, []string{testIssuer}) + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: core} + + m, err := New(Config{ + Issuer: testIssuer, + ClientID: testClientID, + Store: store, + // STSPath intentionally empty + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + _, err = m.TokenForResource(context.Background(), testResource) + if !errors.Is(err, ErrNoSTSPath) { + t.Fatalf("err = %v, want ErrNoSTSPath", err) + } +} + func TestNew_DefaultRequestedTokenType(t *testing.T) { t.Parallel() m, err := New(Config{Issuer: testIssuer, ClientID: testClientID, STSPath: testSTSPath, Store: newMemStore()}) diff --git a/cmd/entire/cli/auth/exchange.go b/cmd/entire/cli/auth/exchange.go index 573b4a330a..bddf13496b 100644 --- a/cmd/entire/cli/auth/exchange.go +++ b/cmd/entire/cli/auth/exchange.go @@ -53,7 +53,7 @@ func defaultManager() (*tokenmanager.Manager, error) { m, err := tokenmanager.New(tokenmanager.Config{ Issuer: api.AuthBaseURL(), ClientID: provider.clientID, - STSPath: provider.tokenPath, + STSPath: provider.stsPath, Store: NewStore(), UserAgent: provider.clientID, Scope: "cli", diff --git a/cmd/entire/cli/auth/provider.go b/cmd/entire/cli/auth/provider.go index d5b681521c..804fe22977 100644 --- a/cmd/entire/cli/auth/provider.go +++ b/cmd/entire/cli/auth/provider.go @@ -18,10 +18,18 @@ import ( const ProviderVersionEnvVar = "ENTIRE_AUTH_PROVIDER_VERSION" // providerConfig captures the per-surface bits of OAuth wiring. +// +// stsPath is the RFC 8693 token-exchange endpoint. v1 is the legacy +// single-host surface where the auth and data API live at the same +// origin (entire.io); the same-host shortcut in tokenmanager.Token +// always wins and STS is never invoked, so v1.stsPath is left empty. +// v2 exposes a dedicated STS path because it's used in split-host +// deployments (e.g. us.auth.partial.to mints, partial.to consumes). type providerConfig struct { clientID string deviceCodePath string tokenPath string + stsPath string } var providers = map[string]providerConfig{ @@ -34,6 +42,7 @@ var providers = map[string]providerConfig{ clientID: "entire-cli", deviceCodePath: "/api/auth/oauth/device/code", tokenPath: "/api/auth/token", + stsPath: "/api/authz/sts/token", }, } diff --git a/cmd/entire/cli/auth/provider_test.go b/cmd/entire/cli/auth/provider_test.go index 8a29667ed6..ddf7550aa9 100644 --- a/cmd/entire/cli/auth/provider_test.go +++ b/cmd/entire/cli/auth/provider_test.go @@ -33,6 +33,11 @@ func TestCurrentProvider_V1Explicit(t *testing.T) { if p.clientID != wantClientIDV1 { t.Fatalf("v1 clientID = %q", p.clientID) } + // v1 is single-host (entire.io); no STS surface, same-host shortcut + // always wins. Empty stsPath is the contract. + if p.stsPath != "" { + t.Fatalf("v1 stsPath = %q, want empty (single-host, no STS)", p.stsPath) + } } func TestCurrentProvider_V2(t *testing.T) { @@ -48,6 +53,9 @@ func TestCurrentProvider_V2(t *testing.T) { if p.tokenPath != "/api/auth/token" { t.Fatalf("v2 tokenPath = %q", p.tokenPath) } + if p.stsPath != "/api/authz/sts/token" { + t.Fatalf("v2 stsPath = %q", p.stsPath) + } } func TestCurrentProvider_UnknownDefaultsToV1(t *testing.T) { From ead027cf9d5580bb428098a53f50c17fe985e4a0 Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Fri, 8 May 2026 16:02:33 +1000 Subject: [PATCH 11/21] search: route bearer through auth.TokenForResource `entire search` was sending the raw core token to the search service, which on split-host deployments has the wrong audience (auth host issuer, not the data API). Switch to auth.TokenForResource(ctx, serviceURL) so the bearer is exchange-resolved against the search service URL: same-host shortcut keeps single-host setups unchanged, split-host setups now get an exchanged token with aud=entire-api. Also moves the auth lookup after the git/repo plumbing so the resource URL (which can come from ENTIRE_SEARCH_URL or api.BaseURL()) is known at the time we resolve the bearer. Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/entire/cli/search_cmd.go | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/cmd/entire/cli/search_cmd.go b/cmd/entire/cli/search_cmd.go index d8d1245845..a654c7ddba 100644 --- a/cmd/entire/cli/search_cmd.go +++ b/cmd/entire/cli/search_cmd.go @@ -76,14 +76,6 @@ branch:, repo:, and repo:* to search all accessible repos.`, return errors.New("query required when using --json, accessible mode, or piped output. Usage: entire search ") } - ghToken, err := auth.LookupCurrentToken() - if err != nil { - return fmt.Errorf("reading credentials: %w", err) - } - if ghToken == "" { - return errors.New("not authenticated. Run 'entire login' to authenticate") - } - // Get the repo's GitHub remote URL repo, err := strategy.OpenRepository(ctx) if err != nil { @@ -114,6 +106,19 @@ branch:, repo:, and repo:* to search all accessible repos.`, serviceURL = api.BaseURL() } + // Resolve a bearer scoped to the search service host. In split-host + // deployments this triggers an RFC 8693 exchange so the bearer + // carries the data-API audience rather than the auth-host one; + // single-host setups hit the same-host shortcut and return the + // core token unchanged. + ghToken, err := auth.TokenForResource(ctx, serviceURL) + if errors.Is(err, auth.ErrNotLoggedIn) { + return errors.New("not authenticated. Run 'entire login' to authenticate") + } + if err != nil { + return fmt.Errorf("reading credentials: %w", err) + } + searchCfg := search.Config{ ServiceURL: serviceURL, GitHubToken: ghToken, From 16746fd66068ff85eec953c4448830820a2f3703 Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Fri, 8 May 2026 16:18:15 +1000 Subject: [PATCH 12/21] auth: fix legacy keyring fallback + cover gaps surfaced by review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Code review surfaced a real correctness bug plus a handful of clarity/coverage gaps. This commit fixes them in one pass. Critical fix — Store.LoadTokens legacy bare-string fallback was dead code: - tokenstore.Keyring.LoadTokens returned "unmarshal TokenSet: ..." for pre-shim bare-string entries, not ErrNotFound. The cmd-side shim's fallback only fired on ErrNotFound, so users with pre-shim keyring entries appeared logged out after upgrading to the manager-backed code path (entire trail/search/etc.). The legacy GetToken path was separately over-permissive: it fell back on any error, masking real keyring errors. - Add tokenstore.ErrMalformed sentinel returned (wrapped) by decodeTokenSet on JSON unmarshal or expires_at parse failures. - Update Store.LoadTokens / Store.GetToken to fall back precisely on ErrMalformed (legacy path) and surface ErrNotFound + real keyring errors verbatim. Regression tests pre-seed bare-string keyring entries and assert the round-trip. api.bearerTransport: reject empty bearer at first request rather than sending Authorization: Bearer on the wire (which produces a confusing 401). New errEmptyBearerToken sentinel. api/auth_tokens: add table-driven test that pins the ENTIRE_AUTH_PROVIDER_VERSION → path mapping (v1/v2/unrecognised/ whitespace) plus an end-to-end ListTokens routing check. The path switch is the whole point of the version env var; it had no test. Doc fixes (review found these stale or misleading): - auth/doc.go: list tokenmanager subpackage (was missing). - auth/tokenstore/tokenstore.go: drop the "File impl" claim — only Keyring ships today. - auth/tokenstore/keyring.go: collapse the duplicated keyringTokenSet comment paragraph, drop the dangling G117 reference. - cmd/entire/cli/auth/exchange.go: defaultManager rationale corrected (sync.Once means later env-var changes are ignored, not honoured); TokenForResource doc points at Manager.Token (the rules live there, not on TokenForResource). - cmd/entire/cli/api_client.go: cache-key undercount — list all wire-affecting fields rather than just (core-token, resource). - cmd/entire/cli/search_cmd.go: rewrite the misleading "fall back to search.DefaultServiceURL" comment (the fallback is api.BaseURL()). Co-Authored-By: Claude Opus 4.7 (1M context) --- auth/doc.go | 10 +++-- auth/tokenstore/keyring.go | 11 +++--- auth/tokenstore/tokenstore.go | 17 ++++----- cmd/entire/cli/api/auth_tokens_test.go | 53 ++++++++++++++++++++++++++ cmd/entire/cli/api/client.go | 8 ++++ cmd/entire/cli/api/client_test.go | 21 ++++++++++ cmd/entire/cli/api_client.go | 3 +- cmd/entire/cli/auth/exchange.go | 9 +++-- cmd/entire/cli/auth/store.go | 22 ++++++++--- cmd/entire/cli/auth/store_test.go | 46 ++++++++++++++++++++++ cmd/entire/cli/search_cmd.go | 7 ++-- 11 files changed, 174 insertions(+), 33 deletions(-) diff --git a/auth/doc.go b/auth/doc.go index 1e3f0571ef..25c175fe5c 100644 --- a/auth/doc.go +++ b/auth/doc.go @@ -2,10 +2,12 @@ // // All real code lives in the subpackages: // -// - deviceflow — RFC 8628 OAuth 2.0 Device Authorization Grant client -// - tokens — TokenSet plus unverified JWT claim parsing -// - tokenstore — pluggable persistence interface with reference impls -// - sts — RFC 8693 Token Exchange client +// - deviceflow — RFC 8628 OAuth 2.0 Device Authorization Grant client +// - sts — RFC 8693 Token Exchange client +// - tokens — TokenSet plus unverified JWT claim parsing +// - tokenstore — pluggable persistence interface with reference impls +// - tokenmanager — orchestrates core-token storage + STS exchanges, +// with caching and a JWT-audience shortcut // // The library is designed to talk RFC 8628 and RFC 8693 to any compliant // OAuth 2.0 server. It contains no provider-specific behaviour; endpoint diff --git a/auth/tokenstore/keyring.go b/auth/tokenstore/keyring.go index 318c76f92e..f80a2e9ddf 100644 --- a/auth/tokenstore/keyring.go +++ b/auth/tokenstore/keyring.go @@ -79,10 +79,9 @@ func (k *Keyring) DeleteTokens(profile string) error { // keyringTokenSet is the on-keyring JSON shape. Time fields are // serialised as RFC 3339 strings so the wire form survives keyring -// implementations that don't preserve byte-for-byte equality. -// keyringTokenSet is the wire shape; access_token is intentionally -// serialised so the OS keyring (encrypted at rest) holds the full -// TokenSet for round-tripping. The G117 lint flag is suppressed below. +// implementations that don't preserve byte-for-byte equality. The +// access_token is intentionally serialised so the OS keyring +// (encrypted at rest) holds the full TokenSet for round-tripping. type keyringTokenSet struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token,omitempty"` @@ -112,7 +111,7 @@ func encodeTokenSet(t tokens.TokenSet) (string, error) { func decodeTokenSet(raw string) (tokens.TokenSet, error) { var wire keyringTokenSet if err := json.Unmarshal([]byte(raw), &wire); err != nil { - return tokens.TokenSet{}, fmt.Errorf("unmarshal TokenSet: %w", err) + return tokens.TokenSet{}, fmt.Errorf("%w: unmarshal TokenSet: %w", ErrMalformed, err) } t := tokens.TokenSet{ @@ -124,7 +123,7 @@ func decodeTokenSet(raw string) (tokens.TokenSet, error) { if wire.ExpiresAt != "" { exp, err := time.Parse(time.RFC3339, wire.ExpiresAt) if err != nil { - return tokens.TokenSet{}, fmt.Errorf("parse expires_at: %w", err) + return tokens.TokenSet{}, fmt.Errorf("%w: parse expires_at: %w", ErrMalformed, err) } t.ExpiresAt = exp.UTC() } diff --git a/auth/tokenstore/tokenstore.go b/auth/tokenstore/tokenstore.go index 875308a37a..40e5a6619e 100644 --- a/auth/tokenstore/tokenstore.go +++ b/auth/tokenstore/tokenstore.go @@ -1,13 +1,5 @@ // Package tokenstore is the persistence interface for tokens, plus -// reference implementations. -// -// Callers pick a Store at startup. Two impls ship with this package: -// -// - Keyring stores one entry per profile in the OS keyring. Suitable -// for interactive single-user CLIs. -// - File stores entries in a JSON file on disk, with refresh tokens -// in the OS keyring. Suitable for CLIs that need to persist -// additional per-profile metadata (e.g. context bindings). +// the Keyring reference implementation. // // Profile is whatever string the caller wants to key by — typically a // base URL, a kubectl-style context name, or a principal handle. @@ -23,6 +15,13 @@ import ( // distinguish "not logged in" from genuine errors with errors.Is. var ErrNotFound = errors.New("token not found") +// ErrMalformed is returned (wrapped) when a stored entry exists but +// can't be decoded into a TokenSet. Used by callers that want to treat +// a malformed entry as a legacy/upgrade path (e.g. pre-shim bare-string +// entries from older binaries) without confusing it with transport +// errors from the underlying keyring. +var ErrMalformed = errors.New("malformed token entry") + // Store persists token bundles keyed by an opaque profile string. // // Implementations must: diff --git a/cmd/entire/cli/api/auth_tokens_test.go b/cmd/entire/cli/api/auth_tokens_test.go index 8fc354c4ab..c23d5155a6 100644 --- a/cmd/entire/cli/api/auth_tokens_test.go +++ b/cmd/entire/cli/api/auth_tokens_test.go @@ -198,3 +198,56 @@ func TestClient_RevokeToken_ReturnsErrorBody(t *testing.T) { t.Errorf("IsHTTPErrorStatus(err, 404) = false; err = %v", err) } } + +// TestAuthTokensBasePath_ProviderVersionRouting locks in the path +// switch so v2 doesn't silently regress to v1's path family. The whole +// reason the version env var exists is to route requests at this layer. +func TestAuthTokensBasePath_ProviderVersionRouting(t *testing.T) { + cases := []struct { + name string + version string + want string + }{ + {"unset defaults to v1", "", "/api/v1/auth/tokens"}, + {"v1 explicit", "v1", "/api/v1/auth/tokens"}, + {"v2", "v2", "/api/auth/tokens"}, + {"unrecognised defaults to v1", "v999", "/api/v1/auth/tokens"}, + // Whitespace trimming must match auth.currentProvider() — both + // trim, so the api and auth packages agree on what "v2" means. + // If either side stops trimming, these tests diverge first. + {"trims whitespace then matches v2", " v2 ", "/api/auth/tokens"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Setenv(authTokensProviderVersionEnvVar, tc.version) + if got := authTokensBasePath(); got != tc.want { + t.Fatalf("authTokensBasePath() = %q, want %q", got, tc.want) + } + }) + } +} + +// TestClient_ListTokens_RoutesV2Path is an end-to-end check that the +// version switch flows through the public Client API, not just the +// internal helper. +func TestClient_ListTokens_RoutesV2Path(t *testing.T) { + t.Setenv(authTokensProviderVersionEnvVar, "v2") + + var gotPath string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"tokens":[]}`)) //nolint:errcheck // test handler + })) + defer server.Close() + + c := NewClient("tok") + c.baseURL = server.URL + + if _, err := c.ListTokens(context.Background()); err != nil { + t.Fatalf("ListTokens: %v", err) + } + if gotPath != "/api/auth/tokens" { + t.Fatalf("path = %q, want /api/auth/tokens (v2)", gotPath) + } +} diff --git a/cmd/entire/cli/api/client.go b/cmd/entire/cli/api/client.go index c594b0c261..b568214a77 100644 --- a/cmd/entire/cli/api/client.go +++ b/cmd/entire/cli/api/client.go @@ -51,7 +51,15 @@ type bearerTransport struct { base http.RoundTripper } +// errEmptyBearerToken surfaces at first request rather than at construction +// because NewClient* don't return errors. An empty bearer otherwise becomes +// "Authorization: Bearer " on the wire and produces a confusing 401. +var errEmptyBearerToken = errors.New("api: refusing to send request with empty bearer token (construct via NewAuthenticatedAPIClient)") + func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if t.token == "" { + return nil, errEmptyBearerToken + } // Clone the request to avoid mutating the caller's request. r := req.Clone(req.Context()) r.Header.Set("Authorization", "Bearer "+t.token) diff --git a/cmd/entire/cli/api/client_test.go b/cmd/entire/cli/api/client_test.go index f5977b7a72..de105409f1 100644 --- a/cmd/entire/cli/api/client_test.go +++ b/cmd/entire/cli/api/client_test.go @@ -3,6 +3,7 @@ package api import ( "context" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" @@ -300,3 +301,23 @@ func TestDecodeJSONResponse(t *testing.T) { t.Errorf("Status = %q, want %q", result.Status, "ok") } } + +// TestBearerTransport_RejectsEmptyToken locks in the early-failure +// behaviour for an accidentally-empty bearer. Without this guard the +// CLI would put "Authorization: Bearer " on the wire and the server +// would respond with a confusing 401. +func TestBearerTransport_RejectsEmptyToken(t *testing.T) { + t.Parallel() + + c := NewClientWithBaseURL("", "https://example.test") + resp, err := c.Get(context.Background(), "/probe") + if resp != nil { + _ = resp.Body.Close() + } + if err == nil { + t.Fatal("Get with empty token must error") + } + if !errors.Is(err, errEmptyBearerToken) { + t.Fatalf("err = %v, want errEmptyBearerToken sentinel", err) + } +} diff --git a/cmd/entire/cli/api_client.go b/cmd/entire/cli/api_client.go index 06a1edc2ce..453c050e3b 100644 --- a/cmd/entire/cli/api_client.go +++ b/cmd/entire/cli/api_client.go @@ -16,7 +16,8 @@ import ( // it directly (single-host setup, or when the core token's `aud` already // covers api.BaseURL()) or performs an RFC 8693 token exchange against // the auth host to obtain a token scoped to the data API. Exchanged -// tokens are cached in-memory per (core-token, resource) pair. +// tokens are cached in-memory keyed off the wire-affecting fields of +// the request — see tokenmanager.cacheKey for the precise key shape. // // Pass insecureHTTP=true to allow plain HTTP base URLs for local // development. Both api.BaseURL() and api.AuthBaseURL() are validated: diff --git a/cmd/entire/cli/auth/exchange.go b/cmd/entire/cli/auth/exchange.go index bddf13496b..f888425fef 100644 --- a/cmd/entire/cli/auth/exchange.go +++ b/cmd/entire/cli/auth/exchange.go @@ -41,9 +41,10 @@ func SetManagerForTest(t interface{ Helper() }, mgr *tokenmanager.Manager) func( // defaultManager returns the package-level Manager built from this // CLI's identity (current provider, AuthBaseURL, NewStore service -// name). Constructed lazily on first use so env-var changes between -// tests are honoured by the first non-test caller in any given -// process. +// name). Constructed lazily on first use so any env-var setup +// (ENTIRE_AUTH_BASE_URL, ENTIRE_AUTH_PROVIDER_VERSION) lands before +// construction. sync.Once means later env-var changes within the same +// process are ignored; tests bypass the singleton via SetManagerForTest. func defaultManager() (*tokenmanager.Manager, error) { if managerForTest != nil { return managerForTest, nil @@ -69,7 +70,7 @@ func defaultManager() (*tokenmanager.Manager, error) { // TokenForResource returns a bearer token suitable for use against // resourceBaseURL, performing an RFC 8693 token exchange when the // stored core token's audience doesn't already cover that resource. -// See tokenmanager.Manager.TokenForResource for the resolution rules. +// See tokenmanager.Manager.Token for the full resolution rules. func TokenForResource(ctx context.Context, resourceBaseURL string) (string, error) { m, err := defaultManager() if err != nil { diff --git a/cmd/entire/cli/auth/store.go b/cmd/entire/cli/auth/store.go index 7391c9ef32..5d344bb481 100644 --- a/cmd/entire/cli/auth/store.go +++ b/cmd/entire/cli/auth/store.go @@ -46,8 +46,11 @@ func (s *Store) SaveToken(baseURL, token string) error { // GetToken retrieves a stored token for the given base URL. Returns // an empty string (and no error) if no token is stored. // -// Falls back to a bare-string read to surface tokens written before -// the shim landed. +// Falls back to a bare-string read when the stored entry is malformed +// JSON, to handle pre-shim entries that stored the raw access token +// rather than a JSON-encoded TokenSet. Real keyring errors (transport, +// permission denied) propagate; only ErrNotFound and ErrMalformed +// trigger the fallback. func (s *Store) GetToken(baseURL string) (string, error) { t, err := s.inner.LoadTokens(baseURL) if err == nil { @@ -56,9 +59,10 @@ func (s *Store) GetToken(baseURL string) (string, error) { if errors.Is(err, tokenstore.ErrNotFound) { return "", nil } + if !errors.Is(err, tokenstore.ErrMalformed) { + return "", fmt.Errorf("load token from keyring: %w", err) + } - // Legacy fallback: pre-shim entries stored the raw access token - // rather than a JSON-encoded TokenSet. raw, kerr := keyring.Get(s.inner.Service, baseURL) if errors.Is(kerr, keyring.ErrNotFound) { return "", nil @@ -82,13 +86,19 @@ func (s *Store) SaveTokens(profile string, t tokens.TokenSet) error { // LoadTokens implements tokenstore.Store, preserving the legacy bare-string // fallback path so users with pre-shim keyring entries don't appear logged // out after upgrading. +// +// Falls back to a bare-string read when the stored entry is malformed +// JSON (pre-shim entries stored the raw access token verbatim). Real +// keyring errors (transport, permission denied) propagate; only +// ErrMalformed triggers the fallback. ErrNotFound surfaces verbatim +// so the manager's "not logged in" branch still works. func (s *Store) LoadTokens(profile string) (tokens.TokenSet, error) { t, err := s.inner.LoadTokens(profile) if err == nil { return t, nil } - if !errors.Is(err, tokenstore.ErrNotFound) { - return tokens.TokenSet{}, err //nolint:wrapcheck // shim returns the lib error verbatim + if !errors.Is(err, tokenstore.ErrMalformed) { + return tokens.TokenSet{}, err //nolint:wrapcheck // surface ErrNotFound and real keyring errors verbatim } raw, kerr := keyring.Get(s.inner.Service, profile) diff --git a/cmd/entire/cli/auth/store_test.go b/cmd/entire/cli/auth/store_test.go index 8ccf122372..7eca339b14 100644 --- a/cmd/entire/cli/auth/store_test.go +++ b/cmd/entire/cli/auth/store_test.go @@ -132,6 +132,52 @@ func TestStoreDeleteToken_NotFoundIsNoop(t *testing.T) { } } +// TestStoreGetToken_LegacyBareStringFallback verifies that a pre-shim +// keyring entry (raw access-token string, not a JSON-encoded TokenSet) +// is still readable via GetToken after the shim landed. Without the +// fallback, pre-shim users would appear logged out after upgrading. +func TestStoreGetToken_LegacyBareStringFallback(t *testing.T) { + // Not parallel: go-keyring's mock provider uses an unprotected map. + const service = "test-legacy-getoken" + const profile = "https://legacy.example.com" + const bareToken = "ent_pre_shim_raw_token" + + if err := keyring.Set(service, profile, bareToken); err != nil { + t.Fatalf("seed keyring: %v", err) + } + + got, err := NewStoreWithService(service).GetToken(profile) + if err != nil { + t.Fatalf("GetToken: %v", err) + } + if got != bareToken { + t.Fatalf("GetToken() = %q, want bare token %q", got, bareToken) + } +} + +// TestStoreLoadTokens_LegacyBareStringFallback is the tokenstore.Store +// counterpart of the above. The tokenmanager calls LoadTokens, so this +// path is what determines whether the manager-backed code path +// recognises pre-shim entries. +func TestStoreLoadTokens_LegacyBareStringFallback(t *testing.T) { + // Not parallel: go-keyring's mock provider uses an unprotected map. + const service = "test-legacy-loadtokens" + const profile = "https://legacy.example.com" + const bareToken = "ent_pre_shim_raw_token" + + if err := keyring.Set(service, profile, bareToken); err != nil { + t.Fatalf("seed keyring: %v", err) + } + + got, err := NewStoreWithService(service).LoadTokens(profile) + if err != nil { + t.Fatalf("LoadTokens: %v", err) + } + if got.AccessToken != bareToken { + t.Fatalf("LoadTokens AccessToken = %q, want %q", got.AccessToken, bareToken) + } +} + func TestLookupCurrentToken(t *testing.T) { t.Setenv(api.BaseURLEnvVar, "http://localhost:8787") diff --git a/cmd/entire/cli/search_cmd.go b/cmd/entire/cli/search_cmd.go index a654c7ddba..04c5aca666 100644 --- a/cmd/entire/cli/search_cmd.go +++ b/cmd/entire/cli/search_cmd.go @@ -100,9 +100,10 @@ branch:, repo:, and repo:* to search all accessible repos.`, serviceURL := os.Getenv("ENTIRE_SEARCH_URL") if serviceURL == "" { - // Honour ENTIRE_API_BASE_URL: search lives on the data API - // host. Fall back to search.DefaultServiceURL only when no - // API base URL is configured (production default). + // Search lives on the data API host. Fall back to + // api.BaseURL() so ENTIRE_API_BASE_URL applies; the search + // package's DefaultServiceURL is only consulted by callers + // that bypass this entry point. serviceURL = api.BaseURL() } From 5173d30fc424c0bc5dae7f4355be46d51a98ac97 Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Fri, 8 May 2026 16:43:01 +1000 Subject: [PATCH 13/21] auth: round-2 review fixes (DeleteCoreToken order, coverage, deprecations) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Behaviour: - tokenmanager.DeleteCoreToken now deletes the keyring entry first and only clears the in-memory exchange cache on success. Pre-emptively clearing would leave a window where the CLI thinks it's logged out but the keyring still hands out the core token to the next process. Surfaces the store error wrapped as "delete core token: ...". Coverage: - tokenmanager: regression tests for the cache-clear (and its inverse — cache survives a failed delete), cache-key independence for RequestedTokenType and Scope (matching the existing Audience test), malformed-JWT fallthrough on the audience shortcut (security contract — corrupt cores must not be returned verbatim), and surface-don't-collapse for non-ErrNotFound store errors. Adds an erroringStore test helper for failure-path tests. - tokenstore.Keyring: pin the ErrMalformed contract — malformed JSON, legacy bare-string entries, and bad expires_at all surface as ErrMalformed (wrapped), not ErrNotFound. cmd-side legacy fallback depends on this distinction. Deprecations / docs: - Mark cmd/entire/cli/auth.Store.SaveToken/GetToken/DeleteToken as // Deprecated so godoc and IDE hover steer new callers to the tokenstore.Store interface methods. Legacy direct-bearer call sites (login, logout, auth status/list/revoke) keep using them; login.go carries a //nolint:staticcheck with a pointer to the doc. - Document keyringService = "entire-cli" as immutable — renaming would orphan every existing user's stored credentials. Co-Authored-By: Claude Opus 4.7 (1M context) --- auth/tokenmanager/tokenmanager.go | 15 +- auth/tokenmanager/tokenmanager_test.go | 216 +++++++++++++++++++++++++ auth/tokenstore/keyring_test.go | 61 +++++++ cmd/entire/cli/auth/store.go | 17 ++ cmd/entire/cli/login.go | 5 +- 5 files changed, 310 insertions(+), 4 deletions(-) diff --git a/auth/tokenmanager/tokenmanager.go b/auth/tokenmanager/tokenmanager.go index f410eb203a..fef1ed1e9e 100644 --- a/auth/tokenmanager/tokenmanager.go +++ b/auth/tokenmanager/tokenmanager.go @@ -155,13 +155,22 @@ func (m *Manager) LookupCoreToken() (string, error) { return t.AccessToken, nil } -// DeleteCoreToken removes the stored core token (and any cached -// exchanges derived from it). +// DeleteCoreToken removes the stored core token and any cached +// exchanges derived from it. +// +// Order matters: the keyring delete runs first, then the in-memory +// cache is cleared. If the keyring delete fails the cache is left +// alone — clearing it pre-emptively would create a window where the +// CLI thinks it's logged out (no cache entries) but the keyring +// still hands out the core token to the next process. func (m *Manager) DeleteCoreToken() error { + if err := m.cfg.Store.DeleteTokens(m.cfg.Issuer); err != nil { + return fmt.Errorf("delete core token: %w", err) + } m.mu.Lock() m.cache = map[string]cachedToken{} m.mu.Unlock() - return m.cfg.Store.DeleteTokens(m.cfg.Issuer) //nolint:wrapcheck // backend error already names the operation + return nil } // TokenRequest customises one Token call. Empty fields fall back to diff --git a/auth/tokenmanager/tokenmanager_test.go b/auth/tokenmanager/tokenmanager_test.go index f3710d9656..5e7fcedae4 100644 --- a/auth/tokenmanager/tokenmanager_test.go +++ b/auth/tokenmanager/tokenmanager_test.go @@ -401,3 +401,219 @@ func TestSaveLookupDeleteCoreToken(t *testing.T) { t.Fatalf("after delete: got=%q err=%v", got, err) } } + +// TestDeleteCoreToken_ClearsExchangeCache exercises the cache-clear +// side of DeleteCoreToken. Without it, a subsequent Token() call after +// re-login could return a stale exchanged token derived from the old +// core token (currently safe because cacheKey includes the core token, +// but the manager promises a clean slate on delete and tests should +// pin that). +func TestDeleteCoreToken_ClearsExchangeCache(t *testing.T) { + t.Parallel() + core := makeJWTWithAudience(t, []string{testIssuer}) + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: core} + + var exchangeCalls int + m := newTestManager(t, store, func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + exchangeCalls++ + return &tokens.TokenSet{AccessToken: "exchanged-old", ExpiresAt: time.Now().Add(time.Hour)}, nil + }) + + // Prime the cache. + if _, err := m.TokenForResource(context.Background(), testResource); err != nil { + t.Fatalf("prime: %v", err) + } + if exchangeCalls != 1 { + t.Fatalf("prime exchanges = %d, want 1", exchangeCalls) + } + + if err := m.DeleteCoreToken(); err != nil { + t.Fatalf("DeleteCoreToken: %v", err) + } + + // Re-login with a fresh core token; the next Token() must not + // surface the stale cached entry. + freshCore := makeJWTWithAudience(t, []string{testIssuer}) + if err := m.SaveCoreToken(freshCore); err != nil { + t.Fatalf("SaveCoreToken: %v", err) + } + if _, err := m.TokenForResource(context.Background(), testResource); err != nil { + t.Fatalf("post-relogin: %v", err) + } + if exchangeCalls != 2 { + t.Fatalf("post-relogin exchanges = %d, want 2 (cache must miss after delete)", exchangeCalls) + } +} + +// TestDeleteCoreToken_PreservesCacheOnStoreFailure pins the order-of- +// operations: if Store.DeleteTokens fails, the in-memory cache must +// stay populated. Clearing pre-emptively would create a window where +// the CLI thinks it's logged out but the keyring still hands out the +// core token to the next process. +func TestDeleteCoreToken_PreservesCacheOnStoreFailure(t *testing.T) { + t.Parallel() + core := makeJWTWithAudience(t, []string{testIssuer}) + store := &erroringStore{inner: newMemStore(), deleteErr: errors.New("keyring locked")} + store.inner.data[testIssuer] = tokens.TokenSet{AccessToken: core} + + var exchangeCalls int + m := newTestManager(t, store, func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + exchangeCalls++ + return &tokens.TokenSet{AccessToken: "exchanged-1", ExpiresAt: time.Now().Add(time.Hour)}, nil + }) + + if _, err := m.TokenForResource(context.Background(), testResource); err != nil { + t.Fatalf("prime: %v", err) + } + if exchangeCalls != 1 { + t.Fatalf("prime exchanges = %d, want 1", exchangeCalls) + } + + if err := m.DeleteCoreToken(); err == nil { + t.Fatal("DeleteCoreToken must surface store error") + } + + // Cache must still hand out the previously exchanged token — + // no exchange call should fire on the second Token(). + if _, err := m.TokenForResource(context.Background(), testResource); err != nil { + t.Fatalf("post-failed-delete: %v", err) + } + if exchangeCalls != 1 { + t.Fatalf("post-failed-delete exchanges = %d, want 1 (cache must survive failed delete)", exchangeCalls) + } +} + +// erroringStore wraps memStore and lets a test force a specific store +// op to fail, so we can exercise failure paths without a flaky real +// keyring. +type erroringStore struct { + inner *memStore + loadErr error + deleteErr error +} + +func (s *erroringStore) SaveTokens(profile string, t tokens.TokenSet) error { + return s.inner.SaveTokens(profile, t) +} + +func (s *erroringStore) LoadTokens(profile string) (tokens.TokenSet, error) { + if s.loadErr != nil { + return tokens.TokenSet{}, s.loadErr + } + return s.inner.LoadTokens(profile) +} + +func (s *erroringStore) DeleteTokens(profile string) error { + if s.deleteErr != nil { + return s.deleteErr + } + return s.inner.DeleteTokens(profile) +} + +// TestToken_CacheKeyDistinguishesRequestedTokenType complements the +// existing audience-independence test: different requested_token_type +// URIs must not shadow each other in the cache. +func TestToken_CacheKeyDistinguishesRequestedTokenType(t *testing.T) { + t.Parallel() + core := makeJWTWithAudience(t, []string{testIssuer}) + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: core} + + var calls int + m := newTestManager(t, store, func(_ context.Context, req sts.ExchangeRequest) (*tokens.TokenSet, error) { + calls++ + return &tokens.TokenSet{AccessToken: "tok-" + req.RequestedTokenType}, nil + }) + + const otherType = "urn:ietf:params:oauth:token-type:jwt" + a, err := m.Token(context.Background(), TokenRequest{Resource: testResource}) + if err != nil { + t.Fatalf("Token(default type): %v", err) + } + b, err := m.Token(context.Background(), TokenRequest{Resource: testResource, RequestedTokenType: otherType}) + if err != nil { + t.Fatalf("Token(otherType): %v", err) + } + if a == b || calls != 2 { + t.Fatalf("expected separate cache entries per requested_token_type, got a=%q b=%q calls=%d", a, b, calls) + } +} + +// TestToken_CacheKeyDistinguishesScope same shape, locks scope into +// the cache key. +func TestToken_CacheKeyDistinguishesScope(t *testing.T) { + t.Parallel() + core := makeJWTWithAudience(t, []string{testIssuer}) + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: core} + + var calls int + m := newTestManager(t, store, func(_ context.Context, req sts.ExchangeRequest) (*tokens.TokenSet, error) { + calls++ + return &tokens.TokenSet{AccessToken: "tok-" + req.Scope}, nil + }) + + a, err := m.Token(context.Background(), TokenRequest{Resource: testResource, Scope: "scope-a"}) + if err != nil { + t.Fatalf("Token(scope-a): %v", err) + } + b, err := m.Token(context.Background(), TokenRequest{Resource: testResource, Scope: "scope-b"}) + if err != nil { + t.Fatalf("Token(scope-b): %v", err) + } + if a == b || calls != 2 { + t.Fatalf("expected separate cache entries per scope, got a=%q b=%q calls=%d", a, b, calls) + } +} + +// TestCoreTokenAudienceShortcut_FallsThroughOnMalformedJWT pins a +// security-sensitive contract: a non-JWT (or malformed JWT) core token +// must NOT be silently treated as audience-matching the resource. +// Otherwise a corrupt/forged-but-undecodeable token could bypass the +// exchange path. The "fallthrough to exchange" behaviour is what makes +// signature-skipping ParseClaims safe here. +func TestCoreTokenAudienceShortcut_FallsThroughOnMalformedJWT(t *testing.T) { + t.Parallel() + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: "not-a-jwt"} + + var exchangeCalls int + m := newTestManager(t, store, func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + exchangeCalls++ + return &tokens.TokenSet{AccessToken: "exchanged"}, nil + }) + + got, err := m.TokenForResource(context.Background(), testResource) + if err != nil { + t.Fatalf("TokenForResource: %v", err) + } + if got == "not-a-jwt" { + t.Fatal("malformed core token must not be returned verbatim — exchange path must fire") + } + if exchangeCalls != 1 { + t.Fatalf("exchanges = %d, want 1 (exchange must run on unparseable JWT)", exchangeCalls) + } +} + +// TestToken_StoreErrorSurfacesNotAsErrNotLoggedIn pins the contract +// that a non-ErrNotFound store error is *not* collapsed to +// ErrNotLoggedIn. Doing so would mask real keyring failures behind a +// "run entire login" message that does nothing. +func TestToken_StoreErrorSurfacesNotAsErrNotLoggedIn(t *testing.T) { + t.Parallel() + store := &erroringStore{inner: newMemStore(), loadErr: errors.New("keyring permission denied")} + + m := newTestManager(t, store, nil) + + _, err := m.TokenForResource(context.Background(), testResource) + if err == nil { + t.Fatal("expected store error to surface") + } + if errors.Is(err, ErrNotLoggedIn) { + t.Fatalf("err = %v, must NOT be ErrNotLoggedIn (real failures must not be silenced)", err) + } + if !strings.Contains(err.Error(), "keyring permission denied") { + t.Fatalf("err = %v, want underlying store error", err) + } +} diff --git a/auth/tokenstore/keyring_test.go b/auth/tokenstore/keyring_test.go index 4cc7a53925..c30ea5469f 100644 --- a/auth/tokenstore/keyring_test.go +++ b/auth/tokenstore/keyring_test.go @@ -134,3 +134,64 @@ func TestKeyring_RoundTrip_NoExpiry(t *testing.T) { t.Fatalf("ExpiresAt = %v, want zero", got.ExpiresAt) } } + +// TestKeyring_LoadTokens_MalformedJSONReturnsErrMalformed pins the +// contract that decode failures surface as ErrMalformed (wrapped), not +// ErrNotFound. Callers (e.g. cmd/entire/cli/auth.Store) use this to +// distinguish "no entry" from "entry exists but can't be parsed", +// which is the hook for the legacy bare-string upgrade fallback. +func TestKeyring_LoadTokens_MalformedJSONReturnsErrMalformed(t *testing.T) { + const service = "test-malformed" + const profile = "https://example.com" + + if err := keyring.Set(service, profile, "{not-valid-json"); err != nil { + t.Fatalf("seed keyring: %v", err) + } + + _, err := NewKeyring(service).LoadTokens(profile) + if err == nil { + t.Fatal("expected error for malformed JSON") + } + if errors.Is(err, ErrNotFound) { + t.Fatalf("err = %v, must NOT be ErrNotFound (entry exists, just malformed)", err) + } + if !errors.Is(err, ErrMalformed) { + t.Fatalf("err = %v, want ErrMalformed sentinel for callers to detect legacy entries", err) + } +} + +// TestKeyring_LoadTokens_BareStringReturnsErrMalformed is the contract +// the cmd-side legacy fallback depends on: a pre-shim raw access-token +// entry must surface as ErrMalformed so the shim knows to fall through +// to a bare-string read. +func TestKeyring_LoadTokens_BareStringReturnsErrMalformed(t *testing.T) { + const service = "test-barestring" + const profile = "https://example.com" + + if err := keyring.Set(service, profile, "ent_pre_shim_raw_token"); err != nil { + t.Fatalf("seed keyring: %v", err) + } + + _, err := NewKeyring(service).LoadTokens(profile) + if !errors.Is(err, ErrMalformed) { + t.Fatalf("err = %v, want ErrMalformed", err) + } +} + +// TestKeyring_LoadTokens_BadExpiresAtReturnsErrMalformed covers the +// other branch in decodeTokenSet: well-formed JSON with a malformed +// expires_at also surfaces as ErrMalformed so the same fallback +// machinery applies. +func TestKeyring_LoadTokens_BadExpiresAtReturnsErrMalformed(t *testing.T) { + const service = "test-bad-expires" + const profile = "https://example.com" + + if err := keyring.Set(service, profile, `{"access_token":"a","expires_at":"not-a-date"}`); err != nil { + t.Fatalf("seed keyring: %v", err) + } + + _, err := NewKeyring(service).LoadTokens(profile) + if !errors.Is(err, ErrMalformed) { + t.Fatalf("err = %v, want ErrMalformed", err) + } +} diff --git a/cmd/entire/cli/auth/store.go b/cmd/entire/cli/auth/store.go index 5d344bb481..1120bea86d 100644 --- a/cmd/entire/cli/auth/store.go +++ b/cmd/entire/cli/auth/store.go @@ -11,6 +11,10 @@ import ( "github.com/zalando/go-keyring" ) +// keyringService is the OS-keyring service name for this CLI. Renaming +// would orphan every existing user's stored credentials — they'd appear +// logged out until they ran `entire login` again. Don't change this +// without a migration path. const keyringService = "entire-cli" // Store manages CLI authentication tokens in the OS keyring. @@ -35,6 +39,11 @@ func NewStoreWithService(service string) *Store { } // SaveToken persists an access token for the given base URL. +// +// Deprecated: prefer SaveTokens (the tokenstore.Store interface method) +// for new callers. SaveToken is kept for the legacy direct-bearer call +// sites (login, logout, auth status/list/revoke) that don't go through +// the tokenmanager. func (s *Store) SaveToken(baseURL, token string) error { token = strings.TrimSpace(token) if token == "" { @@ -51,6 +60,11 @@ func (s *Store) SaveToken(baseURL, token string) error { // rather than a JSON-encoded TokenSet. Real keyring errors (transport, // permission denied) propagate; only ErrNotFound and ErrMalformed // trigger the fallback. +// +// Deprecated: prefer LoadTokens (the tokenstore.Store interface method) +// for new callers — it returns the full TokenSet so refresh tokens and +// expiry survive the round trip. GetToken is retained for the direct- +// bearer call sites that only need the access token string. func (s *Store) GetToken(baseURL string) (string, error) { t, err := s.inner.LoadTokens(baseURL) if err == nil { @@ -74,6 +88,9 @@ func (s *Store) GetToken(baseURL string) (string, error) { } // DeleteToken removes a stored token for the given base URL. +// +// Deprecated: prefer DeleteTokens (the tokenstore.Store interface +// method). DeleteToken is retained for direct-bearer call sites. func (s *Store) DeleteToken(baseURL string) error { return s.inner.DeleteTokens(baseURL) //nolint:wrapcheck // shim returns the lib error verbatim } diff --git a/cmd/entire/cli/login.go b/cmd/entire/cli/login.go index de5ad84f8c..0de1bdeeb4 100644 --- a/cmd/entire/cli/login.go +++ b/cmd/entire/cli/login.go @@ -86,7 +86,10 @@ func runLogin(ctx context.Context, outW, errW io.Writer, client deviceAuthClient store := auth.NewStore() - if err := store.SaveToken(client.BaseURL(), token); err != nil { + // Login deliberately uses the legacy SaveToken (string, string) + // surface — we only have an access-token string at this point; + // the deviceflow client doesn't return a TokenSet here. + if err := store.SaveToken(client.BaseURL(), token); err != nil { //nolint:staticcheck // SA1019: legacy direct-bearer call site, see Store.SaveToken doc return fmt.Errorf("save auth token: %w", err) } From f33b79dfcbe431176d465b04fc8a095beff8c2c9 Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Fri, 8 May 2026 17:36:42 +1000 Subject: [PATCH 14/21] Fix token exchange resource routing --- auth/tokenmanager/tokenmanager.go | 3 +- auth/tokenmanager/tokenmanager_test.go | 49 ++++++++++++++++++++++++++ cmd/entire/cli/search_cmd.go | 12 ++++++- cmd/entire/cli/search_cmd_test.go | 25 +++++++++++++ 4 files changed, 87 insertions(+), 2 deletions(-) diff --git a/auth/tokenmanager/tokenmanager.go b/auth/tokenmanager/tokenmanager.go index fef1ed1e9e..40501a7637 100644 --- a/auth/tokenmanager/tokenmanager.go +++ b/auth/tokenmanager/tokenmanager.go @@ -229,7 +229,7 @@ func (m *Manager) Token(ctx context.Context, req TokenRequest) (string, error) { if req.Audience == "" && m.cfg.Issuer == req.Resource { return core, nil } - if coreTokenAudienceIncludes(core, req.Resource) { + if req.Audience == "" && coreTokenAudienceIncludes(core, req.Resource) { return core, nil } @@ -326,6 +326,7 @@ func (m *Manager) runExchange(ctx context.Context, coreToken string, req TokenRe SubjectTokenType: sts.SubjectTokenTypeJWT, RequestedTokenType: req.RequestedTokenType, Audience: req.Audience, + Resource: req.Resource, Scope: req.Scope, // Public-client identification per RFC 6749 §2.3.1 / §3.2.1. // Carried via Extra because the sts package is provider-agnostic. diff --git a/auth/tokenmanager/tokenmanager_test.go b/auth/tokenmanager/tokenmanager_test.go index 5e7fcedae4..916280c7d1 100644 --- a/auth/tokenmanager/tokenmanager_test.go +++ b/auth/tokenmanager/tokenmanager_test.go @@ -204,6 +204,34 @@ func TestToken_AudienceShortcut(t *testing.T) { } } +func TestToken_ExplicitAudienceBypassesAudienceShortcut(t *testing.T) { + t.Parallel() + core := makeJWTWithAudience(t, []string{testIssuer, testResource}) + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: core} + + const requestedAudience = "https://tokens.example.com" + var got sts.ExchangeRequest + var calls int + m := newTestManager(t, store, func(_ context.Context, req sts.ExchangeRequest) (*tokens.TokenSet, error) { + calls++ + got = req + return &tokens.TokenSet{AccessToken: "exchanged"}, nil + }) + + token, err := m.Token(context.Background(), TokenRequest{Resource: testResource, Audience: requestedAudience}) + if err != nil { + t.Fatalf("Token: %v", err) + } + + if token != "exchanged" || calls != 1 { + t.Fatalf("Token returned %q with %d exchange calls, want exchanged token from one exchange", token, calls) + } + if got.Audience != requestedAudience { + t.Fatalf("exchange Audience = %q, want %q", got.Audience, requestedAudience) + } +} + func TestToken_ExchangesAndCaches(t *testing.T) { t.Parallel() core := makeJWTWithAudience(t, []string{testIssuer}) @@ -251,6 +279,27 @@ func TestToken_ExchangesAndCaches(t *testing.T) { } } +func TestToken_ExchangeIncludesResource(t *testing.T) { + t.Parallel() + core := makeJWTWithAudience(t, []string{testIssuer}) + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: core} + + var got sts.ExchangeRequest + m := newTestManager(t, store, func(_ context.Context, req sts.ExchangeRequest) (*tokens.TokenSet, error) { + got = req + return &tokens.TokenSet{AccessToken: "exchanged"}, nil + }) + + if _, err := m.TokenForResource(context.Background(), testResource); err != nil { + t.Fatalf("TokenForResource: %v", err) + } + + if got.Resource != testResource { + t.Fatalf("exchange Resource = %q, want %q", got.Resource, testResource) + } +} + func TestToken_OverridesAudienceAndType(t *testing.T) { t.Parallel() core := makeJWTWithAudience(t, []string{testIssuer}) diff --git a/cmd/entire/cli/search_cmd.go b/cmd/entire/cli/search_cmd.go index 04c5aca666..55b2cbf00d 100644 --- a/cmd/entire/cli/search_cmd.go +++ b/cmd/entire/cli/search_cmd.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io" + "net/url" "os" "strings" @@ -112,7 +113,7 @@ branch:, repo:, and repo:* to search all accessible repos.`, // carries the data-API audience rather than the auth-host one; // single-host setups hit the same-host shortcut and return the // core token unchanged. - ghToken, err := auth.TokenForResource(ctx, serviceURL) + ghToken, err := auth.TokenForResource(ctx, searchTokenResourceURL(serviceURL)) if errors.Is(err, auth.ErrNotLoggedIn) { return errors.New("not authenticated. Run 'entire login' to authenticate") } @@ -209,6 +210,15 @@ branch:, repo:, and repo:* to search all accessible repos.`, return cmd } +func searchTokenResourceURL(serviceURL string) string { + raw := strings.TrimSpace(serviceURL) + u, err := url.Parse(raw) + if err != nil || u.Scheme == "" || u.Host == "" { + return raw + } + return (&url.URL{Scheme: u.Scheme, Host: u.Host}).String() +} + // completeRepoFlag returns shell-completion suggestions for the search // command's --repo flag. "*" is always offered so the wildcard works // regardless of auth state. Errors are swallowed (rather than surfaced via diff --git a/cmd/entire/cli/search_cmd_test.go b/cmd/entire/cli/search_cmd_test.go index 484627655c..e0f3a7e6ba 100644 --- a/cmd/entire/cli/search_cmd_test.go +++ b/cmd/entire/cli/search_cmd_test.go @@ -76,3 +76,28 @@ func TestWriteSearchJSON_ZeroLimitFallsBackToDefaultPageSize(t *testing.T) { t.Fatalf("output missing total_pages:\n%s", output) } } + +func TestSearchTokenResourceURL_NormalizesToOrigin(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + serviceURL string + want string + }{ + {"plain origin", "https://entire.io", "https://entire.io"}, + {"trailing slash", "https://entire.io/", "https://entire.io"}, + {"pathful search URL", "https://entire.io/custom/search", "https://entire.io"}, + {"localhost port", "http://localhost:8787/search", "http://localhost:8787"}, + {"parse fallback", "://not-a-url", "://not-a-url"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := searchTokenResourceURL(tt.serviceURL); got != tt.want { + t.Fatalf("searchTokenResourceURL(%q) = %q, want %q", tt.serviceURL, got, tt.want) + } + }) + } +} From d8ccd264a97366f685d088e890a576434a0a3bf2 Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Fri, 8 May 2026 17:44:48 +1000 Subject: [PATCH 15/21] Make auth tests independent of provider env --- cmd/entire/cli/api/auth_tokens_test.go | 6 +++--- cmd/entire/cli/auth/provider_test.go | 2 ++ cmd/entire/cli/auth/store_test.go | 1 + cmd/entire/cli/integration_test/login_test.go | 2 ++ 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/cmd/entire/cli/api/auth_tokens_test.go b/cmd/entire/cli/api/auth_tokens_test.go index c23d5155a6..12a9c4628f 100644 --- a/cmd/entire/cli/api/auth_tokens_test.go +++ b/cmd/entire/cli/api/auth_tokens_test.go @@ -10,7 +10,7 @@ import ( ) func TestClient_RevokeCurrentToken_SendsDeleteWithBearer(t *testing.T) { - t.Parallel() + t.Setenv(authTokensProviderVersionEnvVar, "") var gotMethod, gotPath, gotAuth string @@ -71,7 +71,7 @@ func TestClient_RevokeCurrentToken_ReturnsHTTPErrorOn401(t *testing.T) { } func TestClient_ListTokens_DecodesResponse(t *testing.T) { - t.Parallel() + t.Setenv(authTokensProviderVersionEnvVar, "") var gotMethod, gotPath, gotAuth string @@ -142,7 +142,7 @@ func TestClient_ListTokens_ReturnsHTTPErrorOn401(t *testing.T) { } func TestClient_RevokeToken_SendsDeleteWithEscapedID(t *testing.T) { - t.Parallel() + t.Setenv(authTokensProviderVersionEnvVar, "") var gotMethod, gotEscapedPath, gotDecodedPath string diff --git a/cmd/entire/cli/auth/provider_test.go b/cmd/entire/cli/auth/provider_test.go index ddf7550aa9..a67409aab2 100644 --- a/cmd/entire/cli/auth/provider_test.go +++ b/cmd/entire/cli/auth/provider_test.go @@ -78,6 +78,7 @@ func TestCurrentProvider_TrimsWhitespace(t *testing.T) { func TestNewClient_HonoursProviderVersion(t *testing.T) { t.Setenv(api.BaseURLEnvVar, "https://example.test") + t.Setenv(api.AuthBaseURLEnvVar, "") t.Setenv(ProviderVersionEnvVar, "v2") c := NewClient(&http.Client{}) @@ -97,6 +98,7 @@ func TestNewClient_HonoursProviderVersion(t *testing.T) { func TestNewClient_DefaultsToV1(t *testing.T) { t.Setenv(api.BaseURLEnvVar, "https://example.test") + t.Setenv(api.AuthBaseURLEnvVar, "") t.Setenv(ProviderVersionEnvVar, "") c := NewClient(nil) diff --git a/cmd/entire/cli/auth/store_test.go b/cmd/entire/cli/auth/store_test.go index 7eca339b14..5bbfd93fbe 100644 --- a/cmd/entire/cli/auth/store_test.go +++ b/cmd/entire/cli/auth/store_test.go @@ -180,6 +180,7 @@ func TestStoreLoadTokens_LegacyBareStringFallback(t *testing.T) { func TestLookupCurrentToken(t *testing.T) { t.Setenv(api.BaseURLEnvVar, "http://localhost:8787") + t.Setenv(api.AuthBaseURLEnvVar, "") store := NewStore() if err := store.SaveToken("http://localhost:8787", "local-token"); err != nil { diff --git a/cmd/entire/cli/integration_test/login_test.go b/cmd/entire/cli/integration_test/login_test.go index 23932049d9..e64ea4151f 100644 --- a/cmd/entire/cli/integration_test/login_test.go +++ b/cmd/entire/cli/integration_test/login_test.go @@ -202,6 +202,8 @@ func runLoginProcess(t *testing.T, apiBaseURL string) *loginProcess { "ENTIRE_TEST_GEMINI_PROJECT_DIR="+env.GeminiProjectDir, "ENTIRE_TEST_OPENCODE_PROJECT_DIR="+env.OpenCodeProjectDir, "ENTIRE_API_BASE_URL="+apiBaseURL, + "ENTIRE_AUTH_BASE_URL="+apiBaseURL, + "ENTIRE_AUTH_PROVIDER_VERSION=v1", ) stdoutPipe, err := cmd.StdoutPipe() From a9aeb9e6fdee4bbedb19dce8a4d116db3feaa7f3 Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Fri, 8 May 2026 18:31:17 +1000 Subject: [PATCH 16/21] dispatch: route bearer through tokenmanager + document the auth pattern MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `entire dispatch` was sending the raw core token (audience = auth host) to the data API and getting back a 401 that cloud.go mapped to "dispatch requires login — run \`entire login\`" — misleading on split-host deployments where the user IS logged in but with the wrong-audience bearer. Same trap search hit before; same fix. - mode_cloud.go now resolves the bearer via auth.TokenForResource so the tokenmanager's same-host shortcut / JWT-aud shortcut / RFC 8693 exchange all apply. ErrNotLoggedIn is mapped to the friendly "dispatch requires login" message; other errors surface verbatim. - mode_local.go grows a lookupResourceToken seam (defaulted to auth.TokenForResource) for test injection; the existing lookupCurrentToken seam is retained for back-compat with tests that haven't migrated. - Test stubs (stubCloudDispatchAuth + per-test cleanups) updated to swap both seams so the assertions still cover what they used to. Sweep confirmed no other data-API caller bypasses the manager: search, trail, recap, dispatch_wizard, and activity all flow through NewAuthenticatedAPIClient → tokenmanager. Auth-host commands (auth list/revoke/status, logout) correctly retain LookupCurrentToken since they need the auth-audience bearer. Docs: - New CLAUDE.md "Auth and token resolution" section flags the two blessed entry points (NewAuthenticatedAPIClient, TokenForResource), the resolution rules, and that LookupCurrentToken is for auth-host callers only. - New auth/README.md positions the library as shareable across internal CLIs: subpackage map, embedding checklist, design principles (no globals, no env-var reads, provider-agnostic), non-goals (OIDC discovery, server-side, code-flow PKCE), quick- start snippets for login / data-API call / logout. - search.Config.GitHubToken doc now points at TokenForResource (was LookupCurrentToken). Co-Authored-By: Claude Opus 4.7 (1M context) --- CLAUDE.md | 26 +++++ auth/README.md | 120 +++++++++++++++++++++ cmd/entire/cli/dispatch/dispatch_test.go | 10 +- cmd/entire/cli/dispatch/mode_cloud.go | 22 ++-- cmd/entire/cli/dispatch/mode_cloud_test.go | 19 +++- cmd/entire/cli/dispatch/mode_local.go | 14 ++- cmd/entire/cli/search/search.go | 7 +- 7 files changed, 202 insertions(+), 16 deletions(-) create mode 100644 auth/README.md diff --git a/CLAUDE.md b/CLAUDE.md index 93c02c27d3..9b42a9385d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -339,6 +339,32 @@ if settings.IsSummarizeEnabled() { - `settings/settings.go` - `EntireSettings` struct, `Load()`, and helper methods - `config.go` - Higher-level config functions that use settings (for `cli` package consumers) +### Auth and token resolution + +The CLI uses a shareable auth library at `auth/` (subpackages: `deviceflow`, `sts`, `tokens`, `tokenstore`, `tokenmanager`). The `cmd/entire/cli/auth/` package wraps it with entire-specific config (provider table, keyring service name) and exposes the call surface that command code should use. + +**For every data-API call, get the bearer through one of these two entry points — never read the keyring directly:** + +```go +// Preferred — for callers that need an *api.Client. +client, err := cli.NewAuthenticatedAPIClient(ctx, insecureHTTP) + +// Direct — for callers that hand the bearer to a non-api package +// (e.g. the search service client, dispatch CloudClient). +bearer, err := auth.TokenForResource(ctx, serviceURL) +``` + +Both route through `tokenmanager.Manager.Token`, which: + +1. Returns `auth.ErrNotLoggedIn` when the keyring is empty. +2. Hits the same-host shortcut when `api.AuthBaseURL() == resourceURL`. +3. Hits the JWT-`aud`-includes-resource shortcut when the core token is already valid for `resourceURL` (caller didn't request an explicit `Audience`). +4. Otherwise runs an RFC 8693 token exchange against the auth host's STS endpoint and caches the result per `(core token, resource, audience, requested-token-type, scope)`. + +**Don't use `auth.LookupCurrentToken` for data-API calls.** It returns the raw core token (audience = auth host). On split-host deployments (`ENTIRE_AUTH_BASE_URL` set) the data API will reject it with 401. `LookupCurrentToken` is correct only for auth-host-targeted commands (`auth list/revoke/status`, `logout`) — they intentionally hold the auth-audience bearer. + +**Test injection:** at the cmd layer use `auth.SetManagerForTest(t, mgr)` with a `tokenmanager.Manager` constructed via `tokenmanager.New(Config{Exchange: ...})`. The manager's `Config.Exchange` and `Config.Now` fields are test seams — production callers leave them nil. + ### Logging vs User Output - **Internal/debug logging**: Use `logging.Debug/Info/Warn/Error(ctx, msg, attrs...)` from `cmd/entire/cli/logging/`. Writes to `.entire/logs/`. diff --git a/auth/README.md b/auth/README.md new file mode 100644 index 0000000000..7619582661 --- /dev/null +++ b/auth/README.md @@ -0,0 +1,120 @@ +# auth — shareable OAuth 2.0 client library for internal CLIs + +Provider-agnostic Go library for CLIs that authenticate end-users via OAuth 2.0 device flow (RFC 8628), present resource-scoped bearer tokens to data APIs, and (when the auth host and data API live on different origins) exchange tokens via RFC 8693 STS. + +The library has no global state, no env-var reads, and no implicit URLs. Every endpoint, identifier, and default value is supplied by the embedding CLI through a `Config` struct. That keeps it usable by any CLI in the org without forking. + +## Subpackages + +| Package | What it does | +|---|---| +| [`deviceflow`](./deviceflow/) | RFC 8628 OAuth 2.0 Device Authorization Grant client. Polls the token endpoint, surfaces RFC 8628 §3.5 error codes (`authorization_pending`, `slow_down`, `access_denied`, `expired_token`, `invalid_grant`) as Go sentinels with optional `error_description`. | +| [`sts`](./sts/) | RFC 8693 OAuth 2.0 Token Exchange client. Provider-agnostic — caller supplies endpoint path, `subject_token_type`, `requested_token_type`, optional `audience` / `resource` / `scope`, and any provider-specific `Extra` form fields (e.g. `client_id`). | +| [`tokens`](./tokens/) | `TokenSet` value type plus unverified JWT claim parsing. The package never validates signatures — that's the issuing server's responsibility. CLIs use `Claims` for routing decisions (which issuer, which audience) and UX (display the principal handle), not as a security boundary. | +| [`tokenstore`](./tokenstore/) | `Store` interface for token persistence + `Keyring` reference impl backed by `github.com/zalando/go-keyring`. Each CLI passes its own service name so credentials are isolated across CLIs sharing this library. Returns `ErrNotFound` for unknown profiles and `ErrMalformed` (wrapped) when a stored entry exists but can't be decoded — used by upgrade fallbacks. | +| [`tokenmanager`](./tokenmanager/) | Orchestration: stores the device-flow core token, runs RFC 8693 exchanges when needed to obtain resource-scoped bearers, caches the results until expiry, and short-circuits when no exchange is needed (same-host or core-token's `aud` already covers the resource). Most CLIs only need to interact with this package directly. | + +Internal helper: + +| Package | What it does | +|---|---| +| [`internal/oauthhttp`](./internal/oauthhttp/) | Shared HTTP body-reading + JSON-decoding helpers. Detects HTML responses (captive portal / proxy intercept) and surfaces them as actionable errors instead of unmarshal failures. Not exported. | + +## Quick start + +The typical embedding CLI does roughly this at startup: + +```go +import ( + "github.com/entireio/cli/auth/deviceflow" + "github.com/entireio/cli/auth/tokenmanager" + "github.com/entireio/cli/auth/tokenstore" +) + +const ( + issuer = "https://auth.example.com" // auth host base URL + clientID = "my-cli" // public OAuth client_id +) + +store := tokenstore.NewKeyring("my-cli") // service name = your CLI's name + +// One Manager per CLI process. Construct from your CLI's identity. +mgr, err := tokenmanager.New(tokenmanager.Config{ + Issuer: issuer, + ClientID: clientID, + STSPath: "/oauth/token", // RFC 8693 endpoint; usually the OAuth token endpoint + Store: store, + Scope: "cli", +}) +if err != nil { /* misconfiguration */ } +``` + +### Login + +```go +dfc := &deviceflow.Client{ + BaseURL: issuer, + ClientID: clientID, + Scope: "cli", + DeviceCodePath: "/oauth/device/code", + TokenPath: "/oauth/token", +} + +dc, err := dfc.StartDeviceAuth(ctx) +// ... show dc.UserCode + dc.VerificationURI to user, then poll ... +ts, err := dfc.PollDeviceAuth(ctx, dc.DeviceCode) +if err != nil { /* surface RFC 8628 §3.5 sentinel as needed */ } + +if err := mgr.SaveCoreToken(ts.AccessToken); err != nil { /* keyring failed */ } +``` + +### Calling a data API + +```go +bearer, err := mgr.TokenForResource(ctx, "https://api.example.com") +if errors.Is(err, tokenmanager.ErrNotLoggedIn) { + // prompt user to run `mycli login` +} +// bearer is valid for https://api.example.com +req.Header.Set("Authorization", "Bearer "+bearer) +``` + +The manager picks the right strategy automatically: + +- Same-host (`Issuer == resource`): hands back the core token verbatim. +- JWT-`aud`-includes shortcut: same, when the core token's audience already covers the resource (e.g. multi-audience tokens). +- Otherwise: runs an RFC 8693 exchange against `Issuer + STSPath`, caches the exchanged token by `(core, resource, audience, requested_token_type, scope)` until expiry. + +### Logout + +```go +if err := mgr.DeleteCoreToken(); err != nil { /* keyring failed */ } +``` + +Deletes the keyring entry first; only clears the in-memory exchange cache on success, so a failed delete doesn't leave the CLI thinking it's logged out while the keyring still holds the token. + +## Design principles + +- **No globals, no env-var reads, no implicit URLs.** Everything ships through `Config`. The library should compile and run identically inside any CLI. +- **Provider-agnostic.** `deviceflow.Client` and `sts.Client` are field-bag structs; neither knows about your provider's endpoint paths or token-type URIs. Pass them in. +- **Bearer-presenter, not bearer-validator.** This library is for CLIs that *receive* tokens from an auth server and *present* them to a resource server. JWT signature verification is intentionally not done — the resource server validates. `tokens.ParseClaims` is documented as unverified and used only for routing decisions. +- **Per-CLI keyring isolation.** Each CLI passes a unique service name to `tokenstore.NewKeyring`. OS keyrings key by `(service, account)`, so different CLIs naturally get separate credential stores. +- **Caller controls the wire shape.** Default values (RFC 8693 `requested_token_type`, `scope`, audience-empty) live in the embedding CLI's wiring, not in this library. + +## Embedding checklist for a new CLI + +1. Pick a stable service name for `tokenstore.NewKeyring(...)`. **Don't change it later** — renaming orphans every existing user's stored credentials. +2. Pick a `client_id` that the auth server recognises. +3. Decide your `STSPath`: typically the OAuth token endpoint per RFC 8693 convention, or a dedicated path if your auth server exposes one. +4. Construct the `tokenmanager.Manager` once at startup; pass it to your data-API call sites. +5. For multi-environment users (regions, staging), key the keyring by issuer URL — `Manager.Issuer()` returns the configured value. + +## Non-goals + +- **OIDC discovery / ID tokens.** This library is OAuth 2.0 only. If you need OIDC `/.well-known/openid-configuration` + ID-token verification, layer `coreos/go-oidc` on top. +- **PKCE / authorization code flow.** Device flow only; CLIs almost never need code flow. +- **Server-side OIDC.** If you're building an *issuer*, look at `zitadel/oidc`'s `op` package. + +## Status + +Used in production by [`entireio/cli`](https://github.com/entireio/cli). Open to additional internal CLI consumers — file an issue if you hit a gap. diff --git a/cmd/entire/cli/dispatch/dispatch_test.go b/cmd/entire/cli/dispatch/dispatch_test.go index 431e66d2dd..5427141342 100644 --- a/cmd/entire/cli/dispatch/dispatch_test.go +++ b/cmd/entire/cli/dispatch/dispatch_test.go @@ -4,13 +4,17 @@ import ( "context" "strings" "testing" + + "github.com/entireio/cli/cmd/entire/cli/auth" ) func TestRun_ServerAllowsRepos(t *testing.T) { - oldLookup := lookupCurrentToken - lookupCurrentToken = func() (string, error) { return "", nil } + oldResource := lookupResourceToken + lookupResourceToken = func(_ context.Context, _ string) (string, error) { + return "", auth.ErrNotLoggedIn + } t.Cleanup(func() { - lookupCurrentToken = oldLookup + lookupResourceToken = oldResource }) _, err := Run(context.Background(), Options{ diff --git a/cmd/entire/cli/dispatch/mode_cloud.go b/cmd/entire/cli/dispatch/mode_cloud.go index 3c5fa93b3e..21aae53a9a 100644 --- a/cmd/entire/cli/dispatch/mode_cloud.go +++ b/cmd/entire/cli/dispatch/mode_cloud.go @@ -8,6 +8,7 @@ import ( "time" "github.com/entireio/cli/cmd/entire/cli/api" + "github.com/entireio/cli/cmd/entire/cli/auth" "github.com/entireio/cli/cmd/entire/cli/paths" "github.com/go-git/go-git/v6" ) @@ -19,14 +20,6 @@ import ( var requireSecureDispatchURL = api.RequireSecureURL func runServer(ctx context.Context, opts Options) (*Dispatch, error) { - token, err := lookupCurrentToken() - if err != nil { - return nil, fmt.Errorf("reading credentials: %w", err) - } - if token == "" { - return nil, errors.New("dispatch requires login — run `entire login`") - } - baseURL := api.BaseURL() if !opts.InsecureHTTPAuth { if err := requireSecureDispatchURL(baseURL); err != nil { @@ -34,6 +27,19 @@ func runServer(ctx context.Context, opts Options) (*Dispatch, error) { } } + // Resolve a bearer scoped to the dispatch service host. In split-host + // deployments the tokenmanager runs an RFC 8693 exchange so the + // bearer carries the data-API audience rather than the auth-host + // one; single-host setups hit the same-host shortcut and return the + // core token unchanged. + token, err := lookupResourceToken(ctx, baseURL) + if errors.Is(err, auth.ErrNotLoggedIn) { + return nil, errors.New("dispatch requires login — run `entire login`") + } + if err != nil { + return nil, fmt.Errorf("reading credentials: %w", err) + } + now := nowUTC() sinceInput := strings.TrimSpace(opts.Since) if sinceInput == "" { diff --git a/cmd/entire/cli/dispatch/mode_cloud_test.go b/cmd/entire/cli/dispatch/mode_cloud_test.go index c8af1b42c4..521f519e1a 100644 --- a/cmd/entire/cli/dispatch/mode_cloud_test.go +++ b/cmd/entire/cli/dispatch/mode_cloud_test.go @@ -20,11 +20,16 @@ import ( func stubCloudDispatchAuth(t *testing.T) { t.Helper() oldLookup := lookupCurrentToken + oldResource := lookupResourceToken oldRequire := requireSecureDispatchURL lookupCurrentToken = func() (string, error) { return testCloudDispatchToken, nil } + lookupResourceToken = func(_ context.Context, _ string) (string, error) { + return testCloudDispatchToken, nil + } requireSecureDispatchURL = func(string) error { return nil } t.Cleanup(func() { lookupCurrentToken = oldLookup + lookupResourceToken = oldResource requireSecureDispatchURL = oldRequire }) } @@ -368,11 +373,16 @@ func TestServerMode_InsecureHTTPAuthBypassesSecureURLCheck(t *testing.T) { defer mock.Close() oldLookup := lookupCurrentToken + oldResource := lookupResourceToken oldNow := nowUTC lookupCurrentToken = func() (string, error) { return testCloudDispatchToken, nil } + lookupResourceToken = func(_ context.Context, _ string) (string, error) { + return testCloudDispatchToken, nil + } nowUTC = func() time.Time { return time.Date(2026, 4, 16, 0, 0, 0, 0, time.UTC) } t.Cleanup(func() { lookupCurrentToken = oldLookup + lookupResourceToken = oldResource nowUTC = oldNow }) @@ -399,8 +409,15 @@ func TestServerMode_InsecureHTTPAuthBypassesSecureURLCheck(t *testing.T) { // leak reaches users. func TestServerMode_RejectsPlainHTTPBaseURL(t *testing.T) { oldLookup := lookupCurrentToken + oldResource := lookupResourceToken lookupCurrentToken = func() (string, error) { return testCloudDispatchToken, nil } - t.Cleanup(func() { lookupCurrentToken = oldLookup }) + lookupResourceToken = func(_ context.Context, _ string) (string, error) { + return testCloudDispatchToken, nil + } + t.Cleanup(func() { + lookupCurrentToken = oldLookup + lookupResourceToken = oldResource + }) t.Setenv("ENTIRE_API_BASE_URL", "http://dispatch.example.invalid") diff --git a/cmd/entire/cli/dispatch/mode_local.go b/cmd/entire/cli/dispatch/mode_local.go index 58186ac6e1..34ea268f32 100644 --- a/cmd/entire/cli/dispatch/mode_local.go +++ b/cmd/entire/cli/dispatch/mode_local.go @@ -22,8 +22,20 @@ import ( ) var ( + // lookupCurrentToken is retained for test-injection back-compat. The + // cloud-mode runner now resolves its bearer via lookupResourceToken + // so it picks up the RFC 8693 exchange in split-host deployments; + // existing tests that swap lookupCurrentToken keep working because + // the default lookupResourceToken delegates to it. lookupCurrentToken = auth.LookupCurrentToken - nowUTC = func() time.Time { return time.Now().UTC() } + + // lookupResourceToken returns a bearer scoped to the given resource + // origin. Production wiring goes through auth.TokenForResource so + // the tokenmanager's same-host shortcut, JWT-aud shortcut, and + // exchange dispatch all apply. Tests swap to a fixed-token closure. + lookupResourceToken = auth.TokenForResource + + nowUTC = func() time.Time { return time.Now().UTC() } ) func runLocal(ctx context.Context, opts Options) (*Dispatch, error) { diff --git a/cmd/entire/cli/search/search.go b/cmd/entire/cli/search/search.go index b9cf8c8adf..6f93bbb7fd 100644 --- a/cmd/entire/cli/search/search.go +++ b/cmd/entire/cli/search/search.go @@ -67,9 +67,10 @@ type Response struct { // Config holds the configuration for a search request. type Config struct { ServiceURL string // Base URL of the search service - // GitHubToken is a misnomer kept for backwards compatibility: callers - // populate it with the OAuth bearer from auth.LookupCurrentToken(). - // The wire format is unchanged (Authorization: Bearer ). + // GitHubToken is a misnomer kept for backwards compatibility: + // callers populate it with the resource-scoped OAuth bearer from + // auth.TokenForResource(ctx, ServiceURL). The wire format is + // unchanged (Authorization: Bearer ). GitHubToken string Owner string Repo string From b410d5017b32383e933eda6d0f72e28b0a943c2e Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Fri, 8 May 2026 18:54:05 +1000 Subject: [PATCH 17/21] auth: PR review fixes (PollDeviceAuth retry, doc accuracy) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - PollDeviceAuth: unknown OAuth error codes (invalid_request, invalid_client, server_error, unsupported_grant_type, etc.) used to fall through to login.go's transient-retry path, burning ~25-150s on permanent server failures before producing a confusing "after N consecutive failures" message. Replace oauthErrorCode + descriptionFromSentinel with a single oauthErrorParts that also matches deviceflow's generic "oauth error: " wrapper. Unknown codes now land in DeviceAuthPoll.Error so the polling loop's default switch arm fails fast with "device authorization failed: ". Tests cover known sentinels, sentinel-with-description, unknown- passthrough, unknown-with-description, and non-OAuth (transient) errors. - Store.GetToken: doc said "only ErrNotFound and ErrMalformed trigger the fallback" but ErrNotFound short-circuits to the empty-string return without a keyring read; only ErrMalformed actually triggers the bare-string fallback. Fixed the doc to match. - api.RevokeCurrentToken: dropped the stale comment about v2 not exposing /current — server-side fix is incoming. Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/entire/cli/api/auth_tokens.go | 4 - cmd/entire/cli/auth/client.go | 73 +++++++++++------ cmd/entire/cli/auth/client_test.go | 122 +++++++++++++++++++++++++++++ cmd/entire/cli/auth/store.go | 10 +-- 4 files changed, 178 insertions(+), 31 deletions(-) create mode 100644 cmd/entire/cli/auth/client_test.go diff --git a/cmd/entire/cli/api/auth_tokens.go b/cmd/entire/cli/api/auth_tokens.go index f3e40c2cf8..e54c69bc13 100644 --- a/cmd/entire/cli/api/auth_tokens.go +++ b/cmd/entire/cli/api/auth_tokens.go @@ -61,10 +61,6 @@ func (c *Client) ListTokens(ctx context.Context) ([]Token, error) { } // RevokeCurrentToken revokes the bearer token used to authenticate this client. -// -// v1 has a dedicated /current endpoint. v2 doesn't expose one yet -// (would require a family_id claim on the JWT — tracked separately); -// callers can find the active family via ListTokens and revoke by ID. func (c *Client) RevokeCurrentToken(ctx context.Context) error { resp, err := c.Delete(ctx, authTokensBasePath()+"/current") if err != nil { diff --git a/cmd/entire/cli/auth/client.go b/cmd/entire/cli/auth/client.go index 1e6697a244..f6ee0da399 100644 --- a/cmd/entire/cli/auth/client.go +++ b/cmd/entire/cli/auth/client.go @@ -65,17 +65,21 @@ func (c *Client) StartDeviceAuth(ctx context.Context) (*DeviceAuthStart, error) return c.inner.StartDeviceAuth(ctx) //nolint:wrapcheck // shim preserves the lib's wrapped errors verbatim } -// PollDeviceAuth polls the token endpoint. On any RFC 8628 §3.5 error, -// the wire-side error code is returned in DeviceAuthPoll.Error so the -// existing polling loop in login.go can branch on it. Non-RFC errors -// (network, decode) are returned as a real error. +// PollDeviceAuth polls the token endpoint. On any OAuth-protocol error +// (recognised RFC 8628 §3.5 sentinel or unknown but spec-shaped code +// like invalid_request / invalid_client / server_error), the wire-side +// code is returned in DeviceAuthPoll.Error so the existing polling +// loop in login.go can branch on it — known codes hit the dedicated +// switch arms, unknown codes fall through to the default arm and fail +// fast. Non-protocol errors (network, decode) are returned as a real +// error and treated as transient by the polling loop. func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*DeviceAuthPoll, error) { t, err := c.inner.PollDeviceAuth(ctx, deviceCode) if err != nil { - if code := oauthErrorCode(err); code != "" { + if code, description, ok := oauthErrorParts(err); ok { return &DeviceAuthPoll{ Error: code, - ErrorDescription: descriptionFromSentinel(err, code), + ErrorDescription: description, }, nil } return nil, err //nolint:wrapcheck // shim returns deviceflow errors verbatim so callers can errors.Is on sentinels @@ -89,31 +93,56 @@ func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*Device }, nil } -// oauthErrorCode returns the wire-side code for a recognised RFC 8628 -// sentinel error, or "" if err isn't one. -func oauthErrorCode(err error) string { +// oauthErrorParts inspects err for either a recognised RFC 8628 §3.5 +// sentinel or the generic "oauth error: " wrapper deviceflow uses +// for unrecognised but spec-shaped codes (RFC 6749 §5.2: invalid_request, +// invalid_client, server_error, …). +// +// On a match, returns the wire-side code, any error_description the +// server included, and ok=true. Otherwise returns "", "", false — the +// caller should treat the error as a transport/decode failure. +// +// Surfacing unknown codes as ok=true is what keeps login.go's polling +// loop fast-failing on terminal OAuth rejections instead of treating +// them as transient and retrying ~5 times. +func oauthErrorParts(err error) (code, description string, ok bool) { switch { case errors.Is(err, deviceflow.ErrAuthorizationPending): - return "authorization_pending" + code = "authorization_pending" case errors.Is(err, deviceflow.ErrSlowDown): - return "slow_down" + code = "slow_down" case errors.Is(err, deviceflow.ErrAccessDenied): - return "access_denied" + code = "access_denied" case errors.Is(err, deviceflow.ErrExpiredToken): - return "expired_token" + code = "expired_token" case errors.Is(err, deviceflow.ErrInvalidGrant): - return "invalid_grant" + code = "invalid_grant" + default: + // Unknown but legitimate OAuth codes come back from + // deviceflow.errCodeToSentinel as fmt.Errorf("oauth error: %s", + // code), optionally wrapped a second time with ": " + // when the server supplied error_description. + const oauthPrefix = "oauth error: " + rest, hadPrefix := strings.CutPrefix(err.Error(), oauthPrefix) + if !hadPrefix { + return "", "", false + } + if c, d, hasDesc := strings.Cut(rest, ": "); hasDesc { + return c, d, true + } + return rest, "", true } - return "" + description = descriptionFromSentinelError(err, code) + return code, description, true } -// descriptionFromSentinel pulls the description suffix out of a wrapped -// sentinel error. The deviceflow lib uses fmt.Errorf("%w: %s", sentinel, -// description) when the server included an error_description, so the -// formatted error reads ": ". Stripping the -// ": " prefix yields the description; absent prefix means the -// server didn't supply one. -func descriptionFromSentinel(err error, code string) string { +// descriptionFromSentinelError pulls the description suffix out of a +// wrapped sentinel error. The deviceflow lib uses +// fmt.Errorf("%w: %s", sentinel, description) when the server included +// an error_description, so the formatted error reads +// ": ". Stripping the ": " prefix yields the +// description; absent prefix means the server didn't supply one. +func descriptionFromSentinelError(err error, code string) string { msg := err.Error() prefix := code + ": " if rest, ok := strings.CutPrefix(msg, prefix); ok { diff --git a/cmd/entire/cli/auth/client_test.go b/cmd/entire/cli/auth/client_test.go new file mode 100644 index 0000000000..bffef7ee9f --- /dev/null +++ b/cmd/entire/cli/auth/client_test.go @@ -0,0 +1,122 @@ +package auth + +import ( + "errors" + "fmt" + "testing" + + "github.com/entireio/cli/auth/deviceflow" +) + +// TestOAuthErrorParts_KnownSentinels covers the five RFC 8628 §3.5 +// codes the polling loop in login.go switches on by name. Without a +// match here, the loop's switch never fires for these terminal cases. +func TestOAuthErrorParts_KnownSentinels(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + err error + want string + }{ + {"authorization_pending", deviceflow.ErrAuthorizationPending, "authorization_pending"}, + {"slow_down", deviceflow.ErrSlowDown, "slow_down"}, + {"access_denied", deviceflow.ErrAccessDenied, "access_denied"}, + {"expired_token", deviceflow.ErrExpiredToken, "expired_token"}, + {"invalid_grant", deviceflow.ErrInvalidGrant, "invalid_grant"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + code, desc, ok := oauthErrorParts(tc.err) + if !ok { + t.Fatalf("oauthErrorParts returned ok=false for %v", tc.err) + } + if code != tc.want { + t.Errorf("code = %q, want %q", code, tc.want) + } + if desc != "" { + t.Errorf("desc = %q, want empty (no description supplied)", desc) + } + }) + } +} + +// TestOAuthErrorParts_KnownSentinelWithDescription pins the description +// extraction for the wrapped form: fmt.Errorf("%w: %s", sentinel, desc). +func TestOAuthErrorParts_KnownSentinelWithDescription(t *testing.T) { + t.Parallel() + + wrapped := fmt.Errorf("%w: device approval window closed", deviceflow.ErrExpiredToken) + code, desc, ok := oauthErrorParts(wrapped) + if !ok { + t.Fatalf("oauthErrorParts ok=false for wrapped sentinel") + } + if code != "expired_token" { + t.Errorf("code = %q, want expired_token", code) + } + if desc != "device approval window closed" { + t.Errorf("desc = %q, want %q", desc, "device approval window closed") + } +} + +// TestOAuthErrorParts_UnknownCodePassesThrough is the regression for +// the bug surfaced in PR review: unknown OAuth codes (e.g. +// invalid_request, invalid_client, server_error) coming back from +// deviceflow as fmt.Errorf("oauth error: %s", code) used to fall +// through to the transient-retry path in login.go's waitForApproval, +// burning ~25-150s on permanent server errors. They now land in +// DeviceAuthPoll.Error so the polling loop's default switch arm fails +// fast with "device authorization failed: ". +func TestOAuthErrorParts_UnknownCodePassesThrough(t *testing.T) { + t.Parallel() + + cases := []string{"invalid_request", "invalid_client", "server_error", "unsupported_grant_type"} + for _, want := range cases { + t.Run(want, func(t *testing.T) { + t.Parallel() + err := fmt.Errorf("oauth error: %s", want) + code, desc, ok := oauthErrorParts(err) + if !ok { + t.Fatalf("oauthErrorParts ok=false for unknown OAuth code %q", want) + } + if code != want { + t.Errorf("code = %q, want %q", code, want) + } + if desc != "" { + t.Errorf("desc = %q, want empty (no description supplied)", desc) + } + }) + } +} + +// TestOAuthErrorParts_UnknownCodeWithDescription matches the wire +// shape deviceflow produces when the server returns both an unknown +// error code and a non-empty error_description. +func TestOAuthErrorParts_UnknownCodeWithDescription(t *testing.T) { + t.Parallel() + + err := errors.New("oauth error: invalid_client: client authentication failed") + code, desc, ok := oauthErrorParts(err) + if !ok { + t.Fatalf("oauthErrorParts ok=false") + } + if code != "invalid_client" { + t.Errorf("code = %q, want invalid_client", code) + } + if desc != "client authentication failed" { + t.Errorf("desc = %q, want %q", desc, "client authentication failed") + } +} + +// TestOAuthErrorParts_NonOAuthError confirms transport/decode errors +// (no "oauth error:" prefix and not an RFC 8628 sentinel) are reported +// as ok=false so the polling loop treats them as transient. +func TestOAuthErrorParts_NonOAuthError(t *testing.T) { + t.Parallel() + + err := errors.New("connection reset by peer") + if _, _, ok := oauthErrorParts(err); ok { + t.Fatal("oauthErrorParts ok=true for non-OAuth error; would mask transient transport failure") + } +} diff --git a/cmd/entire/cli/auth/store.go b/cmd/entire/cli/auth/store.go index 1120bea86d..4e77ecd979 100644 --- a/cmd/entire/cli/auth/store.go +++ b/cmd/entire/cli/auth/store.go @@ -55,11 +55,11 @@ func (s *Store) SaveToken(baseURL, token string) error { // GetToken retrieves a stored token for the given base URL. Returns // an empty string (and no error) if no token is stored. // -// Falls back to a bare-string read when the stored entry is malformed -// JSON, to handle pre-shim entries that stored the raw access token -// rather than a JSON-encoded TokenSet. Real keyring errors (transport, -// permission denied) propagate; only ErrNotFound and ErrMalformed -// trigger the fallback. +// ErrNotFound short-circuits to the empty-string return without a +// further keyring read. ErrMalformed (a stored entry exists but can't +// be decoded as a TokenSet) triggers a bare-string fallback to handle +// pre-shim entries that stored the raw access token verbatim. Real +// keyring errors (transport, permission denied) propagate. // // Deprecated: prefer LoadTokens (the tokenstore.Store interface method) // for new callers — it returns the full TokenSet so refresh tokens and From 9c2b0703ba1630a3c27bd818be994ff014a2bf16 Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Fri, 8 May 2026 20:56:33 +1000 Subject: [PATCH 18/21] auth: PR review fixes (parallel-safe clock pin, struct cache key) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two cursor bug-bot findings, both Low severity but worth closing. - auth/deviceflow + auth/sts: TestPollDeviceAuth_Success and TestExchange_Success call freezeClock, which mutates the package- level nowFunc. Both tests were marked t.Parallel(), creating a latent race against any future parallel test that reads nowFunc through a real Exchange/PollDeviceAuth call. Drop the t.Parallel() on those two tests with a comment explaining why; the rest of the package keeps parallelism. -race confirms no race remains. - auth/tokenmanager: cacheKey was a delimiter-joined string, structurally vulnerable to collisions if any field embedded the "|" separator (none do today, but no guarantee for future callers). Replace with a struct map key — Go's map can use comparable structs directly, so there's no string encoding to misbehave. Co-Authored-By: Claude Opus 4.7 (1M context) --- auth/deviceflow/deviceflow_test.go | 4 +-- auth/sts/sts_test.go | 4 +-- auth/tokenmanager/tokenmanager.go | 47 +++++++++++++++++++----------- 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/auth/deviceflow/deviceflow_test.go b/auth/deviceflow/deviceflow_test.go index 8166be0341..05fd091c57 100644 --- a/auth/deviceflow/deviceflow_test.go +++ b/auth/deviceflow/deviceflow_test.go @@ -141,8 +141,8 @@ func TestStartDeviceAuth_NonOK(t *testing.T) { } func TestPollDeviceAuth_Success(t *testing.T) { - t.Parallel() - + // Not parallel: freezeClock mutates the package-level nowFunc. + // Any other parallel test calling PollDeviceAuth would race against it. freezeClock(t, time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC)) c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { diff --git a/auth/sts/sts_test.go b/auth/sts/sts_test.go index 04570cb446..5f8282aa20 100644 --- a/auth/sts/sts_test.go +++ b/auth/sts/sts_test.go @@ -50,8 +50,8 @@ func mustReadForm(t *testing.T, r *http.Request) { } func TestExchange_Success(t *testing.T) { - t.Parallel() - + // Not parallel: freezeClock mutates the package-level nowFunc. + // Any other parallel test calling Exchange would race against it. freezeClock(t, time.Date(2026, 5, 6, 12, 0, 0, 0, time.UTC)) c := newTestClient(t, func(w http.ResponseWriter, r *http.Request) { diff --git a/auth/tokenmanager/tokenmanager.go b/auth/tokenmanager/tokenmanager.go index 40501a7637..82aa426068 100644 --- a/auth/tokenmanager/tokenmanager.go +++ b/auth/tokenmanager/tokenmanager.go @@ -114,7 +114,7 @@ type Manager struct { cfg Config mu sync.Mutex - cache map[string]cachedToken + cache map[cacheKey]cachedToken } // New builds a Manager from cfg. Returns an error when required @@ -129,7 +129,7 @@ func New(cfg Config) (*Manager, error) { if cfg.Now == nil { cfg.Now = time.Now } - return &Manager{cfg: cfg, cache: map[string]cachedToken{}}, nil + return &Manager{cfg: cfg, cache: map[cacheKey]cachedToken{}}, nil } // Issuer returns the configured issuer URL. @@ -168,7 +168,7 @@ func (m *Manager) DeleteCoreToken() error { return fmt.Errorf("delete core token: %w", err) } m.mu.Lock() - m.cache = map[string]cachedToken{} + m.cache = map[cacheKey]cachedToken{} m.mu.Unlock() return nil } @@ -234,7 +234,7 @@ func (m *Manager) Token(ctx context.Context, req TokenRequest) (string, error) { } resolved := m.resolve(req) - key := cacheKey(core, resolved) + key := makeCacheKey(core, resolved) if hit, ok := m.cacheLookup(key); ok { return hit, nil } @@ -282,20 +282,33 @@ func (c cachedToken) usable(now time.Time) bool { return now.Add(exchangeSkew).Before(c.expiresAt) } -// cacheKey derives a stable cache key from the (resolved) request. -// Includes every wire-affecting field so different combinations don't -// shadow each other. -func cacheKey(coreToken string, req TokenRequest) string { - return strings.Join([]string{ - coreToken, - req.Resource, - req.Audience, - req.RequestedTokenType, - req.Scope, - }, "|") +// cacheKey is a structurally-keyed exchange-cache key. Using a struct +// rather than a delimiter-joined string sidesteps any chance of two +// distinct (core token, resource, audience, requested-token-type, +// scope) tuples hashing to the same map slot via embedded delimiters +// in any field. +type cacheKey struct { + CoreToken string + Resource string + Audience string + RequestedTokenType string + Scope string +} + +// makeCacheKey builds a cacheKey from the (resolved) request. Includes +// every wire-affecting field so different combinations don't shadow +// each other. +func makeCacheKey(coreToken string, req TokenRequest) cacheKey { + return cacheKey{ + CoreToken: coreToken, + Resource: req.Resource, + Audience: req.Audience, + RequestedTokenType: req.RequestedTokenType, + Scope: req.Scope, + } } -func (m *Manager) cacheLookup(key string) (string, bool) { +func (m *Manager) cacheLookup(key cacheKey) (string, bool) { m.mu.Lock() defer m.mu.Unlock() entry, ok := m.cache[key] @@ -309,7 +322,7 @@ func (m *Manager) cacheLookup(key string) (string, bool) { return entry.accessToken, true } -func (m *Manager) cacheStore(key string, t *tokens.TokenSet) { +func (m *Manager) cacheStore(key cacheKey, t *tokens.TokenSet) { m.mu.Lock() defer m.mu.Unlock() m.cache[key] = cachedToken{ From 6a9e6016a5cbd4f7cb662e1c307a78955c30d3e6 Mon Sep 17 00:00:00 2001 From: Stefan Haubold Date: Fri, 8 May 2026 14:02:27 +0200 Subject: [PATCH 19/21] auth: review follow-ups (provider routing, URL normalization, expiry preflight, timeouts) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Six fixes called out across the security/architecture/correctness reviews of #cli-auth-consolidation. Each fix lands with focused tests; no behaviour changes outside the auth path. 1. Eliminate duplicate ENTIRE_AUTH_PROVIDER_VERSION read in cmd/entire/cli/api/auth_tokens.go. The provider table in cmd/entire/cli/auth.Provider now owns AuthTokensPath; api.Client takes it via WithAuthTokensPath. The api package no longer reads the env var. 2. Read provider version once at startup. CurrentProvider() resolves via sync.Once and freezes; tests inject via SetProviderForTest. resolveProvider is a pure function so the routing table is exercisable without env-var gymnastics. 3. Normalize URLs in tokenmanager same-host / aud / cache-key compares. normalizeOriginURL handles trailing slash, scheme/host case, and default ports (RFC 3986 §6.2.2.1 / §6.2.3); non-URL audiences pass through unchanged for byte-exact compare. 4. Preflight core-token expiry and clear the exchange cache on SaveCoreToken. Long-expired tokens surface as ErrNotLoggedIn (so "run login" UX kicks in) instead of confusing STS / 401 errors; a re-login can't return the previous user's exchanged tokens. 5. Tighten tokenstore malformed-JSON detection. Well-formed JSON without an access_token now surfaces as ErrMalformed. The shim's bare-string fallback rejects JSON-shaped content via looksLikeBareToken so "Authorization: Bearer {}" can't ship. 6. Add per-request timeouts (DefaultRequestTimeout = 30s) to deviceflow.Client and sts.Client via context.WithTimeout. The wrap lives at the method level so the deadline covers the body read, not just the dial. Tests pin both the firing path and the default/override resolution. Entire-Checkpoint: c79c0ff7d6c1 --- auth/deviceflow/deviceflow.go | 46 ++++- auth/deviceflow/deviceflow_test.go | 80 +++++++++ auth/sts/sts.go | 33 ++++ auth/sts/sts_test.go | 59 +++++++ auth/tokenmanager/tokenmanager.go | 102 ++++++++++- auth/tokenmanager/tokenmanager_test.go | 228 ++++++++++++++++++++++++- auth/tokenstore/keyring.go | 11 ++ auth/tokenstore/keyring_test.go | 31 ++++ cmd/entire/cli/api/auth_tokens.go | 43 +++-- cmd/entire/cli/api/auth_tokens_test.go | 100 +++++------ cmd/entire/cli/api/client.go | 20 +++ cmd/entire/cli/auth.go | 8 +- cmd/entire/cli/auth/client.go | 10 +- cmd/entire/cli/auth/exchange.go | 8 +- cmd/entire/cli/auth/provider.go | 111 +++++++++--- cmd/entire/cli/auth/provider_test.go | 105 +++++++----- cmd/entire/cli/auth/store.go | 29 ++++ cmd/entire/cli/auth/store_test.go | 42 +++++ cmd/entire/cli/logout.go | 4 +- 19 files changed, 910 insertions(+), 160 deletions(-) diff --git a/auth/deviceflow/deviceflow.go b/auth/deviceflow/deviceflow.go index 9a72f53340..2827e053e6 100644 --- a/auth/deviceflow/deviceflow.go +++ b/auth/deviceflow/deviceflow.go @@ -47,6 +47,15 @@ type DeviceCode struct { Interval int `json:"interval"` } +// DefaultRequestTimeout caps a single device-flow HTTP round-trip +// (StartDeviceAuth or one PollDeviceAuth call). Set conservatively: +// healthy device-flow endpoints respond in sub-seconds, so the cap +// mainly defends against slow-loris responses dripping bytes within +// MaxResponseBytes — see Client.RequestTimeout for the per-Client +// override. The polling-loop interval is the caller's concern; this +// timeout governs only the individual HTTP request. +const DefaultRequestTimeout = 30 * time.Second + // Client polls an RFC 8628 device authorization grant. // // All configuration is explicit; the package has no global state and @@ -60,6 +69,26 @@ type Client struct { UserAgent string DeviceCodePath string TokenPath string + + // RequestTimeout is the per-request deadline applied via + // context.WithTimeout on top of the caller's context. Zero falls + // back to DefaultRequestTimeout. Negative disables the cap (useful + // for tests that want to drive timing via the caller's ctx alone). + RequestTimeout time.Duration +} + +// requestTimeout resolves the effective per-request timeout: the +// configured RequestTimeout if positive, the package default if zero, +// or zero (no cap) if negative. +func (c *Client) requestTimeout() time.Duration { + switch { + case c.RequestTimeout < 0: + return 0 + case c.RequestTimeout == 0: + return DefaultRequestTimeout + default: + return c.RequestTimeout + } } // Sentinel errors returned by PollDeviceAuth when the token endpoint @@ -110,6 +139,12 @@ func errCodeToSentinel(code string) error { // server. The returned DeviceCode is opaque to the client; pass it // back unmodified on every PollDeviceAuth. func (c *Client) StartDeviceAuth(ctx context.Context) (*DeviceCode, error) { + if timeout := c.requestTimeout(); timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + body := url.Values{} body.Set("client_id", c.ClientID) if c.Scope != "" { @@ -141,6 +176,12 @@ func (c *Client) StartDeviceAuth(ctx context.Context) (*DeviceCode, error) { // the matching sentinel error from this package. Other failures // (network, malformed responses) are wrapped with context. func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*tokens.TokenSet, error) { + if timeout := c.requestTimeout(); timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + body := url.Values{} body.Set("grant_type", deviceCodeGrantType) body.Set("client_id", c.ClientID) @@ -195,7 +236,10 @@ func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*tokens } // postForm POSTs body as application/x-www-form-urlencoded to a path -// resolved against the client's BaseURL. +// resolved against the client's BaseURL. The caller is responsible +// for applying any per-request timeout via context.WithTimeout — the +// timeout must cover the body-read that happens after postForm +// returns, so cancel-on-return here would interrupt that read. func (c *Client) postForm(ctx context.Context, path string, body url.Values) (*http.Response, error) { endpoint, err := resolveURL(c.BaseURL, path) if err != nil { diff --git a/auth/deviceflow/deviceflow_test.go b/auth/deviceflow/deviceflow_test.go index 05fd091c57..a43dfd4709 100644 --- a/auth/deviceflow/deviceflow_test.go +++ b/auth/deviceflow/deviceflow_test.go @@ -360,3 +360,83 @@ func TestResolveURL(t *testing.T) { }) } } + +// TestPollDeviceAuth_RequestTimeoutFires pins the slow-loris defence: +// a handler that never finishes writing a response must surface as a +// context deadline error rather than blocking the polling loop forever. +func TestPollDeviceAuth_RequestTimeoutFires(t *testing.T) { + t.Parallel() + // Cleanup is LIFO and httptest.Server.Close waits for active + // handler goroutines, so close(hung) is registered AFTER + // newTestClient to fire first and let the handler exit before + // srv.Close runs. + hung := make(chan struct{}) + c := newTestClient(t, func(_ http.ResponseWriter, r *http.Request) { + select { + case <-hung: + case <-r.Context().Done(): + } + }) + t.Cleanup(func() { close(hung) }) + c.RequestTimeout = 50 * time.Millisecond + + _, err := c.PollDeviceAuth(context.Background(), "dev-1") + if err == nil { + t.Fatal("expected timeout error, got nil") + } + if !strings.Contains(err.Error(), "context deadline exceeded") { + t.Fatalf("err = %v, want context deadline exceeded", err) + } +} + +// TestStartDeviceAuth_RequestTimeoutFires mirrors the poll-side test +// for the device-code endpoint. +func TestStartDeviceAuth_RequestTimeoutFires(t *testing.T) { + t.Parallel() + // Cleanup is LIFO and httptest.Server.Close waits for active + // handler goroutines, so close(hung) is registered AFTER + // newTestClient to fire first and let the handler exit before + // srv.Close runs. + hung := make(chan struct{}) + c := newTestClient(t, func(_ http.ResponseWriter, r *http.Request) { + select { + case <-hung: + case <-r.Context().Done(): + } + }) + t.Cleanup(func() { close(hung) }) + c.RequestTimeout = 50 * time.Millisecond + + _, err := c.StartDeviceAuth(context.Background()) + if err == nil { + t.Fatal("expected timeout error, got nil") + } + if !strings.Contains(err.Error(), "context deadline exceeded") { + t.Fatalf("err = %v, want context deadline exceeded", err) + } +} + +// TestRequestTimeout_DefaultAndOverride exercises the timeout policy +// without doing IO — pure resolution of the (zero / negative / +// positive) input contract. +func TestRequestTimeout_DefaultAndOverride(t *testing.T) { + t.Parallel() + cases := []struct { + name string + in time.Duration + want time.Duration + }{ + {"zero -> default", 0, DefaultRequestTimeout}, + {"negative -> disabled", -1, 0}, + {"positive -> verbatim", 5 * time.Second, 5 * time.Second}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + c := &Client{RequestTimeout: tc.in} + if got := c.requestTimeout(); got != tc.want { + t.Fatalf("requestTimeout() = %v, want %v", got, tc.want) + } + }) + } +} diff --git a/auth/sts/sts.go b/auth/sts/sts.go index 72a53cb578..d8233cd285 100644 --- a/auth/sts/sts.go +++ b/auth/sts/sts.go @@ -68,6 +68,13 @@ func (r ExchangeRequest) validate() error { return nil } +// DefaultRequestTimeout caps a single token-exchange round-trip. Set +// conservatively: even with a slow auth host plus TLS handshake, a +// healthy exchange completes in sub-seconds. The cap mainly defends +// against slow-loris responses dripping bytes within MaxResponseBytes +// — see Client.RequestTimeout for the per-Client override. +const DefaultRequestTimeout = 30 * time.Second + // Client exchanges subject tokens for tokens of a different type at an // RFC 8693 token endpoint. // @@ -78,6 +85,12 @@ type Client struct { BaseURL string Path string UserAgent string + + // RequestTimeout is the per-Exchange deadline applied via + // context.WithTimeout on top of the caller's context. Zero falls + // back to DefaultRequestTimeout. Negative disables the cap (useful + // for tests that want to drive timing via the caller's ctx alone). + RequestTimeout time.Duration } // Exchange performs one RFC 8693 token exchange. @@ -98,6 +111,12 @@ func (c *Client) Exchange(ctx context.Context, req ExchangeRequest) (*tokens.Tok return nil, fmt.Errorf("token exchange: resolve URL: %w", err) } + if timeout := c.requestTimeout(); timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode())) if err != nil { return nil, fmt.Errorf("token exchange: create request: %w", err) @@ -176,6 +195,20 @@ func buildForm(req ExchangeRequest) url.Values { return form } +// requestTimeout resolves the effective per-request timeout: the +// configured RequestTimeout if positive, the package default if zero, +// or zero (no cap) if negative. +func (c *Client) requestTimeout() time.Duration { + switch { + case c.RequestTimeout < 0: + return 0 + case c.RequestTimeout == 0: + return DefaultRequestTimeout + default: + return c.RequestTimeout + } +} + func resolveURL(baseURL, path string) (string, error) { base, err := url.Parse(baseURL) if err != nil { diff --git a/auth/sts/sts_test.go b/auth/sts/sts_test.go index 5f8282aa20..5e4f348d5c 100644 --- a/auth/sts/sts_test.go +++ b/auth/sts/sts_test.go @@ -297,3 +297,62 @@ func TestExchange_NoExpiry(t *testing.T) { t.Fatalf("ExpiresAt = %v, want zero", got.ExpiresAt) } } + +// TestExchange_RequestTimeoutFires pins the slow-loris defence: a +// handler that never writes a response body must surface as a context +// deadline error rather than blocking the caller indefinitely. +// +// Cleanup order matters: t.Cleanup is LIFO, and httptest.Server.Close +// waits for in-flight handler goroutines to return. We register +// `close(hung)` AFTER newTestClient so it fires first and lets the +// handler exit before srv.Close runs. +func TestExchange_RequestTimeoutFires(t *testing.T) { + t.Parallel() + hung := make(chan struct{}) + + c := newTestClient(t, func(_ http.ResponseWriter, r *http.Request) { + select { + case <-hung: + case <-r.Context().Done(): + } + }) + t.Cleanup(func() { close(hung) }) + c.RequestTimeout = 50 * time.Millisecond + + _, err := c.Exchange(context.Background(), ExchangeRequest{ + SubjectToken: "sub", + SubjectTokenType: SubjectTokenTypeJWT, + RequestedTokenType: "urn:example:t", + }) + if err == nil { + t.Fatal("expected timeout error, got nil") + } + if !strings.Contains(err.Error(), "context deadline exceeded") { + t.Fatalf("err = %v, want context deadline exceeded", err) + } +} + +// TestRequestTimeout_DefaultAndOverride exercises the timeout policy +// without doing IO — pure resolution of the (zero / negative / +// positive) input contract. +func TestRequestTimeout_DefaultAndOverride(t *testing.T) { + t.Parallel() + cases := []struct { + name string + in time.Duration + want time.Duration + }{ + {"zero -> default", 0, DefaultRequestTimeout}, + {"negative -> disabled", -1, 0}, + {"positive -> verbatim", 5 * time.Second, 5 * time.Second}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + c := &Client{RequestTimeout: tc.in} + if got := c.requestTimeout(); got != tc.want { + t.Fatalf("requestTimeout() = %v, want %v", got, tc.want) + } + }) + } +} diff --git a/auth/tokenmanager/tokenmanager.go b/auth/tokenmanager/tokenmanager.go index 82aa426068..2d1af6c6af 100644 --- a/auth/tokenmanager/tokenmanager.go +++ b/auth/tokenmanager/tokenmanager.go @@ -18,7 +18,6 @@ import ( "fmt" "net/http" "net/url" - "slices" "strings" "sync" "time" @@ -137,8 +136,20 @@ func (m *Manager) Issuer() string { return m.cfg.Issuer } // SaveCoreToken persists the device-flow access token under the // configured Issuer. +// +// On successful save the in-memory exchange cache is cleared so a +// re-login under a different identity can't return the previous user's +// exchanged tokens. The cacheKey already binds entries to CoreToken so +// this is defence-in-depth against a future refactor that drops the +// core token from the cache key — see TestSaveCoreToken_ClearsExchangeCache. func (m *Manager) SaveCoreToken(accessToken string) error { - return m.cfg.Store.SaveTokens(m.cfg.Issuer, tokens.TokenSet{AccessToken: accessToken}) //nolint:wrapcheck // backend error already names the operation + if err := m.cfg.Store.SaveTokens(m.cfg.Issuer, tokens.TokenSet{AccessToken: accessToken}); err != nil { + return fmt.Errorf("save core token: %w", err) + } + m.mu.Lock() + m.cache = map[cacheKey]cachedToken{} + m.mu.Unlock() + return nil } // LookupCoreToken returns the stored core token, or "" if none is @@ -225,16 +236,26 @@ func (m *Manager) Token(ctx context.Context, req TokenRequest) (string, error) { if core == "" { return "", ErrNotLoggedIn } + // Preflight expiry: a long-stored core token would otherwise hit the + // resource (or STS) and surface as a confusing "invalid_grant" / + // "401". Parse-failure is intentionally not treated as expired — + // opaque (non-JWT) access tokens have no client-visible expiry, so + // we let them flow and trust the server to reject if necessary. + if coreTokenExpired(core, m.cfg.Now()) { + return "", ErrNotLoggedIn + } + + normResource := normalizeOriginURL(req.Resource) - if req.Audience == "" && m.cfg.Issuer == req.Resource { + if req.Audience == "" && normalizeOriginURL(m.cfg.Issuer) == normResource { return core, nil } - if req.Audience == "" && coreTokenAudienceIncludes(core, req.Resource) { + if req.Audience == "" && coreTokenAudienceIncludes(core, normResource) { return core, nil } resolved := m.resolve(req) - key := makeCacheKey(core, resolved) + key := makeCacheKey(core, resolved, normResource) if hit, ok := m.cacheLookup(key); ok { return hit, nil } @@ -258,12 +279,73 @@ func (m *Manager) resolve(req TokenRequest) TokenRequest { return req } +// coreTokenExpired reports whether the core token has an `exp` claim +// in the past at now. JWT parse failures (and tokens without an `exp` +// claim) are reported as not-expired so opaque access tokens flow +// through the rest of the resolution rules unchanged. +func coreTokenExpired(coreJWT string, now time.Time) bool { + claims, err := tokens.ParseClaims(coreJWT) + if err != nil { + return false + } + if claims.ExpiresAt.IsZero() { + return false + } + return !now.Before(claims.ExpiresAt) +} + +// coreTokenAudienceIncludes reports whether the core JWT's `aud` claim +// covers target. target is expected to already be in normalised form +// (see normalizeOriginURL); aud entries are normalised here so a +// trailing-slash / case difference between the AS and the caller +// doesn't force a needless STS exchange. func coreTokenAudienceIncludes(coreJWT, target string) bool { claims, err := tokens.ParseClaims(coreJWT) if err != nil { return false } - return slices.Contains(claims.Audience, target) + for _, aud := range claims.Audience { + if normalizeOriginURL(aud) == target { + return true + } + } + return false +} + +// normalizeOriginURL canonicalises an origin URL for equality +// comparisons. RFC 3986 §6.2.2.1 makes scheme and host case-insensitive +// and §6.2.3 makes the empty path equivalent to "/" — we collapse to +// no-trailing-slash. Default ports (80/http, 443/https) are stripped. +// +// On parse failure (or when the input lacks a scheme or host — common +// for non-URL audiences) the input is returned unchanged so callers +// fall back to byte-exact comparison. +func normalizeOriginURL(raw string) string { + u, err := url.Parse(raw) + if err != nil || u.Scheme == "" || u.Host == "" { + return raw + } + u.Scheme = strings.ToLower(u.Scheme) + + hostname := strings.ToLower(u.Hostname()) + port := u.Port() + dropPort := (u.Scheme == "http" && port == "80") || + (u.Scheme == "https" && port == "443") || + port == "" + + switch { + case dropPort && strings.Contains(hostname, ":"): // IPv6 without port + u.Host = "[" + hostname + "]" + case dropPort: + u.Host = hostname + case strings.Contains(hostname, ":"): // IPv6 with non-default port + u.Host = "[" + hostname + "]:" + port + default: + u.Host = hostname + ":" + port + } + + u.Path = strings.TrimRight(u.Path, "/") + return u.String() } // cachedToken is one entry in the per-process exchange cache. @@ -297,11 +379,13 @@ type cacheKey struct { // makeCacheKey builds a cacheKey from the (resolved) request. Includes // every wire-affecting field so different combinations don't shadow -// each other. -func makeCacheKey(coreToken string, req TokenRequest) cacheKey { +// each other. normalizedResource is the caller-supplied Resource after +// passing through normalizeOriginURL, so https://api.example.com and +// https://api.example.com/ share a single cache entry. +func makeCacheKey(coreToken string, req TokenRequest, normalizedResource string) cacheKey { return cacheKey{ CoreToken: coreToken, - Resource: req.Resource, + Resource: normalizedResource, Audience: req.Audience, RequestedTokenType: req.RequestedTokenType, Scope: req.Scope, diff --git a/auth/tokenmanager/tokenmanager_test.go b/auth/tokenmanager/tokenmanager_test.go index 916280c7d1..4529509312 100644 --- a/auth/tokenmanager/tokenmanager_test.go +++ b/auth/tokenmanager/tokenmanager_test.go @@ -41,10 +41,11 @@ func (s *memStore) DeleteTokens(profile string) error { } const ( - testIssuer = "https://auth.example.com" - testResource = "https://api.example.com" - testClientID = "test-cli" - testSTSPath = "/sts/token" + testIssuer = "https://auth.example.com" + testResource = "https://api.example.com" + testClientID = "test-cli" + testSTSPath = "/sts/token" + testExchangedTok = "exchanged" ) func makeJWTWithAudience(t *testing.T, aud []string) string { @@ -59,6 +60,25 @@ func makeJWTWithAudience(t *testing.T, aud []string) string { return header + "." + body + "." + sig } +// makeJWTWithExp builds an unsigned JWT carrying `exp` (and optionally +// `aud`). The signature segment is junk — tokenmanager never verifies +// it, ParseClaims is documented as unverified. +func makeJWTWithExp(t *testing.T, exp time.Time, aud []string) string { + t.Helper() + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + claims := map[string]any{"sub": "test", "exp": exp.Unix()} + if len(aud) > 0 { + claims["aud"] = aud + } + payload, err := json.Marshal(claims) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + body := base64.RawURLEncoding.EncodeToString(payload) + sig := base64.RawURLEncoding.EncodeToString([]byte("not-real")) + return header + "." + body + "." + sig +} + func newTestManager(t *testing.T, store tokenstore.Store, exchange func(context.Context, sts.ExchangeRequest) (*tokens.TokenSet, error)) *Manager { t.Helper() m, err := New(Config{ @@ -216,7 +236,7 @@ func TestToken_ExplicitAudienceBypassesAudienceShortcut(t *testing.T) { m := newTestManager(t, store, func(_ context.Context, req sts.ExchangeRequest) (*tokens.TokenSet, error) { calls++ got = req - return &tokens.TokenSet{AccessToken: "exchanged"}, nil + return &tokens.TokenSet{AccessToken: testExchangedTok}, nil }) token, err := m.Token(context.Background(), TokenRequest{Resource: testResource, Audience: requestedAudience}) @@ -224,7 +244,7 @@ func TestToken_ExplicitAudienceBypassesAudienceShortcut(t *testing.T) { t.Fatalf("Token: %v", err) } - if token != "exchanged" || calls != 1 { + if token != testExchangedTok || calls != 1 { t.Fatalf("Token returned %q with %d exchange calls, want exchanged token from one exchange", token, calls) } if got.Audience != requestedAudience { @@ -288,7 +308,7 @@ func TestToken_ExchangeIncludesResource(t *testing.T) { var got sts.ExchangeRequest m := newTestManager(t, store, func(_ context.Context, req sts.ExchangeRequest) (*tokens.TokenSet, error) { got = req - return &tokens.TokenSet{AccessToken: "exchanged"}, nil + return &tokens.TokenSet{AccessToken: testExchangedTok}, nil }) if _, err := m.TokenForResource(context.Background(), testResource); err != nil { @@ -383,7 +403,7 @@ func TestToken_CacheExpires(t *testing.T) { Now: func() time.Time { return now }, Exchange: func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { calls++ - return &tokens.TokenSet{AccessToken: "exchanged", ExpiresAt: now.Add(time.Minute)}, nil + return &tokens.TokenSet{AccessToken: testExchangedTok, ExpiresAt: now.Add(time.Minute)}, nil }, }) if err != nil { @@ -630,7 +650,7 @@ func TestCoreTokenAudienceShortcut_FallsThroughOnMalformedJWT(t *testing.T) { var exchangeCalls int m := newTestManager(t, store, func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { exchangeCalls++ - return &tokens.TokenSet{AccessToken: "exchanged"}, nil + return &tokens.TokenSet{AccessToken: testExchangedTok}, nil }) got, err := m.TokenForResource(context.Background(), testResource) @@ -666,3 +686,193 @@ func TestToken_StoreErrorSurfacesNotAsErrNotLoggedIn(t *testing.T) { t.Fatalf("err = %v, want underlying store error", err) } } + +// TestToken_ExpiredCoreReturnsNotLoggedIn pins the preflight behaviour: +// a core token whose JWT `exp` is in the past surfaces ErrNotLoggedIn +// before the request reaches STS or the resource. Without preflight, +// users see a confusing "invalid_grant" / "401" instead of "run login". +func TestToken_ExpiredCoreReturnsNotLoggedIn(t *testing.T) { + t.Parallel() + now := time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC) + expired := makeJWTWithExp(t, now.Add(-time.Hour), nil) + + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: expired} + + m, err := New(Config{ + Issuer: testIssuer, ClientID: testClientID, STSPath: testSTSPath, Store: store, + Now: func() time.Time { return now }, + Exchange: func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + t.Fatal("exchange must not run for an expired core token") + return nil, errors.New("unreachable") + }, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + _, err = m.TokenForResource(context.Background(), testResource) + if !errors.Is(err, ErrNotLoggedIn) { + t.Fatalf("err = %v, want ErrNotLoggedIn", err) + } +} + +// TestToken_OpaqueCorePassesPreflight guards the parse-failure branch: +// non-JWT (opaque) access tokens have no client-visible expiry, so +// they must NOT be classified as expired by the preflight check. +func TestToken_OpaqueCorePassesPreflight(t *testing.T) { + t.Parallel() + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: "opaque-not-a-jwt"} + + m, err := New(Config{ + Issuer: testIssuer, ClientID: testClientID, Store: store, + Exchange: func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + t.Fatal("same-host shortcut should win for opaque core token == issuer") + return nil, errors.New("unreachable") + }, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + got, err := m.TokenForResource(context.Background(), testIssuer) + if err != nil { + t.Fatalf("TokenForResource: %v", err) + } + if got != "opaque-not-a-jwt" { + t.Fatalf("got %q, want opaque core verbatim", got) + } +} + +// TestSaveCoreToken_ClearsExchangeCache pins the cache-invalidation +// contract on save: a re-login under a different identity must not +// return the previous user's exchanged tokens, even if a future +// refactor accidentally drops CoreToken from the cache key. +func TestSaveCoreToken_ClearsExchangeCache(t *testing.T) { + t.Parallel() + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: "user-a-core"} + + calls := 0 + m := newTestManager(t, store, func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + calls++ + return &tokens.TokenSet{AccessToken: "user-a-exchanged"}, nil + }) + + if _, err := m.TokenForResource(context.Background(), testResource); err != nil { + t.Fatalf("first call: %v", err) + } + if calls != 1 { + t.Fatalf("calls after first Token = %d, want 1", calls) + } + + if err := m.SaveCoreToken("user-b-core"); err != nil { + t.Fatalf("SaveCoreToken: %v", err) + } + + if _, err := m.TokenForResource(context.Background(), testResource); err != nil { + t.Fatalf("post-save call: %v", err) + } + if calls != 2 { + t.Fatalf("exchange calls after save = %d, want 2 (cache must be cleared on save)", calls) + } +} + +// TestNormalizeOriginURL covers the cases where same-host / aud-shortcut +// equality has historically misfired: trailing slash, scheme/host case, +// default-port presence. Inputs that don't parse as origin URLs must +// pass through unchanged so non-URL audiences keep byte-exact compare. +func TestNormalizeOriginURL(t *testing.T) { + t.Parallel() + cases := []struct { + name string + in string + want string + }{ + {"empty", "", ""}, + {"plain", "https://api.example.com", "https://api.example.com"}, + {"trailing slash", "https://api.example.com/", "https://api.example.com"}, + {"upper scheme", "HTTPS://api.example.com", "https://api.example.com"}, + {"upper host", "https://API.Example.COM", "https://api.example.com"}, + {"default https port", "https://api.example.com:443", "https://api.example.com"}, + {"default http port", "http://api.example.com:80/", "http://api.example.com"}, + {"non-default port preserved", "https://api.example.com:8443", "https://api.example.com:8443"}, + {"path preserved (sans trailing slash)", "https://api.example.com/v2/", "https://api.example.com/v2"}, + {"non-URL audience passes through", "urn:example:cli", "urn:example:cli"}, + {"bare string passes through", "some-audience", "some-audience"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := normalizeOriginURL(tc.in); got != tc.want { + t.Errorf("normalizeOriginURL(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + +// TestToken_SameHostShortcut_NormalisesURLs guards against a regression +// where a trailing-slash or case difference between Issuer and Resource +// forces a needless STS exchange (or fails outright on single-host +// deployments where STSPath is empty). +func TestToken_SameHostShortcut_NormalisesURLs(t *testing.T) { + t.Parallel() + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: "core-tok"} + + m, err := New(Config{ + Issuer: testIssuer, ClientID: testClientID, // STSPath intentionally empty + Store: store, + Exchange: func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + t.Fatal("exchange must not run when issuer == resource modulo trailing slash / case") + return nil, errors.New("unreachable") + }, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + for _, resource := range []string{ + testIssuer + "/", // trailing slash + strings.ToUpper(testIssuer[:8]) + testIssuer[8:], // uppercase scheme + } { + got, err := m.TokenForResource(context.Background(), resource) + if err != nil { + t.Fatalf("TokenForResource(%q): %v", resource, err) + } + if got != "core-tok" { + t.Fatalf("TokenForResource(%q) = %q, want core token verbatim", resource, got) + } + } +} + +// TestToken_CacheCollapsesURLEquivalents pins the cache key being +// computed off the normalised resource: two equivalent forms must +// share a single entry rather than each driving its own STS round-trip. +func TestToken_CacheCollapsesURLEquivalents(t *testing.T) { + t.Parallel() + store := newMemStore() + store.data[testIssuer] = tokens.TokenSet{AccessToken: "core-tok"} + + var calls int + m := newTestManager(t, store, func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + calls++ + return &tokens.TokenSet{AccessToken: testExchangedTok}, nil + }) + + first, err := m.TokenForResource(context.Background(), testResource) + if err != nil { + t.Fatalf("first call: %v", err) + } + second, err := m.TokenForResource(context.Background(), testResource+"/") + if err != nil { + t.Fatalf("second call: %v", err) + } + if first != testExchangedTok || second != testExchangedTok { + t.Fatalf("tokens = (%q, %q), want both exchanged", first, second) + } + if calls != 1 { + t.Fatalf("exchange calls = %d, want 1 (trailing-slash variant must hit cache)", calls) + } +} diff --git a/auth/tokenstore/keyring.go b/auth/tokenstore/keyring.go index f80a2e9ddf..8ec3ddfceb 100644 --- a/auth/tokenstore/keyring.go +++ b/auth/tokenstore/keyring.go @@ -114,6 +114,17 @@ func decodeTokenSet(raw string) (tokens.TokenSet, error) { return tokens.TokenSet{}, fmt.Errorf("%w: unmarshal TokenSet: %w", ErrMalformed, err) } + // json.Unmarshal happily decodes any well-formed JSON object to a + // zero keyringTokenSet — for example {} or an unrelated CLI's blob + // that happens to be keyed against the same (service, profile). We + // surface that as ErrMalformed so the embedding shim can route to + // its legacy/upgrade path rather than returning a TokenSet with an + // empty AccessToken (which the caller can't distinguish from a + // successful load of a freshly-cleared entry). + if strings.TrimSpace(wire.AccessToken) == "" { + return tokens.TokenSet{}, fmt.Errorf("%w: stored entry has no access_token", ErrMalformed) + } + t := tokens.TokenSet{ AccessToken: wire.AccessToken, RefreshToken: wire.RefreshToken, diff --git a/auth/tokenstore/keyring_test.go b/auth/tokenstore/keyring_test.go index c30ea5469f..f21402e48b 100644 --- a/auth/tokenstore/keyring_test.go +++ b/auth/tokenstore/keyring_test.go @@ -195,3 +195,34 @@ func TestKeyring_LoadTokens_BadExpiresAtReturnsErrMalformed(t *testing.T) { t.Fatalf("err = %v, want ErrMalformed", err) } } + +// TestKeyring_LoadTokens_EmptyAccessTokenReturnsErrMalformed pins the +// guard against well-formed JSON that decodes to a zero TokenSet. An +// unrelated CLI's blob keyed against the same service/profile, or a +// "{}" entry from a buggy save, would otherwise produce a TokenSet +// with empty AccessToken indistinguishable from a successful load — +// and the shim would then ship "Authorization: Bearer " on the wire. +func TestKeyring_LoadTokens_EmptyAccessTokenReturnsErrMalformed(t *testing.T) { + cases := []struct { + name string + body string + }{ + {"empty object", `{}`}, + {"unrelated fields only", `{"foo":"bar","count":3}`}, + {"explicit empty access_token", `{"access_token":""}`}, + {"whitespace access_token", `{"access_token":" "}`}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + service := "test-empty-" + tc.name + const profile = "https://example.com" + if err := keyring.Set(service, profile, tc.body); err != nil { + t.Fatalf("seed keyring: %v", err) + } + _, err := NewKeyring(service).LoadTokens(profile) + if !errors.Is(err, ErrMalformed) { + t.Fatalf("err = %v, want ErrMalformed", err) + } + }) + } +} diff --git a/cmd/entire/cli/api/auth_tokens.go b/cmd/entire/cli/api/auth_tokens.go index e54c69bc13..557c9e6ab3 100644 --- a/cmd/entire/cli/api/auth_tokens.go +++ b/cmd/entire/cli/api/auth_tokens.go @@ -2,10 +2,9 @@ package api import ( "context" + "errors" "fmt" "net/url" - "os" - "strings" ) // Token is a single API token row returned by the auth-tokens endpoint. @@ -25,25 +24,27 @@ type TokensResponse struct { Tokens []Token `json:"tokens"` } -// authTokensProviderVersionEnvVar must match the env var read by -// cmd/entire/cli/auth's currentProvider(). Duplicated here rather than -// imported because api/ is a leaf package and shouldn't take a -// dependency on auth/ for routing. -const authTokensProviderVersionEnvVar = "ENTIRE_AUTH_PROVIDER_VERSION" //nolint:gosec // env var name, not a credential +// errAuthTokensPathUnset surfaces when an auth-tokens method is called +// on a Client that wasn't given a base path. Construct via +// NewClientWithBaseURL(...).WithAuthTokensPath(...) — the active path +// lives in cmd/entire/cli/auth.CurrentProvider().AuthTokensPath, the +// single source of truth for provider-version routing. +var errAuthTokensPathUnset = errors.New("api: auth-tokens path is unset (call (*Client).WithAuthTokensPath before list/revoke)") -// authTokensBasePath returns the auth-tokens endpoint family base path -// for the active provider version. v1 (default) hits /api/v1/auth/tokens; -// v2 hits /api/auth/tokens (no version segment). -func authTokensBasePath() string { - if strings.TrimSpace(os.Getenv(authTokensProviderVersionEnvVar)) == "v2" { - return "/api/auth/tokens" +func (c *Client) authTokensBasePath() (string, error) { + if c.authTokensPath == "" { + return "", errAuthTokensPathUnset } - return "/api/v1/auth/tokens" + return c.authTokensPath, nil } // ListTokens returns the authenticated user's non-expired API tokens. func (c *Client) ListTokens(ctx context.Context) ([]Token, error) { - resp, err := c.Get(ctx, authTokensBasePath()) + base, err := c.authTokensBasePath() + if err != nil { + return nil, fmt.Errorf("list tokens: %w", err) + } + resp, err := c.Get(ctx, base) if err != nil { return nil, fmt.Errorf("list tokens: %w", err) } @@ -62,7 +63,11 @@ func (c *Client) ListTokens(ctx context.Context) ([]Token, error) { // RevokeCurrentToken revokes the bearer token used to authenticate this client. func (c *Client) RevokeCurrentToken(ctx context.Context) error { - resp, err := c.Delete(ctx, authTokensBasePath()+"/current") + base, err := c.authTokensBasePath() + if err != nil { + return fmt.Errorf("revoke current token: %w", err) + } + resp, err := c.Delete(ctx, base+"/current") if err != nil { return fmt.Errorf("revoke current token: %w", err) } @@ -76,7 +81,11 @@ func (c *Client) RevokeCurrentToken(ctx context.Context) error { // RevokeToken revokes the API token with the given id. func (c *Client) RevokeToken(ctx context.Context, id string) error { - resp, err := c.Delete(ctx, authTokensBasePath()+"/"+url.PathEscape(id)) + base, err := c.authTokensBasePath() + if err != nil { + return fmt.Errorf("revoke token %s: %w", id, err) + } + resp, err := c.Delete(ctx, base+"/"+url.PathEscape(id)) if err != nil { return fmt.Errorf("revoke token %s: %w", id, err) } diff --git a/cmd/entire/cli/api/auth_tokens_test.go b/cmd/entire/cli/api/auth_tokens_test.go index 12a9c4628f..07ccd694d2 100644 --- a/cmd/entire/cli/api/auth_tokens_test.go +++ b/cmd/entire/cli/api/auth_tokens_test.go @@ -9,8 +9,23 @@ import ( "testing" ) +const ( + testV1AuthTokensPath = "/api/v1/auth/tokens" + testV2AuthTokensPath = "/api/auth/tokens" +) + +// newAuthTokensTestClient builds a Client pointed at server.URL with +// the given auth-tokens base path. Used by all auth-tokens tests so +// the wiring matches production: callers chain WithAuthTokensPath at +// construction time. +func newAuthTokensTestClient(serverURL, authTokensPath string) *Client { + c := NewClient("tok").WithAuthTokensPath(authTokensPath) + c.baseURL = serverURL + return c +} + func TestClient_RevokeCurrentToken_SendsDeleteWithBearer(t *testing.T) { - t.Setenv(authTokensProviderVersionEnvVar, "") + t.Parallel() var gotMethod, gotPath, gotAuth string @@ -23,8 +38,7 @@ func TestClient_RevokeCurrentToken_SendsDeleteWithBearer(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) if err := c.RevokeCurrentToken(context.Background()); err != nil { t.Fatalf("RevokeCurrentToken() error = %v", err) @@ -51,8 +65,7 @@ func TestClient_RevokeCurrentToken_ReturnsHTTPErrorOn401(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) err := c.RevokeCurrentToken(context.Background()) if err == nil { @@ -71,7 +84,7 @@ func TestClient_RevokeCurrentToken_ReturnsHTTPErrorOn401(t *testing.T) { } func TestClient_ListTokens_DecodesResponse(t *testing.T) { - t.Setenv(authTokensProviderVersionEnvVar, "") + t.Parallel() var gotMethod, gotPath, gotAuth string @@ -87,8 +100,7 @@ func TestClient_ListTokens_DecodesResponse(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) tokens, err := c.ListTokens(context.Background()) if err != nil { @@ -129,8 +141,7 @@ func TestClient_ListTokens_ReturnsHTTPErrorOn401(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) _, err := c.ListTokens(context.Background()) if err == nil { @@ -142,7 +153,7 @@ func TestClient_ListTokens_ReturnsHTTPErrorOn401(t *testing.T) { } func TestClient_RevokeToken_SendsDeleteWithEscapedID(t *testing.T) { - t.Setenv(authTokensProviderVersionEnvVar, "") + t.Parallel() var gotMethod, gotEscapedPath, gotDecodedPath string @@ -155,8 +166,7 @@ func TestClient_RevokeToken_SendsDeleteWithEscapedID(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) // Use an id that needs URL escaping to verify we don't blindly concat. if err := c.RevokeToken(context.Background(), "abc/def 1"); err != nil { @@ -184,8 +194,7 @@ func TestClient_RevokeToken_ReturnsErrorBody(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) err := c.RevokeToken(context.Background(), "missing") if err == nil { @@ -199,39 +208,12 @@ func TestClient_RevokeToken_ReturnsErrorBody(t *testing.T) { } } -// TestAuthTokensBasePath_ProviderVersionRouting locks in the path -// switch so v2 doesn't silently regress to v1's path family. The whole -// reason the version env var exists is to route requests at this layer. -func TestAuthTokensBasePath_ProviderVersionRouting(t *testing.T) { - cases := []struct { - name string - version string - want string - }{ - {"unset defaults to v1", "", "/api/v1/auth/tokens"}, - {"v1 explicit", "v1", "/api/v1/auth/tokens"}, - {"v2", "v2", "/api/auth/tokens"}, - {"unrecognised defaults to v1", "v999", "/api/v1/auth/tokens"}, - // Whitespace trimming must match auth.currentProvider() — both - // trim, so the api and auth packages agree on what "v2" means. - // If either side stops trimming, these tests diverge first. - {"trims whitespace then matches v2", " v2 ", "/api/auth/tokens"}, - } - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Setenv(authTokensProviderVersionEnvVar, tc.version) - if got := authTokensBasePath(); got != tc.want { - t.Fatalf("authTokensBasePath() = %q, want %q", got, tc.want) - } - }) - } -} - -// TestClient_ListTokens_RoutesV2Path is an end-to-end check that the -// version switch flows through the public Client API, not just the -// internal helper. -func TestClient_ListTokens_RoutesV2Path(t *testing.T) { - t.Setenv(authTokensProviderVersionEnvVar, "v2") +// TestClient_AuthTokens_RoutesV2Path verifies that whatever path the +// caller supplies via WithAuthTokensPath is what hits the wire. The +// provider table itself (which path corresponds to which version) is +// exercised by cmd/entire/cli/auth's resolveProvider tests. +func TestClient_AuthTokens_RoutesV2Path(t *testing.T) { + t.Parallel() var gotPath string server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -241,8 +223,7 @@ func TestClient_ListTokens_RoutesV2Path(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV2AuthTokensPath) if _, err := c.ListTokens(context.Background()); err != nil { t.Fatalf("ListTokens: %v", err) @@ -251,3 +232,22 @@ func TestClient_ListTokens_RoutesV2Path(t *testing.T) { t.Fatalf("path = %q, want /api/auth/tokens (v2)", gotPath) } } + +// TestClient_AuthTokens_UnsetPathErrors guards against silently +// shipping a request to "" — we want a clear error pointing at the +// missing WithAuthTokensPath wiring. +func TestClient_AuthTokens_UnsetPathErrors(t *testing.T) { + t.Parallel() + + c := NewClient("tok") // no WithAuthTokensPath + + if _, err := c.ListTokens(context.Background()); !errors.Is(err, errAuthTokensPathUnset) { + t.Errorf("ListTokens err = %v, want errAuthTokensPathUnset", err) + } + if err := c.RevokeCurrentToken(context.Background()); !errors.Is(err, errAuthTokensPathUnset) { + t.Errorf("RevokeCurrentToken err = %v, want errAuthTokensPathUnset", err) + } + if err := c.RevokeToken(context.Background(), "any"); !errors.Is(err, errAuthTokensPathUnset) { + t.Errorf("RevokeToken err = %v, want errAuthTokensPathUnset", err) + } +} diff --git a/cmd/entire/cli/api/client.go b/cmd/entire/cli/api/client.go index b568214a77..98a7d76ab1 100644 --- a/cmd/entire/cli/api/client.go +++ b/cmd/entire/cli/api/client.go @@ -21,6 +21,26 @@ const ( type Client struct { httpClient *http.Client baseURL string + + // authTokensPath is the base path for the auth-tokens management + // endpoints (list / revoke). Set via WithAuthTokensPath when the + // client targets the auth host. Empty for data-API-only clients; + // auth-tokens methods error out if called against an empty path. + authTokensPath string +} + +// WithAuthTokensPath sets the base path used by ListTokens, +// RevokeCurrentToken, and RevokeToken. The path is supplied by the +// auth shim from auth.CurrentProvider().AuthTokensPath, which is the +// single source of truth for provider-version routing — the api +// package no longer reads ENTIRE_AUTH_PROVIDER_VERSION itself. +// +// Returns the receiver for chaining at construction: +// +// c := api.NewClientWithBaseURL(token, base).WithAuthTokensPath(p) +func (c *Client) WithAuthTokensPath(path string) *Client { + c.authTokensPath = path + return c } // NewClient creates a new authenticated API client with an explicit bearer diff --git a/cmd/entire/cli/auth.go b/cmd/entire/cli/auth.go index e5ba946e2c..d5a08469eb 100644 --- a/cmd/entire/cli/auth.go +++ b/cmd/entire/cli/auth.go @@ -100,7 +100,9 @@ func newAuthStatusCmd() *cobra.Command { } func defaultListTokens(ctx context.Context, token string) ([]api.Token, error) { - return api.NewClientWithBaseURL(token, api.AuthBaseURL()).ListTokens(ctx) //nolint:wrapcheck // ListTokens already wraps with action context + client := api.NewClientWithBaseURL(token, api.AuthBaseURL()). + WithAuthTokensPath(auth.CurrentProvider().AuthTokensPath) + return client.ListTokens(ctx) //nolint:wrapcheck // ListTokens already wraps with action context } func runAuthStatus(ctx context.Context, w io.Writer, store tokenStore, list authTokenLister, baseURL string) error { @@ -426,7 +428,9 @@ func newAuthRevokeCmd() *cobra.Command { } func defaultRevokeTokenByID(ctx context.Context, callerToken, id string) error { - return api.NewClientWithBaseURL(callerToken, api.AuthBaseURL()).RevokeToken(ctx, id) //nolint:wrapcheck // RevokeToken already wraps with action context + client := api.NewClientWithBaseURL(callerToken, api.AuthBaseURL()). + WithAuthTokensPath(auth.CurrentProvider().AuthTokensPath) + return client.RevokeToken(ctx, id) //nolint:wrapcheck // RevokeToken already wraps with action context } func runAuthRevoke( diff --git a/cmd/entire/cli/auth/client.go b/cmd/entire/cli/auth/client.go index f6ee0da399..80e5579fd8 100644 --- a/cmd/entire/cli/auth/client.go +++ b/cmd/entire/cli/auth/client.go @@ -45,15 +45,15 @@ type Client struct { // NewClient constructs a Client targeting the active provider version. // httpClient is used directly when non-nil; otherwise http.DefaultClient. func NewClient(httpClient *http.Client) *Client { - p := currentProvider() + p := CurrentProvider() return &Client{inner: &deviceflow.Client{ HTTP: httpClient, BaseURL: api.AuthBaseURL(), - ClientID: p.clientID, + ClientID: p.ClientID, Scope: "cli", - UserAgent: p.clientID, - DeviceCodePath: p.deviceCodePath, - TokenPath: p.tokenPath, + UserAgent: p.ClientID, + DeviceCodePath: p.DeviceCodePath, + TokenPath: p.TokenPath, }} } diff --git a/cmd/entire/cli/auth/exchange.go b/cmd/entire/cli/auth/exchange.go index f888425fef..23962fde6d 100644 --- a/cmd/entire/cli/auth/exchange.go +++ b/cmd/entire/cli/auth/exchange.go @@ -50,13 +50,13 @@ func defaultManager() (*tokenmanager.Manager, error) { return managerForTest, nil } managerOnce.Do(func() { - provider := currentProvider() + provider := CurrentProvider() m, err := tokenmanager.New(tokenmanager.Config{ Issuer: api.AuthBaseURL(), - ClientID: provider.clientID, - STSPath: provider.stsPath, + ClientID: provider.ClientID, + STSPath: provider.STSPath, Store: NewStore(), - UserAgent: provider.clientID, + UserAgent: provider.ClientID, Scope: "cli", }) manager = m diff --git a/cmd/entire/cli/auth/provider.go b/cmd/entire/cli/auth/provider.go index 804fe22977..b30ec8b6b8 100644 --- a/cmd/entire/cli/auth/provider.go +++ b/cmd/entire/cli/auth/provider.go @@ -3,6 +3,7 @@ package auth import ( "os" "strings" + "sync" ) // ProviderVersionEnvVar selects which OAuth surface this CLI talks to. @@ -15,46 +16,112 @@ import ( // This is a transition-period switch: once v2 is the universal default // the env var goes away. Surfaces are otherwise reachable as RFC 8628 // device-flow endpoints; the only differences are paths and client_id. +// +// Read once at process startup via CurrentProvider; later flips within +// the same process are intentionally ignored. Tests inject via +// SetProviderForTest rather than mutating the env mid-run. const ProviderVersionEnvVar = "ENTIRE_AUTH_PROVIDER_VERSION" -// providerConfig captures the per-surface bits of OAuth wiring. +// Provider captures the per-surface bits of OAuth wiring. // -// stsPath is the RFC 8693 token-exchange endpoint. v1 is the legacy +// STSPath is the RFC 8693 token-exchange endpoint. v1 is the legacy // single-host surface where the auth and data API live at the same // origin (entire.io); the same-host shortcut in tokenmanager.Token -// always wins and STS is never invoked, so v1.stsPath is left empty. +// always wins and STS is never invoked, so v1.STSPath is left empty. // v2 exposes a dedicated STS path because it's used in split-host // deployments (e.g. us.auth.partial.to mints, partial.to consumes). -type providerConfig struct { - clientID string - deviceCodePath string - tokenPath string - stsPath string +// +// AuthTokensPath is the base path for the auth-tokens management +// endpoint family (list / revoke). Routed at the api.Client layer via +// (*api.Client).WithAuthTokensPath so the provider table is the single +// source of truth — no env-var duplication between auth/ and api/. +type Provider struct { + ClientID string + DeviceCodePath string + TokenPath string + STSPath string + AuthTokensPath string } -var providers = map[string]providerConfig{ +var providers = map[string]Provider{ "v1": { //nolint:gosec // OAuth client_id and endpoint paths, not credentials - clientID: "entire-cli", - deviceCodePath: "/oauth/device/code", - tokenPath: "/oauth/token", + ClientID: "entire-cli", + DeviceCodePath: "/oauth/device/code", + TokenPath: "/oauth/token", + AuthTokensPath: "/api/v1/auth/tokens", }, "v2": { //nolint:gosec // OAuth client_id and endpoint paths, not credentials - clientID: "entire-cli", - deviceCodePath: "/api/auth/oauth/device/code", - tokenPath: "/api/auth/token", - stsPath: "/api/authz/sts/token", + ClientID: "entire-cli", + DeviceCodePath: "/api/auth/oauth/device/code", + TokenPath: "/api/auth/token", + STSPath: "/api/authz/sts/token", + AuthTokensPath: "/api/auth/tokens", }, } -// currentProvider returns the active providerConfig, defaulting to v1 -// when ENTIRE_AUTH_PROVIDER_VERSION is unset or holds an unrecognised -// value. Defaulting (rather than erroring) keeps old binaries safe if -// a future v3 ever lands. -func currentProvider() providerConfig { - switch strings.TrimSpace(os.Getenv(ProviderVersionEnvVar)) { +// resolveProvider returns the Provider matching version. Defaulting +// (rather than erroring) on unrecognised values keeps old binaries safe +// if a future v3 ever lands. Pure function — no env reads — so unit +// tests can exercise the routing table without env-var gymnastics. +func resolveProvider(version string) Provider { + switch strings.TrimSpace(version) { case "v2": return providers["v2"] default: return providers["v1"] } } + +var ( + providerOnce sync.Once + resolvedProvider Provider + + // providerForTest, when non-nil, short-circuits CurrentProvider so + // tests can install a specific Provider without racing the + // process-wide sync.Once (which freezes the first observation + // forever). Mutated only via SetProviderForTest. Production code + // never reads this var. + providerForTest *Provider + providerTestMu sync.Mutex +) + +// CurrentProvider returns the active Provider for this process. +// +// Resolution: read ENTIRE_AUTH_PROVIDER_VERSION exactly once on first +// call, freeze the result, and return the same Provider on every +// subsequent call. Tests that need a different provider must use +// SetProviderForTest before any auth call constructs the singleton. +func CurrentProvider() Provider { + providerTestMu.Lock() + override := providerForTest + providerTestMu.Unlock() + if override != nil { + return *override + } + providerOnce.Do(func() { + resolvedProvider = resolveProvider(os.Getenv(ProviderVersionEnvVar)) + }) + return resolvedProvider +} + +// SetProviderForTest installs p as the Provider returned by +// CurrentProvider for the duration of the test, and registers a +// t.Cleanup to remove the override. Test-only. +// +// Takes a tiny interface rather than *testing.T so production builds +// don't import testing. +func SetProviderForTest(t interface { + Helper() + Cleanup(f func()) +}, p Provider) { + t.Helper() + providerTestMu.Lock() + prev := providerForTest + providerForTest = &p + providerTestMu.Unlock() + t.Cleanup(func() { + providerTestMu.Lock() + providerForTest = prev + providerTestMu.Unlock() + }) +} diff --git a/cmd/entire/cli/auth/provider_test.go b/cmd/entire/cli/auth/provider_test.go index a67409aab2..8a47f6894e 100644 --- a/cmd/entire/cli/auth/provider_test.go +++ b/cmd/entire/cli/auth/provider_test.go @@ -17,69 +17,94 @@ const ( wantClientIDV2 = "entire-cli" ) -func TestCurrentProvider_DefaultsToV1(t *testing.T) { - t.Setenv(ProviderVersionEnvVar, "") - - p := currentProvider() - if p.clientID != wantClientIDV1 || p.deviceCodePath != "/oauth/device/code" || p.tokenPath != "/oauth/token" { +// resolveProvider is a pure function — no env reads — so the routing +// table can be exercised without t.Setenv (and without the +// process-wide sync.Once in CurrentProvider freezing the first +// observation forever). + +func TestResolveProvider_DefaultsToV1(t *testing.T) { + t.Parallel() + p := resolveProvider("") + if p.ClientID != wantClientIDV1 || p.DeviceCodePath != "/oauth/device/code" || p.TokenPath != "/oauth/token" { t.Fatalf("default provider = %+v, want v1 config", p) } + if p.AuthTokensPath != "/api/v1/auth/tokens" { + t.Fatalf("default AuthTokensPath = %q, want /api/v1/auth/tokens", p.AuthTokensPath) + } } -func TestCurrentProvider_V1Explicit(t *testing.T) { - t.Setenv(ProviderVersionEnvVar, "v1") - - p := currentProvider() - if p.clientID != wantClientIDV1 { - t.Fatalf("v1 clientID = %q", p.clientID) +func TestResolveProvider_V1Explicit(t *testing.T) { + t.Parallel() + p := resolveProvider("v1") + if p.ClientID != wantClientIDV1 { + t.Fatalf("v1 ClientID = %q", p.ClientID) } // v1 is single-host (entire.io); no STS surface, same-host shortcut - // always wins. Empty stsPath is the contract. - if p.stsPath != "" { - t.Fatalf("v1 stsPath = %q, want empty (single-host, no STS)", p.stsPath) + // always wins. Empty STSPath is the contract. + if p.STSPath != "" { + t.Fatalf("v1 STSPath = %q, want empty (single-host, no STS)", p.STSPath) } } -func TestCurrentProvider_V2(t *testing.T) { - t.Setenv(ProviderVersionEnvVar, "v2") - - p := currentProvider() - if p.clientID != wantClientIDV2 { - t.Fatalf("v2 clientID = %q, want %s", p.clientID, wantClientIDV2) +func TestResolveProvider_V2(t *testing.T) { + t.Parallel() + p := resolveProvider("v2") + if p.ClientID != wantClientIDV2 { + t.Fatalf("v2 ClientID = %q, want %s", p.ClientID, wantClientIDV2) } - if p.deviceCodePath != "/api/auth/oauth/device/code" { - t.Fatalf("v2 deviceCodePath = %q", p.deviceCodePath) + if p.DeviceCodePath != "/api/auth/oauth/device/code" { + t.Fatalf("v2 DeviceCodePath = %q", p.DeviceCodePath) } - if p.tokenPath != "/api/auth/token" { - t.Fatalf("v2 tokenPath = %q", p.tokenPath) + if p.TokenPath != "/api/auth/token" { + t.Fatalf("v2 TokenPath = %q", p.TokenPath) } - if p.stsPath != "/api/authz/sts/token" { - t.Fatalf("v2 stsPath = %q", p.stsPath) + if p.STSPath != "/api/authz/sts/token" { + t.Fatalf("v2 STSPath = %q", p.STSPath) + } + if p.AuthTokensPath != "/api/auth/tokens" { + t.Fatalf("v2 AuthTokensPath = %q, want /api/auth/tokens", p.AuthTokensPath) } } -func TestCurrentProvider_UnknownDefaultsToV1(t *testing.T) { - t.Setenv(ProviderVersionEnvVar, "v999") +func TestResolveProvider_UnknownDefaultsToV1(t *testing.T) { + t.Parallel() + p := resolveProvider("v999") + if p.ClientID != wantClientIDV1 { + t.Fatalf("unknown version should default to v1; got ClientID = %q", p.ClientID) + } +} - p := currentProvider() - if p.clientID != wantClientIDV1 { - t.Fatalf("unknown version should default to v1; got clientID = %q", p.clientID) +func TestResolveProvider_TrimsWhitespace(t *testing.T) { + t.Parallel() + p := resolveProvider(" v2 ") + if p.ClientID != wantClientIDV2 { + t.Fatalf("whitespace-padded v2 ClientID = %q, want %s", p.ClientID, wantClientIDV2) } } -func TestCurrentProvider_TrimsWhitespace(t *testing.T) { - t.Setenv(ProviderVersionEnvVar, " v2 ") +// TestSetProviderForTest_OverridesCurrentProvider locks in the test +// seam: any test that pins a provider via SetProviderForTest must see +// it from CurrentProvider regardless of process-wide singleton state. +func TestSetProviderForTest_OverridesCurrentProvider(t *testing.T) { + pinned := Provider{ + ClientID: "test-client", + DeviceCodePath: "/test/device", + TokenPath: "/test/token", + STSPath: "/test/sts", + AuthTokensPath: "/test/tokens", + } + SetProviderForTest(t, pinned) - p := currentProvider() - if p.clientID != wantClientIDV2 { - t.Fatalf("whitespace-padded v2 clientID = %q, want %s", p.clientID, wantClientIDV2) + got := CurrentProvider() + if got != pinned { + t.Fatalf("CurrentProvider() = %+v, want %+v", got, pinned) } } -func TestNewClient_HonoursProviderVersion(t *testing.T) { +func TestNewClient_HonoursPinnedProvider(t *testing.T) { t.Setenv(api.BaseURLEnvVar, "https://example.test") t.Setenv(api.AuthBaseURLEnvVar, "") - t.Setenv(ProviderVersionEnvVar, "v2") + SetProviderForTest(t, resolveProvider("v2")) c := NewClient(&http.Client{}) if c.inner.ClientID != wantClientIDV2 { @@ -96,10 +121,10 @@ func TestNewClient_HonoursProviderVersion(t *testing.T) { } } -func TestNewClient_DefaultsToV1(t *testing.T) { +func TestNewClient_DefaultsToV1WhenPinned(t *testing.T) { t.Setenv(api.BaseURLEnvVar, "https://example.test") t.Setenv(api.AuthBaseURLEnvVar, "") - t.Setenv(ProviderVersionEnvVar, "") + SetProviderForTest(t, resolveProvider("")) c := NewClient(nil) if c.inner.ClientID != wantClientIDV1 { diff --git a/cmd/entire/cli/auth/store.go b/cmd/entire/cli/auth/store.go index 4e77ecd979..db80889726 100644 --- a/cmd/entire/cli/auth/store.go +++ b/cmd/entire/cli/auth/store.go @@ -84,6 +84,9 @@ func (s *Store) GetToken(baseURL string) (string, error) { if kerr != nil { return "", fmt.Errorf("get token from keyring: %w", kerr) } + if !looksLikeBareToken(raw) { + return "", nil + } return raw, nil } @@ -109,6 +112,12 @@ func (s *Store) SaveTokens(profile string, t tokens.TokenSet) error { // keyring errors (transport, permission denied) propagate; only // ErrMalformed triggers the fallback. ErrNotFound surfaces verbatim // so the manager's "not logged in" branch still works. +// +// The fallback also guards against well-formed-but-empty entries (e.g. +// "{}" or an unrelated CLI's blob keyed against the same service/profile): +// those surface as ErrMalformed from the lib, but using their raw bytes +// as a bearer token would be wrong. looksLikeBareToken filters them out, +// converting back to ErrNotFound so the caller sees "not logged in". func (s *Store) LoadTokens(profile string) (tokens.TokenSet, error) { t, err := s.inner.LoadTokens(profile) if err == nil { @@ -125,9 +134,29 @@ func (s *Store) LoadTokens(profile string) (tokens.TokenSet, error) { if kerr != nil { return tokens.TokenSet{}, fmt.Errorf("get token from keyring: %w", kerr) } + if !looksLikeBareToken(raw) { + return tokens.TokenSet{}, tokenstore.ErrNotFound + } return tokens.TokenSet{AccessToken: raw}, nil } +// looksLikeBareToken reports whether raw is plausibly a pre-shim bare +// access token rather than a JSON blob. Real bearer tokens are JWTs or +// opaque ASCII strings; they don't start with the JSON object/array +// delimiters. Trims whitespace first so a stray newline doesn't fool +// the check. +func looksLikeBareToken(raw string) bool { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return false + } + switch trimmed[0] { + case '{', '[': + return false + } + return true +} + // DeleteTokens implements tokenstore.Store. func (s *Store) DeleteTokens(profile string) error { return s.inner.DeleteTokens(profile) //nolint:wrapcheck // shim returns the lib error verbatim diff --git a/cmd/entire/cli/auth/store_test.go b/cmd/entire/cli/auth/store_test.go index 5bbfd93fbe..7a767b5602 100644 --- a/cmd/entire/cli/auth/store_test.go +++ b/cmd/entire/cli/auth/store_test.go @@ -178,6 +178,48 @@ func TestStoreLoadTokens_LegacyBareStringFallback(t *testing.T) { } } +// TestStoreLoadTokens_RejectsJSONShapedFallback guards against a +// well-formed-but-empty JSON entry being mistakenly treated as a +// bare-string token. The lib surfaces these as ErrMalformed; the +// shim's bare-string fallback must filter out anything starting with +// '{' or '[' so the user sees "not logged in" rather than getting +// "Authorization: Bearer {}" on the wire. +func TestStoreLoadTokens_RejectsJSONShapedFallback(t *testing.T) { + for _, body := range []string{`{}`, `{"foo":"bar"}`, `[]`} { + const profile = "https://json-shaped.example.com" + service := "test-json-fallback-" + body[:1] + if err := keyring.Set(service, profile, body); err != nil { + t.Fatalf("seed keyring: %v", err) + } + + got, err := NewStoreWithService(service).LoadTokens(profile) + // We expect ErrNotFound — JSON-shaped malformed entries must + // not be routed through the bare-string fallback. + if err == nil { + t.Fatalf("LoadTokens(%q) returned %+v; want ErrNotFound", body, got) + } + if got.AccessToken != "" { + t.Fatalf("LoadTokens(%q) AccessToken = %q, want empty", body, got.AccessToken) + } + } +} + +func TestStoreGetToken_RejectsJSONShapedFallback(t *testing.T) { + const service = "test-json-getoken" + const profile = "https://json-shaped.example.com" + if err := keyring.Set(service, profile, `{"unrelated":"blob"}`); err != nil { + t.Fatalf("seed keyring: %v", err) + } + + got, err := NewStoreWithService(service).GetToken(profile) + if err != nil { + t.Fatalf("GetToken: %v", err) + } + if got != "" { + t.Fatalf("GetToken = %q, want empty (JSON blob must not be shipped as a bearer)", got) + } +} + func TestLookupCurrentToken(t *testing.T) { t.Setenv(api.BaseURLEnvVar, "http://localhost:8787") t.Setenv(api.AuthBaseURLEnvVar, "") diff --git a/cmd/entire/cli/logout.go b/cmd/entire/cli/logout.go index b34b7496e2..0d44365035 100644 --- a/cmd/entire/cli/logout.go +++ b/cmd/entire/cli/logout.go @@ -41,7 +41,9 @@ func newLogoutCmd() *cobra.Command { } func defaultRevokeCurrentToken(ctx context.Context, token string) error { - return api.NewClientWithBaseURL(token, api.AuthBaseURL()).RevokeCurrentToken(ctx) //nolint:wrapcheck // RevokeCurrentToken already wraps with action context + client := api.NewClientWithBaseURL(token, api.AuthBaseURL()). + WithAuthTokensPath(auth.CurrentProvider().AuthTokensPath) + return client.RevokeCurrentToken(ctx) //nolint:wrapcheck // RevokeCurrentToken already wraps with action context } func runLogout(ctx context.Context, outW, errW io.Writer, store tokenStore, revoke revokeCurrentFunc, baseURL string) error { From 706a4a28e0ee3848a4a683c20bfa8c1cc8c7e63a Mon Sep 17 00:00:00 2001 From: Alex Ong Date: Thu, 14 May 2026 17:03:40 +1000 Subject: [PATCH 20/21] auth: defense-in-depth security hardening MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Four review-surfaced findings, all defense-in-depth on top of server-side validation: 1. Reject JWTs with alg:none header (RFC 7515 / RFC 7518 §3.6 known attack vector). tokens.ParseClaims now decodes the JWT header and returns tokens.ErrUnsignedJWT for any case-insensitive "none" variant. The CLI's use of ParseClaims is documented as unverified and only feeds routing decisions, but a future caller could be tempted to rely on the values — rejecting the unsigned shape at the source keeps that door closed. New makeJWTWithHeader test helper lets us produce well-formed JWTs with specific issues (alg variants, expired exp, etc.) without a real JOSE library. Five regression tests cover lowercase / capitalised / uppercase / whitespace-padded "none" + a sanity check that standard algs (HS256, RS256, ES256, EdDSA, PS512) still parse. 2. Enforce HTTPS on STS exchange and device-flow endpoints. Both sts.Client and deviceflow.Client now reject http:// BaseURLs unless AllowInsecureHTTP is set; new ErrInsecureBaseURL sentinel on each. The cmd-side wires this to auto-permit only loopback http:// (localhost, 127.0.0.1, ::1) via new isLoopbackHTTP helper in cmd/entire/cli/auth/exchange.go and the deviceflow client constructor — production never qualifies, local dev does. tokenmanager.Config.AllowInsecureHTTP plumbs the flag through to the sts.Client at exchange time. 3. Validate verification_uri before showing it to the user. The device-code response field is what we echo and open in the user's browser — a malicious AS pointing it at a phishing page is a direct credential-harvesting vector. New ErrUnsafeVerificationURI sentinel rejects: missing/non-URL, non-https (loopback http only), embedded userinfo (user:pass@host eye-trick), and control characters in the URI string. Tests cover both the safe shapes (https with port / path / query, loopback http) and the unsafe ones (ftp/javascript/data schemes, plain http on non-loopback, newline injection, control chars). 4. Validate received token in login.go before persisting. New validateReceivedToken runs minimum-trust checks on the access token: rejects alg:none (via ParseClaims), iss-mismatch against the issuer we sent the device-code request to, and already- expired exp. Opaque (non-JWT) tokens are allowed — the AS may not issue JWTs at all. Omitted iss is allowed (some servers skip it) but a non-empty mismatch is hard-rejected. Seven unit tests cover the matrix. Also merges in 164 commits from origin/main with conflict resolution in auth/store.go (preserved both the tokenBackend abstraction for the authfilestore test build and the tokenstore.Store interface for the tokenmanager), recap.go (newRecapClient now goes through auth.TokenForResource so split-host setups work), trail_watch_cmd.go (ctx threading), and integration_test/login_test.go (both ENTIRE_AUTH_BASE_URL/ENTIRE_AUTH_PROVIDER_VERSION and ENTIRE_TEST_AUTH_STORE_FILE env vars now set). Soph's earlier defense — rejecting JSON-shaped values in the keyring fallback so corruption never ships as a Bearer header — is also preserved. Co-Authored-By: Claude Opus 4.7 (1M context) --- auth/deviceflow/deviceflow.go | 92 ++++++++++++++++- auth/deviceflow/deviceflow_test.go | 131 +++++++++++++++++++++---- auth/sts/sts.go | 26 ++++- auth/sts/sts_test.go | 7 +- auth/tokenmanager/tokenmanager.go | 15 ++- auth/tokenmanager/tokenmanager_test.go | 4 +- auth/tokens/tokens.go | 30 ++++++ auth/tokens/tokens_test.go | 76 +++++++++++++- cmd/entire/cli/auth/client.go | 21 ++-- cmd/entire/cli/auth/exchange.go | 24 ++++- cmd/entire/cli/login.go | 88 +++++++++++++++++ cmd/entire/cli/login_test.go | 98 ++++++++++++++++++ 12 files changed, 568 insertions(+), 44 deletions(-) diff --git a/auth/deviceflow/deviceflow.go b/auth/deviceflow/deviceflow.go index 2827e053e6..98ff372f62 100644 --- a/auth/deviceflow/deviceflow.go +++ b/auth/deviceflow/deviceflow.go @@ -75,6 +75,14 @@ type Client struct { // back to DefaultRequestTimeout. Negative disables the cap (useful // for tests that want to drive timing via the caller's ctx alone). RequestTimeout time.Duration + + // AllowInsecureHTTP permits http:// BaseURLs. Default (false) is + // reject — the device-flow token endpoint returns the user's + // freshly-minted access token in the response body and must be + // TLS-protected end to end. Production callers MUST leave this + // false; only tests and local development pinned to loopback + // should flip it. + AllowInsecureHTTP bool } // requestTimeout resolves the effective per-request timeout: the @@ -165,9 +173,74 @@ func (c *Client) StartDeviceAuth(ctx context.Context) (*DeviceCode, error) { if err := oauthhttp.ReadAndDecodeJSON(resp.Body, &result, true); err != nil { return nil, fmt.Errorf("start device auth: %w", err) } + if err := validateVerificationURI(result.VerificationURI, c.AllowInsecureHTTP); err != nil { + return nil, fmt.Errorf("start device auth: verification_uri: %w", err) + } + if result.VerificationURIComplete != "" { + if err := validateVerificationURI(result.VerificationURIComplete, c.AllowInsecureHTTP); err != nil { + return nil, fmt.Errorf("start device auth: verification_uri_complete: %w", err) + } + } return &result, nil } +// ErrUnsafeVerificationURI is returned when the authorization server +// returns a verification_uri that fails minimum-trust checks. Defense +// against a compromised or misconfigured AS pointing users at a +// phishing page: the URL we'd otherwise echo to the user and open in +// their browser carries the user code, so a wrong destination is a +// direct credential-harvesting vector. +var ErrUnsafeVerificationURI = errors.New("unsafe verification_uri") + +// validateVerificationURI rejects URIs that obviously look like +// phishing or shell-injection attempts: +// +// - Must parse as an absolute URL. +// - Scheme must be https (or http only when allowInsecureHTTP is +// set AND the host is loopback — production never qualifies). +// - Must not embed userinfo (user:password@host tricks the eye). +// - Must not contain control characters (CR/LF/etc.) that could +// break terminal output or sneak past glance-checks. +// +// This is the bottom-floor check; the embedding CLI is still expected +// to show the URL to the user for visual inspection, and the user is +// expected to read it before opening. +func validateVerificationURI(raw string, allowInsecureHTTP bool) error { + if raw == "" { + return fmt.Errorf("%w: missing", ErrUnsafeVerificationURI) + } + for _, r := range raw { + if r < 0x20 || r == 0x7f { + return fmt.Errorf("%w: contains control character", ErrUnsafeVerificationURI) + } + } + u, err := url.Parse(raw) + if err != nil { + return fmt.Errorf("%w: parse: %w", ErrUnsafeVerificationURI, err) + } + if u.Host == "" { + return fmt.Errorf("%w: missing host", ErrUnsafeVerificationURI) + } + if u.User != nil { + return fmt.Errorf("%w: embedded userinfo not permitted", ErrUnsafeVerificationURI) + } + switch u.Scheme { + case "https": + // fine + case "http": + if !allowInsecureHTTP { + return fmt.Errorf("%w: scheme must be https", ErrUnsafeVerificationURI) + } + host := u.Hostname() + if host != "localhost" && host != "127.0.0.1" && host != "::1" { + return fmt.Errorf("%w: http only permitted on loopback hosts", ErrUnsafeVerificationURI) + } + default: + return fmt.Errorf("%w: scheme %q (must be https)", ErrUnsafeVerificationURI, u.Scheme) + } + return nil +} + // PollDeviceAuth exchanges deviceCode for a TokenSet at the token // endpoint. // @@ -241,7 +314,7 @@ func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*tokens // timeout must cover the body-read that happens after postForm // returns, so cancel-on-return here would interrupt that read. func (c *Client) postForm(ctx context.Context, path string, body url.Values) (*http.Response, error) { - endpoint, err := resolveURL(c.BaseURL, path) + endpoint, err := resolveURL(c.BaseURL, path, c.AllowInsecureHTTP) if err != nil { return nil, fmt.Errorf("resolve URL %s: %w", path, err) } @@ -269,12 +342,25 @@ func (c *Client) postForm(ctx context.Context, path string, body url.Values) (*h return resp, nil } -func resolveURL(baseURL, path string) (string, error) { +// ErrInsecureBaseURL is returned when device-flow requests are made +// against an http:// BaseURL without AllowInsecureHTTP set. The token +// endpoint returns the user's access token in the response body — over +// plain HTTP that's a credential in the clear. +var ErrInsecureBaseURL = errors.New("refusing to run device-flow over plain HTTP (set Client.AllowInsecureHTTP only for local dev / test)") + +func resolveURL(baseURL, path string, allowInsecureHTTP bool) (string, error) { base, err := url.Parse(baseURL) if err != nil { return "", fmt.Errorf("parse base URL: %w", err) } - if base.Scheme != "http" && base.Scheme != "https" { + switch base.Scheme { + case "https": + // fine + case "http": + if !allowInsecureHTTP { + return "", ErrInsecureBaseURL + } + default: return "", fmt.Errorf("unsupported base URL scheme %q (must be http or https)", base.Scheme) } rel, err := url.Parse(path) diff --git a/auth/deviceflow/deviceflow_test.go b/auth/deviceflow/deviceflow_test.go index a43dfd4709..0ddf1e030f 100644 --- a/auth/deviceflow/deviceflow_test.go +++ b/auth/deviceflow/deviceflow_test.go @@ -2,6 +2,7 @@ package deviceflow import ( "context" + "encoding/json" "errors" "fmt" "io" @@ -42,12 +43,13 @@ func newTestClient(t *testing.T, h http.HandlerFunc) *Client { t.Cleanup(srv.Close) c := &Client{ - HTTP: srv.Client(), - BaseURL: srv.URL, - ClientID: testClientID, - Scope: "cli", - DeviceCodePath: testDeviceCodePath, - TokenPath: testTokenPath, + HTTP: srv.Client(), + BaseURL: srv.URL, + ClientID: testClientID, + Scope: "cli", + DeviceCodePath: testDeviceCodePath, + TokenPath: testTokenPath, + AllowInsecureHTTP: true, // httptest.NewServer is http:// } return c } @@ -102,7 +104,7 @@ func TestStartDeviceAuth_OmitsScopeWhenEmpty(t *testing.T) { t.Errorf("scope should not be sent when Client.Scope is empty") } w.Header().Set("Content-Type", "application/json") - writeBody(t, w, `{"device_code":"d","user_code":"u","verification_uri":"x","expires_in":1,"interval":1}`) + writeBody(t, w, `{"device_code":"d","user_code":"u","verification_uri":"https://example.com/cli","expires_in":1,"interval":1}`) }) c.Scope = "" @@ -116,7 +118,7 @@ func TestStartDeviceAuth_RejectsUnknownFields(t *testing.T) { c := newTestClient(t, func(w http.ResponseWriter, _ *http.Request) { writeBody(t, w, `{ - "device_code":"d","user_code":"u","verification_uri":"x","expires_in":1,"interval":1, + "device_code":"d","user_code":"u","verification_uri":"https://example.com/cli","expires_in":1,"interval":1, "surprise":"field" }`) }) @@ -334,23 +336,25 @@ func TestResolveURL(t *testing.T) { t.Parallel() tests := []struct { - name string - base string - path string - want string - wantErr bool + name string + base string + path string + allowInsecureHTTP bool + want string + wantErr bool }{ - {"https + absolute path", "https://entire.io", "/oauth/device/code", "https://entire.io/oauth/device/code", false}, - {"trailing slash + absolute path", "https://entire.io/", "/oauth/token", "https://entire.io/oauth/token", false}, - {"http allowed", "http://localhost:8180", "/api/auth/token", "http://localhost:8180/api/auth/token", false}, - {"unsupported scheme", "ftp://x", "/y", "", true}, - {"malformed base", "://", "/y", "", true}, + {"https + absolute path", "https://entire.io", "/oauth/device/code", false, "https://entire.io/oauth/device/code", false}, + {"trailing slash + absolute path", "https://entire.io/", "/oauth/token", false, "https://entire.io/oauth/token", false}, + {"http rejected by default", "http://localhost:8180", "/api/auth/token", false, "", true}, + {"http allowed with opt-in", "http://localhost:8180", "/api/auth/token", true, "http://localhost:8180/api/auth/token", false}, + {"unsupported scheme", "ftp://x", "/y", false, "", true}, + {"malformed base", "://", "/y", false, "", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got, err := resolveURL(tt.base, tt.path) + got, err := resolveURL(tt.base, tt.path, tt.allowInsecureHTTP) if (err != nil) != tt.wantErr { t.Fatalf("resolveURL() err = %v, wantErr %v", err, tt.wantErr) } @@ -440,3 +444,92 @@ func TestRequestTimeout_DefaultAndOverride(t *testing.T) { }) } } + +// TestStartDeviceAuth_RejectsUnsafeVerificationURI pins the +// anti-phishing checks on the verification_uri returned by the AS. +// A compromised or misconfigured server must not be able to redirect +// users to an attacker-controlled login page; the URL we'd otherwise +// echo and open carries the user code. +func TestStartDeviceAuth_RejectsUnsafeVerificationURI(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + uri string + }{ + {"empty", ""}, + {"missing scheme", "example.com/cli"}, + {"non-https scheme", "ftp://example.com/cli"}, + {"plain http on non-loopback", "http://example.com/cli"}, + {"embedded userinfo", "https://entire.io@evil.example.com/cli"}, + {"newline injection", "https://example.com/cli\nGET /steal"}, + {"control character", "https://example.com/\x07cli"}, + {"javascript scheme", "javascript:alert(1)"}, + {"data scheme", "data:text/html,