diff --git a/CLAUDE.md b/CLAUDE.md index 01d7bb37a4..7925de6122 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -340,6 +340,32 @@ if settings.IsSummarizeEnabled() { - `settings/settings.go` - `EntireSettings` struct, `Load()`, and helper methods - `config.go` - Higher-level config functions that use settings (for `cli` package consumers) +### Auth and token resolution + +The CLI uses a shareable auth library at `auth/` (subpackages: `deviceflow`, `sts`, `tokens`, `tokenstore`, `tokenmanager`). The `cmd/entire/cli/auth/` package wraps it with entire-specific config (provider table, keyring service name) and exposes the call surface that command code should use. + +**For every data-API call, get the bearer through one of these two entry points — never read the keyring directly:** + +```go +// Preferred — for callers that need an *api.Client. +client, err := cli.NewAuthenticatedAPIClient(ctx, insecureHTTP) + +// Direct — for callers that hand the bearer to a non-api package +// (e.g. the search service client, dispatch CloudClient). +bearer, err := auth.TokenForResource(ctx, serviceURL) +``` + +Both route through `tokenmanager.Manager.Token`, which: + +1. Returns `auth.ErrNotLoggedIn` when the keyring is empty. +2. Hits the same-host shortcut when `api.AuthBaseURL() == resourceURL`. +3. Hits the JWT-`aud`-includes-resource shortcut when the core token is already valid for `resourceURL` (caller didn't request an explicit `Audience`). +4. Otherwise runs an RFC 8693 token exchange against the auth host's STS endpoint and caches the result per `(core token, resource, audience, requested-token-type, scope)`. + +**Don't use `auth.LookupCurrentToken` for data-API calls.** It returns the raw core token (audience = auth host). On split-host deployments (`ENTIRE_AUTH_BASE_URL` set) the data API will reject it with 401. `LookupCurrentToken` is correct only for auth-host-targeted commands (`auth list/revoke/status`, `logout`) — they intentionally hold the auth-audience bearer. + +**Test injection:** at the cmd layer use `auth.SetManagerForTest(t, mgr)` with a `tokenmanager.Manager` constructed via `tokenmanager.New(Config{Exchange: ...})`. The manager's `Config.Exchange` and `Config.Now` fields are test seams — production callers leave them nil. + ### Logging vs User Output - **Internal/debug logging**: Use `logging.Debug/Info/Warn/Error(ctx, msg, attrs...)` from `cmd/entire/cli/logging/`. Writes to `.entire/logs/`. diff --git a/cmd/entire/cli/activity_cmd.go b/cmd/entire/cli/activity_cmd.go index 35b49018fa..c43af88df0 100644 --- a/cmd/entire/cli/activity_cmd.go +++ b/cmd/entire/cli/activity_cmd.go @@ -56,7 +56,7 @@ func newActivityCmd() *cobra.Command { } func runActivity(ctx context.Context, w, errW io.Writer) error { - client, err := NewAuthenticatedAPIClient(false) + client, err := NewAuthenticatedAPIClient(ctx, false) if err != nil { fmt.Fprintln(errW, "Not logged in. Run 'entire login' to authenticate.") return NewSilentError(err) diff --git a/cmd/entire/cli/api/auth_tokens.go b/cmd/entire/cli/api/auth_tokens.go index ad9d53e5ba..557c9e6ab3 100644 --- a/cmd/entire/cli/api/auth_tokens.go +++ b/cmd/entire/cli/api/auth_tokens.go @@ -2,11 +2,12 @@ package api import ( "context" + "errors" "fmt" "net/url" ) -// Token is a single API token row returned by GET /api/v1/auth/tokens. +// Token is a single API token row returned by the auth-tokens endpoint. // Plaintext token values are never returned by the server — only metadata. type Token struct { ID string `json:"id"` @@ -18,15 +19,32 @@ type Token struct { CreatedAt string `json:"created_at"` } -// TokensResponse is the envelope returned by GET /api/v1/auth/tokens. +// TokensResponse is the envelope returned by the list endpoint. type TokensResponse struct { Tokens []Token `json:"tokens"` } +// errAuthTokensPathUnset surfaces when an auth-tokens method is called +// on a Client that wasn't given a base path. Construct via +// NewClientWithBaseURL(...).WithAuthTokensPath(...) — the active path +// lives in cmd/entire/cli/auth.CurrentProvider().AuthTokensPath, the +// single source of truth for provider-version routing. +var errAuthTokensPathUnset = errors.New("api: auth-tokens path is unset (call (*Client).WithAuthTokensPath before list/revoke)") + +func (c *Client) authTokensBasePath() (string, error) { + if c.authTokensPath == "" { + return "", errAuthTokensPathUnset + } + return c.authTokensPath, nil +} + // ListTokens returns the authenticated user's non-expired API tokens. -// Backed by GET /api/v1/auth/tokens. func (c *Client) ListTokens(ctx context.Context) ([]Token, error) { - resp, err := c.Get(ctx, "/api/v1/auth/tokens") + base, err := c.authTokensBasePath() + if err != nil { + return nil, fmt.Errorf("list tokens: %w", err) + } + resp, err := c.Get(ctx, base) if err != nil { return nil, fmt.Errorf("list tokens: %w", err) } @@ -44,9 +62,12 @@ func (c *Client) ListTokens(ctx context.Context) ([]Token, error) { } // RevokeCurrentToken revokes the bearer token used to authenticate this client. -// Backed by DELETE /api/v1/auth/tokens/current. func (c *Client) RevokeCurrentToken(ctx context.Context) error { - resp, err := c.Delete(ctx, "/api/v1/auth/tokens/current") + base, err := c.authTokensBasePath() + if err != nil { + return fmt.Errorf("revoke current token: %w", err) + } + resp, err := c.Delete(ctx, base+"/current") if err != nil { return fmt.Errorf("revoke current token: %w", err) } @@ -59,9 +80,12 @@ func (c *Client) RevokeCurrentToken(ctx context.Context) error { } // RevokeToken revokes the API token with the given id. -// Backed by DELETE /api/v1/auth/tokens/{id}. func (c *Client) RevokeToken(ctx context.Context, id string) error { - resp, err := c.Delete(ctx, "/api/v1/auth/tokens/"+url.PathEscape(id)) + base, err := c.authTokensBasePath() + if err != nil { + return fmt.Errorf("revoke token %s: %w", id, err) + } + resp, err := c.Delete(ctx, base+"/"+url.PathEscape(id)) if err != nil { return fmt.Errorf("revoke token %s: %w", id, err) } diff --git a/cmd/entire/cli/api/auth_tokens_test.go b/cmd/entire/cli/api/auth_tokens_test.go index 8fc354c4ab..07ccd694d2 100644 --- a/cmd/entire/cli/api/auth_tokens_test.go +++ b/cmd/entire/cli/api/auth_tokens_test.go @@ -9,6 +9,21 @@ import ( "testing" ) +const ( + testV1AuthTokensPath = "/api/v1/auth/tokens" + testV2AuthTokensPath = "/api/auth/tokens" +) + +// newAuthTokensTestClient builds a Client pointed at server.URL with +// the given auth-tokens base path. Used by all auth-tokens tests so +// the wiring matches production: callers chain WithAuthTokensPath at +// construction time. +func newAuthTokensTestClient(serverURL, authTokensPath string) *Client { + c := NewClient("tok").WithAuthTokensPath(authTokensPath) + c.baseURL = serverURL + return c +} + func TestClient_RevokeCurrentToken_SendsDeleteWithBearer(t *testing.T) { t.Parallel() @@ -23,8 +38,7 @@ func TestClient_RevokeCurrentToken_SendsDeleteWithBearer(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) if err := c.RevokeCurrentToken(context.Background()); err != nil { t.Fatalf("RevokeCurrentToken() error = %v", err) @@ -51,8 +65,7 @@ func TestClient_RevokeCurrentToken_ReturnsHTTPErrorOn401(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) err := c.RevokeCurrentToken(context.Background()) if err == nil { @@ -87,8 +100,7 @@ func TestClient_ListTokens_DecodesResponse(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) tokens, err := c.ListTokens(context.Background()) if err != nil { @@ -129,8 +141,7 @@ func TestClient_ListTokens_ReturnsHTTPErrorOn401(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) _, err := c.ListTokens(context.Background()) if err == nil { @@ -155,8 +166,7 @@ func TestClient_RevokeToken_SendsDeleteWithEscapedID(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) // Use an id that needs URL escaping to verify we don't blindly concat. if err := c.RevokeToken(context.Background(), "abc/def 1"); err != nil { @@ -184,8 +194,7 @@ func TestClient_RevokeToken_ReturnsErrorBody(t *testing.T) { })) defer server.Close() - c := NewClient("tok") - c.baseURL = server.URL + c := newAuthTokensTestClient(server.URL, testV1AuthTokensPath) err := c.RevokeToken(context.Background(), "missing") if err == nil { @@ -198,3 +207,47 @@ func TestClient_RevokeToken_ReturnsErrorBody(t *testing.T) { t.Errorf("IsHTTPErrorStatus(err, 404) = false; err = %v", err) } } + +// TestClient_AuthTokens_RoutesV2Path verifies that whatever path the +// caller supplies via WithAuthTokensPath is what hits the wire. The +// provider table itself (which path corresponds to which version) is +// exercised by cmd/entire/cli/auth's resolveProvider tests. +func TestClient_AuthTokens_RoutesV2Path(t *testing.T) { + t.Parallel() + + var gotPath string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"tokens":[]}`)) //nolint:errcheck // test handler + })) + defer server.Close() + + c := newAuthTokensTestClient(server.URL, testV2AuthTokensPath) + + if _, err := c.ListTokens(context.Background()); err != nil { + t.Fatalf("ListTokens: %v", err) + } + if gotPath != "/api/auth/tokens" { + t.Fatalf("path = %q, want /api/auth/tokens (v2)", gotPath) + } +} + +// TestClient_AuthTokens_UnsetPathErrors guards against silently +// shipping a request to "" — we want a clear error pointing at the +// missing WithAuthTokensPath wiring. +func TestClient_AuthTokens_UnsetPathErrors(t *testing.T) { + t.Parallel() + + c := NewClient("tok") // no WithAuthTokensPath + + if _, err := c.ListTokens(context.Background()); !errors.Is(err, errAuthTokensPathUnset) { + t.Errorf("ListTokens err = %v, want errAuthTokensPathUnset", err) + } + if err := c.RevokeCurrentToken(context.Background()); !errors.Is(err, errAuthTokensPathUnset) { + t.Errorf("RevokeCurrentToken err = %v, want errAuthTokensPathUnset", err) + } + if err := c.RevokeToken(context.Background(), "any"); !errors.Is(err, errAuthTokensPathUnset) { + t.Errorf("RevokeToken err = %v, want errAuthTokensPathUnset", err) + } +} diff --git a/cmd/entire/cli/api/base_url.go b/cmd/entire/cli/api/base_url.go index 9e684c04a1..0f2109093e 100644 --- a/cmd/entire/cli/api/base_url.go +++ b/cmd/entire/cli/api/base_url.go @@ -17,6 +17,13 @@ const ( // BaseURLEnvVar overrides the Entire API origin for local development. BaseURLEnvVar = "ENTIRE_API_BASE_URL" + + // AuthBaseURLEnvVar overrides only the auth/login origin (device flow, + // auth-tokens management, keyring key). Falls back to BaseURLEnvVar when + // unset, which is the right behavior for single-host deployments. Split + // hosts (e.g. auth on us.console.partial.to, data on partial.to) set + // both. + AuthBaseURLEnvVar = "ENTIRE_AUTH_BASE_URL" ) // BaseURL returns the effective Entire API base URL. @@ -29,6 +36,18 @@ func BaseURL() string { return DefaultBaseURL } +// AuthBaseURL returns the origin used for the device-flow login, auth-token +// management endpoints, and the keyring key under which the bearer token is +// stored. ENTIRE_AUTH_BASE_URL takes precedence; otherwise it falls back to +// BaseURL() so single-host deployments keep working unchanged. +func AuthBaseURL() string { + if raw := strings.TrimSpace(os.Getenv(AuthBaseURLEnvVar)); raw != "" { + return normalizeBaseURL(raw) + } + + return BaseURL() +} + // ResolveURL joins an API-relative path against the effective base URL. func ResolveURL(path string) (string, error) { return ResolveURLFromBase(BaseURL(), path) diff --git a/cmd/entire/cli/api/base_url_test.go b/cmd/entire/cli/api/base_url_test.go index 474de2026a..b094a97a11 100644 --- a/cmd/entire/cli/api/base_url_test.go +++ b/cmd/entire/cli/api/base_url_test.go @@ -53,6 +53,28 @@ func TestRequireSecureURL_RejectsHTTP(t *testing.T) { } } +func TestAuthBaseURL_FallsBackToBaseURL(t *testing.T) { + t.Setenv(BaseURLEnvVar, "https://partial.to") + t.Setenv(AuthBaseURLEnvVar, "") + + if got := AuthBaseURL(); got != "https://partial.to" { + t.Fatalf("AuthBaseURL() = %q, want fallback to BaseURL %q", got, "https://partial.to") + } +} + +func TestAuthBaseURL_OverridesBaseURL(t *testing.T) { + t.Setenv(BaseURLEnvVar, "https://partial.to") + t.Setenv(AuthBaseURLEnvVar, " https://us.console.partial.to/ ") + + if got := AuthBaseURL(); got != "https://us.console.partial.to" { + t.Fatalf("AuthBaseURL() = %q, want %q", got, "https://us.console.partial.to") + } + + if got := BaseURL(); got != "https://partial.to" { + t.Fatalf("BaseURL() = %q, want unchanged %q", got, "https://partial.to") + } +} + func TestResolveURL(t *testing.T) { t.Setenv(BaseURLEnvVar, "http://localhost:8787/") diff --git a/cmd/entire/cli/api/client.go b/cmd/entire/cli/api/client.go index 9be47f45ed..44227139b6 100644 --- a/cmd/entire/cli/api/client.go +++ b/cmd/entire/cli/api/client.go @@ -21,10 +21,39 @@ const ( type Client struct { httpClient *http.Client baseURL string -} -// NewClient creates a new authenticated API client with an explicit bearer token. + // authTokensPath is the base path for the auth-tokens management + // endpoints (list / revoke). Set via WithAuthTokensPath when the + // client targets the auth host. Empty for data-API-only clients; + // auth-tokens methods error out if called against an empty path. + authTokensPath string +} + +// WithAuthTokensPath sets the base path used by ListTokens, +// RevokeCurrentToken, and RevokeToken. The path is supplied by the +// auth shim from auth.CurrentProvider().AuthTokensPath, which is the +// single source of truth for provider-version routing — the api +// package no longer reads ENTIRE_AUTH_PROVIDER_VERSION itself. +// +// Returns the receiver for chaining at construction: +// +// c := api.NewClientWithBaseURL(token, base).WithAuthTokensPath(p) +func (c *Client) WithAuthTokensPath(path string) *Client { + c.authTokensPath = path + return c +} + +// NewClient creates a new authenticated API client with an explicit bearer +// token, targeting the data API base URL (BaseURL()). func NewClient(token string) *Client { + return NewClientWithBaseURL(token, BaseURL()) +} + +// NewClientWithBaseURL creates a new authenticated API client targeting an +// explicit base URL. Use this for endpoints that live on the auth host (e.g. +// auth-token management) when ENTIRE_AUTH_BASE_URL splits the auth origin +// from the data API origin. +func NewClientWithBaseURL(token, baseURL string) *Client { return &Client{ httpClient: &http.Client{ Transport: &bearerTransport{ @@ -32,7 +61,7 @@ func NewClient(token string) *Client { base: http.DefaultTransport, }, }, - baseURL: BaseURL(), + baseURL: baseURL, } } @@ -42,7 +71,15 @@ type bearerTransport struct { base http.RoundTripper } +// errEmptyBearerToken surfaces at first request rather than at construction +// because NewClient* don't return errors. An empty bearer otherwise becomes +// "Authorization: Bearer " on the wire and produces a confusing 401. +var errEmptyBearerToken = errors.New("api: refusing to send request with empty bearer token (construct via NewAuthenticatedAPIClient)") + func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if t.token == "" { + return nil, errEmptyBearerToken + } // Clone the request to avoid mutating the caller's request. r := req.Clone(req.Context()) r.Header.Set("Authorization", "Bearer "+t.token) diff --git a/cmd/entire/cli/api/client_test.go b/cmd/entire/cli/api/client_test.go index f5977b7a72..de105409f1 100644 --- a/cmd/entire/cli/api/client_test.go +++ b/cmd/entire/cli/api/client_test.go @@ -3,6 +3,7 @@ package api import ( "context" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" @@ -300,3 +301,23 @@ func TestDecodeJSONResponse(t *testing.T) { t.Errorf("Status = %q, want %q", result.Status, "ok") } } + +// TestBearerTransport_RejectsEmptyToken locks in the early-failure +// behaviour for an accidentally-empty bearer. Without this guard the +// CLI would put "Authorization: Bearer " on the wire and the server +// would respond with a confusing 401. +func TestBearerTransport_RejectsEmptyToken(t *testing.T) { + t.Parallel() + + c := NewClientWithBaseURL("", "https://example.test") + resp, err := c.Get(context.Background(), "/probe") + if resp != nil { + _ = resp.Body.Close() + } + if err == nil { + t.Fatal("Get with empty token must error") + } + if !errors.Is(err, errEmptyBearerToken) { + t.Fatalf("err = %v, want errEmptyBearerToken sentinel", err) + } +} diff --git a/cmd/entire/cli/api_client.go b/cmd/entire/cli/api_client.go index 65f36968f5..453c050e3b 100644 --- a/cmd/entire/cli/api_client.go +++ b/cmd/entire/cli/api_client.go @@ -1,6 +1,7 @@ package cli import ( + "context" "errors" "fmt" @@ -8,22 +9,37 @@ import ( "github.com/entireio/cli/cmd/entire/cli/auth" ) -// NewAuthenticatedAPIClient creates an API client using the bearer token -// from the CLI login flow. Returns an error if the user is not logged in. -// Pass insecureHTTP=true to allow plain HTTP base URLs (for local development). -func NewAuthenticatedAPIClient(insecureHTTP bool) (*api.Client, error) { - token, err := auth.LookupCurrentToken() - if err != nil { - return nil, fmt.Errorf("lookup auth token: %w", err) - } - if token == "" { - return nil, errors.New("not logged in (run 'entire login' first)") - } - +// NewAuthenticatedAPIClient creates an API client targeting api.BaseURL() +// (the data API origin) carrying a token valid for that audience. +// +// Resolution: looks up the core token from the keyring, then either uses +// it directly (single-host setup, or when the core token's `aud` already +// covers api.BaseURL()) or performs an RFC 8693 token exchange against +// the auth host to obtain a token scoped to the data API. Exchanged +// tokens are cached in-memory keyed off the wire-affecting fields of +// the request — see tokenmanager.cacheKey for the precise key shape. +// +// Pass insecureHTTP=true to allow plain HTTP base URLs for local +// development. Both api.BaseURL() and api.AuthBaseURL() are validated: +// the bearer travels to the data host on resource requests, and the +// core token travels to the auth host during the exchange step. +func NewAuthenticatedAPIClient(ctx context.Context, insecureHTTP bool) (*api.Client, error) { if !insecureHTTP { if err := api.RequireSecureURL(api.BaseURL()); err != nil { return nil, fmt.Errorf("base URL check: %w", err) } + if err := api.RequireSecureURL(api.AuthBaseURL()); err != nil { + return nil, fmt.Errorf("auth base URL check: %w", err) + } } + + token, err := auth.TokenForResource(ctx, api.BaseURL()) + if err != nil { + if errors.Is(err, auth.ErrNotLoggedIn) { + return nil, errors.New("not logged in (run 'entire login' first)") + } + return nil, fmt.Errorf("resolve API token: %w", err) + } + return api.NewClient(token), nil } diff --git a/cmd/entire/cli/auth.go b/cmd/entire/cli/auth.go index 0a441aef07..d5a08469eb 100644 --- a/cmd/entire/cli/auth.go +++ b/cmd/entire/cli/auth.go @@ -35,6 +35,11 @@ const ( // command that sends a bearer token over the network (login, logout, // auth status/list/revoke) must call this so credentials don't leak over // plaintext HTTP without explicit opt-in. +// +// Both the auth and data API origins are checked: the bearer travels to the +// auth host for login + auth-token management, and to the data host for +// search/activity/dispatch/etc. When ENTIRE_AUTH_BASE_URL is unset they +// resolve to the same URL and the second check is a no-op. func requireSecureBaseURL(insecureHTTPAuth bool) error { if insecureHTTPAuth { return nil @@ -42,6 +47,9 @@ func requireSecureBaseURL(insecureHTTPAuth bool) error { if err := api.RequireSecureURL(api.BaseURL()); err != nil { return fmt.Errorf("base URL check: %w", err) } + if err := api.RequireSecureURL(api.AuthBaseURL()); err != nil { + return fmt.Errorf("auth base URL check: %w", err) + } return nil } @@ -84,7 +92,7 @@ func newAuthStatusCmd() *cobra.Command { return err } return runAuthStatus(cmd.Context(), cmd.OutOrStdout(), - auth.NewStore(), defaultListTokens, api.BaseURL()) + auth.NewStore(), defaultListTokens, api.AuthBaseURL()) }, } addInsecureHTTPAuthFlag(cmd, &insecureHTTPAuth) @@ -92,7 +100,9 @@ func newAuthStatusCmd() *cobra.Command { } func defaultListTokens(ctx context.Context, token string) ([]api.Token, error) { - return api.NewClient(token).ListTokens(ctx) //nolint:wrapcheck // ListTokens already wraps with action context + client := api.NewClientWithBaseURL(token, api.AuthBaseURL()). + WithAuthTokensPath(auth.CurrentProvider().AuthTokensPath) + return client.ListTokens(ctx) //nolint:wrapcheck // ListTokens already wraps with action context } func runAuthStatus(ctx context.Context, w io.Writer, store tokenStore, list authTokenLister, baseURL string) error { @@ -135,7 +145,7 @@ func newAuthListCmd() *cobra.Command { return err } return runAuthList(cmd.Context(), cmd.OutOrStdout(), - auth.NewStore(), defaultListTokens, api.BaseURL(), jsonOut) + auth.NewStore(), defaultListTokens, api.AuthBaseURL(), jsonOut) }, } cmd.Flags().BoolVar(&jsonOut, "json", false, "Print tokens as JSON") @@ -409,7 +419,7 @@ func newAuthRevokeCmd() *cobra.Command { } return runAuthRevoke(cmd.Context(), cmd.OutOrStdout(), cmd.ErrOrStderr(), auth.NewStore(), defaultListTokens, defaultRevokeTokenByID, defaultRevokeCurrentToken, - api.BaseURL(), id, revokeCurrent) + api.AuthBaseURL(), id, revokeCurrent) }, } cmd.Flags().BoolVar(&revokeCurrent, "current", false, "Revoke the token used by this CLI and remove the local copy") @@ -418,7 +428,9 @@ func newAuthRevokeCmd() *cobra.Command { } func defaultRevokeTokenByID(ctx context.Context, callerToken, id string) error { - return api.NewClient(callerToken).RevokeToken(ctx, id) //nolint:wrapcheck // RevokeToken already wraps with action context + client := api.NewClientWithBaseURL(callerToken, api.AuthBaseURL()). + WithAuthTokensPath(auth.CurrentProvider().AuthTokensPath) + return client.RevokeToken(ctx, id) //nolint:wrapcheck // RevokeToken already wraps with action context } func runAuthRevoke( diff --git a/cmd/entire/cli/auth/client.go b/cmd/entire/cli/auth/client.go index 5ffd215003..8377a4b8ef 100644 --- a/cmd/entire/cli/auth/client.go +++ b/cmd/entire/cli/auth/client.go @@ -1,187 +1,169 @@ package auth import ( - "bytes" "context" - "encoding/json" - "fmt" - "io" + "errors" "net/http" - "net/url" "strings" + "time" + "github.com/entireio/auth-go/deviceflow" + "github.com/entireio/auth-go/tokens" "github.com/entireio/cli/cmd/entire/cli/api" ) -const ( - maxResponseBytes = 1 << 20 - clientID = "entire-cli" -) +// nowFunc is the package's clock. Override in tests. +var nowFunc = time.Now -type Client struct { - httpClient *http.Client - baseURL string -} - -type DeviceAuthStart struct { - DeviceCode string `json:"device_code"` - UserCode string `json:"user_code"` - VerificationURI string `json:"verification_uri"` - VerificationURIComplete string `json:"verification_uri_complete"` - ExpiresIn int `json:"expires_in"` - Interval int `json:"interval"` -} +// DeviceAuthStart preserves the historical type name; the shape now +// matches deviceflow.DeviceCode field-for-field. +type DeviceAuthStart = deviceflow.DeviceCode +// DeviceAuthPoll is the historical token-poll response shape. The shim +// flattens deviceflow's typed errors back into the Error field so +// existing login.go logic that switches on result.Error keeps working. +// +// ErrorDescription carries the optional `error_description` from the +// server's RFC 8628 §3.5 error response, when present. Used to give +// callers a more actionable message than the bare error code. type DeviceAuthPoll struct { - AccessToken string `json:"access_token,omitempty"` - TokenType string `json:"token_type,omitempty"` - ExpiresIn int `json:"expires_in,omitempty"` - Scope string `json:"scope,omitempty"` - Error string `json:"error,omitempty"` + AccessToken string + TokenType string + ExpiresIn int + Scope string + Error string + ErrorDescription string } -type errorResponse struct { - Error string `json:"error"` +// Client wraps a deviceflow.Client preconfigured for whichever provider +// version is selected via ENTIRE_AUTH_PROVIDER_VERSION (defaulting to +// v1). +type Client struct { + inner *deviceflow.Client } +// NewClient constructs a Client targeting the active provider version. +// httpClient is used directly when non-nil; otherwise http.DefaultClient. +// +// HTTPS is required unless the configured auth host is loopback http:// +// (localhost, 127.0.0.1, ::1) — see isLoopbackHTTP. Production +// deployments never qualify; local dev does. There is no other opt-in +// for plain HTTP on the device-flow surface. func NewClient(httpClient *http.Client) *Client { - if httpClient == nil { - httpClient = &http.Client{} - } - - return &Client{ - httpClient: httpClient, - baseURL: api.BaseURL(), - } + p := CurrentProvider() + issuer := api.AuthBaseURL() + return &Client{inner: &deviceflow.Client{ + HTTP: httpClient, + BaseURL: issuer, + ClientID: p.ClientID, + Scope: "cli", + UserAgent: p.ClientID, + DeviceCodePath: p.DeviceCodePath, + TokenPath: p.TokenPath, + AllowInsecureHTTP: isLoopbackHTTP(issuer), + }} } -func (c *Client) BaseURL() string { - return c.baseURL -} +// BaseURL returns the issuer base URL this client talks to. +func (c *Client) BaseURL() string { return c.inner.BaseURL } +// StartDeviceAuth requests a fresh device code. func (c *Client) StartDeviceAuth(ctx context.Context) (*DeviceAuthStart, error) { - body := url.Values{} - body.Set("client_id", clientID) - body.Set("scope", "cli") - - resp, err := c.postForm(ctx, "/oauth/device/code", body) - if err != nil { - return nil, fmt.Errorf("start device auth: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, readAPIError(resp, "start device auth") - } - - var result DeviceAuthStart - if err := decodeJSONStrict(resp.Body, &result); err != nil { - return nil, fmt.Errorf("decode device auth start response: %w", err) - } - - return &result, nil + return c.inner.StartDeviceAuth(ctx) //nolint:wrapcheck // shim preserves the lib's wrapped errors verbatim } +// PollDeviceAuth polls the token endpoint. On any OAuth-protocol error +// (recognised RFC 8628 §3.5 sentinel or unknown but spec-shaped code +// like invalid_request / invalid_client / server_error), the wire-side +// code is returned in DeviceAuthPoll.Error so the existing polling +// loop in login.go can branch on it — known codes hit the dedicated +// switch arms, unknown codes fall through to the default arm and fail +// fast. Non-protocol errors (network, decode) are returned as a real +// error and treated as transient by the polling loop. func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*DeviceAuthPoll, error) { - body := url.Values{} - body.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code") - body.Set("client_id", clientID) - body.Set("device_code", deviceCode) - - resp, err := c.postForm(ctx, "/oauth/token", body) + t, err := c.inner.PollDeviceAuth(ctx, deviceCode) if err != nil { - return nil, fmt.Errorf("poll device auth: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - apiErr, err := readAPIErrorResponse(resp) - if err != nil { - return nil, fmt.Errorf("poll device auth: %w", err) + if code, description, ok := oauthErrorParts(err); ok { + return &DeviceAuthPoll{ + Error: code, + ErrorDescription: description, + }, nil } - return &DeviceAuthPoll{Error: apiErr.Error}, nil + return nil, err //nolint:wrapcheck // shim returns deviceflow errors verbatim so callers can errors.Is on sentinels } - var result DeviceAuthPoll - if err := decodeJSON(resp.Body, &result); err != nil { - return nil, fmt.Errorf("decode device auth poll response: %w", err) - } - - return &result, nil + return &DeviceAuthPoll{ + AccessToken: t.AccessToken, + TokenType: t.TokenType, + ExpiresIn: secondsUntil(t), + Scope: t.Scope, + }, nil } -// postForm sends a POST request with form-encoded body to an API-relative path. -func (c *Client) postForm(ctx context.Context, path string, body url.Values) (*http.Response, error) { - endpoint, err := api.ResolveURLFromBase(c.baseURL, path) - if err != nil { - return nil, fmt.Errorf("resolve URL %s: %w", path, err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(body.Encode())) - if err != nil { - return nil, fmt.Errorf("create request: %w", err) - } - - req.Header.Set("Accept", "application/json") - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("User-Agent", clientID) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("request %s: %w", path, err) - } - - return resp, nil -} - -func readAPIErrorResponse(resp *http.Response) (*errorResponse, error) { - body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes)) - if err != nil { - return nil, fmt.Errorf("status %d", resp.StatusCode) - } - - var apiErr errorResponse - if err := json.Unmarshal(body, &apiErr); err == nil && strings.TrimSpace(apiErr.Error) != "" { - return &apiErr, nil - } - - text := strings.TrimSpace(string(body)) - if text != "" { - return nil, fmt.Errorf("status %d: %s", resp.StatusCode, text) - } - - return nil, fmt.Errorf("status %d", resp.StatusCode) -} - -func readAPIError(resp *http.Response, action string) error { - apiErr, err := readAPIErrorResponse(resp) - if err == nil { - return fmt.Errorf("%s: %s", action, apiErr.Error) +// oauthErrorParts inspects err for either a recognised RFC 8628 §3.5 +// sentinel or the generic "oauth error: " wrapper deviceflow uses +// for unrecognised but spec-shaped codes (RFC 6749 §5.2: invalid_request, +// invalid_client, server_error, …). +// +// On a match, returns the wire-side code, any error_description the +// server included, and ok=true. Otherwise returns "", "", false — the +// caller should treat the error as a transport/decode failure. +// +// Surfacing unknown codes as ok=true is what keeps login.go's polling +// loop fast-failing on terminal OAuth rejections instead of treating +// them as transient and retrying ~5 times. +func oauthErrorParts(err error) (code, description string, ok bool) { + switch { + case errors.Is(err, deviceflow.ErrAuthorizationPending): + code = "authorization_pending" + case errors.Is(err, deviceflow.ErrSlowDown): + code = "slow_down" + case errors.Is(err, deviceflow.ErrAccessDenied): + code = "access_denied" + case errors.Is(err, deviceflow.ErrExpiredToken): + code = "expired_token" + case errors.Is(err, deviceflow.ErrInvalidGrant): + code = "invalid_grant" + default: + // Unknown but legitimate OAuth codes come back from + // deviceflow.errCodeToSentinel as fmt.Errorf("oauth error: %s", + // code), optionally wrapped a second time with ": " + // when the server supplied error_description. + const oauthPrefix = "oauth error: " + rest, hadPrefix := strings.CutPrefix(err.Error(), oauthPrefix) + if !hadPrefix { + return "", "", false + } + if c, d, hasDesc := strings.Cut(rest, ": "); hasDesc { + return c, d, true + } + return rest, "", true } - return fmt.Errorf("%s: %w", action, err) + description = descriptionFromSentinelError(err, code) + return code, description, true } -func decodeJSON(r io.Reader, dest any) error { - return decodeJSONWithOptions(r, dest, false) +// descriptionFromSentinelError pulls the description suffix out of a +// wrapped sentinel error. The deviceflow lib uses +// fmt.Errorf("%w: %s", sentinel, description) when the server included +// an error_description, so the formatted error reads +// ": ". Stripping the ": " prefix yields the +// description; absent prefix means the server didn't supply one. +func descriptionFromSentinelError(err error, code string) string { + msg := err.Error() + prefix := code + ": " + if rest, ok := strings.CutPrefix(msg, prefix); ok { + return rest + } + return "" } -func decodeJSONStrict(r io.Reader, dest any) error { - return decodeJSONWithOptions(r, dest, true) -} - -func decodeJSONWithOptions(r io.Reader, dest any, strict bool) error { - body, err := io.ReadAll(io.LimitReader(r, maxResponseBytes)) - if err != nil { - return fmt.Errorf("read JSON response: %w", err) - } - - dec := json.NewDecoder(bytes.NewReader(body)) - if strict { - dec.DisallowUnknownFields() - } - if err := dec.Decode(dest); err != nil { - return fmt.Errorf("decode JSON response: %w", err) +// secondsUntil computes seconds-until-expiry for a TokenSet with an +// absolute ExpiresAt. Returns 0 when no expiry is set, mirroring the +// historical shape of DeviceAuthPoll.ExpiresIn. +func secondsUntil(t *tokens.TokenSet) int { + if t.ExpiresAt.IsZero() { + return 0 } - - return nil + return int(t.ExpiresAt.Unix() - nowFunc().Unix()) } diff --git a/cmd/entire/cli/auth/client_test.go b/cmd/entire/cli/auth/client_test.go index 6a6a0bf229..a7ce52ba76 100644 --- a/cmd/entire/cli/auth/client_test.go +++ b/cmd/entire/cli/auth/client_test.go @@ -1,42 +1,122 @@ package auth import ( - "strings" + "errors" + "fmt" "testing" + + "github.com/entireio/auth-go/deviceflow" ) -func TestDecodeJSON_AllowsUnknownFields(t *testing.T) { +// TestOAuthErrorParts_KnownSentinels covers the five RFC 8628 §3.5 +// codes the polling loop in login.go switches on by name. Without a +// match here, the loop's switch never fires for these terminal cases. +func TestOAuthErrorParts_KnownSentinels(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + err error + want string + }{ + {"authorization_pending", deviceflow.ErrAuthorizationPending, "authorization_pending"}, + {"slow_down", deviceflow.ErrSlowDown, "slow_down"}, + {"access_denied", deviceflow.ErrAccessDenied, "access_denied"}, + {"expired_token", deviceflow.ErrExpiredToken, "expired_token"}, + {"invalid_grant", deviceflow.ErrInvalidGrant, "invalid_grant"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + code, desc, ok := oauthErrorParts(tc.err) + if !ok { + t.Fatalf("oauthErrorParts returned ok=false for %v", tc.err) + } + if code != tc.want { + t.Errorf("code = %q, want %q", code, tc.want) + } + if desc != "" { + t.Errorf("desc = %q, want empty (no description supplied)", desc) + } + }) + } +} + +// TestOAuthErrorParts_KnownSentinelWithDescription pins the description +// extraction for the wrapped form: fmt.Errorf("%w: %s", sentinel, desc). +func TestOAuthErrorParts_KnownSentinelWithDescription(t *testing.T) { t.Parallel() - var result DeviceAuthPoll - err := decodeJSON(strings.NewReader(`{ - "access_token": "token", - "token_type": "Bearer", - "refresh_token": "ignored" - }`), &result) - if err != nil { - t.Fatalf("decodeJSON() error = %v", err) + wrapped := fmt.Errorf("%w: device approval window closed", deviceflow.ErrExpiredToken) + code, desc, ok := oauthErrorParts(wrapped) + if !ok { + t.Fatalf("oauthErrorParts ok=false for wrapped sentinel") + } + if code != "expired_token" { + t.Errorf("code = %q, want expired_token", code) } + if desc != "device approval window closed" { + t.Errorf("desc = %q, want %q", desc, "device approval window closed") + } +} - if result.AccessToken != "token" { - t.Fatalf("AccessToken = %q, want %q", result.AccessToken, "token") +// TestOAuthErrorParts_UnknownCodePassesThrough is the regression for +// the bug surfaced in PR review: unknown OAuth codes (e.g. +// invalid_request, invalid_client, server_error) coming back from +// deviceflow as fmt.Errorf("oauth error: %s", code) used to fall +// through to the transient-retry path in login.go's waitForApproval, +// burning ~25-150s on permanent server errors. They now land in +// DeviceAuthPoll.Error so the polling loop's default switch arm fails +// fast with "device authorization failed: ". +func TestOAuthErrorParts_UnknownCodePassesThrough(t *testing.T) { + t.Parallel() + + cases := []string{"invalid_request", "invalid_client", "server_error", "unsupported_grant_type"} + for _, want := range cases { + t.Run(want, func(t *testing.T) { + t.Parallel() + err := fmt.Errorf("oauth error: %s", want) + code, desc, ok := oauthErrorParts(err) + if !ok { + t.Fatalf("oauthErrorParts ok=false for unknown OAuth code %q", want) + } + if code != want { + t.Errorf("code = %q, want %q", code, want) + } + if desc != "" { + t.Errorf("desc = %q, want empty (no description supplied)", desc) + } + }) + } +} + +// TestOAuthErrorParts_UnknownCodeWithDescription matches the wire +// shape deviceflow produces when the server returns both an unknown +// error code and a non-empty error_description. +func TestOAuthErrorParts_UnknownCodeWithDescription(t *testing.T) { + t.Parallel() + + err := errors.New("oauth error: invalid_client: client authentication failed") + code, desc, ok := oauthErrorParts(err) + if !ok { + t.Fatalf("oauthErrorParts ok=false") + } + if code != "invalid_client" { + t.Errorf("code = %q, want invalid_client", code) + } + if desc != "client authentication failed" { + t.Errorf("desc = %q, want %q", desc, "client authentication failed") } } -func TestDecodeJSONStrict_RejectsUnknownFields(t *testing.T) { +// TestOAuthErrorParts_NonOAuthError confirms transport/decode errors +// (no "oauth error:" prefix and not an RFC 8628 sentinel) are reported +// as ok=false so the polling loop treats them as transient. +func TestOAuthErrorParts_NonOAuthError(t *testing.T) { t.Parallel() - var result DeviceAuthStart - err := decodeJSONStrict(strings.NewReader(`{ - "device_code": "device", - "user_code": "ABCD-EFGH", - "verification_uri": "https://example.com/verify", - "verification_uri_complete": "https://example.com/verify?code=ABCD-EFGH", - "expires_in": 600, - "interval": 5, - "extra": true - }`), &result) - if err == nil { - t.Fatal("decodeJSONStrict() error = nil, want unknown-field error") + err := errors.New("connection reset by peer") + if _, _, ok := oauthErrorParts(err); ok { + t.Fatal("oauthErrorParts ok=true for non-OAuth error; would mask transient transport failure") } } diff --git a/cmd/entire/cli/auth/exchange.go b/cmd/entire/cli/auth/exchange.go new file mode 100644 index 0000000000..e5b149abb9 --- /dev/null +++ b/cmd/entire/cli/auth/exchange.go @@ -0,0 +1,113 @@ +package auth + +import ( + "context" + "fmt" + "net/url" + "sync" + + "github.com/entireio/auth-go/tokenmanager" + "github.com/entireio/cli/cmd/entire/cli/api" +) + +// TokenRequest is the entire-CLI alias of tokenmanager.TokenRequest so +// callers don't have to import the underlying package for the common +// case. The two types are interchangeable. +type TokenRequest = tokenmanager.TokenRequest + +// ErrNotLoggedIn re-exports tokenmanager.ErrNotLoggedIn so callers in +// the cli package can errors.Is against it without an extra import. +var ErrNotLoggedIn = tokenmanager.ErrNotLoggedIn + +var ( + managerOnce sync.Once + manager *tokenmanager.Manager + errManager error + + // managerForTest, when non-nil, is returned by defaultManager() + // instead of constructing the production manager. Tests use + // SetManagerForTest to inject a manager that hits a test STS + // server / in-memory store. Production code never reads this var. + managerForTest *tokenmanager.Manager +) + +// SetManagerForTest installs mgr as the manager returned by +// defaultManager() and returns a cleanup function. Test-only. +func SetManagerForTest(t interface{ Helper() }, mgr *tokenmanager.Manager) func() { + t.Helper() + prev := managerForTest + managerForTest = mgr + return func() { managerForTest = prev } +} + +// defaultManager returns the package-level Manager built from this +// CLI's identity (current provider, AuthBaseURL, NewStore service +// name). Constructed lazily on first use so any env-var setup +// (ENTIRE_AUTH_BASE_URL, ENTIRE_AUTH_PROVIDER_VERSION) lands before +// construction. sync.Once means later env-var changes within the same +// process are ignored; tests bypass the singleton via SetManagerForTest. +func defaultManager() (*tokenmanager.Manager, error) { + if managerForTest != nil { + return managerForTest, nil + } + managerOnce.Do(func() { + provider := CurrentProvider() + issuer := api.AuthBaseURL() + m, err := tokenmanager.New(tokenmanager.Config{ + Issuer: issuer, + ClientID: provider.ClientID, + STSPath: provider.STSPath, + Store: NewStore(), + UserAgent: provider.ClientID, + Scope: "cli", + // Auto-permit only loopback http:// for local development. + // Anything else must be https:// — STS ships the user's + // core token in the request body and would leak in clear + // otherwise. Matches the server-side JWKS acceptance + // pattern (isAcceptableJwksOrigin). + AllowInsecureHTTP: isLoopbackHTTP(issuer), + }) + manager = m + if err != nil { + errManager = fmt.Errorf("build token manager: %w", err) + } + }) + return manager, errManager +} + +// TokenForResource returns a bearer token suitable for use against +// resourceBaseURL, performing an RFC 8693 token exchange when the +// stored core token's audience doesn't already cover that resource. +// See tokenmanager.Manager.Token for the full resolution rules. +func TokenForResource(ctx context.Context, resourceBaseURL string) (string, error) { + m, err := defaultManager() + if err != nil { + return "", err + } + return m.TokenForResource(ctx, resourceBaseURL) //nolint:wrapcheck // shim returns the lib error verbatim +} + +// Token is the full-control entry point. Use TokenForResource for the +// common case; this exists so callers can override the wire-level +// Audience, RequestedTokenType, or Scope per call. +func Token(ctx context.Context, req TokenRequest) (string, error) { + m, err := defaultManager() + if err != nil { + return "", err + } + return m.Token(ctx, req) //nolint:wrapcheck // shim returns the lib error verbatim +} + +// isLoopbackHTTP reports whether u is an http:// URL pointing at a +// loopback hostname (localhost, 127.0.0.1, ::1). Used to scope the +// "auto-permit insecure HTTP" path on the tokenmanager so production +// misconfigurations (e.g. http://api.example.com) fail loudly while +// loopback-only local-dev flows keep working. +func isLoopbackHTTP(rawURL string) bool { + u, err := url.Parse(rawURL) + if err != nil || u.Scheme != "http" { + return false + } + host := u.Hostname() + return host == "localhost" || host == "127.0.0.1" || host == "::1" +} diff --git a/cmd/entire/cli/auth/exchange_test.go b/cmd/entire/cli/auth/exchange_test.go new file mode 100644 index 0000000000..334f27c330 --- /dev/null +++ b/cmd/entire/cli/auth/exchange_test.go @@ -0,0 +1,80 @@ +package auth + +import ( + "context" + "errors" + "testing" + + "github.com/entireio/auth-go/sts" + "github.com/entireio/auth-go/tokenmanager" + "github.com/entireio/auth-go/tokens" + "github.com/entireio/auth-go/tokenstore" +) + +// memStoreForExchange mirrors the tokenmanager test helper but local +// to this package — the cmd-side test only exercises wiring, so we +// don't import the manager's test fixtures. +type memStoreForExchange struct { + data map[string]tokens.TokenSet +} + +func (s *memStoreForExchange) SaveTokens(profile string, t tokens.TokenSet) error { + s.data[profile] = t + return nil +} + +func (s *memStoreForExchange) LoadTokens(profile string) (tokens.TokenSet, error) { + t, ok := s.data[profile] + if !ok { + return tokens.TokenSet{}, tokenstore.ErrNotFound + } + return t, nil +} + +func (s *memStoreForExchange) DeleteTokens(profile string) error { + delete(s.data, profile) + return nil +} + +// TestTokenForResource_DelegatesToManager verifies the cmd-side shim +// forwards calls to whatever Manager SetManagerForTest installs. The +// underlying behaviour (cache, exchange, audience checks) is covered +// by the tokenmanager package tests. +func TestTokenForResource_DelegatesToManager(t *testing.T) { + // Not parallel: SetManagerForTest mutates package-level state. + store := &memStoreForExchange{data: map[string]tokens.TokenSet{ + "https://auth.example.com": {AccessToken: "core"}, + }} + mgr, err := tokenmanager.New(tokenmanager.Config{ + Issuer: "https://auth.example.com", + ClientID: "test-cli", + STSPath: "/sts/token", + Store: store, + Exchange: func(_ context.Context, _ sts.ExchangeRequest) (*tokens.TokenSet, error) { + return &tokens.TokenSet{AccessToken: "exchanged"}, nil + }, + }) + if err != nil { + t.Fatalf("New: %v", err) + } + cleanup := SetManagerForTest(t, mgr) + t.Cleanup(cleanup) + + got, err := TokenForResource(context.Background(), "https://api.example.com") + if err != nil { + t.Fatalf("TokenForResource: %v", err) + } + if got != "exchanged" { + t.Fatalf("got %q, want exchanged", got) + } +} + +// TestTokenForResource_ReExportsErrNotLoggedIn ensures the cmd-side +// alias matches the underlying sentinel so callers can errors.Is +// against either. +func TestTokenForResource_ReExportsErrNotLoggedIn(t *testing.T) { + t.Parallel() + if !errors.Is(ErrNotLoggedIn, tokenmanager.ErrNotLoggedIn) { + t.Fatal("auth.ErrNotLoggedIn must alias tokenmanager.ErrNotLoggedIn") + } +} diff --git a/cmd/entire/cli/auth/provider.go b/cmd/entire/cli/auth/provider.go new file mode 100644 index 0000000000..b30ec8b6b8 --- /dev/null +++ b/cmd/entire/cli/auth/provider.go @@ -0,0 +1,127 @@ +package auth + +import ( + "os" + "strings" + "sync" +) + +// ProviderVersionEnvVar selects which OAuth surface this CLI talks to. +// +// Recognised values: +// +// - "v1" (or unset / unrecognised) — current device-flow surface +// - "v2" — next-generation device-flow surface +// +// This is a transition-period switch: once v2 is the universal default +// the env var goes away. Surfaces are otherwise reachable as RFC 8628 +// device-flow endpoints; the only differences are paths and client_id. +// +// Read once at process startup via CurrentProvider; later flips within +// the same process are intentionally ignored. Tests inject via +// SetProviderForTest rather than mutating the env mid-run. +const ProviderVersionEnvVar = "ENTIRE_AUTH_PROVIDER_VERSION" + +// Provider captures the per-surface bits of OAuth wiring. +// +// STSPath is the RFC 8693 token-exchange endpoint. v1 is the legacy +// single-host surface where the auth and data API live at the same +// origin (entire.io); the same-host shortcut in tokenmanager.Token +// always wins and STS is never invoked, so v1.STSPath is left empty. +// v2 exposes a dedicated STS path because it's used in split-host +// deployments (e.g. us.auth.partial.to mints, partial.to consumes). +// +// AuthTokensPath is the base path for the auth-tokens management +// endpoint family (list / revoke). Routed at the api.Client layer via +// (*api.Client).WithAuthTokensPath so the provider table is the single +// source of truth — no env-var duplication between auth/ and api/. +type Provider struct { + ClientID string + DeviceCodePath string + TokenPath string + STSPath string + AuthTokensPath string +} + +var providers = map[string]Provider{ + "v1": { //nolint:gosec // OAuth client_id and endpoint paths, not credentials + ClientID: "entire-cli", + DeviceCodePath: "/oauth/device/code", + TokenPath: "/oauth/token", + AuthTokensPath: "/api/v1/auth/tokens", + }, + "v2": { //nolint:gosec // OAuth client_id and endpoint paths, not credentials + ClientID: "entire-cli", + DeviceCodePath: "/api/auth/oauth/device/code", + TokenPath: "/api/auth/token", + STSPath: "/api/authz/sts/token", + AuthTokensPath: "/api/auth/tokens", + }, +} + +// resolveProvider returns the Provider matching version. Defaulting +// (rather than erroring) on unrecognised values keeps old binaries safe +// if a future v3 ever lands. Pure function — no env reads — so unit +// tests can exercise the routing table without env-var gymnastics. +func resolveProvider(version string) Provider { + switch strings.TrimSpace(version) { + case "v2": + return providers["v2"] + default: + return providers["v1"] + } +} + +var ( + providerOnce sync.Once + resolvedProvider Provider + + // providerForTest, when non-nil, short-circuits CurrentProvider so + // tests can install a specific Provider without racing the + // process-wide sync.Once (which freezes the first observation + // forever). Mutated only via SetProviderForTest. Production code + // never reads this var. + providerForTest *Provider + providerTestMu sync.Mutex +) + +// CurrentProvider returns the active Provider for this process. +// +// Resolution: read ENTIRE_AUTH_PROVIDER_VERSION exactly once on first +// call, freeze the result, and return the same Provider on every +// subsequent call. Tests that need a different provider must use +// SetProviderForTest before any auth call constructs the singleton. +func CurrentProvider() Provider { + providerTestMu.Lock() + override := providerForTest + providerTestMu.Unlock() + if override != nil { + return *override + } + providerOnce.Do(func() { + resolvedProvider = resolveProvider(os.Getenv(ProviderVersionEnvVar)) + }) + return resolvedProvider +} + +// SetProviderForTest installs p as the Provider returned by +// CurrentProvider for the duration of the test, and registers a +// t.Cleanup to remove the override. Test-only. +// +// Takes a tiny interface rather than *testing.T so production builds +// don't import testing. +func SetProviderForTest(t interface { + Helper() + Cleanup(f func()) +}, p Provider) { + t.Helper() + providerTestMu.Lock() + prev := providerForTest + providerForTest = &p + providerTestMu.Unlock() + t.Cleanup(func() { + providerTestMu.Lock() + providerForTest = prev + providerTestMu.Unlock() + }) +} diff --git a/cmd/entire/cli/auth/provider_test.go b/cmd/entire/cli/auth/provider_test.go new file mode 100644 index 0000000000..8a47f6894e --- /dev/null +++ b/cmd/entire/cli/auth/provider_test.go @@ -0,0 +1,139 @@ +package auth + +import ( + "net/http" + "testing" + + "github.com/entireio/cli/cmd/entire/cli/api" +) + +// Test-local mirrors of the v1 / v2 client_id values, so assertions +// don't repeat the same string literal across multiple tests (goconst). +// Both providers now share the same client_id; the constants are kept +// distinct so a future divergence (or a regression that re-splits them) +// shows up at a single edit site. +const ( + wantClientIDV1 = "entire-cli" + wantClientIDV2 = "entire-cli" +) + +// resolveProvider is a pure function — no env reads — so the routing +// table can be exercised without t.Setenv (and without the +// process-wide sync.Once in CurrentProvider freezing the first +// observation forever). + +func TestResolveProvider_DefaultsToV1(t *testing.T) { + t.Parallel() + p := resolveProvider("") + if p.ClientID != wantClientIDV1 || p.DeviceCodePath != "/oauth/device/code" || p.TokenPath != "/oauth/token" { + t.Fatalf("default provider = %+v, want v1 config", p) + } + if p.AuthTokensPath != "/api/v1/auth/tokens" { + t.Fatalf("default AuthTokensPath = %q, want /api/v1/auth/tokens", p.AuthTokensPath) + } +} + +func TestResolveProvider_V1Explicit(t *testing.T) { + t.Parallel() + p := resolveProvider("v1") + if p.ClientID != wantClientIDV1 { + t.Fatalf("v1 ClientID = %q", p.ClientID) + } + // v1 is single-host (entire.io); no STS surface, same-host shortcut + // always wins. Empty STSPath is the contract. + if p.STSPath != "" { + t.Fatalf("v1 STSPath = %q, want empty (single-host, no STS)", p.STSPath) + } +} + +func TestResolveProvider_V2(t *testing.T) { + t.Parallel() + p := resolveProvider("v2") + if p.ClientID != wantClientIDV2 { + t.Fatalf("v2 ClientID = %q, want %s", p.ClientID, wantClientIDV2) + } + if p.DeviceCodePath != "/api/auth/oauth/device/code" { + t.Fatalf("v2 DeviceCodePath = %q", p.DeviceCodePath) + } + if p.TokenPath != "/api/auth/token" { + t.Fatalf("v2 TokenPath = %q", p.TokenPath) + } + if p.STSPath != "/api/authz/sts/token" { + t.Fatalf("v2 STSPath = %q", p.STSPath) + } + if p.AuthTokensPath != "/api/auth/tokens" { + t.Fatalf("v2 AuthTokensPath = %q, want /api/auth/tokens", p.AuthTokensPath) + } +} + +func TestResolveProvider_UnknownDefaultsToV1(t *testing.T) { + t.Parallel() + p := resolveProvider("v999") + if p.ClientID != wantClientIDV1 { + t.Fatalf("unknown version should default to v1; got ClientID = %q", p.ClientID) + } +} + +func TestResolveProvider_TrimsWhitespace(t *testing.T) { + t.Parallel() + p := resolveProvider(" v2 ") + if p.ClientID != wantClientIDV2 { + t.Fatalf("whitespace-padded v2 ClientID = %q, want %s", p.ClientID, wantClientIDV2) + } +} + +// TestSetProviderForTest_OverridesCurrentProvider locks in the test +// seam: any test that pins a provider via SetProviderForTest must see +// it from CurrentProvider regardless of process-wide singleton state. +func TestSetProviderForTest_OverridesCurrentProvider(t *testing.T) { + pinned := Provider{ + ClientID: "test-client", + DeviceCodePath: "/test/device", + TokenPath: "/test/token", + STSPath: "/test/sts", + AuthTokensPath: "/test/tokens", + } + SetProviderForTest(t, pinned) + + got := CurrentProvider() + if got != pinned { + t.Fatalf("CurrentProvider() = %+v, want %+v", got, pinned) + } +} + +func TestNewClient_HonoursPinnedProvider(t *testing.T) { + t.Setenv(api.BaseURLEnvVar, "https://example.test") + t.Setenv(api.AuthBaseURLEnvVar, "") + SetProviderForTest(t, resolveProvider("v2")) + + c := NewClient(&http.Client{}) + if c.inner.ClientID != wantClientIDV2 { + t.Errorf("ClientID = %q, want %s", c.inner.ClientID, wantClientIDV2) + } + if c.inner.DeviceCodePath != "/api/auth/oauth/device/code" { + t.Errorf("DeviceCodePath = %q", c.inner.DeviceCodePath) + } + if c.inner.TokenPath != "/api/auth/token" { + t.Errorf("TokenPath = %q", c.inner.TokenPath) + } + if c.inner.BaseURL != "https://example.test" { + t.Errorf("BaseURL = %q", c.inner.BaseURL) + } +} + +func TestNewClient_DefaultsToV1WhenPinned(t *testing.T) { + t.Setenv(api.BaseURLEnvVar, "https://example.test") + t.Setenv(api.AuthBaseURLEnvVar, "") + SetProviderForTest(t, resolveProvider("")) + + c := NewClient(nil) + if c.inner.ClientID != wantClientIDV1 { + t.Errorf("ClientID = %q, want %s", c.inner.ClientID, wantClientIDV1) + } + if c.inner.DeviceCodePath != "/oauth/device/code" { + t.Errorf("DeviceCodePath = %q", c.inner.DeviceCodePath) + } + if c.inner.TokenPath != "/oauth/token" { + t.Errorf("TokenPath = %q", c.inner.TokenPath) + } +} diff --git a/cmd/entire/cli/auth/store.go b/cmd/entire/cli/auth/store.go index 729bf69eb4..20902304d1 100644 --- a/cmd/entire/cli/auth/store.go +++ b/cmd/entire/cli/auth/store.go @@ -5,16 +5,28 @@ import ( "fmt" "strings" + "github.com/entireio/auth-go/tokens" + "github.com/entireio/auth-go/tokenstore" "github.com/entireio/cli/cmd/entire/cli/api" "github.com/zalando/go-keyring" ) +// keyringService is the OS-keyring service name for this CLI. Renaming +// would orphan every existing user's stored credentials — they'd appear +// logged out until they ran `entire login` again. Don't change this +// without a migration path. const keyringService = "entire-cli" // Store manages CLI authentication tokens via a pluggable backend. The // production binary always resolves to the OS keyring. A file-backed // backend is available only in builds tagged `authfilestore` (used by // integration tests to avoid the OS keychain). +// +// Implements tokenstore.Store so it can be passed to tokenmanager.New +// as the persistence layer. The interface methods (SaveTokens / +// LoadTokens / DeleteTokens) delegate to the same backend as the +// legacy SaveToken / GetToken / DeleteToken pair, so production and +// test paths share a single source of truth. type Store struct { service string backend tokenBackend @@ -50,6 +62,11 @@ func NewStoreWithService(service string) *Store { } // SaveToken persists an access token for the given base URL. +// +// Deprecated: prefer SaveTokens (the tokenstore.Store interface method) +// for new callers. SaveToken is kept for the legacy direct-bearer call +// sites (login, logout, auth status/list/revoke) that don't go through +// the tokenmanager. func (s *Store) SaveToken(baseURL, token string) error { token = strings.TrimSpace(token) if token == "" { @@ -58,22 +75,89 @@ func (s *Store) SaveToken(baseURL, token string) error { return s.backend.save(s.service, baseURL, token) } -// GetToken retrieves a stored token for the given base URL. -// Returns an empty string (and no error) if no token is stored. +// GetToken retrieves a stored token for the given base URL. Returns +// an empty string (and no error) if no token is stored, or if the +// stored value is JSON-shaped (defensive: pre-shim entries are +// opaque token strings, never JSON; a JSON blob in the keyring is +// corruption and must not be put on the wire as a bearer). +// +// Deprecated: prefer LoadTokens (the tokenstore.Store interface method) +// for new callers — it returns the full TokenSet so refresh tokens and +// expiry survive the round trip. GetToken is retained for the direct- +// bearer call sites that only need the access token string. func (s *Store) GetToken(baseURL string) (string, error) { - return s.backend.get(s.service, baseURL) + raw, err := s.backend.get(s.service, baseURL) + if err != nil { + return "", err + } + if looksJSONShaped(raw) { + return "", nil + } + return raw, nil +} + +// looksJSONShaped reports whether the keyring value's first +// non-whitespace byte is '{' or '['. Used to reject corrupt / +// previous-encoding entries before they end up in an +// Authorization: Bearer header. +func looksJSONShaped(s string) bool { + trimmed := strings.TrimLeft(s, " \t\r\n") + if trimmed == "" { + return false + } + return trimmed[0] == '{' || trimmed[0] == '[' } // DeleteToken removes a stored token for the given base URL. // Returns no error if the token does not exist. +// +// Deprecated: prefer DeleteTokens (the tokenstore.Store interface +// method). DeleteToken is retained for direct-bearer call sites. func (s *Store) DeleteToken(baseURL string) error { return s.backend.delete(s.service, baseURL) } -// LookupCurrentToken retrieves the token for the current base URL. +// SaveTokens implements tokenstore.Store. Refresh token, scope, expiry, +// and token type are intentionally dropped — the entire device-flow +// surface doesn't issue refresh tokens, and the legacy keyring/file +// layout stores bare access-token strings. If refresh-token support +// lands, this method (and the tokenBackend interface) become the +// migration point. +func (s *Store) SaveTokens(profile string, t tokens.TokenSet) error { + access := strings.TrimSpace(t.AccessToken) + if access == "" { + return errors.New("refusing to save empty access token") + } + return s.backend.save(s.service, profile, access) +} + +// LoadTokens implements tokenstore.Store. Reads the bare-string entry +// and wraps it back into a TokenSet. Returns tokenstore.ErrNotFound +// when nothing is stored under the profile (or the stored value is +// JSON-shaped — see GetToken's note about defensive rejection of +// non-token blobs) so callers can errors.Is against the lib sentinel. +func (s *Store) LoadTokens(profile string) (tokens.TokenSet, error) { + access, err := s.backend.get(s.service, profile) + if err != nil { + return tokens.TokenSet{}, err + } + if access == "" || looksJSONShaped(access) { + return tokens.TokenSet{}, tokenstore.ErrNotFound + } + return tokens.TokenSet{AccessToken: access}, nil +} + +// DeleteTokens implements tokenstore.Store. +func (s *Store) DeleteTokens(profile string) error { + return s.backend.delete(s.service, profile) +} + +// LookupCurrentToken retrieves the token for the current auth base URL. +// Tokens are keyed by the auth issuer (api.AuthBaseURL()) since that's the +// host that minted them; in single-host deployments AuthBaseURL falls back +// to BaseURL so behaviour is unchanged. func LookupCurrentToken() (string, error) { - store := NewStore() - return store.GetToken(api.BaseURL()) + return NewStore().GetToken(api.AuthBaseURL()) } type keyringBackend struct{} diff --git a/cmd/entire/cli/auth/store_test.go b/cmd/entire/cli/auth/store_test.go index 8ccf122372..7a767b5602 100644 --- a/cmd/entire/cli/auth/store_test.go +++ b/cmd/entire/cli/auth/store_test.go @@ -132,8 +132,97 @@ func TestStoreDeleteToken_NotFoundIsNoop(t *testing.T) { } } +// TestStoreGetToken_LegacyBareStringFallback verifies that a pre-shim +// keyring entry (raw access-token string, not a JSON-encoded TokenSet) +// is still readable via GetToken after the shim landed. Without the +// fallback, pre-shim users would appear logged out after upgrading. +func TestStoreGetToken_LegacyBareStringFallback(t *testing.T) { + // Not parallel: go-keyring's mock provider uses an unprotected map. + const service = "test-legacy-getoken" + const profile = "https://legacy.example.com" + const bareToken = "ent_pre_shim_raw_token" + + if err := keyring.Set(service, profile, bareToken); err != nil { + t.Fatalf("seed keyring: %v", err) + } + + got, err := NewStoreWithService(service).GetToken(profile) + if err != nil { + t.Fatalf("GetToken: %v", err) + } + if got != bareToken { + t.Fatalf("GetToken() = %q, want bare token %q", got, bareToken) + } +} + +// TestStoreLoadTokens_LegacyBareStringFallback is the tokenstore.Store +// counterpart of the above. The tokenmanager calls LoadTokens, so this +// path is what determines whether the manager-backed code path +// recognises pre-shim entries. +func TestStoreLoadTokens_LegacyBareStringFallback(t *testing.T) { + // Not parallel: go-keyring's mock provider uses an unprotected map. + const service = "test-legacy-loadtokens" + const profile = "https://legacy.example.com" + const bareToken = "ent_pre_shim_raw_token" + + if err := keyring.Set(service, profile, bareToken); err != nil { + t.Fatalf("seed keyring: %v", err) + } + + got, err := NewStoreWithService(service).LoadTokens(profile) + if err != nil { + t.Fatalf("LoadTokens: %v", err) + } + if got.AccessToken != bareToken { + t.Fatalf("LoadTokens AccessToken = %q, want %q", got.AccessToken, bareToken) + } +} + +// TestStoreLoadTokens_RejectsJSONShapedFallback guards against a +// well-formed-but-empty JSON entry being mistakenly treated as a +// bare-string token. The lib surfaces these as ErrMalformed; the +// shim's bare-string fallback must filter out anything starting with +// '{' or '[' so the user sees "not logged in" rather than getting +// "Authorization: Bearer {}" on the wire. +func TestStoreLoadTokens_RejectsJSONShapedFallback(t *testing.T) { + for _, body := range []string{`{}`, `{"foo":"bar"}`, `[]`} { + const profile = "https://json-shaped.example.com" + service := "test-json-fallback-" + body[:1] + if err := keyring.Set(service, profile, body); err != nil { + t.Fatalf("seed keyring: %v", err) + } + + got, err := NewStoreWithService(service).LoadTokens(profile) + // We expect ErrNotFound — JSON-shaped malformed entries must + // not be routed through the bare-string fallback. + if err == nil { + t.Fatalf("LoadTokens(%q) returned %+v; want ErrNotFound", body, got) + } + if got.AccessToken != "" { + t.Fatalf("LoadTokens(%q) AccessToken = %q, want empty", body, got.AccessToken) + } + } +} + +func TestStoreGetToken_RejectsJSONShapedFallback(t *testing.T) { + const service = "test-json-getoken" + const profile = "https://json-shaped.example.com" + if err := keyring.Set(service, profile, `{"unrelated":"blob"}`); err != nil { + t.Fatalf("seed keyring: %v", err) + } + + got, err := NewStoreWithService(service).GetToken(profile) + if err != nil { + t.Fatalf("GetToken: %v", err) + } + if got != "" { + t.Fatalf("GetToken = %q, want empty (JSON blob must not be shipped as a bearer)", got) + } +} + func TestLookupCurrentToken(t *testing.T) { t.Setenv(api.BaseURLEnvVar, "http://localhost:8787") + t.Setenv(api.AuthBaseURLEnvVar, "") store := NewStore() if err := store.SaveToken("http://localhost:8787", "local-token"); err != nil { diff --git a/cmd/entire/cli/dispatch/dispatch_test.go b/cmd/entire/cli/dispatch/dispatch_test.go index 431e66d2dd..5427141342 100644 --- a/cmd/entire/cli/dispatch/dispatch_test.go +++ b/cmd/entire/cli/dispatch/dispatch_test.go @@ -4,13 +4,17 @@ import ( "context" "strings" "testing" + + "github.com/entireio/cli/cmd/entire/cli/auth" ) func TestRun_ServerAllowsRepos(t *testing.T) { - oldLookup := lookupCurrentToken - lookupCurrentToken = func() (string, error) { return "", nil } + oldResource := lookupResourceToken + lookupResourceToken = func(_ context.Context, _ string) (string, error) { + return "", auth.ErrNotLoggedIn + } t.Cleanup(func() { - lookupCurrentToken = oldLookup + lookupResourceToken = oldResource }) _, err := Run(context.Background(), Options{ diff --git a/cmd/entire/cli/dispatch/mode_cloud.go b/cmd/entire/cli/dispatch/mode_cloud.go index 3c5fa93b3e..21aae53a9a 100644 --- a/cmd/entire/cli/dispatch/mode_cloud.go +++ b/cmd/entire/cli/dispatch/mode_cloud.go @@ -8,6 +8,7 @@ import ( "time" "github.com/entireio/cli/cmd/entire/cli/api" + "github.com/entireio/cli/cmd/entire/cli/auth" "github.com/entireio/cli/cmd/entire/cli/paths" "github.com/go-git/go-git/v6" ) @@ -19,14 +20,6 @@ import ( var requireSecureDispatchURL = api.RequireSecureURL func runServer(ctx context.Context, opts Options) (*Dispatch, error) { - token, err := lookupCurrentToken() - if err != nil { - return nil, fmt.Errorf("reading credentials: %w", err) - } - if token == "" { - return nil, errors.New("dispatch requires login — run `entire login`") - } - baseURL := api.BaseURL() if !opts.InsecureHTTPAuth { if err := requireSecureDispatchURL(baseURL); err != nil { @@ -34,6 +27,19 @@ func runServer(ctx context.Context, opts Options) (*Dispatch, error) { } } + // Resolve a bearer scoped to the dispatch service host. In split-host + // deployments the tokenmanager runs an RFC 8693 exchange so the + // bearer carries the data-API audience rather than the auth-host + // one; single-host setups hit the same-host shortcut and return the + // core token unchanged. + token, err := lookupResourceToken(ctx, baseURL) + if errors.Is(err, auth.ErrNotLoggedIn) { + return nil, errors.New("dispatch requires login — run `entire login`") + } + if err != nil { + return nil, fmt.Errorf("reading credentials: %w", err) + } + now := nowUTC() sinceInput := strings.TrimSpace(opts.Since) if sinceInput == "" { diff --git a/cmd/entire/cli/dispatch/mode_cloud_test.go b/cmd/entire/cli/dispatch/mode_cloud_test.go index c8af1b42c4..521f519e1a 100644 --- a/cmd/entire/cli/dispatch/mode_cloud_test.go +++ b/cmd/entire/cli/dispatch/mode_cloud_test.go @@ -20,11 +20,16 @@ import ( func stubCloudDispatchAuth(t *testing.T) { t.Helper() oldLookup := lookupCurrentToken + oldResource := lookupResourceToken oldRequire := requireSecureDispatchURL lookupCurrentToken = func() (string, error) { return testCloudDispatchToken, nil } + lookupResourceToken = func(_ context.Context, _ string) (string, error) { + return testCloudDispatchToken, nil + } requireSecureDispatchURL = func(string) error { return nil } t.Cleanup(func() { lookupCurrentToken = oldLookup + lookupResourceToken = oldResource requireSecureDispatchURL = oldRequire }) } @@ -368,11 +373,16 @@ func TestServerMode_InsecureHTTPAuthBypassesSecureURLCheck(t *testing.T) { defer mock.Close() oldLookup := lookupCurrentToken + oldResource := lookupResourceToken oldNow := nowUTC lookupCurrentToken = func() (string, error) { return testCloudDispatchToken, nil } + lookupResourceToken = func(_ context.Context, _ string) (string, error) { + return testCloudDispatchToken, nil + } nowUTC = func() time.Time { return time.Date(2026, 4, 16, 0, 0, 0, 0, time.UTC) } t.Cleanup(func() { lookupCurrentToken = oldLookup + lookupResourceToken = oldResource nowUTC = oldNow }) @@ -399,8 +409,15 @@ func TestServerMode_InsecureHTTPAuthBypassesSecureURLCheck(t *testing.T) { // leak reaches users. func TestServerMode_RejectsPlainHTTPBaseURL(t *testing.T) { oldLookup := lookupCurrentToken + oldResource := lookupResourceToken lookupCurrentToken = func() (string, error) { return testCloudDispatchToken, nil } - t.Cleanup(func() { lookupCurrentToken = oldLookup }) + lookupResourceToken = func(_ context.Context, _ string) (string, error) { + return testCloudDispatchToken, nil + } + t.Cleanup(func() { + lookupCurrentToken = oldLookup + lookupResourceToken = oldResource + }) t.Setenv("ENTIRE_API_BASE_URL", "http://dispatch.example.invalid") diff --git a/cmd/entire/cli/dispatch/mode_local.go b/cmd/entire/cli/dispatch/mode_local.go index 58186ac6e1..34ea268f32 100644 --- a/cmd/entire/cli/dispatch/mode_local.go +++ b/cmd/entire/cli/dispatch/mode_local.go @@ -22,8 +22,20 @@ import ( ) var ( + // lookupCurrentToken is retained for test-injection back-compat. The + // cloud-mode runner now resolves its bearer via lookupResourceToken + // so it picks up the RFC 8693 exchange in split-host deployments; + // existing tests that swap lookupCurrentToken keep working because + // the default lookupResourceToken delegates to it. lookupCurrentToken = auth.LookupCurrentToken - nowUTC = func() time.Time { return time.Now().UTC() } + + // lookupResourceToken returns a bearer scoped to the given resource + // origin. Production wiring goes through auth.TokenForResource so + // the tokenmanager's same-host shortcut, JWT-aud shortcut, and + // exchange dispatch all apply. Tests swap to a fixed-token closure. + lookupResourceToken = auth.TokenForResource + + nowUTC = func() time.Time { return time.Now().UTC() } ) func runLocal(ctx context.Context, opts Options) (*Dispatch, error) { diff --git a/cmd/entire/cli/dispatch_wizard.go b/cmd/entire/cli/dispatch_wizard.go index dadd60ca40..88e5551cab 100644 --- a/cmd/entire/cli/dispatch_wizard.go +++ b/cmd/entire/cli/dispatch_wizard.go @@ -29,7 +29,7 @@ var getDispatchWizardCurrentBranch = GetCurrentBranch var runDispatchWizardForm = func(form *huh.Form) error { return form.Run() } func defaultListDispatchWizardRepoResources(ctx context.Context) ([]api.Repository, error) { - client, err := NewAuthenticatedAPIClient(false) + client, err := NewAuthenticatedAPIClient(ctx, false) if err != nil { return nil, err } diff --git a/cmd/entire/cli/integration_test/login_test.go b/cmd/entire/cli/integration_test/login_test.go index 2954365147..7438d4f361 100644 --- a/cmd/entire/cli/integration_test/login_test.go +++ b/cmd/entire/cli/integration_test/login_test.go @@ -203,6 +203,8 @@ func runLoginProcess(t *testing.T, apiBaseURL string) *loginProcess { "ENTIRE_TEST_GEMINI_PROJECT_DIR="+env.GeminiProjectDir, "ENTIRE_TEST_OPENCODE_PROJECT_DIR="+env.OpenCodeProjectDir, "ENTIRE_API_BASE_URL="+apiBaseURL, + "ENTIRE_AUTH_BASE_URL="+apiBaseURL, + "ENTIRE_AUTH_PROVIDER_VERSION=v1", "ENTIRE_TEST_AUTH_STORE_FILE="+filepath.Join(env.RepoDir, ".entire-test-auth-store.json"), ) diff --git a/cmd/entire/cli/login.go b/cmd/entire/cli/login.go index 88618c522e..78806cf82f 100644 --- a/cmd/entire/cli/login.go +++ b/cmd/entire/cli/login.go @@ -12,6 +12,7 @@ import ( "runtime" "time" + "github.com/entireio/auth-go/tokens" "github.com/entireio/cli/cmd/entire/cli/auth" "github.com/entireio/cli/cmd/entire/cli/interactive" "github.com/spf13/cobra" @@ -84,9 +85,16 @@ func runLogin(ctx context.Context, outW, errW io.Writer, client deviceAuthClient return fmt.Errorf("complete login: %w", err) } + if err := validateReceivedToken(token, client.BaseURL(), time.Now()); err != nil { + return fmt.Errorf("reject login token: %w", err) + } + store := auth.NewStore() - if err := store.SaveToken(client.BaseURL(), token); err != nil { + // Login deliberately uses the legacy SaveToken (string, string) + // surface — we only have an access-token string at this point; + // the deviceflow client doesn't return a TokenSet here. + if err := store.SaveToken(client.BaseURL(), token); err != nil { //nolint:staticcheck // SA1019: legacy direct-bearer call site, see Store.SaveToken doc return fmt.Errorf("save auth token: %w", err) } @@ -146,6 +154,9 @@ func waitForApproval(ctx context.Context, poller deviceAuthClient, deviceCode st case "expired_token": return "", errors.New("device authorization expired") default: + if result.ErrorDescription != "" { + return "", fmt.Errorf("device authorization failed: %s: %s", result.Error, result.ErrorDescription) + } return "", fmt.Errorf("device authorization failed: %s", result.Error) } @@ -218,3 +229,86 @@ func openBrowser(ctx context.Context, browserURL string) error { return nil } + +// validateReceivedToken runs minimum-trust checks on the access token +// the AS handed us before we persist it. The server is the authority +// on signature/exp; this is defense in depth aimed at catching gross +// misbehaviour by a compromised or misconfigured AS (e.g. handing back +// a token from a different issuer than the one we asked, or one whose +// claims are already-expired). +// +// Opaque (non-JWT) tokens are permitted — the AS may not issue JWTs at +// all. Only when we can parse the token as a JWT do we cross-check the +// claims. Unsigned (alg:none) JWTs are always rejected: see +// tokens.ErrUnsignedJWT. +func validateReceivedToken(rawToken, issuerURL string, now time.Time) error { + claims, err := tokens.ParseClaims(rawToken) + switch { + case errors.Is(err, tokens.ErrMalformedJWT): + // Opaque token — no claim-based checks available. Trust the + // server-side validation. (Most OAuth flows allow this; it's + // only a problem if our resource servers later expect JWTs.) + return nil + case errors.Is(err, tokens.ErrUnsignedJWT): + // alg:none is always a refusal — see tokens.ErrUnsignedJWT + // rationale. + return err //nolint:wrapcheck // sentinel surfaces verbatim for caller's errors.Is + case err != nil: + return fmt.Errorf("parse claims: %w", err) + } + + // iss check: the token must claim to come from the issuer we sent + // the device-code request to. A mismatch means either the AS is + // misconfigured or someone's playing games. + if issErr := issMatches(claims.Issuer, issuerURL); issErr != nil { + return issErr + } + + // exp sanity: a token that's already expired before we even store + // it is a smell. Don't reject if exp is unset (some servers omit). + if !claims.ExpiresAt.IsZero() && !now.Before(claims.ExpiresAt) { + return fmt.Errorf("token already expired (exp=%s, now=%s)", + claims.ExpiresAt.Format(time.RFC3339), now.Format(time.RFC3339)) + } + + return nil +} + +// issMatches reports whether claimed equals expected after a light +// normalisation: trim trailing slashes so "https://issuer/" and +// "https://issuer" match. Returns nil on match. +func issMatches(claimed, expected string) error { + if claimed == "" { + // Some servers omit iss (especially in opaque-but-jwt-shaped + // tokens). Allow rather than reject — the server still does + // the real check on every request. + return nil + } + normClaimed, err := normalizeIssuer(claimed) + if err != nil { + return fmt.Errorf("parse iss claim %q: %w", claimed, err) + } + normExpected, err := normalizeIssuer(expected) + if err != nil { + return fmt.Errorf("parse expected issuer %q: %w", expected, err) + } + if normClaimed != normExpected { + return fmt.Errorf("iss mismatch: token claims %q, expected %q", normClaimed, normExpected) + } + return nil +} + +func normalizeIssuer(raw string) (string, error) { + u, err := url.Parse(raw) + if err != nil { + return "", err //nolint:wrapcheck // caller wraps with context + } + if u.Scheme == "" || u.Host == "" { + // Non-URL issuer (logical name) — return verbatim. + return raw, nil + } + u.Path = "" + u.RawQuery = "" + u.Fragment = "" + return u.String(), nil +} diff --git a/cmd/entire/cli/login_test.go b/cmd/entire/cli/login_test.go index 63834fec3a..2ab5a30e06 100644 --- a/cmd/entire/cli/login_test.go +++ b/cmd/entire/cli/login_test.go @@ -2,11 +2,14 @@ package cli import ( "context" + "encoding/base64" + "encoding/json" "errors" "strings" "testing" "time" + "github.com/entireio/auth-go/tokens" "github.com/entireio/cli/cmd/entire/cli/auth" ) @@ -259,3 +262,98 @@ func TestWaitForApproval_ContextCancelled(t *testing.T) { t.Fatalf("err = %v, want context canceled", err) } } + +// makeTestJWT builds a well-formed JWT (alg != none) with the given +// claims for use in login validation tests. Signature is junk — +// validateReceivedToken doesn't verify it. +func makeTestJWT(t *testing.T, claims map[string]any) string { + t.Helper() + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"EdDSA","typ":"JWT"}`)) + body, err := json.Marshal(claims) + if err != nil { + t.Fatalf("marshal: %v", err) + } + return header + "." + base64.RawURLEncoding.EncodeToString(body) + ".sig" +} + +func TestValidateReceivedToken_OpaqueAllowed(t *testing.T) { + t.Parallel() + // Non-JWT tokens are permitted — the AS may not issue JWTs at all. + if err := validateReceivedToken("opaque_token_value", "https://issuer.example", time.Now()); err != nil { + t.Fatalf("validateReceivedToken(opaque) = %v, want nil", err) + } +} + +func TestValidateReceivedToken_RejectsUnsignedJWT(t *testing.T) { + t.Parallel() + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + body := base64.RawURLEncoding.EncodeToString([]byte(`{"iss":"https://issuer.example"}`)) + jwt := header + "." + body + ".sig" + + err := validateReceivedToken(jwt, "https://issuer.example", time.Now()) + if !errors.Is(err, tokens.ErrUnsignedJWT) { + t.Fatalf("validateReceivedToken(alg:none) = %v, want ErrUnsignedJWT", err) + } +} + +func TestValidateReceivedToken_RejectsWrongIssuer(t *testing.T) { + t.Parallel() + jwt := makeTestJWT(t, map[string]any{ + "iss": "https://attacker.example", + "sub": "account:x", + "exp": time.Now().Add(time.Hour).Unix(), + }) + err := validateReceivedToken(jwt, "https://issuer.example", time.Now()) + if err == nil || !strings.Contains(err.Error(), "iss mismatch") { + t.Fatalf("validateReceivedToken(wrong iss) = %v, want iss mismatch", err) + } +} + +func TestValidateReceivedToken_AcceptsMatchingIssuer(t *testing.T) { + t.Parallel() + now := time.Now() + jwt := makeTestJWT(t, map[string]any{ + "iss": "https://issuer.example", + "sub": "account:x", + "exp": now.Add(time.Hour).Unix(), + }) + if err := validateReceivedToken(jwt, "https://issuer.example", now); err != nil { + t.Fatalf("validateReceivedToken(matching iss) = %v, want nil", err) + } +} + +func TestValidateReceivedToken_NormalisesTrailingSlash(t *testing.T) { + t.Parallel() + jwt := makeTestJWT(t, map[string]any{ + "iss": "https://issuer.example/", + "exp": time.Now().Add(time.Hour).Unix(), + }) + if err := validateReceivedToken(jwt, "https://issuer.example", time.Now()); err != nil { + t.Fatalf("validateReceivedToken: trailing-slash iss must normalise, got %v", err) + } +} + +func TestValidateReceivedToken_RejectsExpiredToken(t *testing.T) { + t.Parallel() + now := time.Date(2026, 5, 8, 12, 0, 0, 0, time.UTC) + jwt := makeTestJWT(t, map[string]any{ + "iss": "https://issuer.example", + "exp": now.Add(-time.Hour).Unix(), + }) + err := validateReceivedToken(jwt, "https://issuer.example", now) + if err == nil || !strings.Contains(err.Error(), "already expired") { + t.Fatalf("validateReceivedToken(expired) = %v, want already-expired error", err) + } +} + +func TestValidateReceivedToken_OmittedIssIsAllowed(t *testing.T) { + t.Parallel() + // Some servers omit iss; allow rather than reject — the server + // still does the real check on every request. + jwt := makeTestJWT(t, map[string]any{ + "exp": time.Now().Add(time.Hour).Unix(), + }) + if err := validateReceivedToken(jwt, "https://issuer.example", time.Now()); err != nil { + t.Fatalf("validateReceivedToken(no iss) = %v, want nil", err) + } +} diff --git a/cmd/entire/cli/logout.go b/cmd/entire/cli/logout.go index 77c8533976..0d44365035 100644 --- a/cmd/entire/cli/logout.go +++ b/cmd/entire/cli/logout.go @@ -33,7 +33,7 @@ func newLogoutCmd() *cobra.Command { return err } return runLogout(cmd.Context(), cmd.OutOrStdout(), cmd.ErrOrStderr(), - auth.NewStore(), defaultRevokeCurrentToken, api.BaseURL()) + auth.NewStore(), defaultRevokeCurrentToken, api.AuthBaseURL()) }, } addInsecureHTTPAuthFlag(cmd, &insecureHTTPAuth) @@ -41,7 +41,9 @@ func newLogoutCmd() *cobra.Command { } func defaultRevokeCurrentToken(ctx context.Context, token string) error { - return api.NewClient(token).RevokeCurrentToken(ctx) //nolint:wrapcheck // RevokeCurrentToken already wraps with action context + client := api.NewClientWithBaseURL(token, api.AuthBaseURL()). + WithAuthTokensPath(auth.CurrentProvider().AuthTokensPath) + return client.RevokeCurrentToken(ctx) //nolint:wrapcheck // RevokeCurrentToken already wraps with action context } func runLogout(ctx context.Context, outW, errW io.Writer, store tokenStore, revoke revokeCurrentFunc, baseURL string) error { diff --git a/cmd/entire/cli/recap.go b/cmd/entire/cli/recap.go index 5271ef543d..50be797d78 100644 --- a/cmd/entire/cli/recap.go +++ b/cmd/entire/cli/recap.go @@ -123,7 +123,7 @@ func runRecap(ctx context.Context, w, errW io.Writer, f *recapFlags) error { if err != nil { return err } - client, err := newRecapClient(f.insecureHTTP) + client, err := newRecapClient(ctx, f.insecureHTTP) if err != nil { var keyringErr *keyringReadError switch { @@ -178,8 +178,17 @@ func (e *keyringReadError) Unwrap() error { return e.Cause } // real auth error are not collapsed into one "sign in" hint. A keyring read // failure is surfaced as *keyringReadError so the caller can show a targeted // message instead of misattributing it to a missing login. -func newRecapClient(insecureHTTP bool) (*api.Client, error) { - token, err := auth.LookupCurrentToken() +// +// Goes through auth.TokenForResource so split-host deployments get a +// resource-scoped bearer via RFC 8693 exchange. ErrNotLoggedIn is +// collapsed back into an empty token so the caller's "render with no +// bearer, let the server respond 401" path still fires. +func newRecapClient(ctx context.Context, insecureHTTP bool) (*api.Client, error) { + token, err := auth.TokenForResource(ctx, api.BaseURL()) + if errors.Is(err, auth.ErrNotLoggedIn) { + token = "" + err = nil + } if err != nil { return nil, &keyringReadError{Cause: err} } diff --git a/cmd/entire/cli/search/search.go b/cmd/entire/cli/search/search.go index 2e58557da8..6f93bbb7fd 100644 --- a/cmd/entire/cli/search/search.go +++ b/cmd/entire/cli/search/search.go @@ -66,7 +66,11 @@ type Response struct { // Config holds the configuration for a search request. type Config struct { - ServiceURL string // Base URL of the search service + ServiceURL string // Base URL of the search service + // GitHubToken is a misnomer kept for backwards compatibility: + // callers populate it with the resource-scoped OAuth bearer from + // auth.TokenForResource(ctx, ServiceURL). The wire format is + // unchanged (Authorization: Bearer ). GitHubToken string Owner string Repo string diff --git a/cmd/entire/cli/search_cmd.go b/cmd/entire/cli/search_cmd.go index 299f46f186..55b2cbf00d 100644 --- a/cmd/entire/cli/search_cmd.go +++ b/cmd/entire/cli/search_cmd.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io" + "net/url" "os" "strings" @@ -76,14 +77,6 @@ branch:, repo:, and repo:* to search all accessible repos.`, return errors.New("query required when using --json, accessible mode, or piped output. Usage: entire search ") } - ghToken, err := auth.LookupCurrentToken() - if err != nil { - return fmt.Errorf("reading credentials: %w", err) - } - if ghToken == "" { - return errors.New("not authenticated. Run 'entire login' to authenticate") - } - // Get the repo's GitHub remote URL repo, err := strategy.OpenRepository(ctx) if err != nil { @@ -108,7 +101,24 @@ branch:, repo:, and repo:* to search all accessible repos.`, serviceURL := os.Getenv("ENTIRE_SEARCH_URL") if serviceURL == "" { - serviceURL = search.DefaultServiceURL + // Search lives on the data API host. Fall back to + // api.BaseURL() so ENTIRE_API_BASE_URL applies; the search + // package's DefaultServiceURL is only consulted by callers + // that bypass this entry point. + serviceURL = api.BaseURL() + } + + // Resolve a bearer scoped to the search service host. In split-host + // deployments this triggers an RFC 8693 exchange so the bearer + // carries the data-API audience rather than the auth-host one; + // single-host setups hit the same-host shortcut and return the + // core token unchanged. + ghToken, err := auth.TokenForResource(ctx, searchTokenResourceURL(serviceURL)) + if errors.Is(err, auth.ErrNotLoggedIn) { + return errors.New("not authenticated. Run 'entire login' to authenticate") + } + if err != nil { + return fmt.Errorf("reading credentials: %w", err) } searchCfg := search.Config{ @@ -200,6 +210,15 @@ branch:, repo:, and repo:* to search all accessible repos.`, return cmd } +func searchTokenResourceURL(serviceURL string) string { + raw := strings.TrimSpace(serviceURL) + u, err := url.Parse(raw) + if err != nil || u.Scheme == "" || u.Host == "" { + return raw + } + return (&url.URL{Scheme: u.Scheme, Host: u.Host}).String() +} + // completeRepoFlag returns shell-completion suggestions for the search // command's --repo flag. "*" is always offered so the wildcard works // regardless of auth state. Errors are swallowed (rather than surfaced via @@ -207,7 +226,7 @@ branch:, repo:, and repo:* to search all accessible repos.`, // must never pollute the user's prompt with error output. func completeRepoFlag(cmd *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { suggestions := []string{"*"} - client, err := NewAuthenticatedAPIClient(false) + client, err := NewAuthenticatedAPIClient(cmd.Context(), false) if err != nil { return suggestions, cobra.ShellCompDirectiveNoFileComp } diff --git a/cmd/entire/cli/search_cmd_test.go b/cmd/entire/cli/search_cmd_test.go index 484627655c..e0f3a7e6ba 100644 --- a/cmd/entire/cli/search_cmd_test.go +++ b/cmd/entire/cli/search_cmd_test.go @@ -76,3 +76,28 @@ func TestWriteSearchJSON_ZeroLimitFallsBackToDefaultPageSize(t *testing.T) { t.Fatalf("output missing total_pages:\n%s", output) } } + +func TestSearchTokenResourceURL_NormalizesToOrigin(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + serviceURL string + want string + }{ + {"plain origin", "https://entire.io", "https://entire.io"}, + {"trailing slash", "https://entire.io/", "https://entire.io"}, + {"pathful search URL", "https://entire.io/custom/search", "https://entire.io"}, + {"localhost port", "http://localhost:8787/search", "http://localhost:8787"}, + {"parse fallback", "://not-a-url", "://not-a-url"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := searchTokenResourceURL(tt.serviceURL); got != tt.want { + t.Fatalf("searchTokenResourceURL(%q) = %q, want %q", tt.serviceURL, got, tt.want) + } + }) + } +} diff --git a/cmd/entire/cli/trail_cmd.go b/cmd/entire/cli/trail_cmd.go index 4b71f33746..5833a3d1af 100644 --- a/cmd/entire/cli/trail_cmd.go +++ b/cmd/entire/cli/trail_cmd.go @@ -68,7 +68,7 @@ func runTrailShow(ctx context.Context, w io.Writer, insecureHTTP bool) error { return runTrailListAll(ctx, w, "", false, false, insecureHTTP) } - client, err := NewAuthenticatedAPIClient(insecureHTTP) + client, err := NewAuthenticatedAPIClient(ctx, insecureHTTP) if err != nil { return fmt.Errorf("authentication required: %w", err) } @@ -136,7 +136,7 @@ func newTrailListCmd() *cobra.Command { } func runTrailListAll(ctx context.Context, w io.Writer, statusFilter string, jsonOutput, showAll, insecureHTTP bool) error { - client, err := NewAuthenticatedAPIClient(insecureHTTP) + client, err := NewAuthenticatedAPIClient(ctx, insecureHTTP) if err != nil { return fmt.Errorf("authentication required: %w", err) } @@ -323,7 +323,7 @@ func runTrailCreate(cmd *cobra.Command, title, body, base, branch, statusStr str // --- Phase 2: API operations --- - client, err := NewAuthenticatedAPIClient(trailInsecureHTTP(cmd)) + client, err := NewAuthenticatedAPIClient(cmd.Context(), trailInsecureHTTP(cmd)) if err != nil { return fmt.Errorf("authentication required: %w", err) } @@ -410,7 +410,7 @@ func newTrailUpdateCmd() *cobra.Command { func runTrailUpdate(ctx context.Context, w, errW io.Writer, insecureHTTP bool, statusStr, title, body, branch string, labelAdd, labelRemove []string) error { _ = errW // reserved for future warnings - client, err := NewAuthenticatedAPIClient(insecureHTTP) + client, err := NewAuthenticatedAPIClient(ctx, insecureHTTP) if err != nil { return fmt.Errorf("authentication required: %w", err) } diff --git a/cmd/entire/cli/trail_watch_cmd.go b/cmd/entire/cli/trail_watch_cmd.go index 4732487d8a..413fea5300 100644 --- a/cmd/entire/cli/trail_watch_cmd.go +++ b/cmd/entire/cli/trail_watch_cmd.go @@ -88,7 +88,7 @@ func runTrailWatch(cmd *cobra.Command, number int, jsonOutput, showPings, once b w := cmd.OutOrStdout() errW := cmd.ErrOrStderr() - client, err := NewAuthenticatedAPIClient(trailInsecureHTTP(cmd)) + client, err := NewAuthenticatedAPIClient(ctx, trailInsecureHTTP(cmd)) if err != nil { return fmt.Errorf("authentication required: %w", err) } diff --git a/go.mod b/go.mod index 594a3230af..7b47dcb9ab 100644 --- a/go.mod +++ b/go.mod @@ -68,6 +68,7 @@ require ( github.com/dsnet/compress v0.0.2-0.20230904184137-39efe44ab707 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/emirpasic/gods v1.18.1 // indirect + github.com/entireio/auth-go v0.1.0 github.com/fatih/semgroup v1.2.0 // indirect github.com/fsnotify/fsnotify v1.8.0 // indirect github.com/gitleaks/go-gitdiff v0.9.1 // indirect diff --git a/go.sum b/go.sum index 0d9045ee92..f4a16d3656 100644 --- a/go.sum +++ b/go.sum @@ -109,6 +109,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= +github.com/entireio/auth-go v0.1.0 h1:+1d+jdkxWHTEdTBs1xwZwhIp8g+CmnyVdDaNS6hirE4= +github.com/entireio/auth-go v0.1.0/go.mod h1:tt7T8auf+cZritzm2qeqGQgioLUU3XEwNn6qHiESH08= github.com/fatih/semgroup v1.2.0 h1:h/OLXwEM+3NNyAdZEpMiH1OzfplU09i2qXPVThGZvyg= github.com/fatih/semgroup v1.2.0/go.mod h1:1KAD4iIYfXjE4U13B48VM4z9QUwV5Tt8O4rS879kgm8= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=