Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions driver/pgdriver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,10 @@ func newConn(ctx context.Context, conf *Config) (*Conn, error) {
continue
}

switch {
case !conf.UnsafeStrings && k == "standard_conforming_strings" && v != "on":
return nil, errRequiresParameter
case !conf.UnsafeStrings && k == "client_encoding" && v != "UTF8":
return nil, errRequiresParameter
if !conf.UnsafeStrings {
if err := checkRequiredConnParam(k, v); err != nil {
return nil, err
}
}

if _, err = cn.ExecContext(ctx, fmt.Sprintf("SET %s TO $1", k), []driver.NamedValue{
Expand Down
105 changes: 94 additions & 11 deletions driver/pgdriver/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"io"
"math"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
Expand Down Expand Up @@ -223,12 +224,30 @@ func writeStartup(ctx context.Context, cn *Conn) error {
wb.WriteString("application_name")
wb.WriteString(cn.conf.AppName)
}
writeStartupConnParams(wb, cn.conf)
wb.WriteString("")
wb.FinishMessage()

return cn.write(ctx, wb)
}

func writeStartupConnParams(wb *writeBuffer, conf *Config) {
for k, v := range conf.ConnParams {
if v == nil || !isRequiredParameter(k) {
continue
}

// These parameters must be effective before startup ParameterStatus messages
// are validated; the later SET path runs too late for that check.
value := startupParamValue(v)
if !conf.UnsafeStrings {
value = canonicalRequiredParameterValue(k, value)
}
wb.WriteString(k)
wb.WriteString(value)
}
}

//------------------------------------------------------------------------------

func auth(ctx context.Context, cn *Conn, rd *reader) error {
Expand Down Expand Up @@ -1143,24 +1162,88 @@ func checkParameterStatus(b []byte) error {
return nil
}

var mustBe []byte
switch {
case bytes.Equal(b[:i], []byte("standard_conforming_strings")):
mustBe = []byte("on")
case bytes.Equal(b[:i], []byte("client_encoding")):
mustBe = []byte("UTF8")
default:
return nil
}
name := string(b[:i])

// Move past the parameter key and the first null byte to get to the value.
b = b[i+1:]

if i = bytes.IndexByte(b, 0x00); i == -1 {
return nil
}
if !bytes.Equal(b[:i], mustBe) {
return errRequiresParameter
return checkRequiredParameter(name, string(b[:i]))
}

func checkRequiredConnParam(k string, v any) error {
if v == nil {
return nil
}
return checkRequiredParameter(k, startupParamValue(v))
}

func checkRequiredParameter(k, v string) error {
switch k {
case "standard_conforming_strings":
if !isStandardConformingStringsOn(v) {
return errRequiresParameter
}
case "client_encoding":
if !isClientEncodingUTF8(v) {
return errRequiresParameter
}
}
return nil
}

func isRequiredParameter(k string) bool {
switch k {
case "standard_conforming_strings", "client_encoding":
return true
default:
return false
}
}

func canonicalRequiredParameterValue(k, v string) string {
// PostgreSQL accepts equivalent spellings, but it reports canonical values.
// Sending canonical values keeps startup validation independent of aliases.
switch k {
case "standard_conforming_strings":
if isStandardConformingStringsOn(v) {
return "on"
}
case "client_encoding":
if isClientEncodingUTF8(v) {
return "UTF8"
}
}
return v
}

func isStandardConformingStringsOn(s string) bool {
switch strings.ToLower(strings.TrimSpace(s)) {
case "on", "true", "yes", "1":
return true
default:
return false
}
}

func isClientEncodingUTF8(s string) bool {
switch strings.ToUpper(strings.ReplaceAll(strings.TrimSpace(s), "-", "")) {
case "UTF8", "UNICODE":
return true
default:
return false
}
}

func startupParamValue(v any) string {
switch v := v.(type) {
case string:
return v
case []byte:
return string(v)
default:
return fmt.Sprint(v)
}
}
190 changes: 190 additions & 0 deletions driver/pgdriver/proto_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
package pgdriver

import (
"bytes"
"context"
"encoding/binary"
"io"
"net"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestWriteStartupIncludesRequiredConnParams(t *testing.T) {
netConn := new(bufferConn)
cn := &Conn{
conf: &Config{
User: "user",
Database: "db",
AppName: "app",
ConnParams: map[string]any{
"client_encoding": "utf-8",
"standard_conforming_strings": true,
"search_path": "public",
"statement_timeout": 1000,
"ignored_param_with_nil_value": nil,
},
},
netConn: netConn,
}

err := writeStartup(context.Background(), cn)
require.NoError(t, err)

b := netConn.Bytes()
require.GreaterOrEqual(t, len(b), 8)
require.Equal(t, len(b), int(binary.BigEndian.Uint32(b[:4])))
require.Equal(t, int32(196608), int32(binary.BigEndian.Uint32(b[4:8])))

params := readStartupParams(t, b[8:])
require.Equal(t, "user", params["user"])
require.Equal(t, "db", params["database"])
require.Equal(t, "app", params["application_name"])
require.Equal(t, "UTF8", params["client_encoding"])
require.Equal(t, "on", params["standard_conforming_strings"])
require.NotContains(t, params, "search_path")
require.NotContains(t, params, "statement_timeout")
require.NotContains(t, params, "ignored_param_with_nil_value")
}

func TestCheckRequiredParameter(t *testing.T) {
tests := []struct {
name string
param string
value string
wantErr bool
}{
{
name: "standard_conforming_strings on",
param: "standard_conforming_strings",
value: "on",
},
{
name: "standard_conforming_strings true",
param: "standard_conforming_strings",
value: "true",
},
{
name: "standard_conforming_strings off",
param: "standard_conforming_strings",
value: "off",
wantErr: true,
},
{
name: "client_encoding UTF8",
param: "client_encoding",
value: "UTF8",
},
{
name: "client_encoding utf-8",
param: "client_encoding",
value: "utf-8",
},
{
name: "client_encoding unicode alias",
param: "client_encoding",
value: "Unicode",
},
{
name: "client_encoding latin1",
param: "client_encoding",
value: "LATIN1",
wantErr: true,
},
{
name: "unrelated parameter",
param: "server_version",
value: "16.0",
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
err := checkRequiredParameter(test.param, test.value)
if test.wantErr {
require.ErrorIs(t, err, errRequiresParameter)
} else {
require.NoError(t, err)
}
})
}
}

func TestCheckParameterStatus(t *testing.T) {
err := checkParameterStatus([]byte("client_encoding\x00LATIN1\x00"))
require.ErrorIs(t, err, errRequiresParameter)

err = checkParameterStatus([]byte("client_encoding\x00UTF8\x00"))
require.NoError(t, err)
}

func readStartupParams(t *testing.T, b []byte) map[string]string {
t.Helper()

params := make(map[string]string)
for {
key := readCString(t, &b)
if key == "" {
require.Empty(t, b)
return params
}

value := readCString(t, &b)
params[key] = value
}
}

func readCString(t *testing.T, b *[]byte) string {
t.Helper()

i := bytes.IndexByte(*b, 0)
require.NotEqual(t, -1, i)

s := string((*b)[:i])
*b = (*b)[i+1:]
return s
}

type bufferConn struct {
bytes.Buffer
}

func (c *bufferConn) Read([]byte) (int, error) {
return 0, io.EOF
}

func (c *bufferConn) Close() error {
return nil
}

func (c *bufferConn) LocalAddr() net.Addr {
return testAddr("local")
}

func (c *bufferConn) RemoteAddr() net.Addr {
return testAddr("remote")
}

func (c *bufferConn) SetDeadline(time.Time) error {
return nil
}

func (c *bufferConn) SetReadDeadline(time.Time) error {
return nil
}

func (c *bufferConn) SetWriteDeadline(time.Time) error {
return nil
}

type testAddr string

func (a testAddr) Network() string {
return string(a)
}

func (a testAddr) String() string {
return string(a)
}
Loading