Skip to content

Commit

Permalink
feat: support map kind env (#244)
Browse files Browse the repository at this point in the history
  • Loading branch information
ken8203 authored Feb 7, 2023
1 parent f434e98 commit 266f68b
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 1 deletion.
53 changes: 52 additions & 1 deletion env.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,11 @@ func set(field reflect.Value, sf reflect.StructField, value string, funcMap map[
return nil
}

if field.Kind() == reflect.Slice {
switch field.Kind() {
case reflect.Slice:
return handleSlice(field, value, sf, funcMap)
case reflect.Map:
return handleMap(field, value, sf, funcMap)
}

return newNoParserError(sf)
Expand Down Expand Up @@ -413,6 +416,54 @@ func handleSlice(field reflect.Value, value string, sf reflect.StructField, func
return nil
}

func handleMap(field reflect.Value, value string, sf reflect.StructField, funcMap map[reflect.Type]ParserFunc) error {
keyType := sf.Type.Key()
keyParserFunc, ok := funcMap[keyType]
if !ok {
keyParserFunc, ok = defaultBuiltInParsers[keyType.Kind()]
if !ok {
return newNoParserError(sf)
}
}

elemType := sf.Type.Elem()
elemParserFunc, ok := funcMap[elemType]
if !ok {
elemParserFunc, ok = defaultBuiltInParsers[elemType.Kind()]
if !ok {
return newNoParserError(sf)
}
}

separator := sf.Tag.Get("envSeparator")
if separator == "" {
separator = ","
}

result := reflect.MakeMap(sf.Type)
for _, part := range strings.Split(value, separator) {
pairs := strings.Split(part, ":")
if len(pairs) != 2 {
return fmt.Errorf("map pair: want 2 got %d", len(pairs))
}

key, err := keyParserFunc(pairs[0])
if err != nil {
return newParseError(sf, err)
}

elem, err := elemParserFunc(pairs[1])
if err != nil {
return newParseError(sf, err)
}

result.SetMapIndex(reflect.ValueOf(key).Convert(keyType), reflect.ValueOf(elem).Convert(elemType))
}

field.Set(result)
return nil
}

func asTextUnmarshaler(field reflect.Value) encoding.TextUnmarshaler {
if reflect.Ptr == field.Kind() {
if field.IsNil() {
Expand Down
33 changes: 33 additions & 0 deletions env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,39 @@ func TestParsesEnv(t *testing.T) {
isEqual(t, cfg.unexported, "")
}

func TestParsesEnv_Map(t *testing.T) {
type config struct {
MapStringString map[string]string `env:"MAP_STRING_STRING" envSeparator:","`
MapStringInt64 map[string]int64 `env:"MAP_STRING_INT64"`
MapStringBool map[string]bool `env:"MAP_STRING_BOOL" envSeparator:";"`
}

mss := map[string]string{
"k1": "v1",
"k2": "v2",
}
setEnv(t, "MAP_STRING_STRING", "k1:v1,k2:v2")

msi := map[string]int64{
"k1": 1,
"k2": 2,
}
setEnv(t, "MAP_STRING_INT64", "k1:1,k2:2")

msb := map[string]bool{
"k1": true,
"k2": false,
}
setEnv(t, "MAP_STRING_BOOL", "k1:true;k2:false")

var cfg config
isNoErr(t, Parse(&cfg))

isEqual(t, mss, cfg.MapStringString)
isEqual(t, msi, cfg.MapStringInt64)
isEqual(t, msb, cfg.MapStringBool)
}

func TestSetEnvAndTagOptsChain(t *testing.T) {
type config struct {
Key1 string `mytag:"KEY1,required"`
Expand Down

0 comments on commit 266f68b

Please sign in to comment.