From f8fca46a8a9e95ae9a4630e22f3aca2f67f9796e Mon Sep 17 00:00:00 2001 From: carter-ya <18641671+carter-ya@users.noreply.github.com> Date: Fri, 15 May 2026 18:49:27 +0800 Subject: [PATCH] pgconn: expose frontend max body length --- pgconn/config.go | 5 ++++ pgconn/config_test.go | 2 ++ pgconn/pgconn.go | 6 +++++ pgconn/pgconn_test.go | 61 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 74 insertions(+) diff --git a/pgconn/config.go b/pgconn/config.go index 01710bb23..d6a2674f0 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -44,6 +44,11 @@ type Config struct { LookupFunc LookupFunc // e.g. net.Resolver.LookupHost BuildFrontend BuildFrontendFunc + // MaxBodyLen is the maximum length of a PostgreSQL wire protocol message body in octets. If a message body exceeds + // this length, reading the message will fail with pgproto3.ExceededMaxBodyLenErr. The default value is 0, which + // means no maximum is enforced. + MaxBodyLen int + // BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called // when a context passed to a PgConn method is canceled. BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler diff --git a/pgconn/config_test.go b/pgconn/config_test.go index 1f7ea08a5..a3e0266d0 100644 --- a/pgconn/config_test.go +++ b/pgconn/config_test.go @@ -841,6 +841,7 @@ func TestConfigCopyReturnsEqualConfig(t *testing.T) { connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" original, err := pgconn.ParseConfig(connString) require.NoError(t, err) + original.MaxBodyLen = 12345 copied := original.Copy() assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") @@ -938,6 +939,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName assert.Equalf(t, expected.User, actual.User, "%s - User", testName) assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) + assert.Equalf(t, expected.MaxBodyLen, actual.MaxBodyLen, "%s - MaxBodyLen", testName) assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) // Can't test function equality, so just test that they are set or not. diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index abf6c9d8d..34d9f7234 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -384,6 +384,9 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo pgConn.slowWriteTimer.Stop() pgConn.bgReaderStarted = make(chan struct{}) pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn) + if config.MaxBodyLen > 0 { + pgConn.frontend.SetMaxBodyLen(config.MaxBodyLen) + } startupMsg := pgproto3.StartupMessage{ ProtocolVersion: maxProtocolVersion, @@ -2177,6 +2180,9 @@ func Construct(hc *HijackedConn) (*PgConn, error) { pgConn.slowWriteTimer.Stop() pgConn.bgReaderStarted = make(chan struct{}) pgConn.frontend = hc.Config.BuildFrontend(pgConn.bgReader, pgConn.conn) + if hc.Config.MaxBodyLen > 0 { + pgConn.frontend.SetMaxBodyLen(hc.Config.MaxBodyLen) + } return pgConn, nil } diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index dd6fce6b7..9fd540994 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -875,6 +875,67 @@ func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) { require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(ctx, config) }) } +func TestConnectConfigMaxBodyLen(t *testing.T) { + t.Parallel() + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + serverErrChan <- err + return + } + + backend := pgproto3.NewBackend(conn, conn) + _, err = backend.ReceiveStartupMessage() + if err != nil { + serverErrChan <- err + return + } + + // ParameterStatus declares a 6-octet body, which verifies that Config.MaxBodyLen + // is installed before startup responses are received. + _, err = conn.Write([]byte{'S', 0, 0, 0, 10}) + if err != nil { + serverErrChan <- err + return + } + }() + + host, port, _ := strings.Cut(ln.Addr().String(), ":") + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(connStr) + require.NoError(t, err) + config.MaxBodyLen = 5 + + conn, err := pgconn.ConnectConfig(ctx, config) + require.Nil(t, conn) + + var bodyLenErr *pgproto3.ExceededMaxBodyLenErr + require.ErrorAs(t, err, &bodyLenErr) + assert.Equal(t, 5, bodyLenErr.MaxExpectedBodyLen) + assert.Equal(t, 6, bodyLenErr.ActualBodyLen) + + require.NoError(t, <-serverErrChan) +} + func TestConnPrepareSyntaxError(t *testing.T) { t.Parallel()