Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 64 additions & 5 deletions pgconn/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ type (
GetSSLPasswordFunc func(ctx context.Context) string
)

// LookupSRVFunc is a function that resolves DNS SRV records. It has the same
// signature as [net.Resolver.LookupSRV].
type LookupSRVFunc func(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error)

// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A
// manually initialized Config will cause ConnectConfig to panic.
type Config struct {
Expand All @@ -40,10 +44,26 @@ type Config struct {
Password string
TLSConfig *tls.Config // nil disables TLS
ConnectTimeout time.Duration
DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
LookupSRVFunc LookupSRVFunc // e.g. net.Resolver.LookupSRV; used when SRVHost is set
BuildFrontend BuildFrontendFunc

// SRVHost, when non-empty, causes the connection to be established via DNS
// SRV discovery. ParseConfig looks up _postgresql._tcp.{SRVHost} and uses
// the returned targets in priority/weight order (RFC 2782) as the list of
// hosts to try. It is set automatically when the connection string uses the
// postgres+srv:// or postgresql+srv:// URI scheme, or the srvhost= keyword.
// SRVHost is mutually exclusive with specifying hosts directly in the
// connection string.
SRVHost string

// srvMakeTLS generates per-target TLS configurations for SRV-resolved
// hostnames. It is a closure that captures the TLS settings from ParseConfig
// so that each SRV target gets the correct SNI ServerName. Nil when SRV
// mode is inactive.
srvMakeTLS func(target string) ([]*tls.Config, error)

// BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called
// when a context passed to a PgConn method is canceled.
BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler
Expand Down Expand Up @@ -278,12 +298,27 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
connStringSettings := make(map[string]string)
if connString != "" {
var err error
// connString may be a database URL or in PostgreSQL keyword/value format
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
connStringSettings, err = parseURLSettings(connString)
// connString may be a database URL or in PostgreSQL keyword/value format.
// postgres+srv:// and postgresql+srv:// are SRV-discovery variants: the
// host in the URL names the cluster, not an individual server.
isSRVScheme := strings.HasPrefix(connString, "postgres+srv://") || strings.HasPrefix(connString, "postgresql+srv://")
normalizedConnString := connString
if isSRVScheme {
normalizedConnString = strings.Replace(connString, "+srv", "", 1)
}
if strings.HasPrefix(normalizedConnString, "postgres://") || strings.HasPrefix(normalizedConnString, "postgresql://") {
connStringSettings, err = parseURLSettings(normalizedConnString)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
}
if isSRVScheme {
// Move the parsed host to srvhost; port comes from SRV records.
if host, ok := connStringSettings["host"]; ok && host != "" {
connStringSettings["srvhost"] = host
delete(connStringSettings, "host")
}
delete(connStringSettings, "port")
}
} else {
connStringSettings, err = parseKeywordValueSettings(connString)
if err != nil {
Expand Down Expand Up @@ -360,6 +395,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
"min_protocol_version": {},
"max_protocol_version": {},
"channel_binding": {},
"srvhost": {},
}

// Adding kerberos configuration
Expand All @@ -377,6 +413,29 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
config.RuntimeParams[k] = v
}

// SRV discovery mode: srvhost names the cluster, individual servers come
// from DNS at connection time.
if srvhost := settings["srvhost"]; srvhost != "" {
// Error only when host was supplied explicitly in the connection string
// itself (not when it came from PGHOST or the built-in default).
if _, hostInConnString := connStringSettings["host"]; hostInConnString {
return nil, &ParseConfigError{ConnString: connString, msg: "srvhost and host are mutually exclusive"}
}
config.SRVHost = srvhost
config.LookupSRVFunc = net.DefaultResolver.LookupSRV
// Capture TLS settings in a closure so buildConnectOneConfigs can build
// per-target TLS configs with the correct SNI ServerName for each
// SRV-resolved hostname.
capturedSettings := settings
capturedOptions := options
config.srvMakeTLS = func(target string) ([]*tls.Config, error) {
return configTLS(capturedSettings, target, capturedOptions)
}
// Provide a dummy host so passfile lookup and other host-dependent
// defaults have something to work with.
settings["host"] = srvhost
}

fallbacks := []*FallbackConfig{}

hosts := strings.Split(settings["host"], ",")
Expand Down
29 changes: 29 additions & 0 deletions pgconn/export_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,32 @@
// File export_test exports some methods for better testing.

package pgconn

import "context"

// BuildConnectOneConfigsFromSRV exposes the internal SRV resolution logic for
// white-box testing. It returns the ordered list of (network, address,
// originalHostname) tuples that pgconn would attempt to connect to, without
// actually opening any TCP connections.
func BuildConnectOneConfigsFromSRV(ctx context.Context, config *Config) ([]ResolvedSRVTarget, error) {
configs, errs := buildConnectOneConfigsFromSRV(ctx, config)
if len(errs) > 0 {
return nil, errs[0]
}
targets := make([]ResolvedSRVTarget, len(configs))
for i, c := range configs {
targets[i] = ResolvedSRVTarget{
Network: c.network,
Address: c.address,
OriginalHostname: c.originalHostname,
}
}
return targets, nil
}

// ResolvedSRVTarget holds the resolved address information for one SRV target.
type ResolvedSRVTarget struct {
Network string // "tcp"
Address string // "host:port"
OriginalHostname string // SRV target after trimming trailing dot
}
51 changes: 51 additions & 0 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) {
// slice of successfully resolved connectOneConfigs and a slice of errors. It is possible for both slices to contain
// values if some hosts were successfully resolved and others were not.
func buildConnectOneConfigs(ctx context.Context, config *Config) ([]*connectOneConfig, []error) {
if config.SRVHost != "" {
return buildConnectOneConfigsFromSRV(ctx, config)
}

// Simplify usage by treating primary config and fallbacks the same.
fallbackConfigs := []*FallbackConfig{
{
Expand Down Expand Up @@ -240,6 +244,53 @@ func buildConnectOneConfigs(ctx context.Context, config *Config) ([]*connectOneC
return configs, allErrors
}

// buildConnectOneConfigsFromSRV resolves _postgresql._tcp.{SRVHost} and builds
// connectOneConfigs for each returned target, in priority/weight order per
// RFC 2782. net.DefaultResolver.LookupSRV already handles the sorting and
// weight-based randomisation within a priority group.
func buildConnectOneConfigsFromSRV(ctx context.Context, config *Config) ([]*connectOneConfig, []error) {
_, srvs, err := config.LookupSRVFunc(ctx, "postgresql", "tcp", config.SRVHost)
if err != nil {
return nil, []error{fmt.Errorf("SRV lookup for %s: %w", config.SRVHost, err)}
}
if len(srvs) == 0 {
return nil, []error{fmt.Errorf("SRV lookup for %s returned no records", config.SRVHost)}
}

var configs []*connectOneConfig

for _, srv := range srvs {
// LookupSRV returns FQDNs with a trailing dot; trim it so TLS SNI and
// error messages look like normal hostnames.
target := strings.TrimSuffix(srv.Target, ".")
port := srv.Port

var tlsConfigs []*tls.Config
if config.srvMakeTLS != nil {
tlsConfigs, err = config.srvMakeTLS(target)
if err != nil {
return nil, []error{fmt.Errorf("TLS config for SRV target %s: %w", target, err)}
}
} else {
// Fallback when Config was constructed without ParseConfig (e.g. tests
// that set LookupSRVFunc directly without going through ParseConfig).
tlsConfigs = []*tls.Config{config.TLSConfig}
}

for _, tlsConfig := range tlsConfigs {
network, address := NetworkAddress(target, port)
configs = append(configs, &connectOneConfig{
network: network,
address: address,
originalHostname: target,
tlsConfig: tlsConfig,
})
}
}

return configs, nil
}

// connectPreferred attempts to connect to the preferred host from connectOneConfigs. The connections are attempted in
// order. If a connection is successful it is returned. If no connection is successful then all errors are returned. If
// a connection attempt returns a [NotPreferredError], then that host will be used if no other hosts are successful.
Expand Down
Loading
Loading