Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion auth/deviceflow/deviceflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ type DeviceCode struct {
Interval int `json:"interval"`
}

// DefaultRequestTimeout caps a single device-flow HTTP round-trip
// (StartDeviceAuth or one PollDeviceAuth call). Set conservatively:
// healthy device-flow endpoints respond in sub-seconds, so the cap
// mainly defends against slow-loris responses dripping bytes within
// MaxResponseBytes — see Client.RequestTimeout for the per-Client
// override. The polling-loop interval is the caller's concern; this
// timeout governs only the individual HTTP request.
const DefaultRequestTimeout = 30 * time.Second

// Client polls an RFC 8628 device authorization grant.
//
// All configuration is explicit; the package has no global state and
Expand All @@ -60,6 +69,26 @@ type Client struct {
UserAgent string
DeviceCodePath string
TokenPath string

// RequestTimeout is the per-request deadline applied via
// context.WithTimeout on top of the caller's context. Zero falls
// back to DefaultRequestTimeout. Negative disables the cap (useful
// for tests that want to drive timing via the caller's ctx alone).
RequestTimeout time.Duration
}

// requestTimeout resolves the effective per-request timeout: the
// configured RequestTimeout if positive, the package default if zero,
// or zero (no cap) if negative.
func (c *Client) requestTimeout() time.Duration {
switch {
case c.RequestTimeout < 0:
return 0
case c.RequestTimeout == 0:
return DefaultRequestTimeout
default:
return c.RequestTimeout
}
}

// Sentinel errors returned by PollDeviceAuth when the token endpoint
Expand Down Expand Up @@ -110,6 +139,12 @@ func errCodeToSentinel(code string) error {
// server. The returned DeviceCode is opaque to the client; pass it
// back unmodified on every PollDeviceAuth.
func (c *Client) StartDeviceAuth(ctx context.Context) (*DeviceCode, error) {
if timeout := c.requestTimeout(); timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}

body := url.Values{}
body.Set("client_id", c.ClientID)
if c.Scope != "" {
Expand Down Expand Up @@ -141,6 +176,12 @@ func (c *Client) StartDeviceAuth(ctx context.Context) (*DeviceCode, error) {
// the matching sentinel error from this package. Other failures
// (network, malformed responses) are wrapped with context.
func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*tokens.TokenSet, error) {
if timeout := c.requestTimeout(); timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}

body := url.Values{}
body.Set("grant_type", deviceCodeGrantType)
body.Set("client_id", c.ClientID)
Expand Down Expand Up @@ -195,7 +236,10 @@ func (c *Client) PollDeviceAuth(ctx context.Context, deviceCode string) (*tokens
}

// postForm POSTs body as application/x-www-form-urlencoded to a path
// resolved against the client's BaseURL.
// resolved against the client's BaseURL. The caller is responsible
// for applying any per-request timeout via context.WithTimeout — the
// timeout must cover the body-read that happens after postForm
// returns, so cancel-on-return here would interrupt that read.
func (c *Client) postForm(ctx context.Context, path string, body url.Values) (*http.Response, error) {
endpoint, err := resolveURL(c.BaseURL, path)
if err != nil {
Expand Down
80 changes: 80 additions & 0 deletions auth/deviceflow/deviceflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,83 @@ func TestResolveURL(t *testing.T) {
})
}
}

// TestPollDeviceAuth_RequestTimeoutFires pins the slow-loris defence:
// a handler that never finishes writing a response must surface as a
// context deadline error rather than blocking the polling loop forever.
func TestPollDeviceAuth_RequestTimeoutFires(t *testing.T) {
t.Parallel()
// Cleanup is LIFO and httptest.Server.Close waits for active
// handler goroutines, so close(hung) is registered AFTER
// newTestClient to fire first and let the handler exit before
// srv.Close runs.
hung := make(chan struct{})
c := newTestClient(t, func(_ http.ResponseWriter, r *http.Request) {
select {
case <-hung:
case <-r.Context().Done():
}
})
t.Cleanup(func() { close(hung) })
c.RequestTimeout = 50 * time.Millisecond

_, err := c.PollDeviceAuth(context.Background(), "dev-1")
if err == nil {
t.Fatal("expected timeout error, got nil")
}
if !strings.Contains(err.Error(), "context deadline exceeded") {
t.Fatalf("err = %v, want context deadline exceeded", err)
}
}

// TestStartDeviceAuth_RequestTimeoutFires mirrors the poll-side test
// for the device-code endpoint.
func TestStartDeviceAuth_RequestTimeoutFires(t *testing.T) {
t.Parallel()
// Cleanup is LIFO and httptest.Server.Close waits for active
// handler goroutines, so close(hung) is registered AFTER
// newTestClient to fire first and let the handler exit before
// srv.Close runs.
hung := make(chan struct{})
c := newTestClient(t, func(_ http.ResponseWriter, r *http.Request) {
select {
case <-hung:
case <-r.Context().Done():
}
})
t.Cleanup(func() { close(hung) })
c.RequestTimeout = 50 * time.Millisecond

_, err := c.StartDeviceAuth(context.Background())
if err == nil {
t.Fatal("expected timeout error, got nil")
}
if !strings.Contains(err.Error(), "context deadline exceeded") {
t.Fatalf("err = %v, want context deadline exceeded", err)
}
}

// TestRequestTimeout_DefaultAndOverride exercises the timeout policy
// without doing IO — pure resolution of the (zero / negative /
// positive) input contract.
func TestRequestTimeout_DefaultAndOverride(t *testing.T) {
t.Parallel()
cases := []struct {
name string
in time.Duration
want time.Duration
}{
{"zero -> default", 0, DefaultRequestTimeout},
{"negative -> disabled", -1, 0},
{"positive -> verbatim", 5 * time.Second, 5 * time.Second},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
c := &Client{RequestTimeout: tc.in}
if got := c.requestTimeout(); got != tc.want {
t.Fatalf("requestTimeout() = %v, want %v", got, tc.want)
}
})
}
}
33 changes: 33 additions & 0 deletions auth/sts/sts.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ func (r ExchangeRequest) validate() error {
return nil
}

// DefaultRequestTimeout caps a single token-exchange round-trip. Set
// conservatively: even with a slow auth host plus TLS handshake, a
// healthy exchange completes in sub-seconds. The cap mainly defends
// against slow-loris responses dripping bytes within MaxResponseBytes
// — see Client.RequestTimeout for the per-Client override.
const DefaultRequestTimeout = 30 * time.Second

// Client exchanges subject tokens for tokens of a different type at an
// RFC 8693 token endpoint.
//
Expand All @@ -78,6 +85,12 @@ type Client struct {
BaseURL string
Path string
UserAgent string

// RequestTimeout is the per-Exchange deadline applied via
// context.WithTimeout on top of the caller's context. Zero falls
// back to DefaultRequestTimeout. Negative disables the cap (useful
// for tests that want to drive timing via the caller's ctx alone).
RequestTimeout time.Duration
}

// Exchange performs one RFC 8693 token exchange.
Expand All @@ -98,6 +111,12 @@ func (c *Client) Exchange(ctx context.Context, req ExchangeRequest) (*tokens.Tok
return nil, fmt.Errorf("token exchange: resolve URL: %w", err)
}

if timeout := c.requestTimeout(); timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}

httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
if err != nil {
return nil, fmt.Errorf("token exchange: create request: %w", err)
Expand Down Expand Up @@ -176,6 +195,20 @@ func buildForm(req ExchangeRequest) url.Values {
return form
}

// requestTimeout resolves the effective per-request timeout: the
// configured RequestTimeout if positive, the package default if zero,
// or zero (no cap) if negative.
func (c *Client) requestTimeout() time.Duration {
switch {
case c.RequestTimeout < 0:
return 0
case c.RequestTimeout == 0:
return DefaultRequestTimeout
default:
return c.RequestTimeout
}
}

func resolveURL(baseURL, path string) (string, error) {
base, err := url.Parse(baseURL)
if err != nil {
Expand Down
59 changes: 59 additions & 0 deletions auth/sts/sts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,62 @@ func TestExchange_NoExpiry(t *testing.T) {
t.Fatalf("ExpiresAt = %v, want zero", got.ExpiresAt)
}
}

// TestExchange_RequestTimeoutFires pins the slow-loris defence: a
// handler that never writes a response body must surface as a context
// deadline error rather than blocking the caller indefinitely.
//
// Cleanup order matters: t.Cleanup is LIFO, and httptest.Server.Close
// waits for in-flight handler goroutines to return. We register
// `close(hung)` AFTER newTestClient so it fires first and lets the
// handler exit before srv.Close runs.
func TestExchange_RequestTimeoutFires(t *testing.T) {
t.Parallel()
hung := make(chan struct{})

c := newTestClient(t, func(_ http.ResponseWriter, r *http.Request) {
select {
case <-hung:
case <-r.Context().Done():
}
})
t.Cleanup(func() { close(hung) })
c.RequestTimeout = 50 * time.Millisecond

_, err := c.Exchange(context.Background(), ExchangeRequest{
SubjectToken: "sub",
SubjectTokenType: SubjectTokenTypeJWT,
RequestedTokenType: "urn:example:t",
})
if err == nil {
t.Fatal("expected timeout error, got nil")
}
if !strings.Contains(err.Error(), "context deadline exceeded") {
t.Fatalf("err = %v, want context deadline exceeded", err)
}
}

// TestRequestTimeout_DefaultAndOverride exercises the timeout policy
// without doing IO — pure resolution of the (zero / negative /
// positive) input contract.
func TestRequestTimeout_DefaultAndOverride(t *testing.T) {
t.Parallel()
cases := []struct {
name string
in time.Duration
want time.Duration
}{
{"zero -> default", 0, DefaultRequestTimeout},
{"negative -> disabled", -1, 0},
{"positive -> verbatim", 5 * time.Second, 5 * time.Second},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
c := &Client{RequestTimeout: tc.in}
if got := c.requestTimeout(); got != tc.want {
t.Fatalf("requestTimeout() = %v, want %v", got, tc.want)
}
})
}
}
Loading
Loading