From dad0fe2b0090ae3199a00ac60a184b93c6709b2a Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Mon, 27 Apr 2026 08:44:37 +0300 Subject: [PATCH] fix(pgdriver): send required params during startup --- driver/pgdriver/driver.go | 9 +- driver/pgdriver/proto.go | 105 +++++++++++++++++-- driver/pgdriver/proto_test.go | 190 ++++++++++++++++++++++++++++++++++ 3 files changed, 288 insertions(+), 16 deletions(-) create mode 100644 driver/pgdriver/proto_test.go diff --git a/driver/pgdriver/driver.go b/driver/pgdriver/driver.go index b28f63c82..ef04b00b5 100644 --- a/driver/pgdriver/driver.go +++ b/driver/pgdriver/driver.go @@ -141,11 +141,10 @@ func newConn(ctx context.Context, conf *Config) (*Conn, error) { continue } - switch { - case !conf.UnsafeStrings && k == "standard_conforming_strings" && v != "on": - return nil, errRequiresParameter - case !conf.UnsafeStrings && k == "client_encoding" && v != "UTF8": - return nil, errRequiresParameter + if !conf.UnsafeStrings { + if err := checkRequiredConnParam(k, v); err != nil { + return nil, err + } } if _, err = cn.ExecContext(ctx, fmt.Sprintf("SET %s TO $1", k), []driver.NamedValue{ diff --git a/driver/pgdriver/proto.go b/driver/pgdriver/proto.go index 84a1567db..cf573b2fe 100644 --- a/driver/pgdriver/proto.go +++ b/driver/pgdriver/proto.go @@ -15,6 +15,7 @@ import ( "io" "math" "strconv" + "strings" "sync" "time" "unicode/utf8" @@ -223,12 +224,30 @@ func writeStartup(ctx context.Context, cn *Conn) error { wb.WriteString("application_name") wb.WriteString(cn.conf.AppName) } + writeStartupConnParams(wb, cn.conf) wb.WriteString("") wb.FinishMessage() return cn.write(ctx, wb) } +func writeStartupConnParams(wb *writeBuffer, conf *Config) { + for k, v := range conf.ConnParams { + if v == nil || !isRequiredParameter(k) { + continue + } + + // These parameters must be effective before startup ParameterStatus messages + // are validated; the later SET path runs too late for that check. + value := startupParamValue(v) + if !conf.UnsafeStrings { + value = canonicalRequiredParameterValue(k, value) + } + wb.WriteString(k) + wb.WriteString(value) + } +} + //------------------------------------------------------------------------------ func auth(ctx context.Context, cn *Conn, rd *reader) error { @@ -1143,15 +1162,7 @@ func checkParameterStatus(b []byte) error { return nil } - var mustBe []byte - switch { - case bytes.Equal(b[:i], []byte("standard_conforming_strings")): - mustBe = []byte("on") - case bytes.Equal(b[:i], []byte("client_encoding")): - mustBe = []byte("UTF8") - default: - return nil - } + name := string(b[:i]) // Move past the parameter key and the first null byte to get to the value. b = b[i+1:] @@ -1159,8 +1170,80 @@ func checkParameterStatus(b []byte) error { if i = bytes.IndexByte(b, 0x00); i == -1 { return nil } - if !bytes.Equal(b[:i], mustBe) { - return errRequiresParameter + return checkRequiredParameter(name, string(b[:i])) +} + +func checkRequiredConnParam(k string, v any) error { + if v == nil { + return nil + } + return checkRequiredParameter(k, startupParamValue(v)) +} + +func checkRequiredParameter(k, v string) error { + switch k { + case "standard_conforming_strings": + if !isStandardConformingStringsOn(v) { + return errRequiresParameter + } + case "client_encoding": + if !isClientEncodingUTF8(v) { + return errRequiresParameter + } } return nil } + +func isRequiredParameter(k string) bool { + switch k { + case "standard_conforming_strings", "client_encoding": + return true + default: + return false + } +} + +func canonicalRequiredParameterValue(k, v string) string { + // PostgreSQL accepts equivalent spellings, but it reports canonical values. + // Sending canonical values keeps startup validation independent of aliases. + switch k { + case "standard_conforming_strings": + if isStandardConformingStringsOn(v) { + return "on" + } + case "client_encoding": + if isClientEncodingUTF8(v) { + return "UTF8" + } + } + return v +} + +func isStandardConformingStringsOn(s string) bool { + switch strings.ToLower(strings.TrimSpace(s)) { + case "on", "true", "yes", "1": + return true + default: + return false + } +} + +func isClientEncodingUTF8(s string) bool { + switch strings.ToUpper(strings.ReplaceAll(strings.TrimSpace(s), "-", "")) { + case "UTF8", "UNICODE": + return true + default: + return false + } +} + +func startupParamValue(v any) string { + switch v := v.(type) { + case string: + return v + case []byte: + return string(v) + default: + return fmt.Sprint(v) + } +} diff --git a/driver/pgdriver/proto_test.go b/driver/pgdriver/proto_test.go new file mode 100644 index 000000000..92d44e1f8 --- /dev/null +++ b/driver/pgdriver/proto_test.go @@ -0,0 +1,190 @@ +package pgdriver + +import ( + "bytes" + "context" + "encoding/binary" + "io" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestWriteStartupIncludesRequiredConnParams(t *testing.T) { + netConn := new(bufferConn) + cn := &Conn{ + conf: &Config{ + User: "user", + Database: "db", + AppName: "app", + ConnParams: map[string]any{ + "client_encoding": "utf-8", + "standard_conforming_strings": true, + "search_path": "public", + "statement_timeout": 1000, + "ignored_param_with_nil_value": nil, + }, + }, + netConn: netConn, + } + + err := writeStartup(context.Background(), cn) + require.NoError(t, err) + + b := netConn.Bytes() + require.GreaterOrEqual(t, len(b), 8) + require.Equal(t, len(b), int(binary.BigEndian.Uint32(b[:4]))) + require.Equal(t, int32(196608), int32(binary.BigEndian.Uint32(b[4:8]))) + + params := readStartupParams(t, b[8:]) + require.Equal(t, "user", params["user"]) + require.Equal(t, "db", params["database"]) + require.Equal(t, "app", params["application_name"]) + require.Equal(t, "UTF8", params["client_encoding"]) + require.Equal(t, "on", params["standard_conforming_strings"]) + require.NotContains(t, params, "search_path") + require.NotContains(t, params, "statement_timeout") + require.NotContains(t, params, "ignored_param_with_nil_value") +} + +func TestCheckRequiredParameter(t *testing.T) { + tests := []struct { + name string + param string + value string + wantErr bool + }{ + { + name: "standard_conforming_strings on", + param: "standard_conforming_strings", + value: "on", + }, + { + name: "standard_conforming_strings true", + param: "standard_conforming_strings", + value: "true", + }, + { + name: "standard_conforming_strings off", + param: "standard_conforming_strings", + value: "off", + wantErr: true, + }, + { + name: "client_encoding UTF8", + param: "client_encoding", + value: "UTF8", + }, + { + name: "client_encoding utf-8", + param: "client_encoding", + value: "utf-8", + }, + { + name: "client_encoding unicode alias", + param: "client_encoding", + value: "Unicode", + }, + { + name: "client_encoding latin1", + param: "client_encoding", + value: "LATIN1", + wantErr: true, + }, + { + name: "unrelated parameter", + param: "server_version", + value: "16.0", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := checkRequiredParameter(test.param, test.value) + if test.wantErr { + require.ErrorIs(t, err, errRequiresParameter) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestCheckParameterStatus(t *testing.T) { + err := checkParameterStatus([]byte("client_encoding\x00LATIN1\x00")) + require.ErrorIs(t, err, errRequiresParameter) + + err = checkParameterStatus([]byte("client_encoding\x00UTF8\x00")) + require.NoError(t, err) +} + +func readStartupParams(t *testing.T, b []byte) map[string]string { + t.Helper() + + params := make(map[string]string) + for { + key := readCString(t, &b) + if key == "" { + require.Empty(t, b) + return params + } + + value := readCString(t, &b) + params[key] = value + } +} + +func readCString(t *testing.T, b *[]byte) string { + t.Helper() + + i := bytes.IndexByte(*b, 0) + require.NotEqual(t, -1, i) + + s := string((*b)[:i]) + *b = (*b)[i+1:] + return s +} + +type bufferConn struct { + bytes.Buffer +} + +func (c *bufferConn) Read([]byte) (int, error) { + return 0, io.EOF +} + +func (c *bufferConn) Close() error { + return nil +} + +func (c *bufferConn) LocalAddr() net.Addr { + return testAddr("local") +} + +func (c *bufferConn) RemoteAddr() net.Addr { + return testAddr("remote") +} + +func (c *bufferConn) SetDeadline(time.Time) error { + return nil +} + +func (c *bufferConn) SetReadDeadline(time.Time) error { + return nil +} + +func (c *bufferConn) SetWriteDeadline(time.Time) error { + return nil +} + +type testAddr string + +func (a testAddr) Network() string { + return string(a) +} + +func (a testAddr) String() string { + return string(a) +}