diff --git a/env.go b/env.go index 66e89882..beada81d 100644 --- a/env.go +++ b/env.go @@ -660,39 +660,31 @@ func getOr(key, defaultValue string, defExists bool, envs map[string]string) (va } func set(field reflect.Value, sf reflect.StructField, value string, funcMap map[reflect.Type]ParserFunc) error { - if tm := asTextUnmarshaler(field); tm != nil { - if err := tm.UnmarshalText([]byte(value)); err != nil { + if ok, err := customSet(field, value, funcMap); ok { + if err != nil { return newParseError(sf, err) } return nil } - typee := sf.Type - fieldee := field - if typee.Kind() == reflect.Ptr { - typee = typee.Elem() - fieldee = field.Elem() - } - - parserFunc, ok := funcMap[typee] - if ok { - val, err := parserFunc(value) + if ok, err := textUnmarshalerSet(field, value); ok { if err != nil { return newParseError(sf, err) } - - fieldee.Set(reflect.ValueOf(val)) return nil } - parserFunc, ok = defaultBuiltInParsers[typee.Kind()] + typee := sf.Type + if typee.Kind() == reflect.Ptr { + typee = typee.Elem() + } + + parserFunc, ok := defaultBuiltInParsers[typee.Kind()] if ok { - val, err := parserFunc(value) + err := parseAndSet(field, value, parserFunc) if err != nil { return newParseError(sf, err) } - - fieldee.Set(reflect.ValueOf(val).Convert(typee)) return nil } @@ -706,6 +698,37 @@ func set(field reflect.Value, sf reflect.StructField, value string, funcMap map[ return newNoParserError(sf) } +func customSet(field reflect.Value, value string, funcMap map[reflect.Type]ParserFunc) (ok bool, err error) { + typee := field.Type() + if typee.Kind() == reflect.Ptr { + typee = typee.Elem() + } + parserFunc, ok := funcMap[typee] + if !ok { + return false, nil + } + return true, parseAndSet(field, value, parserFunc) +} + +func parseAndSet(field reflect.Value, value string, parserFunc ParserFunc) error { + val, err := parserFunc(value) + if err != nil { + return err + } + + rVal := reflect.ValueOf(val) + typee := field.Type() + if typee.Kind() == reflect.Ptr && !rVal.CanAddr() { + typee = typee.Elem() + ptr := reflect.New(typee) + ptr.Elem().Set(rVal.Convert(typee)) + field.Set(ptr) + } else { + field.Set(rVal.Convert(typee)) + } + return nil +} + func handleSlice(field reflect.Value, value string, sf reflect.StructField, funcMap map[reflect.Type]ParserFunc) error { separator := sf.Tag.Get("envSeparator") if separator == "" { @@ -800,20 +823,31 @@ func handleMap(field reflect.Value, value string, sf reflect.StructField, funcMa return nil } -func asTextUnmarshaler(field reflect.Value) encoding.TextUnmarshaler { +func textUnmarshalerSet(field reflect.Value, value interface{}) (bool, error) { + fv := field + allocated := false if field.Kind() == reflect.Ptr { if field.IsNil() { - field.Set(reflect.New(field.Type().Elem())) + // TextUnmarshaler requires a zero value pointer receiver. + allocated = true + fv = reflect.New(field.Type().Elem()) } - } else if field.CanAddr() { - field = field.Addr() + } else if fv.CanAddr() { + fv = fv.Addr() } - tm, ok := field.Interface().(encoding.TextUnmarshaler) + tm, ok := fv.Interface().(encoding.TextUnmarshaler) if !ok { - return nil + return false, nil + } + err := tm.UnmarshalText([]byte(value.(string))) + if err != nil { + return true, err + } + if allocated { + field.Set(fv) } - return tm + return true, nil } func parseTextUnmarshalers(field reflect.Value, data []string, sf reflect.StructField) error { diff --git a/env_test.go b/env_test.go index 4d93ff6f..a751aaf5 100644 --- a/env_test.go +++ b/env_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "log/slog" "net/http" "net/url" "os" @@ -1300,6 +1301,30 @@ func TestCustomParserNotCalledForNonAlias(t *testing.T) { isEqual(t, U(44), cfg.Other) } +func TestCustomParserCalledForTextMarshallImpl(t *testing.T) { + type config struct { + Val slog.Level `env:"" envDefault:"TRACE"` + } + + tParser := func(s string) (interface{}, error) { + l := slog.LevelInfo + if s == "TRACE" { + return slog.LevelDebug - 4, nil + } + err := l.UnmarshalText([]byte(s)) + return l, err + } + + cfg := config{} + + err := ParseWithOptions(&cfg, Options{FuncMap: map[reflect.Type]ParserFunc{ + reflect.TypeOf(slog.Level(0)): tParser, + }}) + + isNoErr(t, err) + isEqual(t, slog.LevelDebug-4, cfg.Val) +} + func TestCustomParserBasicUnsupported(t *testing.T) { type ConstT struct { A int