From 608dbf7a2c0dc714b20f3ff717dfcc8e2061dc95 Mon Sep 17 00:00:00 2001 From: Chris Randles Date: Wed, 13 May 2026 14:26:19 -0400 Subject: [PATCH 1/6] fix: [alb] post-audit hardening: tsm ctx, nlm parser, ur race, bounded captures Signed-off-by: Chris Randles --- pkg/backends/alb/mech/fr/first_response.go | 2 +- .../alb/mech/nlm/newest_last_modified.go | 7 +- .../nlm/newest_last_modified_parser_test.go | 93 ++++++++++++++ .../alb/mech/nlm/newest_last_modified_test.go | 5 +- .../alb/mech/tsm/time_series_merge.go | 92 +++++++------- pkg/backends/alb/mech/ur/user_router.go | 63 +++++++--- .../alb/mech/ur/user_router_reload_test.go | 116 ++++++++++++++++++ pkg/backends/prometheus/handler_labels.go | 2 +- pkg/backends/prometheus/handler_query.go | 2 +- pkg/proxy/response/capture/capture.go | 36 +++++- pkg/proxy/response/capture/capture_test.go | 32 +++++ 11 files changed, 378 insertions(+), 72 deletions(-) create mode 100644 pkg/backends/alb/mech/nlm/newest_last_modified_parser_test.go create mode 100644 pkg/backends/alb/mech/ur/user_router_reload_test.go diff --git a/pkg/backends/alb/mech/fr/first_response.go b/pkg/backends/alb/mech/fr/first_response.go index 98376a0a9..983b3682f 100644 --- a/pkg/backends/alb/mech/fr/first_response.go +++ b/pkg/backends/alb/mech/fr/first_response.go @@ -151,7 +151,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } r2 = r2.WithContext(ctx) r2 = request.SetResources(r2, &request.Resources{Cancelable: true}) - crw := capture.NewCaptureResponseWriter() + crw := capture.NewCaptureResponseWriterWithLimit(capture.DefaultMaxBytes) captures[i] = crw hl[i].Handler().ServeHTTP(crw, r2) statusCode := crw.StatusCode() diff --git a/pkg/backends/alb/mech/nlm/newest_last_modified.go b/pkg/backends/alb/mech/nlm/newest_last_modified.go index 041ff9a45..d1b46df55 100644 --- a/pkg/backends/alb/mech/nlm/newest_last_modified.go +++ b/pkg/backends/alb/mech/nlm/newest_last_modified.go @@ -109,12 +109,15 @@ func (h *handler) 
ServeHTTP(w http.ResponseWriter, r *http.Request) { return err } r2 = r2.WithContext(bareCtx) - crw := capture.NewCaptureResponseWriter() + crw := capture.NewCaptureResponseWriterWithLimit(capture.DefaultMaxBytes) captures[i] = crw hl[i].Handler().ServeHTTP(crw, r2) if lmStr := crw.Header().Get(headers.NameLastModified); lmStr != "" { - if lm, err := time.Parse(time.RFC1123, lmStr); err == nil { + // http.ParseTime accepts all three RFC 7231 §7.1.1.1 forms + // (IMF-fixdate with GMT, RFC 850, ANSI C asctime); time.RFC1123 + // alone parses only the IMF-fixdate form and rejects the other two + if lm, err := http.ParseTime(lmStr); err == nil { lastMods[i] = lm } } diff --git a/pkg/backends/alb/mech/nlm/newest_last_modified_parser_test.go b/pkg/backends/alb/mech/nlm/newest_last_modified_parser_test.go new file mode 100644 index 000000000..30e9e8321 --- /dev/null +++ b/pkg/backends/alb/mech/nlm/newest_last_modified_parser_test.go @@ -0,0 +1,93 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package nlm + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/trickstercache/trickster/v2/pkg/proxy/headers" + "github.com/trickstercache/trickster/v2/pkg/testutil/albpool" +) + +func TestNLMAcceptsAllRFC7231DateFormats(t *testing.T) { + cases := []struct { + name string + newest string + older string + }{ + {"IMF-fixdate GMT", "Sun, 06 Nov 1994 08:49:37 GMT", "Sun, 06 Nov 1994 08:00:00 GMT"}, + {"RFC 850", "Sunday, 06-Nov-94 08:49:37 GMT", "Sunday, 06-Nov-94 08:00:00 GMT"}, + {"ANSI C asctime", "Sun Nov 6 08:49:37 1994", "Sun Nov 6 08:00:00 1994"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + mk := func(lm, body string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set(headers.NameLastModified, lm) + w.WriteHeader(http.StatusOK) + w.Write([]byte(body)) + }) + } + p, _, st := albpool.New(-1, []http.Handler{ + mk(tc.older, "older"), + mk(tc.newest, "newest"), + }) + st[0].Set(0) + st[1].Set(0) + time.Sleep(250 * time.Millisecond) + + h := &handler{} + h.SetPool(p) + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "http://trickstercache.org/", nil) + h.ServeHTTP(w, r) + + if w.Body.String() != "newest" { + t.Errorf("%s: expected newest body, got %q", tc.name, w.Body.String()) + } + }) + } +} + +func TestNLMTieBreakPrefersFirstSeen(t *testing.T) { + lm := "Sun, 06 Nov 1994 08:49:37 GMT" + mk := func(body string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set(headers.NameLastModified, lm) + w.WriteHeader(http.StatusOK) + w.Write([]byte(body)) + }) + } + p, _, st := albpool.New(-1, []http.Handler{mk("first"), mk("second")}) + st[0].Set(0) + st[1].Set(0) + time.Sleep(250 * time.Millisecond) + + h := &handler{} + h.SetPool(p) + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "http://trickstercache.org/", nil) + h.ServeHTTP(w, r) + + if w.Body.String() != "first" { 
+ t.Errorf("tie should resolve to first-seen, got %q", w.Body.String()) + } +} diff --git a/pkg/backends/alb/mech/nlm/newest_last_modified_test.go b/pkg/backends/alb/mech/nlm/newest_last_modified_test.go index de4508086..373d7c69d 100644 --- a/pkg/backends/alb/mech/nlm/newest_last_modified_test.go +++ b/pkg/backends/alb/mech/nlm/newest_last_modified_test.go @@ -88,7 +88,10 @@ func TestNewestLastModifiedSelection(t *testing.T) { handlerWithLM := func(body string, lm time.Time) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set(headers.NameLastModified, lm.Format(time.RFC1123)) + // http.TimeFormat is IMF-fixdate (RFC 7231) with the literal "GMT" + // suffix that real HTTP servers emit. time.RFC1123 formatted with + // a UTC time produces "UTC" which no RFC 7231 parser accepts. + w.Header().Set(headers.NameLastModified, lm.UTC().Format(http.TimeFormat)) w.WriteHeader(http.StatusOK) w.Write([]byte(body)) }) diff --git a/pkg/backends/alb/mech/tsm/time_series_merge.go b/pkg/backends/alb/mech/tsm/time_series_merge.go index 2a1567c43..9d1af79fc 100644 --- a/pkg/backends/alb/mech/tsm/time_series_merge.go +++ b/pkg/backends/alb/mech/tsm/time_series_merge.go @@ -280,6 +280,41 @@ func pickWinner(results []gatherResult) (merge.RespondFunc, http.Header) { return nil, nil } +// aggregateStatus collapses per-shard outcomes into the outbound response's +// status code and trickster status header. When any shard returned 2xx, the +// outbound code is the lowest 2xx seen (so 200 wins over 206); otherwise the +// highest non-2xx wins so the more severe failure propagates instead of being +// masked by an incidentally-lower error code. 
+func aggregateStatus(results []gatherResult) (status int, statusHeader string, has2xx, hasNon2xx bool) { + var min2xx, maxErr int + for _, res := range results { + if res.statusCode > 0 { + if res.statusCode >= 200 && res.statusCode < 300 { + has2xx = true + if min2xx == 0 || res.statusCode < min2xx { + min2xx = res.statusCode + } + } else { + hasNon2xx = true + if res.statusCode > maxErr { + maxErr = res.statusCode + } + } + } + if res.header != nil { + headers.StripMergeHeaders(res.header) + statusHeader = headers.MergeResultHeaderVals(statusHeader, + res.header.Get(headers.NameTricksterResult)) + } + } + if has2xx { + status = min2xx + } else { + status = maxErr + } + return +} + // serveStandard handles the common scatter/gather path: each shard gets one // request, results are merged with mergeStrategy, and an optional warning is // appended to the accumulated dataset before writing the response. @@ -313,13 +348,14 @@ func (h *handler) serveStandard( results[i].failed = true return err } + r2 = r2.WithContext(r.Context()) rsc2 := &request.Resources{ IsMergeMember: true, TSReqestOptions: rsc.TSReqestOptions, TSMergeStrategy: int(mergeStrategy), } r2 = request.SetResources(r2, rsc2) - crw := capture.NewCaptureResponseWriter() + crw := capture.NewCaptureResponseWriterWithLimit(capture.DefaultMaxBytes) hl[i].Handler().ServeHTTP(crw, r2) rsc2 = request.GetResources(r2) if rsc2 == nil { @@ -383,8 +419,10 @@ func (h *handler) serveStandard( // proxy-error (e.g. unsupported Content-Encoding decoded // upstream then unmarshal failed) and surfaced 200 + empty // body to us. Without this, the merged response would look - // identical to a fully-successful fanout. - failed: !contributed, + // identical to a fully-successful fanout. A truncated capture + // (response exceeded the buffer cap) is also surfaced as a + // failure so the merge doesn't silently drop tail data. 
+ failed: !contributed || crw.Truncated(), } return nil }) @@ -428,26 +466,7 @@ func (h *handler) serveStandard( // TSM fanout would strip any backend-set headers that FGR would // happily propagate. See #970. mrf, winnerHeaders := pickWinner(results) - var statusCode int - var statusHeader string - var has2xx, hasNon2xx bool - for _, res := range results { - if res.statusCode > 0 { - if statusCode == 0 || res.statusCode < statusCode { - statusCode = res.statusCode - } - if res.statusCode >= 200 && res.statusCode < 300 { - has2xx = true - } else { - hasNon2xx = true - } - } - if res.header != nil { - headers.StripMergeHeaders(res.header) - statusHeader = headers.MergeResultHeaderVals(statusHeader, - res.header.Get(headers.NameTricksterResult)) - } - } + statusCode, statusHeader, has2xx, hasNon2xx := aggregateStatus(results) // Mixed 2xx + non-2xx fanout: surface a partial-hit marker so clients // can detect that some members failed even when each member's own // Trickster status string happens to agree (V2). 
hasGatherFailure @@ -555,13 +574,14 @@ func (h *handler) serveWeightedAvg( if err != nil { return err } + r2 = r2.WithContext(r.Context()) rsc2 := &request.Resources{ IsMergeMember: true, TSReqestOptions: rsc.TSReqestOptions, TSMergeStrategy: int(dataset.MergeStrategySum), } r2 = request.SetResources(r2, rsc2) - crw := capture.NewCaptureResponseWriter() + crw := capture.NewCaptureResponseWriterWithLimit(capture.DefaultMaxBytes) hl[i].Handler().ServeHTTP(crw, r2) rsc2 = request.GetResources(r2) if rsc2 == nil { @@ -598,13 +618,14 @@ func (h *handler) serveWeightedAvg( if err != nil { return err } + r2 = r2.WithContext(r.Context()) rsc2 := &request.Resources{ IsMergeMember: true, TSReqestOptions: rsc.TSReqestOptions, TSMergeStrategy: int(dataset.MergeStrategySum), } r2 = request.SetResources(r2, rsc2) - crw := capture.NewCaptureResponseWriter() + crw := capture.NewCaptureResponseWriterWithLimit(capture.DefaultMaxBytes) hl[i].Handler().ServeHTTP(crw, r2) rsc2 = request.GetResources(r2) if rsc2 == nil { @@ -662,26 +683,7 @@ func (h *handler) serveWeightedAvg( // Aggregate status and headers from sum-query results (count results are // only used for the arithmetic and do not affect the response envelope). 
mrf, winnerHeaders := pickWinner(results) - var statusCode int - var statusHeader string - var has2xx, hasNon2xx bool - for _, res := range results { - if res.statusCode > 0 { - if statusCode == 0 || res.statusCode < statusCode { - statusCode = res.statusCode - } - if res.statusCode >= 200 && res.statusCode < 300 { - has2xx = true - } else { - hasNon2xx = true - } - } - if res.header != nil { - headers.StripMergeHeaders(res.header) - statusHeader = headers.MergeResultHeaderVals(statusHeader, - res.header.Get(headers.NameTricksterResult)) - } - } + statusCode, statusHeader, has2xx, hasNon2xx := aggregateStatus(results) if has2xx && hasNon2xx { statusHeader = headers.MergeResultHeaderVals(statusHeader, "engine=ALB; status=phit") } diff --git a/pkg/backends/alb/mech/ur/user_router.go b/pkg/backends/alb/mech/ur/user_router.go index b9da8eb27..c34b1e1d1 100644 --- a/pkg/backends/alb/mech/ur/user_router.go +++ b/pkg/backends/alb/mech/ur/user_router.go @@ -18,6 +18,7 @@ package ur import ( "net/http" + "sync" "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech/types" uropt "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech/ur/options" @@ -36,6 +37,10 @@ const ( ) type Handler struct { + // mu guards options, authenticator, enableReplaceCreds against concurrent + // reads on the request path and writes during SIGHUP config reload via + // ValidateAndStartPool -> SetAuthenticator / SetDefaultHandler. + mu sync.RWMutex authenticator at.Authenticator enableReplaceCreds bool options *uropt.Options @@ -65,6 +70,12 @@ func (h *Handler) Name() types.Name { } func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.mu.RLock() + opts := h.options + auth := h.authenticator + replaceAllowed := h.enableReplaceCreds + h.mu.RUnlock() + var username, cred string var enableReplaceCreds bool @@ -74,9 +85,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // default authenticator (usually Basic Auth) for the username. 
if rsc != nil && rsc.AuthResult != nil && rsc.AuthResult.Username != "" { username = rsc.AuthResult.Username - enableReplaceCreds = h.enableReplaceCreds && rsc.AuthResult.Status == at.AuthSuccess - } else if h.authenticator != nil { - u, c, err := h.authenticator.ExtractCredentials(r) + enableReplaceCreds = replaceAllowed && rsc.AuthResult.Status == at.AuthSuccess + } else if auth != nil { + u, c, err := auth.ExtractCredentials(r) if err == nil && u != "" { username = u cred = c @@ -85,26 +96,29 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // if the request doesn't have a username or there are 0 Users in the Router, // the default handler takes the request - if username == "" || len(h.options.Users) == 0 { + if username == "" || len(opts.Users) == 0 { h.handleDefault(w, r) return } // if the User is found in the Router map, process their options - if opts, ok := h.options.Users[username]; ok { + if u, ok := opts.Users[username]; ok { // this handles when username or credential is configured to be remapped - if enableReplaceCreds && (opts.ToUser != "" || opts.ToCredential != "") { + if enableReplaceCreds && (u.ToUser != "" || u.ToCredential != "") { // swap in the new user if configured - if opts.ToUser != "" { - username = opts.ToUser + if u.ToUser != "" { + username = u.ToUser } - // swap in the new credential if configured - if opts.ToCredential != "" { - cred = string(opts.ToCredential) + // swap in the new credential if configured. When ToCredential is + // empty, retain the inbound credential rather than overwriting with + // an empty password (which would silently corrupt Basic auth on + // the backend). 
+ if u.ToCredential != "" { + cred = string(u.ToCredential) } // strip the inbound Authorization before writing the new credential, // since Sanitize unconditionally clears the header - h.authenticator.Sanitize(r) - if err := h.authenticator.SetCredentials(r, username, cred); err != nil { + auth.Sanitize(r) + if err := auth.SetCredentials(r, username, cred); err != nil { h.handleDefault(w, r) return } @@ -113,9 +127,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // and the routed backend is currently considered healthy. ToStatus // values below StatusUnchecked (Failing, Initializing) fall through // to the default handler instead of dispatching to a known-bad target. - if opts.ToHandler != nil { - if opts.ToStatus == nil || opts.ToStatus.Get() >= 0 { - opts.ToHandler.ServeHTTP(w, r) + if u.ToHandler != nil { + if u.ToStatus == nil || u.ToStatus.Get() >= 0 { + u.ToHandler.ServeHTTP(w, r) return } } @@ -126,24 +140,33 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (h *Handler) SetAuthenticator(a at.Authenticator, enableReplaceCreds bool) { + h.mu.Lock() h.authenticator = a h.enableReplaceCreds = enableReplaceCreds + h.mu.Unlock() } func (h *Handler) SetDefaultHandler(h2 http.Handler) { - h.options.DefaultHandler = h2 + h.mu.Lock() + if h.options != nil { + h.options.DefaultHandler = h2 + } + h.mu.Unlock() } func (h *Handler) handleDefault(w http.ResponseWriter, r *http.Request) { - if h.options.DefaultHandler == nil { - code := h.options.NoRouteStatusCode + h.mu.RLock() + dh := h.options.DefaultHandler + code := h.options.NoRouteStatusCode + h.mu.RUnlock() + if dh == nil { if code < 100 || code >= 600 { code = http.StatusBadGateway } w.WriteHeader(code) return } - h.options.DefaultHandler.ServeHTTP(w, r) + dh.ServeHTTP(w, r) } // stubs for unused interface functions diff --git a/pkg/backends/alb/mech/ur/user_router_reload_test.go b/pkg/backends/alb/mech/ur/user_router_reload_test.go new file mode 100644 
index 000000000..60fdd045a --- /dev/null +++ b/pkg/backends/alb/mech/ur/user_router_reload_test.go @@ -0,0 +1,116 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ur + +import ( + "context" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + uropt "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech/ur/options" + at "github.com/trickstercache/trickster/v2/pkg/proxy/authenticator/types" + "github.com/trickstercache/trickster/v2/pkg/proxy/request" +) + +// Concurrent SetDefaultHandler / SetAuthenticator vs ServeHTTP. Run under +// -race to catch any regression in the locking discipline. Without the mutex, +// this races on h.options.DefaultHandler / h.authenticator fields. 
+func TestURConcurrentReloadIsRaceFree(t *testing.T) { + okH := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + h := &Handler{ + options: &uropt.Options{ + DefaultHandler: okH, + NoRouteStatusCode: http.StatusUnauthorized, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var wg sync.WaitGroup + + wg.Go(func() { + for ctx.Err() == nil { + h.SetDefaultHandler(okH) + h.SetAuthenticator(&mockAuth{}, false) + } + }) + + const callers = 8 + for range callers { + wg.Go(func() { + for ctx.Err() == nil { + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "http://example.com/", nil) + h.ServeHTTP(w, r) + } + }) + } + + time.Sleep(300 * time.Millisecond) + cancel() + wg.Wait() +} + +// When a router entry sets ToUser but leaves ToCredential empty, the ToUser +// remap must still fire and ToCredential must not be applied as an override. +// On the AuthResult-driven path no inbound credential is extracted, so +// SetCredentials is still called with (ToUser, ""); this test pins that contract. 
+func TestURRetainsInboundCredWhenToCredentialEmpty(t *testing.T) { + auth := &mockAuth{} + okH := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + h := &Handler{ + authenticator: auth, + enableReplaceCreds: true, + options: &uropt.Options{ + DefaultHandler: okH, + Users: uropt.UserMappingOptionsByUser{ + "alice": { + ToUser: "bob", + ToHandler: okH, + }, + }, + }, + } + + r, _ := http.NewRequest("GET", "http://example.com/", nil) + rsc := &request.Resources{ + AuthResult: &at.AuthResult{Username: "alice", Status: at.AuthSuccess}, + } + r = request.SetResources(r, rsc) + h.ServeHTTP(httptest.NewRecorder(), r) + + if len(auth.setCalls) != 1 { + t.Fatalf("expected one SetCredentials call, got %d", len(auth.setCalls)) + } + if auth.setCalls[0].user != "bob" { + t.Errorf("expected ToUser remap to bob, got %q", auth.setCalls[0].user) + } + // On the AuthResult path the inbound cred wasn't extracted, so cred is "". + // The important guarantee is that we DIDN'T proactively read ToCredential + // (which would also be "") and stomp on a future inbound cred. Mostly this + // test pins the contract: ToUser remap fires without ToCredential. 
+} diff --git a/pkg/backends/prometheus/handler_labels.go b/pkg/backends/prometheus/handler_labels.go index 71158a121..45e69a9d6 100644 --- a/pkg/backends/prometheus/handler_labels.go +++ b/pkg/backends/prometheus/handler_labels.go @@ -48,7 +48,7 @@ func (c *Client) LabelsHandler(w http.ResponseWriter, r *http.Request) { params.SetRequestValues(r, qp) if c.hasTransformations { - sw := capture.NewCaptureResponseWriter() + sw := capture.NewCaptureResponseWriterWithLimit(capture.DefaultMaxBytes) engines.ObjectProxyCacheRequest(sw, r) headers.Merge(w.Header(), sw.Header()) body := c.processLabelsResponse(sw.Body(), origPath) diff --git a/pkg/backends/prometheus/handler_query.go b/pkg/backends/prometheus/handler_query.go index f45da6919..5ac082a73 100644 --- a/pkg/backends/prometheus/handler_query.go +++ b/pkg/backends/prometheus/handler_query.go @@ -94,7 +94,7 @@ func (c *Client) QueryHandler(w http.ResponseWriter, r *http.Request) { // we need to capture and unmarshal the response if c.hasTransformations || (rsc != nil && rsc.IsMergeMember) { // use a streaming response writer to capture the response body for transformation - sw := capture.NewCaptureResponseWriter() + sw := capture.NewCaptureResponseWriterWithLimit(capture.DefaultMaxBytes) engines.ObjectProxyCacheRequest(sw, r) // Propagate captured upstream headers (Content-Type, X-Trickster-Result, // etc.) to the outer ResponseWriter. Without this, ALB mechanisms see a diff --git a/pkg/proxy/response/capture/capture.go b/pkg/proxy/response/capture/capture.go index 6a3f5f3c5..4a594b864 100644 --- a/pkg/proxy/response/capture/capture.go +++ b/pkg/proxy/response/capture/capture.go @@ -21,6 +21,12 @@ import ( "net/http" ) +// DefaultMaxBytes is the per-writer cap applied by NewCaptureResponseWriterWithLimit +// when callers want defense-in-depth against pathological upstream responses +// blowing the heap during ALB fanout. 
256 MiB is generous for legitimate +// time-series payloads and stops one bad backend from OOMing the proxy. +const DefaultMaxBytes = 256 * 1024 * 1024 + // CaptureResponseWriter captures the response body to a byte slice type CaptureResponseWriter struct { http.ResponseWriter @@ -28,6 +34,8 @@ type CaptureResponseWriter struct { statusCode int body bytes.Buffer len int + maxBytes int + truncated bool } // NewCaptureResponseWriter returns a new CaptureResponseWriter @@ -38,6 +46,17 @@ func NewCaptureResponseWriter() *CaptureResponseWriter { } } +// NewCaptureResponseWriterWithLimit returns a CaptureResponseWriter that drops +// bytes past maxBytes and flips Truncated() to true. A non-positive maxBytes +// means unlimited. +func NewCaptureResponseWriterWithLimit(maxBytes int) *CaptureResponseWriter { + return &CaptureResponseWriter{ + header: make(http.Header), + statusCode: http.StatusOK, + maxBytes: maxBytes, + } +} + // Header returns the response header map func (sw *CaptureResponseWriter) Header() http.Header { return sw.header @@ -51,8 +70,18 @@ func (sw *CaptureResponseWriter) WriteHeader(code int) { sw.statusCode = code } -// Write appends data to the response body +// Write appends data to the response body. Returns len(b) even after the cap +// is reached so the upstream producer doesn't error or block; Truncated() +// surfaces the drop to the merge layer. func (sw *CaptureResponseWriter) Write(b []byte) (int, error) { + if sw.maxBytes > 0 && sw.len+len(b) > sw.maxBytes { + if remaining := sw.maxBytes - sw.len; remaining > 0 { + sw.body.Write(b[:remaining]) + sw.len += remaining + } + sw.truncated = true + return len(b), nil + } sw.body.Write(b) sw.len += len(b) return len(b), nil @@ -70,3 +99,8 @@ func (sw *CaptureResponseWriter) StatusCode() int { } return sw.statusCode } + +// Truncated reports whether Write dropped bytes due to hitting maxBytes. 
+func (sw *CaptureResponseWriter) Truncated() bool { + return sw.truncated +} diff --git a/pkg/proxy/response/capture/capture_test.go b/pkg/proxy/response/capture/capture_test.go index 76ca7501e..7c9c29d1d 100644 --- a/pkg/proxy/response/capture/capture_test.go +++ b/pkg/proxy/response/capture/capture_test.go @@ -63,6 +63,38 @@ func TestWrite_IoCopy_SmallBody(t *testing.T) { require.Equal(t, src, sw.Body()) } +func TestWriteWithLimit_TruncatesAndFlags(t *testing.T) { + sw := NewCaptureResponseWriterWithLimit(10) + n, err := sw.Write([]byte("0123456789ABCDEF")) + require.NoError(t, err) + require.Equal(t, 16, n, "Write must report full input length even when truncating") + require.Equal(t, []byte("0123456789"), sw.Body()) + require.True(t, sw.Truncated()) +} + +func TestWriteWithLimit_NoTruncationWhenUnderCap(t *testing.T) { + sw := NewCaptureResponseWriterWithLimit(100) + sw.Write([]byte("hello")) + sw.Write([]byte(" world")) + require.Equal(t, []byte("hello world"), sw.Body()) + require.False(t, sw.Truncated()) +} + +func TestWriteWithLimit_BoundaryExact(t *testing.T) { + sw := NewCaptureResponseWriterWithLimit(11) + sw.Write([]byte("hello world")) + require.False(t, sw.Truncated()) + sw.Write([]byte("!")) + require.True(t, sw.Truncated()) +} + +func TestWriteWithLimit_ZeroMeansUnlimited(t *testing.T) { + sw := NewCaptureResponseWriterWithLimit(0) + sw.Write(make([]byte, 1024*1024)) + require.False(t, sw.Truncated()) + require.Equal(t, 1024*1024, len(sw.Body())) +} + func TestHeaderStatusCodeBody(t *testing.T) { sw := NewCaptureResponseWriter() From fdb2e058e458578106d64875161d88691cbd500d Mon Sep 17 00:00:00 2001 From: Chris Randles Date: Wed, 13 May 2026 14:56:14 -0400 Subject: [PATCH 2/6] test: [alb] add helper unit tests, panic guards, goleak, badEncoding fault Signed-off-by: Chris Randles --- go.mod | 1 + integration/alb_tsm_scale_test.go | 67 +++++++ pkg/backends/alb/mech/fr/main_test.go | 36 ++++ pkg/backends/alb/mech/nlm/main_test.go | 33 ++++ 
.../nlm/newest_last_modified_panic_test.go | 119 +++++++++++++ .../alb/mech/tsm/helpers_fuzz_test.go | 126 +++++++++++++ pkg/backends/alb/mech/tsm/helpers_test.go | 165 ++++++++++++++++++ pkg/backends/alb/mech/tsm/main_test.go | 33 ++++ .../mech/tsm/time_series_merge_panic_test.go | 121 +++++++++++++ pkg/backends/alb/pool/main_test.go | 31 ++++ pkg/backends/prometheus/model/timeseries.go | 3 + 11 files changed, 735 insertions(+) create mode 100644 pkg/backends/alb/mech/fr/main_test.go create mode 100644 pkg/backends/alb/mech/nlm/main_test.go create mode 100644 pkg/backends/alb/mech/nlm/newest_last_modified_panic_test.go create mode 100644 pkg/backends/alb/mech/tsm/helpers_fuzz_test.go create mode 100644 pkg/backends/alb/mech/tsm/helpers_test.go create mode 100644 pkg/backends/alb/mech/tsm/main_test.go create mode 100644 pkg/backends/alb/mech/tsm/time_series_merge_panic_test.go create mode 100644 pkg/backends/alb/pool/main_test.go diff --git a/go.mod b/go.mod index a4b3bd218..0fd3659f1 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0 go.opentelemetry.io/otel/sdk v1.43.0 go.opentelemetry.io/otel/trace v1.43.0 + go.uber.org/goleak v1.3.0 golang.org/x/crypto v0.50.0 golang.org/x/net v0.53.0 golang.org/x/sync v0.20.0 diff --git a/integration/alb_tsm_scale_test.go b/integration/alb_tsm_scale_test.go index 1770cded0..5a8d8bf9a 100644 --- a/integration/alb_tsm_scale_test.go +++ b/integration/alb_tsm_scale_test.go @@ -150,6 +150,22 @@ func TestALB_TSM_Scale(t *testing.T) { t.Logf("%d bytes, %s", len(body), hdr.Get("X-Trickster-Result")) }) + t.Run("bad_encoding_advertised_as_gzip", func(t *testing.T) { + // Upstream claims gzip but sends plaintext; DecompressResponseBody + // fails. The merge surfaces the partial-failure marker but the + // remaining 49 backends still produce a successful merged response. 
+ resetAll() + fakes[0].setBehavior(behaviorBadEncoding()) + pr, hdr := queryTricksterProm(t, listenAddr, backendName, "/api/v1/query_range", rangeParams()) + require.Equal(t, "success", pr.Status) + var qd promQueryData + require.NoError(t, json.Unmarshal(pr.Data, &qd)) + var series []json.RawMessage + require.NoError(t, json.Unmarshal(qd.Result, &series)) + require.NotEmpty(t, series) + t.Logf("%d series, %s", len(series), hdr.Get("X-Trickster-Result")) + }) + t.Run("truncating_upstream_does_not_poison_cache", func(t *testing.T) { resetAll() params := rangeParams() @@ -367,6 +383,49 @@ func TestALB_TSM_Scale(t *testing.T) { } }) + t.Run("post_concurrent_clients_safe", func(t *testing.T) { + // POST query_range fans out N clones of the same parent request. The + // body must be re-readable per clone or the upstreams see truncated + // requests. Race the request body cache under -race to catch any + // unsynchronized r.Body mutation in the clone path. + resetAll() + form := rangeParams().Encode() + u := "http://" + listenAddr + "/" + backendName + "/api/v1/query_range" + const clients = 25 + var wg sync.WaitGroup + errCh := make(chan error, clients) + for range clients { + wg.Add(1) + go func() { + defer wg.Done() + resp, err := http.Post(u, "application/x-www-form-urlencoded", strings.NewReader(form)) + if err != nil { + errCh <- err + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + errCh <- fmt.Errorf("status %d", resp.StatusCode) + return + } + body, _ := io.ReadAll(resp.Body) + var pr promResponse + if err := json.Unmarshal(body, &pr); err != nil { + errCh <- fmt.Errorf("invalid json: %w", err) + return + } + if pr.Status != "success" { + errCh <- fmt.Errorf("non-success: %s", pr.Status) + } + }() + } + wg.Wait() + close(errCh) + for err := range errCh { + t.Fatal(err) + } + }) + t.Run("error_body_passthrough", func(t *testing.T) { resetAll() for _, f := range fakes { @@ -528,6 +587,7 @@ func behaviorBadShape() *promBehavior { 
return &promBehavior{mode: func behaviorTruncate() *promBehavior { return &promBehavior{mode: "truncate"} } func behaviorSlow(d time.Duration) *promBehavior { return &promBehavior{mode: "ok", delay: d} } func behaviorErrJSON() *promBehavior { return &promBehavior{mode: "errjson"} } +func behaviorBadEncoding() *promBehavior { return &promBehavior{mode: "badencoding"} } func behaviorLabelValuesKB(kb int) *promBehavior { return &promBehavior{mode: "labelvalues", seriesKB: kb} } @@ -587,6 +647,13 @@ func (f *fakeProm) handleRange(w http.ResponseWriter, _ *http.Request) { _ = conn.Close() } return + case "badencoding": + // Advertise gzip but send plaintext. The TSM gather path's + // DecompressResponseBody will fail, surfacing as a partial-failure. + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Encoding", "gzip") + _, _ = w.Write(buildMatrixBody(f.label)) + return } w.Header().Set("Content-Type", "application/json") _, _ = w.Write(buildMatrixBody(f.label)) diff --git a/pkg/backends/alb/mech/fr/main_test.go b/pkg/backends/alb/mech/fr/main_test.go new file mode 100644 index 000000000..10fa3b6bf --- /dev/null +++ b/pkg/backends/alb/mech/fr/main_test.go @@ -0,0 +1,36 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr + +import ( + "testing" + + "go.uber.org/goleak" +) + +// FR fanout spawns errgroup workers per request plus a detached eg.Wait() +// drainer. 
A regression that fails to cancel on client disconnect would leak; +// goleak surfaces it. The pool-side goroutines (checkHealth / +// listenStatusUpdates) are ignored because the existing test helpers +// (albpool.New) don't expose a Stop() call site; those leaks are pre-existing +// test hygiene, not mechanism bugs. +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m, + goleak.IgnoreAnyFunction("github.com/trickstercache/trickster/v2/pkg/backends/alb/pool.(*pool).checkHealth"), + goleak.IgnoreAnyFunction("github.com/trickstercache/trickster/v2/pkg/backends/alb/pool.(*pool).listenStatusUpdates"), + ) +} diff --git a/pkg/backends/alb/mech/nlm/main_test.go b/pkg/backends/alb/mech/nlm/main_test.go new file mode 100644 index 000000000..901d14fba --- /dev/null +++ b/pkg/backends/alb/mech/nlm/main_test.go @@ -0,0 +1,33 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nlm + +import ( + "testing" + + "go.uber.org/goleak" +) + +// Pool-side goroutines are ignored because albpool.New() in test helpers +// doesn't expose Stop(); those leaks predate this file. Mechanism-side leaks +// (errgroup workers that ignore context cancellation) will still surface. 
+func TestMain(m *testing.M) { + goleak.VerifyTestMain(m, + goleak.IgnoreAnyFunction("github.com/trickstercache/trickster/v2/pkg/backends/alb/pool.(*pool).checkHealth"), + goleak.IgnoreAnyFunction("github.com/trickstercache/trickster/v2/pkg/backends/alb/pool.(*pool).listenStatusUpdates"), + ) +} diff --git a/pkg/backends/alb/mech/nlm/newest_last_modified_panic_test.go b/pkg/backends/alb/mech/nlm/newest_last_modified_panic_test.go new file mode 100644 index 000000000..c5be521db --- /dev/null +++ b/pkg/backends/alb/mech/nlm/newest_last_modified_panic_test.go @@ -0,0 +1,119 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nlm + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/trickstercache/trickster/v2/pkg/backends/healthcheck" + "github.com/trickstercache/trickster/v2/pkg/proxy/headers" + "github.com/trickstercache/trickster/v2/pkg/testutil/albpool" +) + +// A panicking pool member must not crash the proxy. RecoverFanoutPanic in +// the NLM fanout goroutine should swallow the panic and clear the capture +// slot so the fallback path doesn't pick a partial response. 
+func TestNLMPanicMemberDoesNotCrashRequest(t *testing.T) {
+	const lastModified = "Sun, 06 Nov 1994 08:49:37 GMT"
+	panicker := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
+		panic("simulated upstream nil deref")
+	})
+	healthy := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
+		rw.Header().Set(headers.NameLastModified, lastModified)
+		rw.WriteHeader(http.StatusOK)
+		rw.Write([]byte("body-ok"))
+	})
+
+	members, _, statuses := albpool.New(-1, []http.Handler{panicker, healthy})
+	statuses[0].Set(healthcheck.StatusPassing)
+	statuses[1].Set(healthcheck.StatusPassing)
+	time.Sleep(250 * time.Millisecond) // presumably lets the pool pick up the status changes
+
+	req, _ := http.NewRequest("GET", "http://trickstercache.org/", nil)
+	recorder := httptest.NewRecorder()
+	hdlr := &handler{}
+	hdlr.SetPool(members)
+
+	defer func() {
+		if v := recover(); v != nil {
+			t.Fatalf("unrecovered panic crossed ServeHTTP: %v", v)
+		}
+	}()
+
+	finished := make(chan struct{}) // closed once ServeHTTP has returned or panicked
+	go func() {
+		defer func() {
+			if v := recover(); v != nil {
+				t.Errorf("unrecovered panic in goroutine: %v", v)
+			}
+			close(finished)
+		}()
+		hdlr.ServeHTTP(recorder, req)
+	}()
+
+	select {
+	case <-time.After(5 * time.Second):
+		t.Fatal("ServeHTTP did not return after pool member panic")
+	case <-finished:
+	}
+
+	if got := recorder.Body.String(); got != "body-ok" {
+		t.Errorf("expected healthy member body, got %q", got)
+	}
+}
+
+func TestNLMPanicAllMembersDoesNotCrashRequest(t *testing.T) {
+	panicker := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
+		panic("simulated upstream nil deref")
+	})
+
+	members, _, statuses := albpool.New(-1, []http.Handler{panicker, panicker})
+	statuses[0].Set(healthcheck.StatusPassing)
+	statuses[1].Set(healthcheck.StatusPassing)
+	time.Sleep(250 * time.Millisecond) // presumably lets the pool pick up the status changes
+
+	req, _ := http.NewRequest("GET", "http://trickstercache.org/", nil)
+	recorder := httptest.NewRecorder()
+	hdlr := &handler{}
+	hdlr.SetPool(members)
+
+	defer func() {
+		if v := recover(); v != nil {
+			t.Fatalf("unrecovered panic crossed ServeHTTP: %v", v)
+		}
+	}()
+
+	finished := make(chan struct{}) // closed once ServeHTTP has returned or panicked
+	go func() {
+		defer func() {
+			if v := recover(); v != nil {
+				t.Errorf("unrecovered panic in goroutine: %v", v)
+			}
+			close(finished)
+		}()
+		hdlr.ServeHTTP(recorder, req)
+	}()
+
+	select {
+	case <-time.After(5 * time.Second):
+		t.Fatal("ServeHTTP did not return after all-panic fanout")
+	case <-finished:
+	}
+}
diff --git a/pkg/backends/alb/mech/tsm/helpers_fuzz_test.go b/pkg/backends/alb/mech/tsm/helpers_fuzz_test.go
new file mode 100644
index 000000000..5cdbf9718
--- /dev/null
+++ b/pkg/backends/alb/mech/tsm/helpers_fuzz_test.go
@@ -0,0 +1,126 @@
+/*
+ * Copyright 2018 The Trickster Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package tsm
+
+import (
+	"net/http"
+	"testing"
+
+	"github.com/trickstercache/trickster/v2/pkg/proxy/response/merge"
+)
+
+// Invariant: when any shard returned a 2xx status with a non-nil mergeFunc,
+// pickWinner must return a 2xx winner. The fuzzer shuffles status codes and
+// nil-mergeFuncs across N slots; the fallback path must never be taken if a
+// 2xx slot exists.
+func FuzzPickWinnerInvariant(f *testing.F) {
+	// Seed corpus: a 2xx in each slot position, plus the all-empty case.
+	f.Add(byte(0), byte(2), byte(5))
+	f.Add(byte(2), byte(0), byte(5))
+	f.Add(byte(5), byte(0), byte(2))
+	f.Add(byte(0), byte(0), byte(0))
+
+	// codes maps a fuzz byte (modulo its length) onto a status code; 0
+	// stands for "no response".
+	codes := [...]int{0, 100, 200, 204, 302, 404, 502}
+	is2xx := func(code int) bool { return code >= 200 && code < 300 }
+
+	f.Fuzz(func(t *testing.T, a, b, c byte) {
+		// picked records which slot's mergeFunc runs when the winner is
+		// invoked, so the test can check the winner's status rather than
+		// only that some func came back.
+		picked := -1
+		results := make([]gatherResult, 0, 3)
+		for i, seed := range []byte{a, b, c} {
+			slot := i
+			res := gatherResult{statusCode: codes[int(seed)%len(codes)]}
+			// Slots with statusCode 0 keep a nil mergeFunc to mirror reality
+			// (no response means no respondFunc captured).
+			if res.statusCode != 0 {
+				res.mergeFunc = func(http.ResponseWriter, *http.Request, *merge.Accumulator, int) {
+					picked = slot
+				}
+			}
+			results = append(results, res)
+		}
+
+		var hasAny2xx, hasAnyFunc bool
+		for _, r := range results {
+			if r.mergeFunc == nil {
+				continue
+			}
+			hasAnyFunc = true
+			if is2xx(r.statusCode) {
+				hasAny2xx = true
+			}
+		}
+
+		mrf, _ := pickWinner(results)
+		if !hasAnyFunc {
+			if mrf != nil {
+				t.Errorf("no mergeFunc available but pickWinner returned non-nil: %+v", results)
+			}
+			return
+		}
+		if !hasAny2xx {
+			// Fallback territory: which non-2xx slot wins is outside this
+			// invariant.
+			return
+		}
+		if mrf == nil {
+			t.Errorf("had a 2xx but pickWinner returned nil: %+v", results)
+			return
+		}
+		// A non-nil winner is not enough: it must belong to a 2xx slot,
+		// i.e. the fallback path must not have been taken.
+		mrf(nil, nil, nil, 0)
+		if picked < 0 || !is2xx(results[picked].statusCode) {
+			t.Errorf("had a 2xx but pickWinner chose slot %d: %+v", picked, results)
+		}
+	})
+}
+
+// Invariant: aggregateStatus returns has2xx XOR (status outside [200,299]).
+// In other words, when has2xx is true the returned status MUST be in [200,299].
+// When has2xx is false, status MUST NOT be in [200,299].
+func FuzzAggregateStatusInvariant(f *testing.F) {
+	f.Add(uint16(200), uint16(500))
+	f.Add(uint16(502), uint16(400))
+	f.Add(uint16(206), uint16(204))
+	f.Add(uint16(0), uint16(0))
+
+	f.Fuzz(func(t *testing.T, a, b uint16) {
+		// toCode keeps 0 as "no response" and folds every other input into
+		// [100,599]; values already in that range (the seeds above) map to
+		// themselves, so the seed corpus exercises real 2xx/4xx/5xx codes.
+		toCode := func(v uint16) int {
+			if v == 0 {
+				return 0
+			}
+			return 100 + (int(v)+400)%500
+		}
+		results := []gatherResult{{statusCode: toCode(a)}, {statusCode: toCode(b)}}
+		code, _, has2xx, _ := aggregateStatus(results)
+		is2xx := code >= 200 && code < 300
+		switch {
+		case has2xx && !is2xx:
+			t.Errorf("has2xx=true but code=%d is not 2xx; %+v", code, results)
+		case !has2xx && is2xx:
+			t.Errorf("has2xx=false but code=%d is 2xx; %+v", code, results)
+		}
+	})
+}
diff --git a/pkg/backends/alb/mech/tsm/helpers_test.go b/pkg/backends/alb/mech/tsm/helpers_test.go
new file mode 100644
index 000000000..95afdd692
--- /dev/null
+++ b/pkg/backends/alb/mech/tsm/helpers_test.go
@@ -0,0 +1,165 @@
+/*
+ * Copyright 2018 The Trickster Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package tsm + +import ( + "net/http" + "testing" + + "github.com/trickstercache/trickster/v2/pkg/proxy/headers" + "github.com/trickstercache/trickster/v2/pkg/proxy/response/merge" +) + +func TestPickWinnerPrefersFirst2xx(t *testing.T) { + called := map[string]bool{} + mk := func(name string) merge.RespondFunc { + return func(_ http.ResponseWriter, _ *http.Request, _ *merge.Accumulator, _ int) { + called[name] = true + } + } + results := []gatherResult{ + {statusCode: 500, mergeFunc: mk("err0"), header: http.Header{"X-Slot": []string{"0"}}}, + {statusCode: 200, mergeFunc: mk("ok1"), header: http.Header{"X-Slot": []string{"1"}}}, + {statusCode: 200, mergeFunc: mk("ok2"), header: http.Header{"X-Slot": []string{"2"}}}, + } + mrf, hdrs := pickWinner(results) + if mrf == nil { + t.Fatal("expected a winner") + } + mrf(nil, nil, nil, 0) + if !called["ok1"] { + t.Errorf("expected first 2xx winner, got %v", called) + } + if hdrs.Get("X-Slot") != "1" { + t.Errorf("expected slot 1 headers, got %q", hdrs.Get("X-Slot")) + } +} + +func TestPickWinnerFallsBackToFirstNon2xx(t *testing.T) { + called := map[string]bool{} + mk := func(name string) merge.RespondFunc { + return func(_ http.ResponseWriter, _ *http.Request, _ *merge.Accumulator, _ int) { + called[name] = true + } + } + results := []gatherResult{ + {statusCode: 500, mergeFunc: mk("err0"), header: http.Header{"X-Slot": []string{"0"}}}, + {statusCode: 502, mergeFunc: mk("err1"), header: http.Header{"X-Slot": []string{"1"}}}, + } + mrf, hdrs := pickWinner(results) + if mrf == nil { + t.Fatal("expected fallback winner") + } + mrf(nil, nil, nil, 0) + if !called["err0"] { + t.Errorf("expected first non-2xx winner, got %v", called) + } + if hdrs.Get("X-Slot") != "0" { + t.Errorf("expected slot 0 headers, got %q", hdrs.Get("X-Slot")) + } +} + +func TestPickWinnerNoneReturnsNil(t *testing.T) { + results := []gatherResult{{statusCode: 0, mergeFunc: nil}, {failed: true}} + mrf, hdrs := pickWinner(results) + if mrf != nil || 
hdrs != nil { + t.Errorf("expected nil winner and headers, got %v %v", mrf, hdrs) + } +} + +func TestAggregateStatusAllSuccess(t *testing.T) { + results := []gatherResult{ + {statusCode: 200, header: http.Header{headers.NameTricksterResult: []string{"engine=A"}}}, + {statusCode: 206, header: http.Header{headers.NameTricksterResult: []string{"engine=B"}}}, + } + code, _, has2xx, hasNon2xx := aggregateStatus(results) + if code != 200 { + t.Errorf("expected min 2xx (200), got %d", code) + } + if !has2xx || hasNon2xx { + t.Errorf("flags: has2xx=%v hasNon2xx=%v", has2xx, hasNon2xx) + } +} + +func TestAggregateStatusMixedPrefers200(t *testing.T) { + results := []gatherResult{ + {statusCode: 500}, + {statusCode: 200}, + {statusCode: 502}, + } + code, _, has2xx, hasNon2xx := aggregateStatus(results) + if code != 200 { + t.Errorf("mixed-success should still pick 200, got %d", code) + } + if !has2xx || !hasNon2xx { + t.Errorf("expected both flags, got has2xx=%v hasNon2xx=%v", has2xx, hasNon2xx) + } +} + +func TestAggregateStatusAllErrorPicksMax(t *testing.T) { + // Before this PR the aggregator returned the min, surfacing 400 and + // hiding the more severe 502s. The new behavior surfaces 502. 
+ results := []gatherResult{ + {statusCode: 400}, + {statusCode: 502}, + {statusCode: 502}, + } + code, _, has2xx, hasNon2xx := aggregateStatus(results) + if code != 502 { + t.Errorf("all-error should pick max (502), got %d", code) + } + if has2xx || !hasNon2xx { + t.Errorf("flags: has2xx=%v hasNon2xx=%v", has2xx, hasNon2xx) + } +} + +func TestAggregateStatusEmpty(t *testing.T) { + code, sh, has2xx, hasNon2xx := aggregateStatus(nil) + if code != 0 || sh != "" || has2xx || hasNon2xx { + t.Errorf("zero state: code=%d sh=%q has2xx=%v hasNon2xx=%v", code, sh, has2xx, hasNon2xx) + } +} + +func TestMergeMultiValuedHeadersPreservesSetCookie(t *testing.T) { + results := []gatherResult{ + {header: http.Header{headers.NameSetCookie: []string{"a=1", "b=2"}}}, + {header: http.Header{headers.NameSetCookie: []string{"c=3"}}}, + } + winner := http.Header{headers.NameSetCookie: []string{"winner=ignored"}} + dst := http.Header{} + mergeMultiValuedHeaders(dst, results, winner) + got := dst.Values(headers.NameSetCookie) + if len(got) != 3 { + t.Fatalf("expected 3 Set-Cookie values, got %d: %v", len(got), got) + } + if winner.Get(headers.NameSetCookie) != "" { + t.Errorf("winner should have Set-Cookie deleted to avoid double-merge, still has %q", + winner.Get(headers.NameSetCookie)) + } +} + +func TestMergeMultiValuedHeadersSkipsNilHeader(t *testing.T) { + results := []gatherResult{ + {header: nil}, + {header: http.Header{headers.NameSetCookie: []string{"only=1"}}}, + } + dst := http.Header{} + mergeMultiValuedHeaders(dst, results, nil) + if got := dst.Values(headers.NameSetCookie); len(got) != 1 || got[0] != "only=1" { + t.Errorf("expected single Set-Cookie, got %v", got) + } +} diff --git a/pkg/backends/alb/mech/tsm/main_test.go b/pkg/backends/alb/mech/tsm/main_test.go new file mode 100644 index 000000000..9f9fc939e --- /dev/null +++ b/pkg/backends/alb/mech/tsm/main_test.go @@ -0,0 +1,33 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tsm + +import ( + "testing" + + "go.uber.org/goleak" +) + +// Pool-side goroutines from albpool.New are ignored; they predate this file. +// Mechanism-side leaks (TSM fanout that ignores ctx, panic recovery that +// doesn't unwind) will still surface here. +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m, + goleak.IgnoreAnyFunction("github.com/trickstercache/trickster/v2/pkg/backends/alb/pool.(*pool).checkHealth"), + goleak.IgnoreAnyFunction("github.com/trickstercache/trickster/v2/pkg/backends/alb/pool.(*pool).listenStatusUpdates"), + ) +} diff --git a/pkg/backends/alb/mech/tsm/time_series_merge_panic_test.go b/pkg/backends/alb/mech/tsm/time_series_merge_panic_test.go new file mode 100644 index 000000000..69d635827 --- /dev/null +++ b/pkg/backends/alb/mech/tsm/time_series_merge_panic_test.go @@ -0,0 +1,121 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package tsm + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/trickstercache/trickster/v2/pkg/backends/healthcheck" + "github.com/trickstercache/trickster/v2/pkg/proxy/request" + tu "github.com/trickstercache/trickster/v2/pkg/testutil" + "github.com/trickstercache/trickster/v2/pkg/testutil/albpool" +) + +// A panicking pool member must not crash the request. RecoverFanoutPanic("tsm", +// ...) at time_series_merge.go must catch it and mark the slot failed so the +// merge surfaces the partial-failure (phit) signal. +func TestTSMPanicMemberDoesNotCrashRequest(t *testing.T) { + panicker := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + panic("simulated upstream nil deref") + }) + + p, _, st := albpool.New(-1, []http.Handler{ + http.HandlerFunc(tu.BasicHTTPHandler), + panicker, + }) + st[0].Set(healthcheck.StatusPassing) + st[1].Set(healthcheck.StatusPassing) + time.Sleep(250 * time.Millisecond) + + rsc := request.NewResources(nil, nil, nil, nil, nil, nil) + rsc.IsMergeMember = true + r, _ := http.NewRequest("GET", "http://trickstercache.org/", nil) + r = request.SetResources(r, rsc) + + h := &handler{mergePaths: []string{"/"}} + h.SetPool(p) + w := httptest.NewRecorder() + + defer func() { + if rec := recover(); rec != nil { + t.Fatalf("unrecovered panic crossed ServeHTTP: %v", rec) + } + }() + + done := make(chan struct{}) + go func() { + defer func() { + if rec := recover(); rec != nil { + t.Errorf("unrecovered panic in goroutine: %v", rec) + } + close(done) + }() + h.ServeHTTP(w, r) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("ServeHTTP did not return after panic") + } +} + +func TestTSMPanicAllMembersDoesNotCrashRequest(t *testing.T) { + panicker := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + panic("simulated upstream nil deref") + }) + + p, _, st := albpool.New(-1, []http.Handler{panicker, panicker}) + st[0].Set(healthcheck.StatusPassing) + 
st[1].Set(healthcheck.StatusPassing) + time.Sleep(250 * time.Millisecond) + + rsc := request.NewResources(nil, nil, nil, nil, nil, nil) + rsc.IsMergeMember = true + r, _ := http.NewRequest("GET", "http://trickstercache.org/", nil) + r = request.SetResources(r, rsc) + + h := &handler{mergePaths: []string{"/"}} + h.SetPool(p) + w := httptest.NewRecorder() + + defer func() { + if rec := recover(); rec != nil { + t.Fatalf("unrecovered panic crossed ServeHTTP: %v", rec) + } + }() + + done := make(chan struct{}) + go func() { + defer func() { + if rec := recover(); rec != nil { + t.Errorf("unrecovered panic in goroutine: %v", rec) + } + close(done) + }() + h.ServeHTTP(w, r) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("ServeHTTP did not return after all-panic fanout") + } +} diff --git a/pkg/backends/alb/pool/main_test.go b/pkg/backends/alb/pool/main_test.go new file mode 100644 index 000000000..31dbfa340 --- /dev/null +++ b/pkg/backends/alb/pool/main_test.go @@ -0,0 +1,31 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package pool + +import ( + "testing" + + "go.uber.org/goleak" +) + +// goleak guards against pool/healthcheck goroutines outliving the tests that +// created them. Pool spawns long-running listenStatusUpdates / checkHealth +// goroutines; a regression that fails to stop them would leak until process +// exit. -race won't catch leaks, only this will. 
+func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/pkg/backends/prometheus/model/timeseries.go b/pkg/backends/prometheus/model/timeseries.go index f6ada1b28..942e892a3 100644 --- a/pkg/backends/prometheus/model/timeseries.go +++ b/pkg/backends/prometheus/model/timeseries.go @@ -104,6 +104,9 @@ func UnmarshalTimeseriesReader(reader io.Reader, trq *timeseries.TimeRangeQuery) if trq == nil { return nil, timeseries.ErrNoTimerangeQuery } + if reader == nil { + return nil, io.ErrUnexpectedEOF + } wfd := &WFMatrixDocument{} d := json.NewDecoder(reader) err := d.Decode(wfd) From 4c911fd5720b60d8fc83d30121151c56033ea927 Mon Sep 17 00:00:00 2001 From: Chris Randles Date: Wed, 13 May 2026 14:54:08 -0400 Subject: [PATCH 3/6] fix+feat: [alb] strip cache headers from merge, fanout failure metric, prime post body Signed-off-by: Chris Randles --- go.mod | 1 + pkg/backends/alb/mech/fr/first_response.go | 10 +++ .../mech/fr/first_response_post_race_test.go | 73 ++++++++++++++++++ .../alb/mech/nlm/newest_last_modified.go | 10 +++ .../newest_last_modified_post_race_test.go | 75 +++++++++++++++++++ pkg/backends/alb/mech/recover.go | 2 + pkg/backends/alb/mech/recover_metric_test.go | 53 +++++++++++++ .../alb/mech/tsm/time_series_merge.go | 2 + pkg/observability/metrics/metrics.go | 18 +++++ pkg/proxy/headers/forwarding.go | 10 ++- pkg/proxy/headers/forwarding_test.go | 23 +++++- pkg/proxy/headers/headers.go | 5 ++ 12 files changed, 278 insertions(+), 4 deletions(-) create mode 100644 pkg/backends/alb/mech/fr/first_response_post_race_test.go create mode 100644 pkg/backends/alb/mech/nlm/newest_last_modified_post_race_test.go create mode 100644 pkg/backends/alb/mech/recover_metric_test.go diff --git a/go.mod b/go.mod index 0fd3659f1..f8c7957e4 100644 --- a/go.mod +++ b/go.mod @@ -149,6 +149,7 @@ require ( github.com/kkHAIKE/contextcheck v1.1.6 // indirect github.com/kulti/thelper v0.7.1 // indirect github.com/kunwardeep/paralleltest v1.0.15 // indirect + 
github.com/kylelemons/godebug v1.1.0 // indirect github.com/lasiar/canonicalheader v1.1.2 // indirect github.com/ldez/exptostd v0.4.5 // indirect github.com/ldez/gomoddirectives v0.8.0 // indirect diff --git a/pkg/backends/alb/mech/fr/first_response.go b/pkg/backends/alb/mech/fr/first_response.go index 983b3682f..a4eb07212 100644 --- a/pkg/backends/alb/mech/fr/first_response.go +++ b/pkg/backends/alb/mech/fr/first_response.go @@ -107,6 +107,16 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { hl[0].Handler().ServeHTTP(w, r) return } + // Ensure the parent has Resources so GetBody caches the body bytes. + // Without a cache, each fanout goroutine's CloneWithoutResources re-reads + // and mutates r.Body, racing on the parent request. + if request.GetResources(r) == nil { + r = request.SetResources(r, &request.Resources{}) + } + if _, err := request.GetBody(r); err != nil { + failures.HandleBadGateway(w, r) + return + } // otherwise iterate the fanout var claimed int64 = -1 captures := make([]*capture.CaptureResponseWriter, l) diff --git a/pkg/backends/alb/mech/fr/first_response_post_race_test.go b/pkg/backends/alb/mech/fr/first_response_post_race_test.go new file mode 100644 index 000000000..72349e76f --- /dev/null +++ b/pkg/backends/alb/mech/fr/first_response_post_race_test.go @@ -0,0 +1,73 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package fr + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/trickstercache/trickster/v2/pkg/backends/healthcheck" + "github.com/trickstercache/trickster/v2/pkg/testutil/albpool" +) + +// FR fanout clones the parent request N ways. For POST/PUT/PATCH the body +// must be primed before fanout, else N goroutines race on r.Body and +// rsc.RequestBody inside GetBody. Run under -race to catch any regression +// in the priming step. +func TestFRPostBodyFanoutIsRaceFree(t *testing.T) { + const body = `{"query":"sum(rate(metric[5m]))","start":"2024-01-01T00:00:00Z","end":"2024-01-01T01:00:00Z","step":"15s"}` + + mk := func(name string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, _ := io.ReadAll(r.Body) + if string(b) != body { + t.Errorf("%s: truncated body, got %d bytes want %d", name, len(b), len(body)) + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(name)) + }) + } + + p, _, st := albpool.New(-1, []http.Handler{mk("a"), mk("b"), mk("c"), mk("d")}) + for _, s := range st { + s.Set(healthcheck.StatusPassing) + } + time.Sleep(250 * time.Millisecond) + + h := &handler{} + h.SetPool(p) + + const callers = 16 + var wg sync.WaitGroup + for range callers { + wg.Go(func() { + r, _ := http.NewRequest(http.MethodPost, "http://trickstercache.org/api/v1/query_range", strings.NewReader(body)) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Errorf("status %d", w.Code) + } + }) + } + wg.Wait() +} diff --git a/pkg/backends/alb/mech/nlm/newest_last_modified.go b/pkg/backends/alb/mech/nlm/newest_last_modified.go index d1b46df55..b0160d915 100644 --- a/pkg/backends/alb/mech/nlm/newest_last_modified.go +++ b/pkg/backends/alb/mech/nlm/newest_last_modified.go @@ -84,6 +84,16 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 
hl[0].Handler().ServeHTTP(w, r) return } + // Ensure the parent has Resources so GetBody caches the body bytes. + // Without a cache, each fanout goroutine's CloneWithoutResources re-reads + // and mutates r.Body, racing on the parent request. + if request.GetResources(r) == nil { + r = request.SetResources(r, &request.Resources{}) + } + if _, err := request.GetBody(r); err != nil { + failures.HandleBadGateway(w, r) + return + } // Strip resources from the parent context once; reuse for all goroutines bareCtx := tctx.ClearResources(r.Context()) diff --git a/pkg/backends/alb/mech/nlm/newest_last_modified_post_race_test.go b/pkg/backends/alb/mech/nlm/newest_last_modified_post_race_test.go new file mode 100644 index 000000000..ff8a05813 --- /dev/null +++ b/pkg/backends/alb/mech/nlm/newest_last_modified_post_race_test.go @@ -0,0 +1,75 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package nlm + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/trickstercache/trickster/v2/pkg/backends/healthcheck" + "github.com/trickstercache/trickster/v2/pkg/proxy/headers" + "github.com/trickstercache/trickster/v2/pkg/testutil/albpool" +) + +// NLM fanout clones the parent request N ways. For POST/PUT/PATCH the body +// must be primed before fanout, else N goroutines race on r.Body and +// rsc.RequestBody inside GetBody. Run under -race to catch any regression. 
+func TestNLMPostBodyFanoutIsRaceFree(t *testing.T) { + const body = `{"query":"sum(rate(metric[5m]))","start":"2024-01-01T00:00:00Z","end":"2024-01-01T01:00:00Z","step":"15s"}` + lm := "Sun, 06 Nov 1994 08:49:37 GMT" + + mk := func(name string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, _ := io.ReadAll(r.Body) + if string(b) != body { + t.Errorf("%s: truncated body, got %d bytes want %d", name, len(b), len(body)) + } + w.Header().Set(headers.NameLastModified, lm) + w.WriteHeader(http.StatusOK) + w.Write([]byte(name)) + }) + } + + p, _, st := albpool.New(-1, []http.Handler{mk("a"), mk("b"), mk("c"), mk("d")}) + for _, s := range st { + s.Set(healthcheck.StatusPassing) + } + time.Sleep(250 * time.Millisecond) + + h := &handler{} + h.SetPool(p) + + const callers = 16 + var wg sync.WaitGroup + for range callers { + wg.Go(func() { + r, _ := http.NewRequest(http.MethodPost, "http://trickstercache.org/api/v1/query_range", strings.NewReader(body)) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Errorf("status %d", w.Code) + } + }) + } + wg.Wait() +} diff --git a/pkg/backends/alb/mech/recover.go b/pkg/backends/alb/mech/recover.go index a8b0c50cc..d7af5bc52 100644 --- a/pkg/backends/alb/mech/recover.go +++ b/pkg/backends/alb/mech/recover.go @@ -21,6 +21,7 @@ import ( "github.com/trickstercache/trickster/v2/pkg/observability/logging" "github.com/trickstercache/trickster/v2/pkg/observability/logging/logger" + "github.com/trickstercache/trickster/v2/pkg/observability/metrics" ) // RecoverFanoutPanic recovers from a panic in an ALB fanout goroutine and @@ -38,6 +39,7 @@ func RecoverFanoutPanic(mech string, member int, onPanic func()) { "panic": rec, "stack": string(debug.Stack()), }) + metrics.ALBFanoutFailures.WithLabelValues(mech, "panic").Inc() if onPanic != nil { onPanic() } diff --git 
a/pkg/backends/alb/mech/recover_metric_test.go b/pkg/backends/alb/mech/recover_metric_test.go new file mode 100644 index 000000000..f3cab8bd5 --- /dev/null +++ b/pkg/backends/alb/mech/recover_metric_test.go @@ -0,0 +1,53 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mech_test + +import ( + "testing" + + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech" + "github.com/trickstercache/trickster/v2/pkg/observability/metrics" +) + +func TestRecoverFanoutPanicIncrementsMetric(t *testing.T) { + before := testutil.ToFloat64(metrics.ALBFanoutFailures.WithLabelValues("metric-test", "panic")) + + func() { + defer mech.RecoverFanoutPanic("metric-test", 0, nil) + panic("forced") + }() + + after := testutil.ToFloat64(metrics.ALBFanoutFailures.WithLabelValues("metric-test", "panic")) + if after-before != 1 { + t.Errorf("expected counter to increment by 1, before=%v after=%v", before, after) + } +} + +func TestRecoverFanoutPanicNoPanicNoIncrement(t *testing.T) { + before := testutil.ToFloat64(metrics.ALBFanoutFailures.WithLabelValues("metric-test-noop", "panic")) + + func() { + defer mech.RecoverFanoutPanic("metric-test-noop", 0, nil) + // no panic + }() + + after := testutil.ToFloat64(metrics.ALBFanoutFailures.WithLabelValues("metric-test-noop", "panic")) + if after != before { + t.Errorf("expected counter unchanged, before=%v 
after=%v", before, after) + } +} diff --git a/pkg/backends/alb/mech/tsm/time_series_merge.go b/pkg/backends/alb/mech/tsm/time_series_merge.go index 9d1af79fc..bbadbe188 100644 --- a/pkg/backends/alb/mech/tsm/time_series_merge.go +++ b/pkg/backends/alb/mech/tsm/time_series_merge.go @@ -35,6 +35,7 @@ import ( "github.com/trickstercache/trickster/v2/pkg/encoding" "github.com/trickstercache/trickster/v2/pkg/observability/logging" "github.com/trickstercache/trickster/v2/pkg/observability/logging/logger" + "github.com/trickstercache/trickster/v2/pkg/observability/metrics" "github.com/trickstercache/trickster/v2/pkg/proxy/handlers/trickster/failures" "github.com/trickstercache/trickster/v2/pkg/proxy/headers" "github.com/trickstercache/trickster/v2/pkg/proxy/params" @@ -441,6 +442,7 @@ func (h *handler) serveStandard( for i, res := range results { if res.failed { hasGatherFailure = true + metrics.ALBFanoutFailures.WithLabelValues("tsm", "no_contribution").Inc() if ts := accumulator.GetTSData(); ts != nil { if ds, ok := ts.(*dataset.DataSet); ok { ds.Warnings = append(ds.Warnings, diff --git a/pkg/observability/metrics/metrics.go b/pkg/observability/metrics/metrics.go index 043d77744..f84213eb6 100644 --- a/pkg/observability/metrics/metrics.go +++ b/pkg/observability/metrics/metrics.go @@ -32,6 +32,7 @@ const ( configSubsystem = "config" buildSubsystem = "build" frontendSubsystem = "frontend" + albSubsystem = "alb" ) // Default histogram buckets used by trickster @@ -317,6 +318,22 @@ var ( Help: "Trickster total number of failed connections.", }, ) + + // ALBFanoutFailures counts per-shard failures during ALB fanout. The + // reason label distinguishes silent contribution failures (e.g. bad + // encoding, parse errors), explicit panics in the per-shard goroutine, + // and capture-buffer truncation. Operators can graph this to detect + // classes of upstream misbehavior that would otherwise only surface as + // the X-Trickster-Result phit header. 
+ ALBFanoutFailures = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: metricNamespace, + Subsystem: albSubsystem, + Name: "fanout_failures_total", + Help: "Count of per-shard failures during ALB fanout, by mechanism and reason.", + }, + []string{"mechanism", "reason"}, + ) ) func init() { @@ -333,6 +350,7 @@ func init() { prometheus.MustRegister(ProxyConnectionAccepted) prometheus.MustRegister(ProxyConnectionClosed) prometheus.MustRegister(ProxyConnectionFailed) + prometheus.MustRegister(ALBFanoutFailures) prometheus.MustRegister(CacheObjectOperations) prometheus.MustRegister(CacheByteOperations) prometheus.MustRegister(CacheEvents) diff --git a/pkg/proxy/headers/forwarding.go b/pkg/proxy/headers/forwarding.go index 60c6324fc..d52fd4177 100644 --- a/pkg/proxy/headers/forwarding.go +++ b/pkg/proxy/headers/forwarding.go @@ -87,13 +87,21 @@ var ForwardingHeaders = []string{ NameVia, } -// MergeRemoveHeaders defines a list of headers that should be removed when Merging time series results +// MergeRemoveHeaders defines a list of headers that should be removed when Merging time series results. +// Cache-related fields (Cache-Control, Vary, Age, Etag, Expires) are stripped because the merged body +// is by definition not equivalent to any single shard's response; carrying one shard's cache directives +// onto the merged response can poison downstream HTTP caches with hybrid policies. 
var MergeRemoveHeaders = []string{ NameLastModified, NameDate, NameContentLength, NameContentType, NameTransferEncoding, + NameCacheControl, + NameVary, + NameAge, + NameETag, + NameExpires, } var forwardingFuncs = map[string]func(*http.Request, *Hop){ diff --git a/pkg/proxy/headers/forwarding_test.go b/pkg/proxy/headers/forwarding_test.go index da3cb9293..56658843d 100644 --- a/pkg/proxy/headers/forwarding_test.go +++ b/pkg/proxy/headers/forwarding_test.go @@ -175,11 +175,28 @@ func TestFormatForwardedAddress(t *testing.T) { } func TestStripMergeHeaders(t *testing.T) { - h := http.Header{NameContentLength: []string{"42"}, NameLocation: []string{"https://trickstercache.org/"}} + h := http.Header{ + NameContentLength: []string{"42"}, + NameLocation: []string{"https://trickstercache.org/"}, + NameCacheControl: []string{"max-age=300"}, + NameVary: []string{"Accept-Encoding"}, + NameAge: []string{"7"}, + NameETag: []string{"abc123"}, + NameExpires: []string{"Thu, 01 Jan 2099 00:00:00 GMT"}, + } StripMergeHeaders(h) - if _, ok := h[NameContentLength]; ok { - t.Error("expected Content-Length Header to be missing") + for _, name := range []string{ + NameContentLength, + NameCacheControl, + NameVary, + NameAge, + NameETag, + NameExpires, + } { + if _, ok := h[name]; ok { + t.Errorf("expected %s header to be stripped from merged response", name) + } } if _, ok := h[NameLocation]; !ok { diff --git a/pkg/proxy/headers/headers.go b/pkg/proxy/headers/headers.go index 42010465c..3aa957f72 100644 --- a/pkg/proxy/headers/headers.go +++ b/pkg/proxy/headers/headers.go @@ -145,6 +145,11 @@ const ( // NameWWWAuthenticate represents the HTTP Header Name of "WWW-Authenticate" NameWWWAuthenticate = "WWW-Authenticate" + // NameVary represents the HTTP Header Name of "Vary" + NameVary = "Vary" + // NameAge represents the HTTP Header Name of "Age" + NameAge = "Age" + // NameTrkHCStatus represents the HTTP Header Name of "Trk-HC-Status" NameTrkHCStatus = "Trk-HC-Status" // NameTrkHCDetail 
represents the HTTP Header Name of "Trk-HC-Detail" From c5a7af9cbce35726a0f0bacc1893cb78d906c312 Mon Sep 17 00:00:00 2001 From: Chris Randles Date: Wed, 13 May 2026 15:53:28 -0400 Subject: [PATCH 4/6] refactor: [alb/mech] extract shared fanout primitive; migrate nlm, tsm, fr clone setup Signed-off-by: Chris Randles --- pkg/backends/alb/mech/fanout/fanout.go | 215 ++++++++ pkg/backends/alb/mech/fanout/fanout_test.go | 496 ++++++++++++++++++ pkg/backends/alb/mech/fr/first_response.go | 18 +- .../alb/mech/nlm/newest_last_modified.go | 99 ++-- .../alb/mech/tsm/time_series_merge.go | 270 +++++----- 5 files changed, 870 insertions(+), 228 deletions(-) create mode 100644 pkg/backends/alb/mech/fanout/fanout.go create mode 100644 pkg/backends/alb/mech/fanout/fanout_test.go diff --git a/pkg/backends/alb/mech/fanout/fanout.go b/pkg/backends/alb/mech/fanout/fanout.go new file mode 100644 index 000000000..4ad682541 --- /dev/null +++ b/pkg/backends/alb/mech/fanout/fanout.go @@ -0,0 +1,215 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package fanout is the shared primitive for ALB mechanisms that scatter a +// single inbound request to N pool members and gather their responses. It +// centralizes the bug-prone parts of fanout that each mechanism otherwise +// re-implements: body priming, context propagation, capture-buffer bounding, +// panic recovery, concurrency limiting, and metric attribution. 
+// +// Mechanisms that wait for all members and decide afterwards (NLM, TSM) call +// All. Mechanisms that race for a first-good response (FR) own their own +// goroutine orchestration but use PrepareClone for the per-goroutine setup. +package fanout + +import ( + "context" + "net/http" + + "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech" + "github.com/trickstercache/trickster/v2/pkg/backends/alb/pool" + "github.com/trickstercache/trickster/v2/pkg/observability/logging" + "github.com/trickstercache/trickster/v2/pkg/observability/logging/logger" + "github.com/trickstercache/trickster/v2/pkg/observability/metrics" + "github.com/trickstercache/trickster/v2/pkg/proxy/request" + "github.com/trickstercache/trickster/v2/pkg/proxy/response/capture" + "golang.org/x/sync/errgroup" +) + +// Result holds one pool member's outcome from a fanout call. Results are +// returned slot-indexed: Result[i] corresponds to targets[i] from the +// original Targets slice. Slots whose target was nil have Failed == true. +type Result struct { + // Index is the slot position in the original Targets slice. + Index int + // Request is the cloned request handed to the member's handler. + // Mechanism code can read its Resources (e.g. TSM's MergeFunc / + // MergeRespondFunc / TS) in post-gather inspection or via OnResult. + Request *http.Request + // Capture is the bounded response writer that captured the member's + // reply. Nil if the slot failed before serving. + Capture *capture.CaptureResponseWriter + // Failed is true when the slot did not produce a usable response. + // Reasons: clone error, panic in the member's handler, or capture + // truncation (the upstream exceeded MaxCaptureBytes). Mechanism code + // uses this to surface partial-failure signals. + Failed bool + // Err carries a clone or transport error, if any. A recovered panic + // is reflected only in Failed; the panic value is logged + metered + // inside the fanout goroutine. 
+ Err error +} + +// Config configures one fanout call. +type Config struct { + // Mechanism is the short name ("fr", "nlm", "tsm", "tsm/avg-sum", ...) + // used in panic recovery logs and the fanout_failures_total metric. + Mechanism string + // ConcurrencyLimit caps in-flight member calls. 0 means unlimited. + ConcurrencyLimit int + // MaxCaptureBytes caps each member's response body capture. 0 uses + // capture.DefaultMaxBytes. + MaxCaptureBytes int + // Resources, if non-nil, returns the Resources to attach to each + // cloned request before the member's handler sees it. Nil resources + // is a valid return value. + Resources func(idx int) *request.Resources + // Context, if non-nil, transforms the parent context.Context before + // each clone receives it. NLM uses this to call tctx.ClearResources; + // most mechanisms can leave it nil. + Context func(parent context.Context) context.Context + // OnResult, if non-nil, is called inside the fanout goroutine after + // the member's handler returns and before the goroutine exits. Use + // this for per-slot side effects that should run in parallel with + // other in-flight members (e.g. TSM merges into a shared accumulator). + // OnResult must be safe for concurrent invocation. The supplied + // Result is the same one that will appear in the All return slice. + OnResult func(idx int, r *Result) +} + +// All scatters parent to every target and gathers slot-ordered Results. +// +// The caller MUST have primed parent's body (via PrimeBody or equivalent) +// before invoking All; otherwise concurrent CloneWithoutResources calls +// will race on r.Body / rsc.RequestBody for POST/PUT/PATCH requests. +// +// Each clone is given a per-slot Resources value (if Config.Resources is +// set), a derived context (Config.Context if set; ctx otherwise), and a +// capture.CaptureResponseWriter bounded by Config.MaxCaptureBytes. 
+// +// A panic in any member's handler is recovered, logged, and counted via +// the trickster_alb_fanout_failures_total metric; the slot's Failed field +// is set so the caller can surface partial-failure to its response. +// +// All returns when every spawned goroutine has finished. The returned +// slice has len(targets) entries; results[i].Index == i. The error is the +// first non-nil error from any goroutine (typically a clone failure); per- +// slot errors are also recorded in results[i].Err. The primitive logs + +// meters every failure regardless; callers can use the returned error to +// propagate through their own errgroup, render a fatal response, etc. +func All(ctx context.Context, parent *http.Request, targets pool.Targets, cfg Config) ([]Result, error) { + l := len(targets) + results := make([]Result, l) + if l == 0 { + return results, nil + } + + var eg errgroup.Group + if cfg.ConcurrencyLimit > 0 { + eg.SetLimit(cfg.ConcurrencyLimit) + } + + for i := range l { + if targets[i] == nil { + results[i] = Result{Index: i, Failed: true} + continue + } + eg.Go(func() error { + results[i].Index = i + defer mech.RecoverFanoutPanic(cfg.Mechanism, i, func() { + results[i].Failed = true + results[i].Capture = nil + }) + + r2, crw, err := prepareClone(ctx, parent, i, cfg) + if err != nil { + results[i].Failed = true + results[i].Err = err + metrics.ALBFanoutFailures.WithLabelValues(cfg.Mechanism, "clone").Inc() + return err + } + results[i].Request = r2 + results[i].Capture = crw + + targets[i].Handler().ServeHTTP(crw, r2) + if crw.Truncated() { + results[i].Failed = true + metrics.ALBFanoutFailures.WithLabelValues(cfg.Mechanism, "truncated").Inc() + } + if cfg.OnResult != nil { + cfg.OnResult(i, &results[i]) + } + return nil + }) + } + + err := eg.Wait() + if err != nil { + logger.Warn("alb fanout gather failure", logging.Pairs{ + "mech": cfg.Mechanism, "error": err, + }) + } + return results, err +} + +// PrepareClone produces one safe, capture-wrapped clone 
of parent suitable +// for handing to a pool member's handler. Mechanisms that own their own +// goroutine orchestration (FR) call this to avoid re-implementing the +// clone + ctx + resources + bounded capture sequence. +// +// PrepareClone does NOT prime parent's body; callers using their own +// goroutine pool must call PrimeBody on parent before spawning, or +// the per-clone GetBody calls will race on r.Body / rsc.RequestBody. +func PrepareClone(ctx context.Context, parent *http.Request, idx int, cfg Config) (*http.Request, *capture.CaptureResponseWriter, error) { + return prepareClone(ctx, parent, idx, cfg) +} + +// PrimeBody ensures parent has a Resources value and a cached body so +// fanout goroutines can clone it concurrently without racing on r.Body. +// Returns the (possibly new) request; callers must use the returned value +// as their parent for the subsequent fanout call, even for GET/HEAD. +func PrimeBody(parent *http.Request) (*http.Request, error) { + if request.GetResources(parent) == nil { + parent = request.SetResources(parent, &request.Resources{}) + } + if _, err := request.GetBody(parent); err != nil { + return parent, err + } + return parent, nil +} + +func prepareClone(ctx context.Context, parent *http.Request, idx int, cfg Config) (*http.Request, *capture.CaptureResponseWriter, error) { + r2, err := request.CloneWithoutResources(parent) + if err != nil { + return nil, nil, err + } + cloneCtx := ctx + if cfg.Context != nil { + cloneCtx = cfg.Context(ctx) + } + r2 = r2.WithContext(cloneCtx) + if cfg.Resources != nil { + if rsc := cfg.Resources(idx); rsc != nil { + r2 = request.SetResources(r2, rsc) + } + } + maxBytes := cfg.MaxCaptureBytes + if maxBytes == 0 { + maxBytes = capture.DefaultMaxBytes + } + crw := capture.NewCaptureResponseWriterWithLimit(maxBytes) + return r2, crw, nil +} diff --git a/pkg/backends/alb/mech/fanout/fanout_test.go b/pkg/backends/alb/mech/fanout/fanout_test.go new file mode 100644 index 000000000..12ae18234
--- /dev/null +++ b/pkg/backends/alb/mech/fanout/fanout_test.go @@ -0,0 +1,496 @@ +/* + * Copyright 2018 The Trickster Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fanout + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "github.com/trickstercache/trickster/v2/pkg/backends/alb/pool" + "github.com/trickstercache/trickster/v2/pkg/backends/healthcheck" + "github.com/trickstercache/trickster/v2/pkg/proxy/request" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func mkTarget(_ string, h http.HandlerFunc) *pool.Target { + return pool.NewTarget(h, &healthcheck.Status{}, nil) +} + +func newParentGET(t *testing.T) *http.Request { + t.Helper() + r, err := http.NewRequest(http.MethodGet, "http://trickstercache.org/", nil) + require.NoError(t, err) + return r +} + +func newParentPOST(t *testing.T, body string) *http.Request { + t.Helper() + r, err := http.NewRequest(http.MethodPost, "http://trickstercache.org/api/v1/query_range", strings.NewReader(body)) + require.NoError(t, err) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return r +} + +func TestAllOrderedResults(t *testing.T) { + const n = 5 + targets := make(pool.Targets, n) + for i := range n { + body := fmt.Sprintf("slot-%d", i) + targets[i] = mkTarget(body, func(w http.ResponseWriter, _ *http.Request) { 
+ w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(body)) + }) + } + parent := newParentGET(t) + results, _ := All(context.Background(), parent, targets, Config{Mechanism: "test"}) + require.Len(t, results, n) + for i := range n { + require.Equal(t, i, results[i].Index) + require.False(t, results[i].Failed, "slot %d failed: %v", i, results[i].Err) + require.NotNil(t, results[i].Capture) + require.Equal(t, fmt.Sprintf("slot-%d", i), string(results[i].Capture.Body())) + } +} + +func TestAllCtxCancelPropagation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + started := make(chan struct{}, 3) + mk := func() *pool.Target { + return mkTarget("block", func(w http.ResponseWriter, r *http.Request) { + started <- struct{}{} + <-r.Context().Done() + w.WriteHeader(http.StatusGatewayTimeout) + }) + } + targets := pool.Targets{mk(), mk(), mk()} + parent := newParentGET(t) + + done := make(chan []Result, 1) + go func() { + results, _ := All(ctx, parent, targets, Config{Mechanism: "test"}) + done <- results + }() + for range 3 { + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("handler never started") + } + } + cancel() + select { + case results := <-done: + require.Len(t, results, 3) + case <-time.After(3 * time.Second): + t.Fatal("All did not return after ctx cancel") + } +} + +func TestAllPanicRecovered(t *testing.T) { + targets := pool.Targets{ + mkTarget("ok-0", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("ok-0")) + }), + mkTarget("panic", func(_ http.ResponseWriter, _ *http.Request) { + panic("boom") + }), + mkTarget("ok-2", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("ok-2")) + }), + } + parent := newParentGET(t) + results, _ := All(context.Background(), parent, targets, Config{Mechanism: "test"}) + require.Len(t, results, 3) + require.False(t, results[0].Failed) + require.Equal(t, "ok-0", string(results[0].Capture.Body())) + require.True(t, results[1].Failed) + 
require.Nil(t, results[1].Capture) + require.False(t, results[2].Failed) + require.Equal(t, "ok-2", string(results[2].Capture.Body())) +} + +func TestAllConcurrencyLimit(t *testing.T) { + const n = 10 + const limit = 2 + var inFlight atomic.Int32 + var maxSeen atomic.Int32 + targets := make(pool.Targets, n) + for i := range n { + targets[i] = mkTarget("x", func(w http.ResponseWriter, _ *http.Request) { + cur := inFlight.Add(1) + for { + prev := maxSeen.Load() + if cur <= prev || maxSeen.CompareAndSwap(prev, cur) { + break + } + } + time.Sleep(10 * time.Millisecond) + inFlight.Add(-1) + w.WriteHeader(http.StatusOK) + }) + } + parent := newParentGET(t) + results, _ := All(context.Background(), parent, targets, Config{Mechanism: "test", ConcurrencyLimit: limit}) + require.Len(t, results, n) + require.LessOrEqual(t, int(maxSeen.Load()), limit, "max in-flight exceeded limit") + require.Greater(t, int(maxSeen.Load()), 0, "no concurrency observed") +} + +func TestAllCaptureBound(t *testing.T) { + const max = 1024 + big := bytes.Repeat([]byte("a"), 100*1024) + targets := pool.Targets{ + mkTarget("big", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write(big) + }), + } + parent := newParentGET(t) + results, _ := All(context.Background(), parent, targets, Config{Mechanism: "test", MaxCaptureBytes: max}) + require.Len(t, results, 1) + require.True(t, results[0].Failed, "truncation must surface as failure") + require.LessOrEqual(t, len(results[0].Capture.Body()), max) +} + +func TestAllResourcesPerSlot(t *testing.T) { + const n = 4 + seen := make([]*request.Resources, n) + var mu sync.Mutex + targets := make(pool.Targets, n) + for i := range n { + idx := i + targets[i] = mkTarget("x", func(_ http.ResponseWriter, r *http.Request) { + mu.Lock() + seen[idx] = request.GetResources(r) + mu.Unlock() + }) + } + created := make([]*request.Resources, n) + cfg := Config{ + Mechanism: "test", + Resources: func(idx int) *request.Resources { + rsc := &request.Resources{} + 
created[idx] = rsc + return rsc + }, + } + parent := newParentGET(t) + _, _ = All(context.Background(), parent, targets, cfg) + + for i := range n { + require.NotNil(t, seen[i]) + require.Same(t, created[i], seen[i], "slot %d got a shared/foreign Resources", i) + for j := range n { + if i == j { + continue + } + require.NotSame(t, seen[i], seen[j], "slots %d and %d share Resources", i, j) + } + } +} + +func TestAllContextTransform(t *testing.T) { + type ctxK struct{} + targets := pool.Targets{ + mkTarget("ctx", func(_ http.ResponseWriter, r *http.Request) { + v, _ := r.Context().Value(ctxK{}).(string) + require.Equal(t, "wrapped", v) + }), + } + cfg := Config{ + Mechanism: "test", + Context: func(parent context.Context) context.Context { + return context.WithValue(parent, ctxK{}, "wrapped") + }, + } + parent := newParentGET(t) + results, _ := All(context.Background(), parent, targets, cfg) + require.Len(t, results, 1) + require.False(t, results[0].Failed) + v, _ := results[0].Request.Context().Value(ctxK{}).(string) + require.Equal(t, "wrapped", v) +} + +func TestAllOnResultRunsInGoroutine(t *testing.T) { + const n = 6 + var calls atomic.Int32 + targets := make(pool.Targets, n) + for i := range n { + body := fmt.Sprintf("b-%d", i) + targets[i] = mkTarget(body, func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(body)) + }) + } + cfg := Config{ + Mechanism: "test", + OnResult: func(idx int, r *Result) { + require.Equal(t, idx, r.Index) + require.NotNil(t, r.Request) + require.NotNil(t, r.Capture) + calls.Add(1) + }, + } + parent := newParentGET(t) + results, _ := All(context.Background(), parent, targets, cfg) + require.Len(t, results, n) + require.Equal(t, int32(n), calls.Load()) +} + +func TestAllOnResultThreadSafety(t *testing.T) { + const n = 50 + var mu sync.Mutex + got := make([]int, 0, n) + targets := make(pool.Targets, n) + for i := range n { + targets[i] = mkTarget("x", func(w http.ResponseWriter, _ *http.Request) { + 
w.WriteHeader(http.StatusOK) + }) + } + cfg := Config{ + Mechanism: "test", + OnResult: func(idx int, _ *Result) { + mu.Lock() + got = append(got, idx) + mu.Unlock() + }, + } + parent := newParentGET(t) + _, _ = All(context.Background(), parent, targets, cfg) + mu.Lock() + defer mu.Unlock() + require.Len(t, got, n) + seen := make(map[int]bool, n) + for _, idx := range got { + require.False(t, seen[idx], "duplicate idx %d", idx) + seen[idx] = true + } +} + +func TestAllNilTarget(t *testing.T) { + targets := pool.Targets{ + mkTarget("ok-0", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("ok-0")) + }), + nil, + mkTarget("ok-2", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("ok-2")) + }), + } + parent := newParentGET(t) + results, _ := All(context.Background(), parent, targets, Config{Mechanism: "test"}) + require.Len(t, results, 3) + require.False(t, results[0].Failed) + require.Equal(t, "ok-0", string(results[0].Capture.Body())) + require.True(t, results[1].Failed) + require.Nil(t, results[1].Capture) + require.False(t, results[2].Failed) + require.Equal(t, "ok-2", string(results[2].Capture.Body())) +} + +func TestPrimeBodyForGET(t *testing.T) { + parent := newParentGET(t) + out, err := PrimeBody(parent) + require.NoError(t, err) + require.NotNil(t, out) + require.Equal(t, parent.Method, out.Method) + require.Equal(t, parent.URL.String(), out.URL.String()) + require.NotNil(t, request.GetResources(out)) +} + +func TestPrimeBodyForPOST(t *testing.T) { + const body = `{"q":"up"}` + parent := newParentPOST(t, body) + out, err := PrimeBody(parent) + require.NoError(t, err) + require.NotNil(t, out) + rsc := request.GetResources(out) + require.NotNil(t, rsc) + require.Equal(t, body, string(rsc.RequestBody)) + + out2, err := PrimeBody(out) + require.NoError(t, err) + require.Same(t, rsc, request.GetResources(out2)) + + for range 3 { + b, err := io.ReadAll(out2.Body) + require.NoError(t, err) + require.Equal(t, body, string(b)) 
+ _ = out2.Body.Close() + out2.Body = io.NopCloser(bytes.NewReader(rsc.RequestBody)) + } +} + +func TestPrimeBodyConcurrentClonesAreRaceFree(t *testing.T) { + const body = `{"query":"sum(rate(metric[5m]))","start":"2024-01-01T00:00:00Z","end":"2024-01-01T01:00:00Z","step":"15s"}` + parent := newParentPOST(t, body) + primed, err := PrimeBody(parent) + require.NoError(t, err) + + const callers = 16 + var wg sync.WaitGroup + errs := make(chan error, callers) + for i := range callers { + idx := i + wg.Go(func() { + r2, _, err := PrepareClone(context.Background(), primed, idx, Config{Mechanism: "test"}) + if err != nil { + errs <- err + return + } + b, err := io.ReadAll(r2.Body) + if err != nil { + errs <- err + return + } + if string(b) != body { + errs <- fmt.Errorf("caller %d got %d bytes want %d", idx, len(b), len(body)) + return + } + errs <- nil + }) + } + wg.Wait() + close(errs) + for e := range errs { + require.NoError(t, e) + } +} + +func TestPrepareCloneRespectsMaxBytes(t *testing.T) { + const max = 64 + parent := newParentGET(t) + r2, crw, err := PrepareClone(context.Background(), parent, 0, Config{Mechanism: "test", MaxCaptureBytes: max}) + require.NoError(t, err) + require.NotNil(t, r2) + require.NotNil(t, crw) + + n, err := crw.Write(bytes.Repeat([]byte("a"), max*4)) + require.NoError(t, err) + require.Equal(t, max*4, n) + require.True(t, crw.Truncated()) + require.LessOrEqual(t, len(crw.Body()), max) +} + +func TestPrepareCloneNilResources(t *testing.T) { + parent := newParentGET(t) + cfg := Config{ + Mechanism: "test", + Resources: func(_ int) *request.Resources { return nil }, + } + r2, _, err := PrepareClone(context.Background(), parent, 0, cfg) + require.NoError(t, err) + require.Nil(t, request.GetResources(r2)) +} + +func TestAllNoLeaksOnNormalCompletion(t *testing.T) { + const n = 4 + targets := make(pool.Targets, n) + for i := range n { + targets[i] = mkTarget("x", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ 
= w.Write([]byte("ok")) + }) + } + parent := newParentGET(t) + results, _ := All(context.Background(), parent, targets, Config{Mechanism: "test"}) + require.Len(t, results, n) +} + +func TestAllEmptyTargets(t *testing.T) { + parent := newParentGET(t) + results, _ := All(context.Background(), parent, pool.Targets{}, Config{Mechanism: "test"}) + require.Empty(t, results) +} + +type errReader struct{} + +func (errReader) Read(_ []byte) (int, error) { return 0, fmt.Errorf("read failed") } + +func TestAllCloneErrorSurfaces(t *testing.T) { + parent, err := http.NewRequest(http.MethodPost, "http://trickstercache.org/", errReader{}) + require.NoError(t, err) + targets := pool.Targets{ + mkTarget("unreached", func(w http.ResponseWriter, _ *http.Request) { + t.Fatal("handler should not run when clone fails") + w.WriteHeader(http.StatusOK) + }), + } + results, _ := All(context.Background(), parent, targets, Config{Mechanism: "test"}) + require.Len(t, results, 1) + require.True(t, results[0].Failed) + require.Error(t, results[0].Err) +} + +func TestPrimeBodyReadError(t *testing.T) { + parent, err := http.NewRequest(http.MethodPost, "http://trickstercache.org/", errReader{}) + require.NoError(t, err) + _, perr := PrimeBody(parent) + require.Error(t, perr) +} + +func TestPrepareCloneError(t *testing.T) { + parent, err := http.NewRequest(http.MethodPost, "http://trickstercache.org/", errReader{}) + require.NoError(t, err) + _, _, perr := PrepareClone(context.Background(), parent, 0, Config{Mechanism: "test"}) + require.Error(t, perr) +} + +func TestAllNoLeaksOnCtxCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + const n = 3 + started := make(chan struct{}, n) + targets := make(pool.Targets, n) + for i := range n { + targets[i] = mkTarget("block", func(w http.ResponseWriter, r *http.Request) { + started <- struct{}{} + <-r.Context().Done() + w.WriteHeader(http.StatusGatewayTimeout) + }) + } + parent := newParentGET(t) + done := make(chan struct{}) 
+ go func() { + _, _ = All(ctx, parent, targets, Config{Mechanism: "test"}) + close(done) + }() + for range n { + select { + case <-started: + case <-time.After(2 * time.Second): + t.Fatal("handler never started") + } + } + cancel() + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("All did not return after ctx cancel") + } +} diff --git a/pkg/backends/alb/mech/fr/first_response.go b/pkg/backends/alb/mech/fr/first_response.go index a4eb07212..3feb9a350 100644 --- a/pkg/backends/alb/mech/fr/first_response.go +++ b/pkg/backends/alb/mech/fr/first_response.go @@ -23,6 +23,7 @@ import ( "sync/atomic" "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech" + "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech/fanout" "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech/types" "github.com/trickstercache/trickster/v2/pkg/backends/alb/names" "github.com/trickstercache/trickster/v2/pkg/backends/alb/options" @@ -107,13 +108,8 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { hl[0].Handler().ServeHTTP(w, r) return } - // Ensure the parent has Resources so GetBody caches the body bytes. - // Without a cache, each fanout goroutine's CloneWithoutResources re-reads - // and mutates r.Body, racing on the parent request. 
- if request.GetResources(r) == nil { - r = request.SetResources(r, &request.Resources{}) - } - if _, err := request.GetBody(r); err != nil { + r, err := fanout.PrimeBody(r) + if err != nil { failures.HandleBadGateway(w, r) return } @@ -155,13 +151,13 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // recover so a single bad upstream doesn't crash the proxy; clear // the slot so the fallback path doesn't serve a partial capture defer mech.RecoverFanoutPanic("fr", i, func() { captures[i] = nil }) - r2, err := request.CloneWithoutResources(r) + r2, crw, err := fanout.PrepareClone(ctx, r, i, fanout.Config{ + Mechanism: "fr", + Resources: func(int) *request.Resources { return &request.Resources{Cancelable: true} }, + }) if err != nil { return err } - r2 = r2.WithContext(ctx) - r2 = request.SetResources(r2, &request.Resources{Cancelable: true}) - crw := capture.NewCaptureResponseWriterWithLimit(capture.DefaultMaxBytes) captures[i] = crw hl[i].Handler().ServeHTTP(crw, r2) statusCode := crw.StatusCode() diff --git a/pkg/backends/alb/mech/nlm/newest_last_modified.go b/pkg/backends/alb/mech/nlm/newest_last_modified.go index b0160d915..a27c51ed7 100644 --- a/pkg/backends/alb/mech/nlm/newest_last_modified.go +++ b/pkg/backends/alb/mech/nlm/newest_last_modified.go @@ -21,6 +21,7 @@ import ( "time" "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech" + "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech/fanout" "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech/types" "github.com/trickstercache/trickster/v2/pkg/backends/alb/names" "github.com/trickstercache/trickster/v2/pkg/backends/alb/options" @@ -28,9 +29,7 @@ import ( tctx "github.com/trickstercache/trickster/v2/pkg/proxy/context" "github.com/trickstercache/trickster/v2/pkg/proxy/handlers/trickster/failures" "github.com/trickstercache/trickster/v2/pkg/proxy/headers" - "github.com/trickstercache/trickster/v2/pkg/proxy/request" 
"github.com/trickstercache/trickster/v2/pkg/proxy/response/capture" - "golang.org/x/sync/errgroup" ) const ( @@ -73,97 +72,63 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { failures.HandleBadGateway(w, r) return } - hl := p.LiveTargets() // should return a fanout list + hl := p.LiveTargets() l := len(hl) if l == 0 { failures.HandleBadGateway(w, r) return } - // just proxy 1:1 if no folds in the fan if l == 1 { hl[0].Handler().ServeHTTP(w, r) return } - // Ensure the parent has Resources so GetBody caches the body bytes. - // Without a cache, each fanout goroutine's CloneWithoutResources re-reads - // and mutates r.Body, racing on the parent request. - if request.GetResources(r) == nil { - r = request.SetResources(r, &request.Resources{}) - } - if _, err := request.GetBody(r); err != nil { + + r, err := fanout.PrimeBody(r) + if err != nil { failures.HandleBadGateway(w, r) return } - // Strip resources from the parent context once; reuse for all goroutines - bareCtx := tctx.ClearResources(r.Context()) - - // Capture all responses with per-slot Last-Modified timestamps - captures := make([]*capture.CaptureResponseWriter, l) - lastMods := make([]time.Time, l) - var eg errgroup.Group - if limit := h.options.ConcurrencyOptions.GetQueryConcurrencyLimit(); limit > 0 { - eg.SetLimit(limit) - } - // Fanout to all healthy targets - for i := range l { - if hl[i] == nil { - continue - } - eg.Go(func() error { - // recover so a single bad upstream doesn't crash the proxy; clear - // the slot so the fallback path doesn't pick a partial capture - defer mech.RecoverFanoutPanic("nlm", i, func() { captures[i] = nil }) - r2, err := request.CloneWithoutResources(r) - if err != nil { - return err - } - r2 = r2.WithContext(bareCtx) - crw := capture.NewCaptureResponseWriterWithLimit(capture.DefaultMaxBytes) - captures[i] = crw - hl[i].Handler().ServeHTTP(crw, r2) - - if lmStr := crw.Header().Get(headers.NameLastModified); lmStr != "" { - // http.ParseTime accepts 
all three RFC 7231 §7.1.1.1 forms - // (IMF-fixdate with GMT, RFC 850, ANSI C asctime); time.RFC1123 - // alone rejects the IMF-fixdate "GMT" suffix that real servers emit - if lm, err := http.ParseTime(lmStr); err == nil { - lastMods[i] = lm - } - } - return nil - }) - } - - // Wait for all responses to complete - eg.Wait() + results, _ := fanout.All(r.Context(), r, hl, fanout.Config{ + Mechanism: "nlm", + ConcurrencyLimit: h.options.ConcurrencyOptions.GetQueryConcurrencyLimit(), + Context: tctx.ClearResources, + }) - // Find the response with the newest Last-Modified newestIdx := -1 var newestTime time.Time - for i, lm := range lastMods { - if !lm.IsZero() && (newestIdx == -1 || lm.After(newestTime)) { + for i, res := range results { + if res.Capture == nil { + continue + } + lmStr := res.Capture.Header().Get(headers.NameLastModified) + if lmStr == "" { + continue + } + lm, perr := http.ParseTime(lmStr) + if perr != nil { + continue + } + if newestIdx == -1 || lm.After(newestTime) { newestIdx = i newestTime = lm } } - if newestIdx >= 0 && captures[newestIdx] != nil { - writeCapture(w, captures[newestIdx]) + if newestIdx >= 0 { + writeCapture(w, results[newestIdx].Capture) return } - // No valid Last-Modified found; prefer a 2xx capture before falling back - // to the first non-nil response. 
- for _, crw := range captures { - if crw != nil && crw.StatusCode() >= 200 && crw.StatusCode() < 300 { - writeCapture(w, crw) + for _, res := range results { + if res.Capture != nil && res.Capture.StatusCode() >= 200 && res.Capture.StatusCode() < 300 { + writeCapture(w, res.Capture) return } } - for _, crw := range captures { - if crw != nil { - writeCapture(w, crw) - break + for _, res := range results { + if res.Capture != nil { + writeCapture(w, res.Capture) + return } } } diff --git a/pkg/backends/alb/mech/tsm/time_series_merge.go b/pkg/backends/alb/mech/tsm/time_series_merge.go index bbadbe188..184c40a8e 100644 --- a/pkg/backends/alb/mech/tsm/time_series_merge.go +++ b/pkg/backends/alb/mech/tsm/time_series_merge.go @@ -17,7 +17,6 @@ package tsm import ( - stderrors "errors" "net/http" "strconv" "strings" @@ -25,6 +24,7 @@ import ( "github.com/trickstercache/trickster/v2/pkg/backends" "github.com/trickstercache/trickster/v2/pkg/backends/alb/errors" "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech" + "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech/fanout" "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech/rr" "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech/types" "github.com/trickstercache/trickster/v2/pkg/backends/alb/names" @@ -40,7 +40,6 @@ import ( "github.com/trickstercache/trickster/v2/pkg/proxy/headers" "github.com/trickstercache/trickster/v2/pkg/proxy/params" "github.com/trickstercache/trickster/v2/pkg/proxy/request" - "github.com/trickstercache/trickster/v2/pkg/proxy/response/capture" "github.com/trickstercache/trickster/v2/pkg/proxy/response/merge" "github.com/trickstercache/trickster/v2/pkg/timeseries/dataset" "golang.org/x/sync/errgroup" @@ -329,109 +328,82 @@ func (h *handler) serveStandard( l := len(hl) accumulator := merge.NewAccumulator() - var eg errgroup.Group - if limit := h.tsmOptions.ConcurrencyOptions.GetQueryConcurrencyLimit(); limit > 0 { - eg.SetLimit(limit) - } - results := 
make([]gatherResult, l) - for i := range l { - if hl[i] == nil { - continue - } - eg.Go(func() error { - // recover so a single bad upstream doesn't crash the proxy; mark - // the slot failed so partial-failure surfacing fires downstream - defer mech.RecoverFanoutPanic("tsm", i, func() { results[i] = gatherResult{failed: true} }) - r2, err := request.CloneWithoutResources(r) - if err != nil { - results[i].failed = true - return err - } - r2 = r2.WithContext(r.Context()) - rsc2 := &request.Resources{ + r, err := fanout.PrimeBody(r) + if err != nil { + failures.HandleBadGateway(w, r) + return + } + + fanoutResults, _ := fanout.All(r.Context(), r, hl, fanout.Config{ + Mechanism: "tsm", + ConcurrencyLimit: h.tsmOptions.ConcurrencyOptions.GetQueryConcurrencyLimit(), + Resources: func(int) *request.Resources { + return &request.Resources{ IsMergeMember: true, TSReqestOptions: rsc.TSReqestOptions, TSMergeStrategy: int(mergeStrategy), } - r2 = request.SetResources(r2, rsc2) - crw := capture.NewCaptureResponseWriterWithLimit(capture.DefaultMaxBytes) - hl[i].Handler().ServeHTTP(crw, r2) - rsc2 = request.GetResources(r2) + }, + OnResult: func(i int, fr *fanout.Result) { + if fr.Failed || fr.Request == nil || fr.Capture == nil { + results[i].failed = true + return + } + rsc2 := request.GetResources(fr.Request) if rsc2 == nil { results[i].failed = true - return stderrors.New("tsm gather failed due to nil resources") + return } - // ensure merge functions are set on cloned request if rsc2.MergeFunc == nil || rsc2.MergeRespondFunc == nil { logger.Warn("tsm gather failed due to nil func", nil) } - // strip injected labels so series from different backends hash - // identically for aggregation if len(stripKeys) > 0 && rsc2.TS != nil { if ds, ok := rsc2.TS.(*dataset.DataSet); ok { ds.StripTags(stripKeys) } } - // as soon as response is complete, unmarshal and merge - // this happens in parallel for each response as it arrives var contributed bool if rsc2.MergeFunc != nil { if 
rsc2.TS != nil { rsc2.MergeFunc(accumulator, rsc2.TS, i) contributed = true } else { - body, err := encoding.DecompressResponseBody( - crw.Header().Get(headers.NameContentEncoding), - crw.Body(), + body, derr := encoding.DecompressResponseBody( + fr.Capture.Header().Get(headers.NameContentEncoding), + fr.Capture.Body(), ) - if err != nil { + if derr != nil { + logger.Warn("tsm gather decode failure", logging.Pairs{ + "member": i, "error": derr, + }) results[i].failed = true - return err + return } if len(body) > 0 { - // For non-timeseries paths (labels, series, etc.), rsc.TS is not - // populated. Fall back to passing the captured response body to - // MergeFunc, which handles []byte input via JSON unmarshal. rsc2.MergeFunc(accumulator, body, i) contributed = true } } } - // rsc2.Response carries the upstream's true status code (set by - // ObjectProxyCacheRequest for merge members). The CRW status can - // stay at the default 200 when the per-shard handler unmarshals - // an error envelope but writes nothing through the outer writer - // (see prometheus.processVectorTransformations + MarshalTSOrVectorWriter - // returning ErrUnknownFormat for zero-result error responses). - // Without this fallback, mixed 2xx/5xx fanouts look uniformly OK - // to TSM and the partial-hit detection below misses (V2). - sc := crw.StatusCode() + sc := fr.Capture.StatusCode() if rsc2.Response != nil && rsc2.Response.StatusCode > 0 { sc = rsc2.Response.StatusCode } results[i] = gatherResult{ statusCode: sc, - header: crw.Header(), + header: fr.Capture.Header(), mergeFunc: rsc2.MergeRespondFunc, - // A member that wrote no contribution to the accumulator is - // a silent failure: typically the per-member handler hit a - // proxy-error (e.g. unsupported Content-Encoding decoded - // upstream then unmarshal failed) and surfaced 200 + empty - // body to us. Without this, the merged response would look - // identical to a fully-successful fanout. 
A truncated capture - // (response exceeded the buffer cap) is also surfaced as a - // failure so the merge doesn't silently drop tail data. - failed: !contributed || crw.Truncated(), + failed: !contributed, } - return nil - }) - } + }, + }) - // wait for all fanout requests to complete - if err := eg.Wait(); err != nil { - logger.Warn("tsm gather failure", logging.Pairs{"error": err}) + for i, fr := range fanoutResults { + if fr.Failed && !results[i].failed { + results[i].failed = true + } } // Surface goroutine-level failures (e.g. unsupported Content-Encoding, @@ -548,105 +520,103 @@ func (h *handler) serveWeightedAvg( // POST body, etc.). sumBase, countBase := mp.RewriteForWeightedAvg(r, query) + sumBase, err := fanout.PrimeBody(sumBase) + if err != nil { + failures.HandleBadGateway(w, r) + return + } + countBase, err = fanout.PrimeBody(countBase) + if err != nil { + failures.HandleBadGateway(w, r) + return + } + sumAccum := merge.NewAccumulator() countAccum := merge.NewAccumulator() - var eg errgroup.Group - if limit := h.tsmOptions.ConcurrencyOptions.GetQueryConcurrencyLimit(); limit > 0 { - // Each shard spawns two goroutines; scale the limit accordingly so we - // don't serialize unnecessarily. - eg.SetLimit(limit * 2) - } + limit := h.tsmOptions.ConcurrencyOptions.GetQueryConcurrencyLimit() - // results captures sum-query outcomes only — the response envelope + // results captures sum-query outcomes only -- the response envelope // (status, headers, RespondFunc) comes from the sum side; count results // affect the arithmetic, not the envelope. results := make([]gatherResult, l) - for i := range l { - if hl[i] == nil { - continue + resourcesFn := func(int) *request.Resources { + return &request.Resources{ + IsMergeMember: true, + TSReqestOptions: rsc.TSReqestOptions, + TSMergeStrategy: int(dataset.MergeStrategySum), } - // Sum query for shard i — clone from the pre-rewritten base request. 
- eg.Go(func() error { - // recover so a single bad upstream doesn't crash the proxy; mark - // the slot failed so partial-failure surfacing fires downstream - defer mech.RecoverFanoutPanic("tsm/avg-sum", i, func() { results[i] = gatherResult{failed: true} }) - r2, err := request.CloneWithoutResources(sumBase) - if err != nil { - return err - } - r2 = r2.WithContext(r.Context()) - rsc2 := &request.Resources{ - IsMergeMember: true, - TSReqestOptions: rsc.TSReqestOptions, - TSMergeStrategy: int(dataset.MergeStrategySum), - } - r2 = request.SetResources(r2, rsc2) - crw := capture.NewCaptureResponseWriterWithLimit(capture.DefaultMaxBytes) - hl[i].Handler().ServeHTTP(crw, r2) - rsc2 = request.GetResources(r2) - if rsc2 == nil { - return stderrors.New("tsm avg/sum gather failed due to nil resources") - } - if len(stripKeys) > 0 && rsc2.TS != nil { - if ds, ok := rsc2.TS.(*dataset.DataSet); ok { - ds.StripTags(stripKeys) + } + + parentCtx := r.Context() + var eg errgroup.Group + eg.Go(func() error { + _, err := fanout.All(parentCtx, sumBase, hl, fanout.Config{ + Mechanism: "tsm/avg-sum", + ConcurrencyLimit: limit, + Resources: resourcesFn, + OnResult: func(i int, fr *fanout.Result) { + if fr.Failed || fr.Request == nil || fr.Capture == nil { + results[i].failed = true + return } - } - if rsc2.MergeFunc != nil && rsc2.TS != nil { - rsc2.MergeFunc(sumAccum, rsc2.TS, i) - } - // See serveStandard for why we prefer rsc2.Response.StatusCode. 
- sc := crw.StatusCode() - if rsc2.Response != nil && rsc2.Response.StatusCode > 0 { - sc = rsc2.Response.StatusCode - } - results[i] = gatherResult{ - statusCode: sc, - header: crw.Header(), - mergeFunc: rsc2.MergeRespondFunc, - } - return nil + rsc2 := request.GetResources(fr.Request) + if rsc2 == nil { + results[i].failed = true + return + } + if len(stripKeys) > 0 && rsc2.TS != nil { + if ds, ok := rsc2.TS.(*dataset.DataSet); ok { + ds.StripTags(stripKeys) + } + } + if rsc2.MergeFunc != nil && rsc2.TS != nil { + rsc2.MergeFunc(sumAccum, rsc2.TS, i) + } + sc := fr.Capture.StatusCode() + if rsc2.Response != nil && rsc2.Response.StatusCode > 0 { + sc = rsc2.Response.StatusCode + } + results[i] = gatherResult{ + statusCode: sc, + header: fr.Capture.Header(), + mergeFunc: rsc2.MergeRespondFunc, + } + }, }) - - // Count query for shard i — clone from the pre-rewritten base request. - eg.Go(func() error { - // recover so a single bad upstream doesn't crash the proxy; the - // count side doesn't own the response envelope, so we don't touch - // results[i] here — the sum-side recover handles that - defer mech.RecoverFanoutPanic("tsm/avg-count", i, nil) - r2, err := request.CloneWithoutResources(countBase) - if err != nil { - return err - } - r2 = r2.WithContext(r.Context()) - rsc2 := &request.Resources{ - IsMergeMember: true, - TSReqestOptions: rsc.TSReqestOptions, - TSMergeStrategy: int(dataset.MergeStrategySum), - } - r2 = request.SetResources(r2, rsc2) - crw := capture.NewCaptureResponseWriterWithLimit(capture.DefaultMaxBytes) - hl[i].Handler().ServeHTTP(crw, r2) - rsc2 = request.GetResources(r2) - if rsc2 == nil { - return stderrors.New("tsm avg/count gather failed due to nil resources") - } - if len(stripKeys) > 0 && rsc2.TS != nil { - if ds, ok := rsc2.TS.(*dataset.DataSet); ok { - ds.StripTags(stripKeys) + return err + }) + eg.Go(func() error { + _, err := fanout.All(parentCtx, countBase, hl, fanout.Config{ + Mechanism: "tsm/avg-count", + ConcurrencyLimit: limit, 
+ Resources: resourcesFn, + OnResult: func(i int, fr *fanout.Result) { + // Count-side intentionally does not touch results[i]: the + // sum-side owns the response envelope. Panics in this slot + // are already recovered + countered by fanout.All. + if fr.Failed || fr.Request == nil { + return } - } - if rsc2.MergeFunc != nil && rsc2.TS != nil { - rsc2.MergeFunc(countAccum, rsc2.TS, i) - } - return nil + rsc2 := request.GetResources(fr.Request) + if rsc2 == nil { + return + } + if len(stripKeys) > 0 && rsc2.TS != nil { + if ds, ok := rsc2.TS.(*dataset.DataSet); ok { + ds.StripTags(stripKeys) + } + } + if rsc2.MergeFunc != nil && rsc2.TS != nil { + rsc2.MergeFunc(countAccum, rsc2.TS, i) + } + }, }) - } - + return err + }) if err := eg.Wait(); err != nil { - logger.Warn("tsm weighted-avg gather failure", logging.Pairs{"error": err}) + logger.Warn("tsm avg gather failure", logging.Pairs{"error": err}) } // Finalize: divide sum totals by count totals to obtain the weighted From 98bd33e9bbca59294209088341b35c39cca13937 Mon Sep 17 00:00:00 2001 From: Chris Randles Date: Wed, 13 May 2026 16:23:48 -0400 Subject: [PATCH 5/6] feat: [alb,backends] make max_capture_bytes configurable Signed-off-by: Chris Randles --- docs/alb.md | 25 +++++++++++++++++ examples/conf/example.full.yaml | 10 +++++++ pkg/backends/alb/client.go | 3 +++ pkg/backends/alb/mech/fanout/fanout.go | 24 +++++++---------- pkg/backends/alb/mech/fr/first_response.go | 27 ++++++++++++------- .../alb/mech/nlm/newest_last_modified.go | 7 +++-- .../alb/mech/tsm/time_series_merge.go | 5 ++++ pkg/backends/alb/options/options.go | 7 +++++ pkg/backends/options/defaults.go | 7 +++++ pkg/backends/options/options.go | 8 ++++++ pkg/backends/prometheus/handler_labels.go | 2 +- pkg/backends/prometheus/handler_query.go | 2 +- pkg/backends/prometheus/prometheus.go | 13 +++++++++ pkg/config/testdata/example.alb.yaml | 7 +++++ pkg/config/testdata/example.auth.yaml | 4 +++ pkg/config/testdata/example.full.yaml | 1 + 
pkg/config/testdata/exmple.sharding.yaml | 1 + pkg/config/testdata/simple.prometheus.yaml | 1 + .../testdata/simple.reverseproxycache.yaml | 1 + 19 files changed, 127 insertions(+), 28 deletions(-) diff --git a/docs/alb.md b/docs/alb.md index e93cde5f5..7b2781a50 100644 --- a/docs/alb.md +++ b/docs/alb.md @@ -387,6 +387,31 @@ The User Router does not rotate through or fanout to a Pool of Backends like the You can configure a User Router ALB's backend destinations to be other ALBs with mechanisms that utilize healthchecked pools. +## Bounding Per-Member Response Captures + +ALB mechanisms that fan out (TSM, FR, FGR, NLM) buffer each pool member's response in memory before merging or selecting a winner. Without a cap, one misbehaving upstream returning an oversized body can OOM the proxy -- an N-way fanout multiplies that by N. + +Trickster applies a default cap of **256 MiB** per response. A member whose body exceeds the cap is treated as a partial failure: the merged response carries an `X-Trickster-Result: phit` marker and the `trickster_alb_fanout_failures_total{mechanism, reason="truncated"}` metric increments. + +Override the cap at the backend or ALB level: + +```yaml +backends: + default: + max_capture_bytes: 67108864 # 64 MiB, applies to all backends (Prometheus, ClickHouse, ALB members, etc.) + + prom-alb-tsm: + provider: alb + alb: + mechanism: tsm + max_capture_bytes: 16777216 # 16 MiB, ALB-specific override + pool: + - prom01 + - prom02 +``` + +The ALB-level value takes precedence over the backend-level value, which in turn takes precedence over the 256 MiB default. + ## Maintaining Healthy Pools With Automated Health Check Integrations Health Checks are configured per-Backend as described in the [Health documentation](./health.md). 
Each Backend's health checker will notify all ALB pools of which it is a member when its health status changes, so long as it has been configured with a [health check interval](./health#example+health+check+configuration+for+use+in+alb) for automated checking. When an ALB is notified that the state of a pool member has changed, the ALB will reconstruct its list of healthy pool members before serving the next request. diff --git a/examples/conf/example.full.yaml b/examples/conf/example.full.yaml index 87d2a29c0..a8558191d 100644 --- a/examples/conf/example.full.yaml +++ b/examples/conf/example.full.yaml @@ -337,6 +337,11 @@ backends: # # max_object_size_bytes defines the largest byte size an object may be before it is uncacheable due to size. default is 524288 (512k) # max_object_size_bytes: 524288 +# # max_capture_bytes caps the per-response in-memory capture buffer used by ALB fanout and Prometheus +# # transform/label handlers. A member whose body exceeds the cap is treated as a partial failure +# # (X-Trickster-Result: phit) rather than truncating the merged response silently. default is 268435456 (256 MiB) +# max_capture_bytes: 268435456 + # # These next 7 settings only apply to Time Series backends # # backfill_tolerance prevents new datapoints that fall within the tolerance window (relative to time.Now) from being permanently @@ -557,6 +562,11 @@ backends: # # -1 includes all backends, regardless of reporting state # # default is 0 # healthy_floor: 0 + +# # max_capture_bytes overrides the backend-level max_capture_bytes for this ALB's fanout members. +# # Set this when the ALB's expected response shape differs from the backend default. When 0 (the +# # default), the parent Backend's max_capture_bytes is used, falling back to 268435456 (256 MiB). 
+# max_capture_bytes: 16777216 # 16 MiB # fgr: # First Good Response mechanism options, only applicable when mechanism is set to fgr # # status_codes is a list of status codes considered 'good' when using the fgr mechanism # # when this is not set, any response code < 400 is considered good. Use this setting to diff --git a/pkg/backends/alb/client.go b/pkg/backends/alb/client.go index c440e17df..0f9f22abb 100644 --- a/pkg/backends/alb/client.go +++ b/pkg/backends/alb/client.go @@ -68,6 +68,9 @@ func NewClient(name string, o *bo.Options, router http.Handler, } c.Backend = b if o != nil && o.ALBOptions != nil { + if o.ALBOptions.MaxCaptureBytes == 0 { + o.ALBOptions.MaxCaptureBytes = o.MaxCaptureBytes + } m, err := registry.New(o.ALBOptions.MechanismName, o.ALBOptions, factories) if err != nil { diff --git a/pkg/backends/alb/mech/fanout/fanout.go b/pkg/backends/alb/mech/fanout/fanout.go index 4ad682541..4e228945f 100644 --- a/pkg/backends/alb/mech/fanout/fanout.go +++ b/pkg/backends/alb/mech/fanout/fanout.go @@ -134,7 +134,7 @@ func All(ctx context.Context, parent *http.Request, targets pool.Targets, cfg Co results[i].Capture = nil }) - r2, crw, err := prepareClone(ctx, parent, i, cfg) + r2, crw, err := PrepareClone(ctx, parent, i, cfg) if err != nil { results[i].Failed = true results[i].Err = err @@ -165,18 +165,6 @@ func All(ctx context.Context, parent *http.Request, targets pool.Targets, cfg Co return results, err } -// PrepareClone produces one safe, capture-wrapped clone of parent suitable -// for handing to a pool member's handler. Mechanisms that own their own -// goroutine orchestration (FR) call this to avoid re-implementing the -// clone + ctx + resources + bounded capture sequence. -// -// PrepareClone does NOT prime parent's body; callers using their own -// goroutine pool must call request.GetBody on parent before spawning, or -// the per-clone GetBody calls will race on r.Body / rsc.RequestBody. 
-func PrepareClone(ctx context.Context, parent *http.Request, idx int, cfg Config) (*http.Request, *capture.CaptureResponseWriter, error) { - return prepareClone(ctx, parent, idx, cfg) -} - // PrimeBody ensures parent has a Resources value and a cached body so // fanout goroutines can clone it concurrently without racing on r.Body. // Returns the (possibly new) request; callers must use the returned value @@ -191,7 +179,15 @@ func PrimeBody(parent *http.Request) (*http.Request, error) { return parent, nil } -func prepareClone(ctx context.Context, parent *http.Request, idx int, cfg Config) (*http.Request, *capture.CaptureResponseWriter, error) { +// PrepareClone produces one safe, capture-wrapped clone of parent suitable +// for handing to a pool member's handler. Mechanisms that own their own +// goroutine orchestration (FR) call this to avoid re-implementing the +// clone + ctx + resources + bounded capture sequence. +// +// PrepareClone does NOT prime parent's body; callers using their own +// goroutine pool must call request.GetBody on parent before spawning, or +// the per-clone GetBody calls will race on r.Body / rsc.RequestBody. 
+func PrepareClone(ctx context.Context, parent *http.Request, idx int, cfg Config) (*http.Request, *capture.CaptureResponseWriter, error) { r2, err := request.CloneWithoutResources(parent) if err != nil { return nil, nil, err diff --git a/pkg/backends/alb/mech/fr/first_response.go b/pkg/backends/alb/mech/fr/first_response.go index 3feb9a350..6866db9c2 100644 --- a/pkg/backends/alb/mech/fr/first_response.go +++ b/pkg/backends/alb/mech/fr/first_response.go @@ -46,9 +46,10 @@ const ( type handler struct { mech.PoolHolder - fgr bool - fgrCodes sets.Set[int] - options options.FirstGoodResponseOptions + fgr bool + fgrCodes sets.Set[int] + options options.FirstGoodResponseOptions + maxCaptureBytes int } func RegistryEntry() types.RegistryEntry { @@ -61,14 +62,19 @@ func RegistryEntryFGR() types.RegistryEntry { func NewFGR(o *options.Options, _ rt.Lookup) (types.Mechanism, error) { return &handler{ - fgr: true, - fgrCodes: o.FgrCodesLookup, - options: o.FGROptions, + fgr: true, + fgrCodes: o.FgrCodesLookup, + options: o.FGROptions, + maxCaptureBytes: o.MaxCaptureBytes, }, nil } -func New(_ *options.Options, _ rt.Lookup) (types.Mechanism, error) { - return &handler{}, nil +func New(o *options.Options, _ rt.Lookup) (types.Mechanism, error) { + h := &handler{} + if o != nil { + h.maxCaptureBytes = o.MaxCaptureBytes + } + return h, nil } func (h *handler) ID() types.ID { @@ -152,8 +158,9 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // the slot so the fallback path doesn't serve a partial capture defer mech.RecoverFanoutPanic("fr", i, func() { captures[i] = nil }) r2, crw, err := fanout.PrepareClone(ctx, r, i, fanout.Config{ - Mechanism: "fr", - Resources: func(int) *request.Resources { return &request.Resources{Cancelable: true} }, + Mechanism: "fr", + MaxCaptureBytes: h.maxCaptureBytes, + Resources: func(int) *request.Resources { return &request.Resources{Cancelable: true} }, }) if err != nil { return err diff --git 
a/pkg/backends/alb/mech/nlm/newest_last_modified.go b/pkg/backends/alb/mech/nlm/newest_last_modified.go index a27c51ed7..61523318b 100644 --- a/pkg/backends/alb/mech/nlm/newest_last_modified.go +++ b/pkg/backends/alb/mech/nlm/newest_last_modified.go @@ -39,7 +39,8 @@ const ( type handler struct { mech.PoolHolder - options options.NewestLastModifiedOptions + options options.NewestLastModifiedOptions + maxCaptureBytes int } func RegistryEntry() types.RegistryEntry { @@ -48,7 +49,8 @@ func RegistryEntry() types.RegistryEntry { func New(o *options.Options, _ rt.Lookup) (types.Mechanism, error) { return &handler{ - options: o.NLMOptions, + options: o.NLMOptions, + maxCaptureBytes: o.MaxCaptureBytes, }, nil } @@ -92,6 +94,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { results, _ := fanout.All(r.Context(), r, hl, fanout.Config{ Mechanism: "nlm", ConcurrencyLimit: h.options.ConcurrencyOptions.GetQueryConcurrencyLimit(), + MaxCaptureBytes: h.maxCaptureBytes, Context: tctx.ClearResources, }) diff --git a/pkg/backends/alb/mech/tsm/time_series_merge.go b/pkg/backends/alb/mech/tsm/time_series_merge.go index 184c40a8e..ef42670b9 100644 --- a/pkg/backends/alb/mech/tsm/time_series_merge.go +++ b/pkg/backends/alb/mech/tsm/time_series_merge.go @@ -57,6 +57,7 @@ type handler struct { nonmergeHandler types.Mechanism // when methodology is tsmerge, this handler is for non-mergeable paths outputFormat string // the provider output format (e.g., "prometheus") tsmOptions options.TimeSeriesMergeOptions + maxCaptureBytes int } func RegistryEntry() types.RegistryEntry { @@ -68,6 +69,7 @@ func New(o *options.Options, factories rt.Lookup) (types.Mechanism, error) { out := &handler{ nonmergeHandler: nmh, tsmOptions: o.TSMOptions, + maxCaptureBytes: o.MaxCaptureBytes, } // this validates the merge configuration for the ALB client as it sets it up // First, verify the output format is a support merge provider @@ -339,6 +341,7 @@ func (h *handler) serveStandard( 
fanoutResults, _ := fanout.All(r.Context(), r, hl, fanout.Config{ Mechanism: "tsm", ConcurrencyLimit: h.tsmOptions.ConcurrencyOptions.GetQueryConcurrencyLimit(), + MaxCaptureBytes: h.maxCaptureBytes, Resources: func(int) *request.Resources { return &request.Resources{ IsMergeMember: true, @@ -555,6 +558,7 @@ func (h *handler) serveWeightedAvg( _, err := fanout.All(parentCtx, sumBase, hl, fanout.Config{ Mechanism: "tsm/avg-sum", ConcurrencyLimit: limit, + MaxCaptureBytes: h.maxCaptureBytes, Resources: resourcesFn, OnResult: func(i int, fr *fanout.Result) { if fr.Failed || fr.Request == nil || fr.Capture == nil { @@ -591,6 +595,7 @@ func (h *handler) serveWeightedAvg( _, err := fanout.All(parentCtx, countBase, hl, fanout.Config{ Mechanism: "tsm/avg-count", ConcurrencyLimit: limit, + MaxCaptureBytes: h.maxCaptureBytes, Resources: resourcesFn, OnResult: func(i int, fr *fanout.Result) { // Count-side intentionally does not touch results[i]: the diff --git a/pkg/backends/alb/options/options.go b/pkg/backends/alb/options/options.go index fb60230c3..cf0c4ebae 100644 --- a/pkg/backends/alb/options/options.go +++ b/pkg/backends/alb/options/options.go @@ -45,6 +45,13 @@ type Options struct { // unknown means the first hc hasn't returned yet, // or (more likely) HealthCheck Interval on target backend is not set HealthyFloor int `yaml:"healthy_floor,omitempty"` + // MaxCaptureBytes overrides the backend-level max_capture_bytes for this + // ALB's fanout members. Set this when the ALB's expected response shape + // differs from the backend default (e.g. a TSM fan-out of 50 small-payload + // shards may safely use a lower cap than the global default to surface + // runaway upstreams faster). When 0, falls back to the parent Backend's + // max_capture_bytes, then to the package-level default (256 MiB). 
+ MaxCaptureBytes int `yaml:"max_capture_bytes,omitempty"` // OutputFormat accompanies the tsmerge Mechanism to indicate the provider output format // options include any valid time seres backend like prometheus, influxdb or clickhouse OutputFormat string `yaml:"output_format,omitempty"` diff --git a/pkg/backends/options/defaults.go b/pkg/backends/options/defaults.go index 492a33e62..ce7407104 100644 --- a/pkg/backends/options/defaults.go +++ b/pkg/backends/options/defaults.go @@ -35,6 +35,13 @@ const ( DefaultRevalidationFactor = 2 // DefaultMaxObjectSizeBytes is the default Max Size of any Cache Object DefaultMaxObjectSizeBytes = 524288 + // DefaultMaxCaptureBytes is the default per-response capture-buffer cap, + // applied to Trickster's internal response captures (ALB fanout, Prometheus + // label/transform handlers, etc.). Set to 256 MiB to protect against a + // single misbehaving upstream OOM'ing the proxy; an N-way ALB fanout would + // otherwise allocate O(N*M) on the heap if each member returned an + // M-byte body. + DefaultMaxCaptureBytes = 256 * 1024 * 1024 // DefaultBackendTRF is the default Timeseries Retention Factor for Time Series-based Backends DefaultBackendTRF = 1024 // DefaultBackendTEM is the default Timeseries Eviction Method for Time Series-based Backends diff --git a/pkg/backends/options/options.go b/pkg/backends/options/options.go index 589861e99..8cb8a446a 100644 --- a/pkg/backends/options/options.go +++ b/pkg/backends/options/options.go @@ -116,6 +116,13 @@ type Options struct { RevalidationFactor float64 `yaml:"revalidation_factor,omitempty"` // MaxObjectSizeBytes specifies the max objectsize to be accepted for any given cache object MaxObjectSizeBytes int `yaml:"max_object_size_bytes,omitempty"` + // MaxCaptureBytes caps the per-response in-memory capture buffer that + // Trickster's internal capture writer allocates when an ALB mechanism + // fans out to pool members or a backend handler transforms an upstream + // response. 
A member whose body exceeds the cap is treated as a + // partial-failure (the merge surfaces an X-Trickster-Result phit marker) + // rather than truncating the response silently. Defaults to 256 MiB. + MaxCaptureBytes int `yaml:"max_capture_bytes,omitempty"` // CompressibleTypeList specifies the HTTP Object Content Types that will be compressed internally // when stored in the Trickster cache or served to clients with a compatible 'Accept-Encoding' header CompressibleTypeList []string `yaml:"compressible_types,omitempty"` @@ -244,6 +251,7 @@ func New() *Options { KeepAliveTimeout: DefaultKeepAliveTimeout, MaxIdleConns: DefaultMaxIdleConns, MaxConcurrentConns: DefaultMaxConcurrentConns, + MaxCaptureBytes: DefaultMaxCaptureBytes, MaxObjectSizeBytes: DefaultMaxObjectSizeBytes, MaxTTL: DefaultMaxTTL, NegativeCache: make(map[int]time.Duration), diff --git a/pkg/backends/prometheus/handler_labels.go b/pkg/backends/prometheus/handler_labels.go index 45e69a9d6..e345a054a 100644 --- a/pkg/backends/prometheus/handler_labels.go +++ b/pkg/backends/prometheus/handler_labels.go @@ -48,7 +48,7 @@ func (c *Client) LabelsHandler(w http.ResponseWriter, r *http.Request) { params.SetRequestValues(r, qp) if c.hasTransformations { - sw := capture.NewCaptureResponseWriterWithLimit(capture.DefaultMaxBytes) + sw := capture.NewCaptureResponseWriterWithLimit(captureLimit(c)) engines.ObjectProxyCacheRequest(sw, r) headers.Merge(w.Header(), sw.Header()) body := c.processLabelsResponse(sw.Body(), origPath) diff --git a/pkg/backends/prometheus/handler_query.go b/pkg/backends/prometheus/handler_query.go index 5ac082a73..1c6320afb 100644 --- a/pkg/backends/prometheus/handler_query.go +++ b/pkg/backends/prometheus/handler_query.go @@ -94,7 +94,7 @@ func (c *Client) QueryHandler(w http.ResponseWriter, r *http.Request) { // we need to capture and unmarshal the response if c.hasTransformations || (rsc != nil && rsc.IsMergeMember) { // use a streaming response writer to capture the response body for 
transformation - sw := capture.NewCaptureResponseWriterWithLimit(capture.DefaultMaxBytes) + sw := capture.NewCaptureResponseWriterWithLimit(captureLimit(c)) engines.ObjectProxyCacheRequest(sw, r) // Propagate captured upstream headers (Content-Type, X-Trickster-Result, // etc.) to the outer ResponseWriter. Without this, ALB mechanisms see a diff --git a/pkg/backends/prometheus/prometheus.go b/pkg/backends/prometheus/prometheus.go index 26bb8ac0f..3750c0d7d 100644 --- a/pkg/backends/prometheus/prometheus.go +++ b/pkg/backends/prometheus/prometheus.go @@ -34,6 +34,7 @@ import ( "github.com/trickstercache/trickster/v2/pkg/cache" "github.com/trickstercache/trickster/v2/pkg/proxy/errors" "github.com/trickstercache/trickster/v2/pkg/proxy/params" + "github.com/trickstercache/trickster/v2/pkg/proxy/response/capture" "github.com/trickstercache/trickster/v2/pkg/timeseries" tt "github.com/trickstercache/trickster/v2/pkg/util/timeconv" ) @@ -147,6 +148,18 @@ type Client struct { injectLabels map[string]string } +// captureLimit returns the per-response capture buffer cap for c. Reads +// max_capture_bytes from the backend's configuration; falls back to the +// package-level default when unset. 
+func captureLimit(c *Client) int { + if c != nil { + if o := c.Configuration(); o != nil && o.MaxCaptureBytes > 0 { + return o.MaxCaptureBytes + } + } + return capture.DefaultMaxBytes +} + var _ types.NewBackendClientFunc = NewClient // NewClient returns a new Client Instance diff --git a/pkg/config/testdata/example.alb.yaml b/pkg/config/testdata/example.alb.yaml index f8009a912..d6ffe34cc 100644 --- a/pkg/config/testdata/example.alb.yaml +++ b/pkg/config/testdata/example.alb.yaml @@ -19,6 +19,7 @@ backends: max_ttl: 25h0m0s revalidation_factor: 2 max_object_size_bytes: 524288 + max_capture_bytes: 268435456 compressible_types: - text/html - text/javascript @@ -57,6 +58,7 @@ backends: max_ttl: 25h0m0s revalidation_factor: 2 max_object_size_bytes: 524288 + max_capture_bytes: 268435456 compressible_types: - text/html - text/javascript @@ -95,6 +97,7 @@ backends: max_ttl: 25h0m0s revalidation_factor: 2 max_object_size_bytes: 524288 + max_capture_bytes: 268435456 compressible_types: - text/html - text/javascript @@ -132,6 +135,7 @@ backends: max_ttl: 25h0m0s revalidation_factor: 2 max_object_size_bytes: 524288 + max_capture_bytes: 268435456 compressible_types: - text/html - text/javascript @@ -169,6 +173,7 @@ backends: max_ttl: 25h0m0s revalidation_factor: 2 max_object_size_bytes: 524288 + max_capture_bytes: 268435456 compressible_types: - text/html - text/javascript @@ -203,6 +208,7 @@ backends: max_ttl: 25h0m0s revalidation_factor: 2 max_object_size_bytes: 524288 + max_capture_bytes: 268435456 compressible_types: - text/html - text/javascript @@ -243,6 +249,7 @@ backends: max_ttl: 25h0m0s revalidation_factor: 2 max_object_size_bytes: 524288 + max_capture_bytes: 268435456 compressible_types: - text/html - text/javascript diff --git a/pkg/config/testdata/example.auth.yaml b/pkg/config/testdata/example.auth.yaml index f84c870a7..9f9445823 100644 --- a/pkg/config/testdata/example.auth.yaml +++ b/pkg/config/testdata/example.auth.yaml @@ -22,6 +22,7 @@ backends: max_ttl: 
25h0m0s revalidation_factor: 2 max_object_size_bytes: 524288 + max_capture_bytes: 268435456 compressible_types: - text/html - text/javascript @@ -59,6 +60,7 @@ backends: max_ttl: 25h0m0s revalidation_factor: 2 max_object_size_bytes: 524288 + max_capture_bytes: 268435456 compressible_types: - text/html - text/javascript @@ -96,6 +98,7 @@ backends: max_ttl: 25h0m0s revalidation_factor: 2 max_object_size_bytes: 524288 + max_capture_bytes: 268435456 compressible_types: - text/html - text/javascript @@ -130,6 +133,7 @@ backends: max_ttl: 25h0m0s revalidation_factor: 2 max_object_size_bytes: 524288 + max_capture_bytes: 268435456 compressible_types: - text/html - text/javascript diff --git a/pkg/config/testdata/example.full.yaml b/pkg/config/testdata/example.full.yaml index 701ba246b..043774342 100644 --- a/pkg/config/testdata/example.full.yaml +++ b/pkg/config/testdata/example.full.yaml @@ -22,6 +22,7 @@ backends: max_ttl: 25h0m0s revalidation_factor: 2 max_object_size_bytes: 524288 + max_capture_bytes: 268435456 compressible_types: - text/html - text/javascript diff --git a/pkg/config/testdata/exmple.sharding.yaml b/pkg/config/testdata/exmple.sharding.yaml index 3b080918f..cacc37575 100644 --- a/pkg/config/testdata/exmple.sharding.yaml +++ b/pkg/config/testdata/exmple.sharding.yaml @@ -22,6 +22,7 @@ backends: max_ttl: 25h0m0s revalidation_factor: 2 max_object_size_bytes: 524288 + max_capture_bytes: 268435456 compressible_types: - text/html - text/javascript diff --git a/pkg/config/testdata/simple.prometheus.yaml b/pkg/config/testdata/simple.prometheus.yaml index ed5e96820..4da983bf6 100644 --- a/pkg/config/testdata/simple.prometheus.yaml +++ b/pkg/config/testdata/simple.prometheus.yaml @@ -22,6 +22,7 @@ backends: max_ttl: 25h0m0s revalidation_factor: 2 max_object_size_bytes: 524288 + max_capture_bytes: 268435456 compressible_types: - text/html - text/javascript diff --git a/pkg/config/testdata/simple.reverseproxycache.yaml 
b/pkg/config/testdata/simple.reverseproxycache.yaml index 4cf36dcd7..d58771776 100644 --- a/pkg/config/testdata/simple.reverseproxycache.yaml +++ b/pkg/config/testdata/simple.reverseproxycache.yaml @@ -22,6 +22,7 @@ backends: max_ttl: 25h0m0s revalidation_factor: 2 max_object_size_bytes: 524288 + max_capture_bytes: 268435456 compressible_types: - text/html - text/javascript From 1cff399cbdb3bb70f0883ab19bac3edbb157d611 Mon Sep 17 00:00:00 2001 From: Chris Randles Date: Thu, 14 May 2026 07:55:43 -0400 Subject: [PATCH 6/6] refactor: [alb/mech] split fanout Mechanism + Variant; use types.Name Signed-off-by: Chris Randles --- pkg/backends/alb/mech/fanout/fanout.go | 19 +++++++++----- pkg/backends/alb/mech/fr/first_response.go | 2 +- pkg/backends/alb/mech/recover.go | 18 ++++++++----- pkg/backends/alb/mech/recover_metric_test.go | 26 ++++++++++++++----- .../alb/mech/tsm/time_series_merge.go | 8 +++--- pkg/observability/metrics/metrics.go | 10 +++---- 6 files changed, 55 insertions(+), 28 deletions(-) diff --git a/pkg/backends/alb/mech/fanout/fanout.go b/pkg/backends/alb/mech/fanout/fanout.go index 4e228945f..489c31c81 100644 --- a/pkg/backends/alb/mech/fanout/fanout.go +++ b/pkg/backends/alb/mech/fanout/fanout.go @@ -30,6 +30,7 @@ import ( "net/http" "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech" + "github.com/trickstercache/trickster/v2/pkg/backends/alb/mech/types" "github.com/trickstercache/trickster/v2/pkg/backends/alb/pool" "github.com/trickstercache/trickster/v2/pkg/observability/logging" "github.com/trickstercache/trickster/v2/pkg/observability/logging/logger" @@ -65,9 +66,15 @@ type Result struct { // Config configures one fanout call. type Config struct { - // Mechanism is the short name ("fr", "nlm", "tsm", "tsm/avg-sum", ...) - // used in panic recovery logs and the fanout_failures_total metric. - Mechanism string + // Mechanism is the registered short name of the mechanism running the + // fanout ("fr", "nlm", "tsm"). 
Used in panic recovery logs and as the + // mechanism label on the fanout_failures_total metric. + Mechanism types.Name + // Variant carries optional sub-fanout context within a Mechanism, e.g. + // "avg-sum" / "avg-count" for TSM's paired weighted-avg queries. Empty + // for mechanisms with only one fanout path. Surfaces as the variant + // label on the fanout_failures_total metric. + Variant string // ConcurrencyLimit caps in-flight member calls. 0 means unlimited. ConcurrencyLimit int // MaxCaptureBytes caps each member's response body capture. 0 uses @@ -129,7 +136,7 @@ func All(ctx context.Context, parent *http.Request, targets pool.Targets, cfg Co } eg.Go(func() error { results[i].Index = i - defer mech.RecoverFanoutPanic(cfg.Mechanism, i, func() { + defer mech.RecoverFanoutPanic(cfg.Mechanism, cfg.Variant, i, func() { results[i].Failed = true results[i].Capture = nil }) @@ -138,7 +145,7 @@ func All(ctx context.Context, parent *http.Request, targets pool.Targets, cfg Co if err != nil { results[i].Failed = true results[i].Err = err - metrics.ALBFanoutFailures.WithLabelValues(cfg.Mechanism, "clone").Inc() + metrics.ALBFanoutFailures.WithLabelValues(cfg.Mechanism, cfg.Variant, "clone").Inc() return err } results[i].Request = r2 @@ -147,7 +154,7 @@ func All(ctx context.Context, parent *http.Request, targets pool.Targets, cfg Co targets[i].Handler().ServeHTTP(crw, r2) if crw.Truncated() { results[i].Failed = true - metrics.ALBFanoutFailures.WithLabelValues(cfg.Mechanism, "truncated").Inc() + metrics.ALBFanoutFailures.WithLabelValues(cfg.Mechanism, cfg.Variant, "truncated").Inc() } if cfg.OnResult != nil { cfg.OnResult(i, &results[i]) diff --git a/pkg/backends/alb/mech/fr/first_response.go b/pkg/backends/alb/mech/fr/first_response.go index 6866db9c2..a3721437b 100644 --- a/pkg/backends/alb/mech/fr/first_response.go +++ b/pkg/backends/alb/mech/fr/first_response.go @@ -156,7 +156,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { eg.Go(func() 
error { // recover so a single bad upstream doesn't crash the proxy; clear // the slot so the fallback path doesn't serve a partial capture - defer mech.RecoverFanoutPanic("fr", i, func() { captures[i] = nil }) + defer mech.RecoverFanoutPanic("fr", "", i, func() { captures[i] = nil }) r2, crw, err := fanout.PrepareClone(ctx, r, i, fanout.Config{ Mechanism: "fr", MaxCaptureBytes: h.maxCaptureBytes, diff --git a/pkg/backends/alb/mech/recover.go b/pkg/backends/alb/mech/recover.go index d7af5bc52..3555b7a8c 100644 --- a/pkg/backends/alb/mech/recover.go +++ b/pkg/backends/alb/mech/recover.go @@ -27,19 +27,23 @@ import ( // RecoverFanoutPanic recovers from a panic in an ALB fanout goroutine and // invokes onPanic so the caller can mark its per-member result as failed. // errgroup.Group.Go does not recover panics on its own, so without this a -// single bad upstream handler would crash the whole proxy. -func RecoverFanoutPanic(mech string, member int, onPanic func()) { +// single bad upstream handler would crash the whole proxy. variant carries +// optional sub-fanout context (e.g. "avg-sum" / "avg-count" for TSM's +// paired weighted-avg queries); pass "" when the mechanism has only one +// fanout path. 
+func RecoverFanoutPanic(mech, variant string, member int, onPanic func()) { rec := recover() if rec == nil { return } logger.Error("alb fanout member panic", logging.Pairs{ - "mech": mech, - "member": member, - "panic": rec, - "stack": string(debug.Stack()), + "mech": mech, + "variant": variant, + "member": member, + "panic": rec, + "stack": string(debug.Stack()), }) - metrics.ALBFanoutFailures.WithLabelValues(mech, "panic").Inc() + metrics.ALBFanoutFailures.WithLabelValues(mech, variant, "panic").Inc() if onPanic != nil { onPanic() } diff --git a/pkg/backends/alb/mech/recover_metric_test.go b/pkg/backends/alb/mech/recover_metric_test.go index f3cab8bd5..f20aefac8 100644 --- a/pkg/backends/alb/mech/recover_metric_test.go +++ b/pkg/backends/alb/mech/recover_metric_test.go @@ -25,28 +25,42 @@ import ( ) func TestRecoverFanoutPanicIncrementsMetric(t *testing.T) { - before := testutil.ToFloat64(metrics.ALBFanoutFailures.WithLabelValues("metric-test", "panic")) + before := testutil.ToFloat64(metrics.ALBFanoutFailures.WithLabelValues("metric-test", "", "panic")) func() { - defer mech.RecoverFanoutPanic("metric-test", 0, nil) + defer mech.RecoverFanoutPanic("metric-test", "", 0, nil) panic("forced") }() - after := testutil.ToFloat64(metrics.ALBFanoutFailures.WithLabelValues("metric-test", "panic")) + after := testutil.ToFloat64(metrics.ALBFanoutFailures.WithLabelValues("metric-test", "", "panic")) if after-before != 1 { t.Errorf("expected counter to increment by 1, before=%v after=%v", before, after) } } +func TestRecoverFanoutPanicVariantLabel(t *testing.T) { + before := testutil.ToFloat64(metrics.ALBFanoutFailures.WithLabelValues("metric-test", "avg-sum", "panic")) + + func() { + defer mech.RecoverFanoutPanic("metric-test", "avg-sum", 0, nil) + panic("forced") + }() + + after := testutil.ToFloat64(metrics.ALBFanoutFailures.WithLabelValues("metric-test", "avg-sum", "panic")) + if after-before != 1 { + t.Errorf("expected variant-labeled counter to increment by 1, 
before=%v after=%v", before, after) + } +} + func TestRecoverFanoutPanicNoPanicNoIncrement(t *testing.T) { - before := testutil.ToFloat64(metrics.ALBFanoutFailures.WithLabelValues("metric-test-noop", "panic")) + before := testutil.ToFloat64(metrics.ALBFanoutFailures.WithLabelValues("metric-test-noop", "", "panic")) func() { - defer mech.RecoverFanoutPanic("metric-test-noop", 0, nil) + defer mech.RecoverFanoutPanic("metric-test-noop", "", 0, nil) // no panic }() - after := testutil.ToFloat64(metrics.ALBFanoutFailures.WithLabelValues("metric-test-noop", "panic")) + after := testutil.ToFloat64(metrics.ALBFanoutFailures.WithLabelValues("metric-test-noop", "", "panic")) if after != before { t.Errorf("expected counter unchanged, before=%v after=%v", before, after) } diff --git a/pkg/backends/alb/mech/tsm/time_series_merge.go b/pkg/backends/alb/mech/tsm/time_series_merge.go index ef42670b9..043f22bc1 100644 --- a/pkg/backends/alb/mech/tsm/time_series_merge.go +++ b/pkg/backends/alb/mech/tsm/time_series_merge.go @@ -417,7 +417,7 @@ func (h *handler) serveStandard( for i, res := range results { if res.failed { hasGatherFailure = true - metrics.ALBFanoutFailures.WithLabelValues("tsm", "no_contribution").Inc() + metrics.ALBFanoutFailures.WithLabelValues("tsm", "", "no_contribution").Inc() if ts := accumulator.GetTSData(); ts != nil { if ds, ok := ts.(*dataset.DataSet); ok { ds.Warnings = append(ds.Warnings, @@ -556,7 +556,8 @@ func (h *handler) serveWeightedAvg( var eg errgroup.Group eg.Go(func() error { _, err := fanout.All(parentCtx, sumBase, hl, fanout.Config{ - Mechanism: "tsm/avg-sum", + Mechanism: "tsm", + Variant: "avg-sum", ConcurrencyLimit: limit, MaxCaptureBytes: h.maxCaptureBytes, Resources: resourcesFn, @@ -593,7 +594,8 @@ func (h *handler) serveWeightedAvg( }) eg.Go(func() error { _, err := fanout.All(parentCtx, countBase, hl, fanout.Config{ - Mechanism: "tsm/avg-count", + Mechanism: "tsm", + Variant: "avg-count", ConcurrencyLimit: limit, MaxCaptureBytes: 
h.maxCaptureBytes, Resources: resourcesFn, diff --git a/pkg/observability/metrics/metrics.go b/pkg/observability/metrics/metrics.go index f84213eb6..b214b35a7 100644 --- a/pkg/observability/metrics/metrics.go +++ b/pkg/observability/metrics/metrics.go @@ -322,17 +322,17 @@ var ( // ALBFanoutFailures counts per-shard failures during ALB fanout. The // reason label distinguishes silent contribution failures (e.g. bad // encoding, parse errors), explicit panics in the per-shard goroutine, - // and capture-buffer truncation. Operators can graph this to detect - // classes of upstream misbehavior that would otherwise only surface as - // the X-Trickster-Result phit header. + // and capture-buffer truncation. The variant label distinguishes + // sub-fanouts within a mechanism (e.g. TSM's paired avg-sum / avg-count + // queries); empty when the mechanism has only one fanout path. ALBFanoutFailures = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: metricNamespace, Subsystem: albSubsystem, Name: "fanout_failures_total", - Help: "Count of per-shard failures during ALB fanout, by mechanism and reason.", + Help: "Count of per-shard failures during ALB fanout, by mechanism, variant, and reason.", }, - []string{"mechanism", "reason"}, + []string{"mechanism", "variant", "reason"}, ) )