From 15aa177440e793323379793775292508db0f6082 Mon Sep 17 00:00:00 2001 From: Sebastien Blot Date: Mon, 18 May 2026 15:22:45 +0200 Subject: [PATCH] add custom HTTP server for appsec --- pkg/acquisition/modules/appsec/config.go | 12 +- .../modules/appsec/httpserver/bench_test.go | 107 ++++++ .../modules/appsec/httpserver/body.go | 192 ++++++++++ .../modules/appsec/httpserver/body_test.go | 159 ++++++++ .../modules/appsec/httpserver/parser.go | 215 +++++++++++ .../modules/appsec/httpserver/parser_test.go | 267 ++++++++++++++ .../modules/appsec/httpserver/server.go | 346 ++++++++++++++++++ .../modules/appsec/httpserver/server_test.go | 281 ++++++++++++++ .../modules/appsec/httpserver/writer.go | 118 ++++++ pkg/acquisition/modules/appsec/source.go | 3 +- 10 files changed, 1691 insertions(+), 9 deletions(-) create mode 100644 pkg/acquisition/modules/appsec/httpserver/bench_test.go create mode 100644 pkg/acquisition/modules/appsec/httpserver/body.go create mode 100644 pkg/acquisition/modules/appsec/httpserver/body_test.go create mode 100644 pkg/acquisition/modules/appsec/httpserver/parser.go create mode 100644 pkg/acquisition/modules/appsec/httpserver/parser_test.go create mode 100644 pkg/acquisition/modules/appsec/httpserver/server.go create mode 100644 pkg/acquisition/modules/appsec/httpserver/server_test.go create mode 100644 pkg/acquisition/modules/appsec/httpserver/writer.go diff --git a/pkg/acquisition/modules/appsec/config.go b/pkg/acquisition/modules/appsec/config.go index 9bea4890631..0bd816e8af2 100644 --- a/pkg/acquisition/modules/appsec/config.go +++ b/pkg/acquisition/modules/appsec/config.go @@ -17,6 +17,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/appsec/httpserver" "github.com/crowdsecurity/crowdsec/pkg/apiclient/useragent" "github.com/crowdsecurity/crowdsec/pkg/appsec" "github.com/crowdsecurity/crowdsec/pkg/appsec/allowlists" @@ -156,16 +157,11 @@ func (w *Source) Configure(ctx context.Context, yamlConfig []byte, logger *log.E w.mux = http.NewServeMux() - w.server = &http.Server{ - Addr: w.config.ListenAddr, - Handler: w.mux, - Protocols: &http.Protocols{}, + w.server = &httpserver.Server{ + Handler: w.mux, + Logger: w.logger, } - w.server.Protocols.SetHTTP1(true) - w.server.Protocols.SetUnencryptedHTTP2(true) - w.server.Protocols.SetHTTP2(true) - w.InChan = make(chan appsec.ParsedRequest) appsecCfg := appsec.AppsecConfig{Logger: w.logger.WithField("component", "appsec_config")} diff --git a/pkg/acquisition/modules/appsec/httpserver/bench_test.go b/pkg/acquisition/modules/appsec/httpserver/bench_test.go new file mode 100644 index 00000000000..8b418c5cf52 --- /dev/null +++ b/pkg/acquisition/modules/appsec/httpserver/bench_test.go @@ -0,0 +1,107 @@ +package httpserver_test + +// End-to-end benchmark comparing the lenient appsec httpserver against +// net/http's server on the same handler, same wire format, same payload. +// +// Run: +// +// go test -bench=. -benchmem -count=3 \ +// ./pkg/acquisition/modules/appsec/httpserver/ +// +// Add -cpuprofile=cpu.out / -memprofile=mem.out to capture profiles, then +// inspect with `go tool pprof`. + +import ( + "bufio" + "io" + "net" + "net/http" + "strconv" + "strings" + "testing" + + "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/appsec/httpserver" +) + +// benchHandler mimics the cheap path of appsecHandler: drain the body and +// write a small JSON-shaped response. It deliberately does nothing else so +// the benchmark measures transport cost, not Coraza. +func benchHandler(w http.ResponseWriter, r *http.Request) { + _, _ = io.Copy(io.Discard, r.Body) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"action":"allow"}`)) +} + +type serverStarter func(net.Listener, http.Handler) (stop func()) + +func startLenient(l net.Listener, h http.Handler) func() { + s := &httpserver.Server{Handler: h} + go func() { _ = s.Serve(l) }() + return func() { _ = s.Close() } +} + +func startStd(l net.Listener, h http.Handler) func() { + s := &http.Server{Handler: h} + go func() { _ = s.Serve(l) }() + return func() { _ = s.Close() } +} + +// buildRequest produces a representative bouncer→crowdsec request: the headers +// the bouncer always sets plus a body of the requested size. +func buildRequest(bodySize int) string { + body := strings.Repeat("a", bodySize) + var b strings.Builder + b.WriteString("POST / HTTP/1.1\r\n") + b.WriteString("Host: x\r\n") + b.WriteString("X-Crowdsec-Appsec-Ip: 1.2.3.4\r\n") + b.WriteString("X-Crowdsec-Appsec-Uri: /foo\r\n") + b.WriteString("X-Crowdsec-Appsec-Verb: POST\r\n") + b.WriteString("X-Crowdsec-Appsec-Api-Key: bench\r\n") + b.WriteString("Content-Length: ") + b.WriteString(strconv.Itoa(len(body))) + b.WriteString("\r\n\r\n") + b.WriteString(body) + return b.String() +} + +func runBench(b *testing.B, start serverStarter, bodySize int) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + b.Fatal(err) + } + stop := start(l, http.HandlerFunc(benchHandler)) + defer stop() + + req := buildRequest(bodySize) + addr := l.Addr().String() + + b.ReportAllocs() + b.SetBytes(int64(len(req))) + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + c, err := net.Dial("tcp", addr) + if err != nil { + b.Fatal(err) + } + defer c.Close() + br := bufio.NewReader(c) + for pb.Next() { + if _, err := io.WriteString(c, req); err != nil { + b.Fatal(err) + } + resp, err := http.ReadResponse(br, nil) + if err != nil { + b.Fatal(err) + } + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + } + }) +} + +func BenchmarkServer_Lenient_SmallBody(b *testing.B) { runBench(b, startLenient, 100) } +func BenchmarkServer_Std_SmallBody(b *testing.B) { runBench(b, startStd, 100) } +func BenchmarkServer_Lenient_LargeBody(b *testing.B) { runBench(b, startLenient, 8*1024) } +func BenchmarkServer_Std_LargeBody(b *testing.B) { runBench(b, startStd, 8*1024) } diff --git a/pkg/acquisition/modules/appsec/httpserver/body.go b/pkg/acquisition/modules/appsec/httpserver/body.go new file mode 100644 index 00000000000..99e252e31f1 --- /dev/null +++ b/pkg/acquisition/modules/appsec/httpserver/body.go @@ -0,0 +1,192 @@ +package httpserver + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" +) + +var ( + errInvalidContentLength = errors.New("invalid Content-Length") + errMalformedChunkSize = errors.New("malformed chunk size") + errMalformedChunk = errors.New("missing CRLF after chunk data") +) + +// bodyInfo describes how the request body is framed on the wire. +type bodyInfo struct { + Body io.ReadCloser + ContentLength int64 + TransferEncoding []string + Chunked bool +} + +// newBodyReader builds an io.ReadCloser bounded by Transfer-Encoding / +// Content-Length so the server never relies on the connection EOF for framing. +// +// If both Transfer-Encoding: chunked and Content-Length are present, chunked +// wins (per RFC 7230 §3.3.3) and Content-Length is dropped. When chunked is +// used, the returned reader does NOT consume the trailer section — the caller +// must drain it before reading the next request. +func newBodyReader(src *bufio.Reader, headers http.Header) (bodyInfo, error) { + te := parseTransferEncoding(headers.Get("Transfer-Encoding")) + chunked := len(te) > 0 && te[len(te)-1] == "chunked" + if chunked { + headers.Del("Content-Length") + return bodyInfo{ + Body: &chunkedBody{r: newChunkedReader(src)}, + ContentLength: -1, + TransferEncoding: te, + Chunked: true, + }, nil + } + + cls := headers.Values("Content-Length") + if len(cls) == 0 { + return bodyInfo{ + Body: http.NoBody, + ContentLength: 0, + TransferEncoding: te, + }, nil + } + cl, err := strconv.ParseInt(strings.TrimSpace(cls[0]), 10, 64) + if err != nil || cl < 0 { + return bodyInfo{}, errInvalidContentLength + } + if cl == 0 { + return bodyInfo{ + Body: http.NoBody, + ContentLength: 0, + TransferEncoding: te, + }, nil + } + return bodyInfo{ + Body: &fixedBody{r: src, remaining: cl}, + ContentLength: cl, + TransferEncoding: te, + }, nil +} + +// fixedBody is a Content-Length bounded ReadCloser. It replaces the +// io.NopCloser(io.LimitReader(...)) chain with a single allocation. +type fixedBody struct { + r *bufio.Reader + remaining int64 +} + +func (b *fixedBody) Read(p []byte) (int, error) { + if b.remaining <= 0 { + return 0, io.EOF + } + if int64(len(p)) > b.remaining { + p = p[:b.remaining] + } + n, err := b.r.Read(p) + b.remaining -= int64(n) + if err == io.EOF && b.remaining > 0 { + err = io.ErrUnexpectedEOF + } + return n, err +} + +func (b *fixedBody) Close() error { return nil } + +// chunkedBody wraps a chunkedReader in an io.ReadCloser without allocating an +// io.NopCloser indirection. +type chunkedBody struct{ r *chunkedReader } + +func (b *chunkedBody) Read(p []byte) (int, error) { return b.r.Read(p) } +func (b *chunkedBody) Close() error { return nil } + +func parseTransferEncoding(raw string) []string { + if raw == "" { + return nil + } + parts := strings.Split(raw, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.ToLower(strings.TrimSpace(p)) + if p != "" { + out = append(out, p) + } + } + return out +} + +// chunkedReader decodes RFC 7230 chunked transfer-encoding. It returns io.EOF +// when the zero-length terminator chunk is seen. The trailer section (optional +// trailer headers + final CRLF) is NOT consumed. +type chunkedReader struct { + r *bufio.Reader + remaining int64 + eof bool +} + +const maxChunkSizeLine = 256 + +func newChunkedReader(r *bufio.Reader) *chunkedReader { + return &chunkedReader{r: r} +} + +func (cr *chunkedReader) Read(p []byte) (int, error) { + if cr.eof { + return 0, io.EOF + } + if cr.remaining == 0 { + size, err := cr.readChunkSize() + if err != nil { + return 0, err + } + if size == 0 { + cr.eof = true + return 0, io.EOF + } + cr.remaining = size + } + toRead := min(int64(len(p)), cr.remaining) + n, err := cr.r.Read(p[:toRead]) + cr.remaining -= int64(n) + if err != nil { + return n, err + } + if cr.remaining == 0 { + if err := readCRLF(cr.r); err != nil { + return n, err + } + } + return n, nil +} + +func (cr *chunkedReader) readChunkSize() (int64, error) { + line, err := readLine(cr.r, maxChunkSizeLine) + if err != nil { + return 0, err + } + if i := bytes.IndexByte(line, ';'); i >= 0 { + line = line[:i] + } + line = bytes.TrimSpace(line) + if len(line) == 0 { + return 0, errMalformedChunkSize + } + size, err := strconv.ParseInt(string(line), 16, 64) + if err != nil || size < 0 { + return 0, fmt.Errorf("%w: %q", errMalformedChunkSize, line) + } + return size, nil +} + +func readCRLF(r *bufio.Reader) error { + var buf [2]byte + if _, err := io.ReadFull(r, buf[:2]); err != nil { + return err + } + if buf[0] != '\r' || buf[1] != '\n' { + return errMalformedChunk + } + return nil +} diff --git a/pkg/acquisition/modules/appsec/httpserver/body_test.go b/pkg/acquisition/modules/appsec/httpserver/body_test.go new file mode 100644 index 00000000000..609b1cd150f --- /dev/null +++ b/pkg/acquisition/modules/appsec/httpserver/body_test.go @@ -0,0 +1,159 @@ +package httpserver + +import ( + "bufio" + "errors" + "io" + "net/http" + "strings" + "testing" +) + +func TestNewBodyReader_ContentLength(t *testing.T) { + r := bufio.NewReader(strings.NewReader("hello, world")) + h := http.Header{"Content-Length": {"12"}} + info, err := newBodyReader(r, h) + if err != nil { + t.Fatalf("newBodyReader: %v", err) + } + if info.ContentLength != 12 || info.Chunked { + t.Errorf("info=%+v", info) + } + got, err := io.ReadAll(info.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + if string(got) != "hello, world" { + t.Errorf("got %q", got) + } +} + +func TestNewBodyReader_NoBody(t *testing.T) { + r := bufio.NewReader(strings.NewReader("")) + info, err := newBodyReader(r, http.Header{}) + if err != nil { + t.Fatalf("newBodyReader: %v", err) + } + if info.ContentLength != 0 || info.Chunked { + t.Errorf("info=%+v", info) + } + if got, _ := io.ReadAll(info.Body); len(got) != 0 { + t.Errorf("expected empty body, got %q", got) + } +} + +func TestNewBodyReader_Chunked(t *testing.T) { + body := "4\r\nWiki\r\n6\r\npedia \r\nE\r\nin \r\n\r\nchunks.\r\n0\r\n\r\n" + r := bufio.NewReader(strings.NewReader(body)) + h := http.Header{"Transfer-Encoding": {"chunked"}} + info, err := newBodyReader(r, h) + if err != nil { + t.Fatalf("newBodyReader: %v", err) + } + if !info.Chunked { + t.Errorf("expected chunked=true") + } + got, err := io.ReadAll(info.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + if string(got) != "Wikipedia in \r\n\r\nchunks." { + t.Errorf("got %q", got) + } +} + +func TestNewBodyReader_ChunkedDropsContentLength(t *testing.T) { + // Per RFC 7230 §3.3.3: if both are present, chunked wins and CL is removed. + r := bufio.NewReader(strings.NewReader("0\r\n\r\n")) + h := http.Header{ + "Transfer-Encoding": {"chunked"}, + "Content-Length": {"42"}, + } + info, err := newBodyReader(r, h) + if err != nil { + t.Fatalf("newBodyReader: %v", err) + } + if !info.Chunked { + t.Errorf("expected chunked=true") + } + if h.Get("Content-Length") != "" { + t.Errorf("Content-Length should have been removed: %v", h) + } +} + +func TestNewBodyReader_InvalidContentLength(t *testing.T) { + r := bufio.NewReader(strings.NewReader("")) + h := http.Header{"Content-Length": {"abc"}} + if _, err := newBodyReader(r, h); !errors.Is(err, errInvalidContentLength) { + t.Errorf("got err=%v, want errInvalidContentLength", err) + } +} + +func TestNewBodyReader_NegativeContentLength(t *testing.T) { + r := bufio.NewReader(strings.NewReader("")) + h := http.Header{"Content-Length": {"-1"}} + if _, err := newBodyReader(r, h); !errors.Is(err, errInvalidContentLength) { + t.Errorf("got err=%v, want errInvalidContentLength", err) + } +} + +func TestParseTransferEncoding(t *testing.T) { + cases := []struct { + in string + want []string + }{ + {"", nil}, + {"chunked", []string{"chunked"}}, + {"gzip, chunked", []string{"gzip", "chunked"}}, + {" CHUNKED ", []string{"chunked"}}, + } + for _, c := range cases { + got := parseTransferEncoding(c.in) + if len(got) != len(c.want) { + t.Errorf("parseTransferEncoding(%q) len mismatch: got %v, want %v", c.in, got, c.want) + continue + } + for i := range got { + if got[i] != c.want[i] { + t.Errorf("parseTransferEncoding(%q)[%d] = %q, want %q", c.in, i, got[i], c.want[i]) + } + } + } +} + +func TestChunkedReader_MalformedSize(t *testing.T) { + r := bufio.NewReader(strings.NewReader("ZZ\r\nhi\r\n0\r\n\r\n")) + cr := newChunkedReader(r) + buf := make([]byte, 16) + if _, err := cr.Read(buf); !errors.Is(err, errMalformedChunkSize) { + t.Errorf("got err=%v, want errMalformedChunkSize", err) + } +} + +func TestChunkedReader_MissingCRLFAfterChunk(t *testing.T) { + // Chunk says 4 bytes but no CRLF after — should error out at the boundary + // check, returning the consumed bytes along with the error. + r := bufio.NewReader(strings.NewReader("4\r\ndataNOTCRLF")) + cr := newChunkedReader(r) + buf := make([]byte, 4) + n, err := cr.Read(buf) + if n != 4 { + t.Errorf("got n=%d, want 4", n) + } + if !errors.Is(err, errMalformedChunk) { + t.Errorf("got err=%v, want errMalformedChunk", err) + } +} + +func TestChunkedReader_ChunkExt(t *testing.T) { + // chunk-size with chunk-ext: "5;name=value\r\nhello\r\n0\r\n\r\n" + r := bufio.NewReader(strings.NewReader("5;name=value\r\nhello\r\n0\r\n\r\n")) + cr := newChunkedReader(r) + got, err := io.ReadAll(cr) + if err != nil { + t.Fatalf("read: %v", err) + } + if string(got) != "hello" { + t.Errorf("got %q", got) + } +} diff --git a/pkg/acquisition/modules/appsec/httpserver/parser.go b/pkg/acquisition/modules/appsec/httpserver/parser.go new file mode 100644 index 00000000000..2c656f52ad0 --- /dev/null +++ b/pkg/acquisition/modules/appsec/httpserver/parser.go @@ -0,0 +1,215 @@ +package httpserver + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "net/http" + "net/textproto" + "strconv" + "strings" +) + +var ( + ErrLineTooLong = errors.New("line too long") + ErrHeadersTooLarge = errors.New("headers too large") + ErrTooManyHeaders = errors.New("too many headers") + ErrMalformedRequestLine = errors.New("malformed request line") +) + +const ( + DefaultMaxLineSize = 16 * 1024 + DefaultMaxHeaderBytes = 1 * 1024 * 1024 + DefaultMaxHeaderCount = 128 +) + +// Limits bounds parser memory usage. Zero fields fall back to the package defaults. +type Limits struct { + MaxLineSize int + MaxHeaderBytes int + MaxHeaderCount int +} + +func (l Limits) withDefaults() Limits { + if l.MaxLineSize <= 0 { + l.MaxLineSize = DefaultMaxLineSize + } + if l.MaxHeaderBytes <= 0 { + l.MaxHeaderBytes = DefaultMaxHeaderBytes + } + if l.MaxHeaderCount <= 0 { + l.MaxHeaderCount = DefaultMaxHeaderCount + } + return l +} + +// readLine reads bytes through the next \n. The terminating \r\n or bare \n is +// stripped from the returned slice. Returns io.EOF only when no bytes were read. +// +// The returned slice aliases the bufio.Reader's internal buffer and is only +// valid until the next read from r. Callers that need to retain bytes across +// further reads must copy them (e.g. via string(line) when storing in a map). +// +// The bufio.Reader's buffer must be sized to hold the longest acceptable line; +// callers configure this via bufio.NewReaderSize. +func readLine(r *bufio.Reader, max int) ([]byte, error) { + line, err := r.ReadSlice('\n') + if err != nil { + if errors.Is(err, bufio.ErrBufferFull) { + return nil, ErrLineTooLong + } + if errors.Is(err, io.EOF) && len(line) > 0 { + return nil, io.ErrUnexpectedEOF + } + return nil, err + } + if len(line) > max+1 { + return nil, ErrLineTooLong + } + if n := len(line); n > 0 && line[n-1] == '\n' { + line = line[:n-1] + } + if n := len(line); n > 0 && line[n-1] == '\r' { + line = line[:n-1] + } + return line, nil +} + +// requestLine holds the parsed components of an HTTP request line. +type requestLine struct { + Method string + Target string + Proto string + ProtoMajor int + ProtoMinor int +} + +// readRequestLine reads and parses "METHOD SP URI SP HTTP/X.Y" from r. +// Lenient: skips leading blank lines, accepts any non-SP/CR/LF bytes in URI, +// records the proto string even when its format is unrecognized. +func readRequestLine(r *bufio.Reader, maxLine int) (requestLine, error) { + var ( + line []byte + err error + ) + for { + line, err = readLine(r, maxLine) + if err != nil { + return requestLine{}, err + } + if len(line) > 0 { + break + } + } + first := bytes.IndexByte(line, ' ') + if first <= 0 { + return requestLine{}, fmt.Errorf("%w: %q", ErrMalformedRequestLine, line) + } + rest := line[first+1:] + last := bytes.LastIndexByte(rest, ' ') + if last <= 0 { + return requestLine{}, fmt.Errorf("%w: %q", ErrMalformedRequestLine, line) + } + rl := requestLine{ + Method: string(line[:first]), + Target: string(rest[:last]), + Proto: string(rest[last+1:]), + } + if major, minor, ok := parseHTTPVersion(rl.Proto); ok { + rl.ProtoMajor = major + rl.ProtoMinor = minor + } + return rl, nil +} + +func parseHTTPVersion(proto string) (major, minor int, ok bool) { + rest, found := strings.CutPrefix(proto, "HTTP/") + if !found { + return 0, 0, false + } + majorStr, minorStr, found := strings.Cut(rest, ".") + if !found { + return 0, 0, false + } + m, err := strconv.Atoi(majorStr) + if err != nil || m < 0 { + return 0, 0, false + } + n, err := strconv.Atoi(minorStr) + if err != nil || n < 0 { + return 0, 0, false + } + return m, n, true +} + +// readHeaders reads header lines until an empty line. Lenient: any byte except +// CR/LF is allowed in values (including control chars). Invalid names (non-token +// bytes) are skipped silently. Obsolete line folding is dropped. +func readHeaders(r *bufio.Reader, limits Limits) (http.Header, error) { + limits = limits.withDefaults() + // Pre-size to a typical header count to avoid map growth allocations. + h := make(http.Header, 16) + totalBytes := 0 + count := 0 + for { + line, err := readLine(r, limits.MaxLineSize) + if err != nil { + return nil, err + } + if len(line) == 0 { + return h, nil + } + totalBytes += len(line) + 2 + if totalBytes > limits.MaxHeaderBytes { + return nil, ErrHeadersTooLarge + } + if line[0] == ' ' || line[0] == '\t' { + continue + } + colon := bytes.IndexByte(line, ':') + if colon <= 0 { + continue + } + nameBytes := line[:colon] + if !isValidHeaderName(nameBytes) { + continue + } + count++ + if count > limits.MaxHeaderCount { + return nil, ErrTooManyHeaders + } + name := textproto.CanonicalMIMEHeaderKey(string(nameBytes)) + value := bytes.TrimLeft(line[colon+1:], " \t") + value = bytes.TrimRight(value, " \t") + h[name] = append(h[name], string(value)) + } +} + +func isValidHeaderName(name []byte) bool { + if len(name) == 0 { + return false + } + for _, b := range name { + if !isTokenByte(b) { + return false + } + } + return true +} + +// isTokenByte reports whether b is a valid RFC 7230 token byte. +func isTokenByte(b byte) bool { + switch { + case b >= 'a' && b <= 'z', + b >= 'A' && b <= 'Z', + b >= '0' && b <= '9': + return true + } + switch b { + case '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~': + return true + } + return false +} diff --git a/pkg/acquisition/modules/appsec/httpserver/parser_test.go b/pkg/acquisition/modules/appsec/httpserver/parser_test.go new file mode 100644 index 00000000000..a19ef0f25c6 --- /dev/null +++ b/pkg/acquisition/modules/appsec/httpserver/parser_test.go @@ -0,0 +1,267 @@ +package httpserver + +import ( + "bufio" + "errors" + "io" + "strings" + "testing" +) + +func bufReader(s string) *bufio.Reader { + return bufio.NewReader(strings.NewReader(s)) +} + +func TestReadLine_CRLF(t *testing.T) { + r := bufReader("hello\r\n") + got, err := readLine(r, 1024) + if err != nil { + t.Fatalf("readLine: %v", err) + } + if string(got) != "hello" { + t.Errorf("got %q, want %q", got, "hello") + } +} + +func TestReadLine_BareLF(t *testing.T) { + r := bufReader("hello\n") + got, err := readLine(r, 1024) + if err != nil { + t.Fatalf("readLine: %v", err) + } + if string(got) != "hello" { + t.Errorf("got %q, want %q", got, "hello") + } +} + +func TestReadLine_Empty(t *testing.T) { + r := bufReader("\r\n") + got, err := readLine(r, 1024) + if err != nil { + t.Fatalf("readLine: %v", err) + } + if len(got) != 0 { + t.Errorf("got %q, want empty", got) + } +} + +func TestReadLine_TooLong(t *testing.T) { + r := bufReader("xxxxxxxxxxxxxxxx\r\n") + if _, err := readLine(r, 4); !errors.Is(err, ErrLineTooLong) { + t.Errorf("got err=%v, want ErrLineTooLong", err) + } +} + +func TestReadLine_EOFNoData(t *testing.T) { + r := bufReader("") + if _, err := readLine(r, 1024); !errors.Is(err, io.EOF) { + t.Errorf("got err=%v, want io.EOF", err) + } +} + +func TestReadLine_UnexpectedEOF(t *testing.T) { + r := bufReader("partial") + _, err := readLine(r, 1024) + if !errors.Is(err, io.ErrUnexpectedEOF) { + t.Errorf("got err=%v, want io.ErrUnexpectedEOF", err) + } +} + +func TestReadRequestLine_Standard(t *testing.T) { + r := bufReader("GET /foo HTTP/1.1\r\n") + rl, err := readRequestLine(r, 1024) + if err != nil { + t.Fatalf("readRequestLine: %v", err) + } + if rl.Method != "GET" || rl.Target != "/foo" || rl.Proto != "HTTP/1.1" { + t.Errorf("got %+v", rl) + } + if rl.ProtoMajor != 1 || rl.ProtoMinor != 1 { + t.Errorf("proto version got %d.%d, want 1.1", rl.ProtoMajor, rl.ProtoMinor) + } +} + +func TestReadRequestLine_SkipsBlankLines(t *testing.T) { + r := bufReader("\r\n\r\nPOST /a HTTP/1.0\r\n") + rl, err := readRequestLine(r, 1024) + if err != nil { + t.Fatalf("readRequestLine: %v", err) + } + if rl.Method != "POST" || rl.Target != "/a" { + t.Errorf("got %+v", rl) + } +} + +func TestReadRequestLine_URIWithWeirdBytes(t *testing.T) { + // URI contains bytes net/http would reject (e.g. raw 0x01) + r := bufReader("GET /foo\x01bar HTTP/1.1\r\n") + rl, err := readRequestLine(r, 1024) + if err != nil { + t.Fatalf("readRequestLine: %v", err) + } + if rl.Target != "/foo\x01bar" { + t.Errorf("target got %q, want with embedded \\x01", rl.Target) + } +} + +func TestReadRequestLine_UnknownProto(t *testing.T) { + r := bufReader("GET / WAT/9.9\r\n") + rl, err := readRequestLine(r, 1024) + if err != nil { + t.Fatalf("readRequestLine: %v", err) + } + if rl.Proto != "WAT/9.9" { + t.Errorf("proto got %q", rl.Proto) + } + if rl.ProtoMajor != 0 || rl.ProtoMinor != 0 { + t.Errorf("expected zero version for unrecognized proto, got %d.%d", rl.ProtoMajor, rl.ProtoMinor) + } +} + +func TestReadRequestLine_Malformed(t *testing.T) { + for _, in := range []string{ + "GET\r\n", + " HTTP/1.1\r\n", + "GET\r\n", + } { + r := bufReader(in) + if _, err := readRequestLine(r, 1024); !errors.Is(err, ErrMalformedRequestLine) { + t.Errorf("input %q: got err=%v, want ErrMalformedRequestLine", in, err) + } + } +} + +func TestReadHeaders_Basic(t *testing.T) { + r := bufReader("Host: example\r\nX-Foo: bar\r\n\r\n") + h, err := readHeaders(r, Limits{}) + if err != nil { + t.Fatalf("readHeaders: %v", err) + } + if h.Get("Host") != "example" || h.Get("X-Foo") != "bar" { + t.Errorf("got %v", h) + } +} + +func TestReadHeaders_ControlCharsInValue(t *testing.T) { + // This is the key acceptance test: header values with control characters + // must be preserved. net/http would reject this with 400. + r := bufReader("X-Evil: ab\x01cd\x7fef\r\n\r\n") + h, err := readHeaders(r, Limits{}) + if err != nil { + t.Fatalf("readHeaders: %v", err) + } + if got := h.Get("X-Evil"); got != "ab\x01cd\x7fef" { + t.Errorf("value lost control chars: got %q", got) + } +} + +func TestReadHeaders_BareLF(t *testing.T) { + r := bufReader("A: 1\nB: 2\n\n") + h, err := readHeaders(r, Limits{}) + if err != nil { + t.Fatalf("readHeaders: %v", err) + } + if h.Get("A") != "1" || h.Get("B") != "2" { + t.Errorf("got %v", h) + } +} + +func TestReadHeaders_MultipleValues(t *testing.T) { + r := bufReader("Set-Cookie: a=1\r\nSet-Cookie: b=2\r\n\r\n") + h, err := readHeaders(r, Limits{}) + if err != nil { + t.Fatalf("readHeaders: %v", err) + } + if got := h.Values("Set-Cookie"); len(got) != 2 || got[0] != "a=1" || got[1] != "b=2" { + t.Errorf("got %v", got) + } +} + +func TestReadHeaders_SkipInvalidName(t *testing.T) { + // "X Foo" has a space — invalid token. Should be silently dropped. + r := bufReader("X Foo: bad\r\nX-Good: ok\r\n\r\n") + h, err := readHeaders(r, Limits{}) + if err != nil { + t.Fatalf("readHeaders: %v", err) + } + if h.Get("X-Good") != "ok" { + t.Errorf("X-Good missing: %v", h) + } + if h.Get("X Foo") != "" { + t.Errorf("invalid header should have been dropped") + } +} + +func TestReadHeaders_DropObsFold(t *testing.T) { + r := bufReader("X-Foo: a\r\n b\r\n\r\n") + h, err := readHeaders(r, Limits{}) + if err != nil { + t.Fatalf("readHeaders: %v", err) + } + if h.Get("X-Foo") != "a" { + t.Errorf("got %q, continuation line should have been dropped", h.Get("X-Foo")) + } +} + +func TestReadHeaders_TooMany(t *testing.T) { + var b strings.Builder + for i := range 130 { + b.WriteString("X-H") + b.WriteByte(byte('a' + i%26)) + b.WriteString(": v\r\n") + } + b.WriteString("\r\n") + r := bufReader(b.String()) + if _, err := readHeaders(r, Limits{MaxHeaderCount: 100}); !errors.Is(err, ErrTooManyHeaders) { + t.Errorf("got err=%v, want ErrTooManyHeaders", err) + } +} + +func TestReadHeaders_TooLarge(t *testing.T) { + // 200 bytes of header data with a 64-byte total budget. + var b strings.Builder + for range 10 { + b.WriteString("X-Foo: ") + b.WriteString(strings.Repeat("a", 30)) + b.WriteString("\r\n") + } + b.WriteString("\r\n") + r := bufReader(b.String()) + if _, err := readHeaders(r, Limits{MaxHeaderBytes: 64}); !errors.Is(err, ErrHeadersTooLarge) { + t.Errorf("got err=%v, want ErrHeadersTooLarge", err) + } +} + +func TestReadHeaders_TrimsOWS(t *testing.T) { + r := bufReader("X-Foo: bar \r\n\r\n") + h, err := readHeaders(r, Limits{}) + if err != nil { + t.Fatalf("readHeaders: %v", err) + } + if h.Get("X-Foo") != "bar" { + t.Errorf("got %q, want %q", h.Get("X-Foo"), "bar") + } +} + +func TestParseHTTPVersion(t *testing.T) { + cases := []struct { + in string + major, minor int + ok bool + }{ + {"HTTP/1.1", 1, 1, true}, + {"HTTP/1.0", 1, 0, true}, + {"HTTP/2.0", 2, 0, true}, + {"HTTP/9.9", 9, 9, true}, + {"WAT/1.1", 0, 0, false}, + {"HTTP/", 0, 0, false}, + {"HTTP/abc", 0, 0, false}, + {"HTTP/1", 0, 0, false}, + } + for _, c := range cases { + maj, min, ok := parseHTTPVersion(c.in) + if ok != c.ok || maj != c.major || min != c.minor { + t.Errorf("parseHTTPVersion(%q) = (%d, %d, %v), want (%d, %d, %v)", c.in, maj, min, ok, c.major, c.minor, c.ok) + } + } +} diff --git a/pkg/acquisition/modules/appsec/httpserver/server.go b/pkg/acquisition/modules/appsec/httpserver/server.go new file mode 100644 index 00000000000..ddb9857f0b3 --- /dev/null +++ b/pkg/acquisition/modules/appsec/httpserver/server.go @@ -0,0 +1,346 @@ +// This server is NOT a general-purpose HTTP server. It speaks only HTTP/1.x, +// trusts its peers (deployments are expected to put it behind a bouncer), and +// makes minimal effort to defend against pathological clients beyond enforcing +// configurable size limits. +// Its goal is to be able to accept almost any request that looks like HTTP, even if it's technically malformed +package httpserver + +import ( + "bufio" + "context" + "crypto/tls" + "errors" + "io" + "net" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" +) + +// Server is a minimal lenient HTTP/1.x server. The zero value is usable; set +// Handler before calling Serve / ServeTLS. +type Server struct { + Handler http.Handler + + // TLSConfig is used by ServeTLS. If nil, a fresh tls.Config is created. + TLSConfig *tls.Config + + // ReadHeaderTimeout bounds the time spent reading the request line and headers. + // Zero disables the timeout (matching net/http's default). + ReadHeaderTimeout time.Duration + + // IdleTimeout bounds the wait for the next request on a keep-alive connection. + // Zero disables the timeout. + IdleTimeout time.Duration + + // WriteTimeout bounds the time spent writing the response. Zero disables it. + WriteTimeout time.Duration + + // Limits bounds request line / header sizes. Zero values fall back to defaults. + Limits Limits + + // Logger receives diagnostic messages. May be nil. + Logger *log.Entry + + mu sync.Mutex + listeners map[net.Listener]struct{} + activeConns map[net.Conn]struct{} + inShutdown atomic.Bool +} + +// Serve accepts connections on l and serves each one in its own goroutine. +// Returns http.ErrServerClosed after a successful Shutdown. +func (s *Server) Serve(l net.Listener) error { + return s.serve(l, false) +} + +// ServeTLS wraps l in a TLS listener using TLSConfig and the provided certificate +// pair, then serves on it. +func (s *Server) ServeTLS(l net.Listener, certFile, keyFile string) error { + cfg := s.TLSConfig + if cfg == nil { + cfg = &tls.Config{} + } else { + cfg = cfg.Clone() + } + if certFile != "" || keyFile != "" { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return err + } + cfg.Certificates = append(cfg.Certificates, cert) + } + return s.serve(tls.NewListener(l, cfg), true) +} + +func (s *Server) serve(l net.Listener, isTLS bool) error { + if s.inShutdown.Load() { + return http.ErrServerClosed + } + s.trackListener(l, true) + defer s.trackListener(l, false) + + var tempDelay time.Duration + for { + conn, err := l.Accept() + if err != nil { + if s.inShutdown.Load() { + return http.ErrServerClosed + } + var ne net.Error + if errors.As(err, &ne) && ne.Timeout() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if tempDelay > time.Second { + tempDelay = time.Second + } + s.logf("accept temporary error: %v; retrying in %v", err, tempDelay) + time.Sleep(tempDelay) + continue + } + return err + } + tempDelay = 0 + go s.serveConn(conn, isTLS) + } +} + +func (s *Server) serveConn(conn net.Conn, isTLS bool) { + s.trackConn(conn, true) + defer s.trackConn(conn, false) + defer conn.Close() + + limits := s.Limits.withDefaults() + // Size the read buffer to fit the longest acceptable line so readLine can + // use bufio.ReadSlice (zero-alloc) without hitting ErrBufferFull. + br := bufio.NewReaderSize(conn, limits.MaxLineSize+512) + bw := bufio.NewWriter(conn) + first := true + + for { + if s.inShutdown.Load() { + return + } + + // Bound the wait for the next request byte: ReadHeaderTimeout on the + // first request, IdleTimeout (falling back to ReadHeaderTimeout) for + // subsequent ones on the same keep-alive connection. + wait := s.ReadHeaderTimeout + if !first && s.IdleTimeout > 0 { + wait = s.IdleTimeout + } + if wait > 0 { + _ = conn.SetReadDeadline(time.Now().Add(wait)) + } + if _, err := br.Peek(1); err != nil { + return + } + first = false + + // Now bound the time to actually parse the headers. + if s.ReadHeaderTimeout > 0 { + _ = conn.SetReadDeadline(time.Now().Add(s.ReadHeaderTimeout)) + } + + req, info, closeAfter, err := s.readRequest(br, conn, isTLS, limits) + if err != nil { + if !errors.Is(err, io.EOF) { + s.logf("read request: %v", err) + writeBadRequest(bw) + } + return + } + + // Clear the deadline; the handler may set its own via http.NewResponseController. + _ = conn.SetReadDeadline(time.Time{}) + if s.WriteTimeout > 0 { + _ = conn.SetWriteDeadline(time.Now().Add(s.WriteTimeout)) + } + + rw := newResponseWriter(conn, bw, req.Proto) + rw.closeConn = closeAfter + + s.Handler.ServeHTTP(rw, req) + + // Drain any unread body so the next request lands on a clean boundary. + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + _ = req.Body.Close() + } + + // For chunked requests we still need to consume the trailer section + // (optional trailer headers + final CRLF) that http.NewChunkedReader + // leaves on the wire. + if info.Chunked { + if err := drainChunkedTrailer(br, limits); err != nil { + s.logf("drain chunked trailer: %v", err) + _ = rw.flush() + return + } + } + + if err := rw.flush(); err != nil { + s.logf("write response: %v", err) + return + } + + if closeAfter { + return + } + } +} + +func (s *Server) readRequest(br *bufio.Reader, conn net.Conn, isTLS bool, limits Limits) (*http.Request, bodyInfo, bool, error) { + rl, err := readRequestLine(br, limits.MaxLineSize) + if err != nil { + return nil, bodyInfo{}, true, err + } + headers, err := readHeaders(br, limits) + if err != nil { + return nil, bodyInfo{}, true, err + } + info, err := newBodyReader(br, headers) + if err != nil { + return nil, bodyInfo{}, true, err + } + + req := &http.Request{ + Method: rl.Method, + Proto: rl.Proto, + ProtoMajor: rl.ProtoMajor, + ProtoMinor: rl.ProtoMinor, + Header: headers, + Body: info.Body, + ContentLength: info.ContentLength, + TransferEncoding: info.TransferEncoding, + Host: headers.Get("Host"), + RequestURI: rl.Target, + RemoteAddr: conn.RemoteAddr().String(), + } + if u, err := url.ParseRequestURI(rl.Target); err == nil { + req.URL = u + } else if u, err := url.Parse(rl.Target); err == nil { + req.URL = u + } else { + req.URL = &url.URL{Path: rl.Target} + } + + if isTLS { + if tlsConn, ok := conn.(*tls.Conn); ok { + state := tlsConn.ConnectionState() + req.TLS = &state + } + } + + // (*http.Request).Context() returns context.Background() when the unexported + // ctx field is nil, so we skip WithContext entirely — it would clone the + // whole request struct for no benefit. + + return req, info, shouldClose(rl.ProtoMajor, rl.ProtoMinor, headers), nil +} + +func shouldClose(major, minor int, h http.Header) bool { + connHdr := strings.ToLower(h.Get("Connection")) + if major < 1 || (major == 1 && minor < 1) { + return !strings.Contains(connHdr, "keep-alive") + } + return strings.Contains(connHdr, "close") +} + +func writeBadRequest(bw *bufio.Writer) { + _, _ = bw.WriteString("HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") + _ = bw.Flush() +} + +func drainChunkedTrailer(br *bufio.Reader, limits Limits) error { + for { + line, err := readLine(br, limits.MaxLineSize) + if err != nil { + return err + } + if len(line) == 0 { + return nil + } + } +} + +// Shutdown stops accepting new connections and waits for active connections to +// finish, up to ctx's deadline. +func (s *Server) Shutdown(ctx context.Context) error { + s.inShutdown.Store(true) + s.mu.Lock() + for l := range s.listeners { + _ = l.Close() + } + s.mu.Unlock() + + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + for { + s.mu.Lock() + n := len(s.activeConns) + s.mu.Unlock() + if n == 0 { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } +} + +// Close immediately closes all listeners and active connections without waiting. +func (s *Server) Close() error { + s.inShutdown.Store(true) + s.mu.Lock() + defer s.mu.Unlock() + for l := range s.listeners { + _ = l.Close() + } + for c := range s.activeConns { + _ = c.Close() + } + return nil +} + +func (s *Server) trackListener(l net.Listener, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.listeners == nil { + s.listeners = make(map[net.Listener]struct{}) + } + if add { + s.listeners[l] = struct{}{} + } else { + delete(s.listeners, l) + } +} + +func (s *Server) trackConn(c net.Conn, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.activeConns == nil { + s.activeConns = make(map[net.Conn]struct{}) + } + if add { + s.activeConns[c] = struct{}{} + } else { + delete(s.activeConns, c) + } +} + +func (s *Server) logf(format string, args ...any) { + if s.Logger != nil { + s.Logger.Debugf(format, args...) + } +} diff --git a/pkg/acquisition/modules/appsec/httpserver/server_test.go b/pkg/acquisition/modules/appsec/httpserver/server_test.go new file mode 100644 index 00000000000..067da90cded --- /dev/null +++ b/pkg/acquisition/modules/appsec/httpserver/server_test.go @@ -0,0 +1,281 @@ +package httpserver + +import ( + "bufio" + "context" + "io" + "net" + "net/http" + "sync/atomic" + "testing" + "time" +) + +// startServer launches a Server on a loopback listener and returns the address +// plus a teardown function. +func startServer(t *testing.T, h http.Handler) (string, func()) { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + srv := &Server{ + Handler: h, + ReadHeaderTimeout: 2 * time.Second, + IdleTimeout: 2 * time.Second, + } + serveErr := make(chan error, 1) + go func() { + serveErr <- srv.Serve(l) + }() + addr := l.Addr().String() + return addr, func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = srv.Shutdown(ctx) + select { + case <-serveErr: + case <-time.After(2 * time.Second): + t.Errorf("server did not return from Serve in time") + } + } +} + +// dial returns a connected client with a bufio.Reader for reading the response. +func dial(t *testing.T, addr string) (net.Conn, *bufio.Reader) { + t.Helper() + c, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("dial: %v", err) + } + _ = c.SetDeadline(time.Now().Add(3 * time.Second)) + return c, bufio.NewReader(c) +} + +func TestServer_BasicGET(t *testing.T) { + addr, stop := startServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" || r.URL.Path != "/foo" { + t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path) + } + _, _ = w.Write([]byte("ok")) + })) + defer stop() + + c, br := dial(t, addr) + defer c.Close() + _, _ = io.WriteString(c, "GET /foo HTTP/1.1\r\nHost: x\r\nConnection: close\r\n\r\n") + + resp, err := http.ReadResponse(br, nil) + if err != nil { + t.Fatalf("read response: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("status %d", resp.StatusCode) + } + body, _ := io.ReadAll(resp.Body) + if string(body) != "ok" { + t.Errorf("body %q", body) + } +} + +// TestServer_AcceptsControlCharsInHeader is the key acceptance test: a header +// value containing a control character (0x01) — which net/http rejects with a +// 400 — must be delivered to the handler unchanged. +func TestServer_AcceptsControlCharsInHeader(t *testing.T) { + var seen atomic.Value + addr, stop := startServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seen.Store(r.Header.Get("X-Evil")) + _, _ = w.Write([]byte("ok")) + })) + defer stop() + + c, br := dial(t, addr) + defer c.Close() + _, _ = io.WriteString(c, "GET / HTTP/1.1\r\nHost: x\r\nX-Evil: ab\x01cd\r\nConnection: close\r\n\r\n") + + resp, err := http.ReadResponse(br, nil) + if err != nil { + t.Fatalf("read response: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("status %d", resp.StatusCode) + } + got, _ := seen.Load().(string) + if got != "ab\x01cd" { + t.Errorf("handler received %q, want %q", got, "ab\x01cd") + } +} + +func TestServer_KeepAlive(t *testing.T) { + var count atomic.Int32 + addr, stop := startServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count.Add(1) + _, _ = w.Write([]byte("ok")) + })) + defer stop() + + c, br := dial(t, addr) + defer c.Close() + + // Send two requests on the same connection. + for i := 0; i < 2; i++ { + _, _ = io.WriteString(c, "GET / HTTP/1.1\r\nHost: x\r\n\r\n") + resp, err := http.ReadResponse(br, nil) + if err != nil { + t.Fatalf("request %d: %v", i, err) + } + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + if got := count.Load(); got != 2 { + t.Errorf("handler called %d times, want 2", got) + } +} + +func TestServer_ConnectionClose(t *testing.T) { + addr, stop := startServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("bye")) + })) + defer stop() + + c, br := dial(t, addr) + defer c.Close() + _, _ = io.WriteString(c, "GET / HTTP/1.1\r\nHost: x\r\nConnection: close\r\n\r\n") + resp, err := http.ReadResponse(br, nil) + if err != nil { + t.Fatalf("read response: %v", err) + } + // http.ReadResponse strips Connection into resp.Close — check there. + if !resp.Close { + t.Errorf("resp.Close = false, want true") + } + _, _ = io.Copy(io.Discard, resp.Body) + one := make([]byte, 1) + _ = c.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + if _, err := c.Read(one); err == nil { + t.Errorf("expected error reading from closed connection") + } +} + +func TestServer_HTTP10DefaultsClose(t *testing.T) { + addr, stop := startServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("ok")) + })) + defer stop() + + c, br := dial(t, addr) + defer c.Close() + _, _ = io.WriteString(c, "GET / HTTP/1.0\r\nHost: x\r\n\r\n") + resp, err := http.ReadResponse(br, nil) + if err != nil { + t.Fatalf("read response: %v", err) + } + if resp.StatusCode != 200 { + t.Errorf("status %d", resp.StatusCode) + } + _, _ = io.Copy(io.Discard, resp.Body) + // HTTP/1.0 without keep-alive: server should close. + one := make([]byte, 1) + _ = c.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + if _, err := c.Read(one); err == nil { + t.Errorf("expected close after HTTP/1.0 response") + } +} + +func TestServer_PostWithBody(t *testing.T) { + var got []byte + addr, stop := startServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + } + got = b + _, _ = w.Write([]byte("ok")) + })) + defer stop() + + c, br := dial(t, addr) + defer c.Close() + body := "hello body" + _, _ = io.WriteString(c, "POST / HTTP/1.1\r\nHost: x\r\nContent-Length: 10\r\nConnection: close\r\n\r\n"+body) + resp, err := http.ReadResponse(br, nil) + if err != nil { + t.Fatalf("read response: %v", err) + } + _, _ = io.Copy(io.Discard, resp.Body) + if string(got) != body { + t.Errorf("handler got body %q, want %q", got, body) + } +} + +func TestServer_PostChunkedBody(t *testing.T) { + var got []byte + addr, stop := startServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + } + got = b + _, _ = w.Write([]byte("ok")) + })) + defer stop() + + c, br := dial(t, addr) + defer c.Close() + chunked := "5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n" + _, _ = io.WriteString(c, "POST / HTTP/1.1\r\nHost: x\r\nTransfer-Encoding: chunked\r\nConnection: close\r\n\r\n"+chunked) + resp, err := http.ReadResponse(br, nil) + if err != nil { + t.Fatalf("read response: %v", err) + } + _, _ = io.Copy(io.Discard, resp.Body) + if string(got) != "hello world" { + t.Errorf("handler got %q", got) + } +} + +func TestServer_BadRequestLine(t *testing.T) { + addr, stop := startServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + t.Errorf("handler should not be called for malformed request") + })) + defer stop() + + c, br := dial(t, addr) + defer c.Close() + _, _ = io.WriteString(c, "GARBAGE\r\n\r\n") + resp, err := http.ReadResponse(br, nil) + if err != nil { + t.Fatalf("read response: %v", err) + } + if resp.StatusCode != 400 { + t.Errorf("status %d, want 400", resp.StatusCode) + } +} + +func TestServer_Shutdown(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + srv := &Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("ok")) + }), + } + done := make(chan error, 1) + go func() { done <- srv.Serve(l) }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + t.Fatalf("Shutdown: %v", err) + } + select { + case err := <-done: + if err != http.ErrServerClosed { + t.Errorf("Serve returned %v, want http.ErrServerClosed", err) + } + case <-time.After(2 * time.Second): + t.Errorf("Serve did not return after Shutdown") + } +} diff --git a/pkg/acquisition/modules/appsec/httpserver/writer.go b/pkg/acquisition/modules/appsec/httpserver/writer.go new file mode 100644 index 00000000000..4a6c026d76f --- /dev/null +++ b/pkg/acquisition/modules/appsec/httpserver/writer.go @@ -0,0 +1,118 @@ +package httpserver + +import ( + "bufio" + "bytes" + "net" + "net/http" + "strconv" + "time" +) + +// responseWriter is a minimal http.ResponseWriter that buffers the response +// body in memory and emits a single HTTP/1.x message on flush. Body buffering +// is fine for the appsec handler — responses are small JSON blobs. +// +// It implements SetReadDeadline / SetWriteDeadline so the existing handler can +// keep using http.NewResponseController(rw).SetReadDeadline(...) to bound the +// body read. +type responseWriter struct { + conn net.Conn + bw *bufio.Writer + header http.Header // lazy: nil until Header() is called + body bytes.Buffer + status int + headerSent bool + proto string + closeConn bool +} + +func newResponseWriter(conn net.Conn, bw *bufio.Writer, proto string) *responseWriter { + if proto != "HTTP/1.0" { + proto = "HTTP/1.1" + } + return &responseWriter{ + conn: conn, + bw: bw, + proto: proto, + } +} + +func (w *responseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header, 4) + } + return w.header +} + +func (w *responseWriter) WriteHeader(code int) { + if w.headerSent { + return + } + w.status = code + w.headerSent = true +} + +func (w *responseWriter) Write(p []byte) (int, error) { + if !w.headerSent { + w.WriteHeader(http.StatusOK) + } + return w.body.Write(p) +} + +// SetReadDeadline lets http.NewResponseController set a body-read deadline. +func (w *responseWriter) SetReadDeadline(t time.Time) error { + return w.conn.SetReadDeadline(t) +} + +// SetWriteDeadline mirrors SetReadDeadline for completeness. +func (w *responseWriter) SetWriteDeadline(t time.Time) error { + return w.conn.SetWriteDeadline(t) +} + +// flush writes the status line, headers, and buffered body to the connection. +// Server-controlled headers (Content-Length, Connection) are written directly +// to avoid map operations; handler-set duplicates are skipped. +func (w *responseWriter) flush() error { + if !w.headerSent { + w.WriteHeader(http.StatusOK) + } + status := w.status + if status == 0 { + status = http.StatusOK + } + bw := w.bw + var nbuf [20]byte + + bw.WriteString(w.proto) + bw.WriteByte(' ') + bw.Write(strconv.AppendInt(nbuf[:0], int64(status), 10)) + bw.WriteByte(' ') + reason := http.StatusText(status) + if reason == "" { + reason = "status" + } + bw.WriteString(reason) + bw.WriteString("\r\n") + + bw.WriteString("Content-Length: ") + bw.Write(strconv.AppendInt(nbuf[:0], int64(w.body.Len()), 10)) + bw.WriteString("\r\n") + if w.closeConn { + bw.WriteString("Connection: close\r\n") + } + for name, vals := range w.header { + if name == "Content-Length" || name == "Connection" || name == "Date" { + continue + } + for _, v := range vals { + bw.WriteString(name) + bw.WriteString(": ") + bw.WriteString(v) + bw.WriteString("\r\n") + } + } + bw.WriteString("\r\n") + bw.Write(w.body.Bytes()) + return bw.Flush() +} diff --git a/pkg/acquisition/modules/appsec/source.go b/pkg/acquisition/modules/appsec/source.go index 0c7b4fbb2da..83fd64d73ad 100644 --- a/pkg/acquisition/modules/appsec/source.go +++ b/pkg/acquisition/modules/appsec/source.go @@ -8,6 +8,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/appsec/httpserver" "github.com/crowdsecurity/crowdsec/pkg/appsec" "github.com/crowdsecurity/crowdsec/pkg/appsec/allowlists" "github.com/crowdsecurity/crowdsec/pkg/csconfig" @@ -20,7 +21,7 @@ type Source struct { lapiClientConfig *csconfig.LocalApiClientCfg logger *log.Entry mux *http.ServeMux - server *http.Server + server *httpserver.Server InChan chan appsec.ParsedRequest AppsecRuntime *appsec.AppsecRuntimeConfig AppsecConfigs map[string]appsec.AppsecConfig