Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 59 additions & 25 deletions env.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,39 +660,31 @@ func getOr(key, defaultValue string, defExists bool, envs map[string]string) (va
}

func set(field reflect.Value, sf reflect.StructField, value string, funcMap map[reflect.Type]ParserFunc) error {
if tm := asTextUnmarshaler(field); tm != nil {
if err := tm.UnmarshalText([]byte(value)); err != nil {
if ok, err := customSet(field, value, funcMap); ok {
if err != nil {
return newParseError(sf, err)
}
return nil
}

typee := sf.Type
fieldee := field
if typee.Kind() == reflect.Ptr {
typee = typee.Elem()
fieldee = field.Elem()
}

parserFunc, ok := funcMap[typee]
if ok {
val, err := parserFunc(value)
if ok, err := textUnmarshalerSet(field, value); ok {
if err != nil {
return newParseError(sf, err)
}

fieldee.Set(reflect.ValueOf(val))
return nil
}

parserFunc, ok = defaultBuiltInParsers[typee.Kind()]
typee := sf.Type
if typee.Kind() == reflect.Ptr {
typee = typee.Elem()
}

parserFunc, ok := defaultBuiltInParsers[typee.Kind()]
if ok {
val, err := parserFunc(value)
err := parseAndSet(field, value, parserFunc)
if err != nil {
return newParseError(sf, err)
}

fieldee.Set(reflect.ValueOf(val).Convert(typee))
return nil
}

Expand All @@ -706,6 +698,37 @@ func set(field reflect.Value, sf reflect.StructField, value string, funcMap map[
return newNoParserError(sf)
}

func customSet(field reflect.Value, value string, funcMap map[reflect.Type]ParserFunc) (ok bool, err error) {
typee := field.Type()
if typee.Kind() == reflect.Ptr {
typee = typee.Elem()
}
parserFunc, ok := funcMap[typee]
if !ok {
return false, nil
}
return true, parseAndSet(field, value, parserFunc)
}

func parseAndSet(field reflect.Value, value string, parserFunc ParserFunc) error {
val, err := parserFunc(value)
if err != nil {
return err
}

rVal := reflect.ValueOf(val)
typee := field.Type()
if typee.Kind() == reflect.Ptr && !rVal.CanAddr() {
typee = typee.Elem()
ptr := reflect.New(typee)
ptr.Elem().Set(rVal.Convert(typee))
field.Set(ptr)
} else {
field.Set(rVal.Convert(typee))
}
return nil
}

func handleSlice(field reflect.Value, value string, sf reflect.StructField, funcMap map[reflect.Type]ParserFunc) error {
separator := sf.Tag.Get("envSeparator")
if separator == "" {
Expand Down Expand Up @@ -800,20 +823,31 @@ func handleMap(field reflect.Value, value string, sf reflect.StructField, funcMa
return nil
}

func asTextUnmarshaler(field reflect.Value) encoding.TextUnmarshaler {
func textUnmarshalerSet(field reflect.Value, value interface{}) (bool, error) {
fv := field
allocated := false
if field.Kind() == reflect.Ptr {
if field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
// TextUnmarshaler requires a zero value pointer receiver.
allocated = true
fv = reflect.New(field.Type().Elem())
}
} else if field.CanAddr() {
field = field.Addr()
} else if fv.CanAddr() {
fv = fv.Addr()
}

tm, ok := field.Interface().(encoding.TextUnmarshaler)
tm, ok := fv.Interface().(encoding.TextUnmarshaler)
if !ok {
return nil
return false, nil
}
err := tm.UnmarshalText([]byte(value.(string)))
if err != nil {
return true, err
}
if allocated {
field.Set(fv)
}
return tm
return true, nil
}

func parseTextUnmarshalers(field reflect.Value, data []string, sf reflect.StructField) error {
Expand Down
25 changes: 25 additions & 0 deletions env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"net/url"
"os"
Expand Down Expand Up @@ -1300,6 +1301,30 @@ func TestCustomParserNotCalledForNonAlias(t *testing.T) {
isEqual(t, U(44), cfg.Other)
}

func TestCustomParserCalledForTextMarshallImpl(t *testing.T) {
type config struct {
Val slog.Level `env:"" envDefault:"TRACE"`
}

tParser := func(s string) (interface{}, error) {
l := slog.LevelInfo
if s == "TRACE" {
return slog.LevelDebug - 4, nil
}
err := l.UnmarshalText([]byte(s))
return l, err
}

cfg := config{}

err := ParseWithOptions(&cfg, Options{FuncMap: map[reflect.Type]ParserFunc{
reflect.TypeOf(slog.Level(0)): tParser,
}})

isNoErr(t, err)
isEqual(t, slog.LevelDebug-4, cfg.Val)
}

func TestCustomParserBasicUnsupported(t *testing.T) {
type ConstT struct {
A int
Expand Down