Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pgconn/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ type Config struct {
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
BuildFrontend BuildFrontendFunc

// MaxBodyLen is the maximum length of a PostgreSQL wire protocol message body in octets. If a message body exceeds
// this length, reading the message will fail with pgproto3.ExceededMaxBodyLenErr. The default value is 0, which
// means no maximum is enforced.
MaxBodyLen int

// BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called
// when a context passed to a PgConn method is canceled.
BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler
Expand Down
2 changes: 2 additions & 0 deletions pgconn/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,7 @@ func TestConfigCopyReturnsEqualConfig(t *testing.T) {
connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5"
original, err := pgconn.ParseConfig(connString)
require.NoError(t, err)
original.MaxBodyLen = 12345

copied := original.Copy()
assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config")
Expand Down Expand Up @@ -938,6 +939,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
assert.Equalf(t, expected.MaxBodyLen, actual.MaxBodyLen, "%s - MaxBodyLen", testName)
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)

// Can't test function equality, so just test that they are set or not.
Expand Down
6 changes: 6 additions & 0 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,9 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
pgConn.slowWriteTimer.Stop()
pgConn.bgReaderStarted = make(chan struct{})
pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn)
if config.MaxBodyLen > 0 {
pgConn.frontend.SetMaxBodyLen(config.MaxBodyLen)
}

startupMsg := pgproto3.StartupMessage{
ProtocolVersion: maxProtocolVersion,
Expand Down Expand Up @@ -2177,6 +2180,9 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
pgConn.slowWriteTimer.Stop()
pgConn.bgReaderStarted = make(chan struct{})
pgConn.frontend = hc.Config.BuildFrontend(pgConn.bgReader, pgConn.conn)
if hc.Config.MaxBodyLen > 0 {
pgConn.frontend.SetMaxBodyLen(hc.Config.MaxBodyLen)
}

return pgConn, nil
}
Expand Down
61 changes: 61 additions & 0 deletions pgconn/pgconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,67 @@ func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) {
require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(ctx, config) })
}

func TestConnectConfigMaxBodyLen(t *testing.T) {
t.Parallel()

ln, err := net.Listen("tcp", "127.0.0.1:")
require.NoError(t, err)
defer ln.Close()

serverErrChan := make(chan error, 1)
go func() {
defer close(serverErrChan)

conn, err := ln.Accept()
if err != nil {
serverErrChan <- err
return
}
defer conn.Close()

err = conn.SetDeadline(time.Now().Add(5 * time.Second))
if err != nil {
serverErrChan <- err
return
}

backend := pgproto3.NewBackend(conn, conn)
_, err = backend.ReceiveStartupMessage()
if err != nil {
serverErrChan <- err
return
}

// ParameterStatus declares a 6-octet body, which verifies that Config.MaxBodyLen
// is installed before startup responses are received.
_, err = conn.Write([]byte{'S', 0, 0, 0, 10})
if err != nil {
serverErrChan <- err
return
}
}()

host, port, _ := strings.Cut(ln.Addr().String(), ":")
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

config, err := pgconn.ParseConfig(connStr)
require.NoError(t, err)
config.MaxBodyLen = 5

conn, err := pgconn.ConnectConfig(ctx, config)
require.Nil(t, conn)

var bodyLenErr *pgproto3.ExceededMaxBodyLenErr
require.ErrorAs(t, err, &bodyLenErr)
assert.Equal(t, 5, bodyLenErr.MaxExpectedBodyLen)
assert.Equal(t, 6, bodyLenErr.ActualBodyLen)

require.NoError(t, <-serverErrChan)
}

func TestConnPrepareSyntaxError(t *testing.T) {
t.Parallel()

Expand Down