From 9213f8d27edfc236619c426d59b2e8803c290664 Mon Sep 17 00:00:00 2001 From: Vitalii Solodilov Date: Wed, 28 Sep 2022 01:44:00 +0300 Subject: [PATCH] Fully overrides default slice fields of config (#4001) mapstructure library doesn't override full slice during unmarshalling. Origin issue: https://github.com/mitchellh/mapstructure/issues/74#issuecomment-279886492 To address this we zeroes every slice before unmarshalling unless user provided slice is nil. --- .chloggen/config-slice-fields.yaml | 5 ++ confmap/confmap.go | 18 +++++ confmap/confmap_test.go | 119 +++++++++++++++++++++++++++++ 3 files changed, 142 insertions(+) create mode 100644 .chloggen/config-slice-fields.yaml diff --git a/.chloggen/config-slice-fields.yaml b/.chloggen/config-slice-fields.yaml new file mode 100644 index 00000000000..09ce945ce83 --- /dev/null +++ b/.chloggen/config-slice-fields.yaml @@ -0,0 +1,5 @@ +change_type: "bug_fix" +component: "config" +note: "Fully overrides default slice fields of config" +issues: [4001] +subtext: diff --git a/confmap/confmap.go b/confmap/confmap.go index caf52a57c07..e15fec0fd22 100644 --- a/confmap/confmap.go +++ b/confmap/confmap.go @@ -168,6 +168,7 @@ func decodeConfig(m *Conf, result interface{}, errorUnused bool) error { mapstructure.StringToTimeDurationHookFunc(), mapstructure.TextUnmarshallerHookFunc(), unmarshalerHookFunc(result), + zeroesSliceHookFunc(), ), } decoder, err := mapstructure.NewDecoder(dc) @@ -320,6 +321,23 @@ func marshalerHookFunc(orig interface{}) mapstructure.DecodeHookFuncValue { } } +// mapstructure library doesn't override full slice during unmarshalling. +// Origin issue: https://github.com/mitchellh/mapstructure/issues/74#issuecomment-279886492 +// To address this we zeroes every slice before unmarshalling unless user provided slice is nil. +func zeroesSliceHookFunc() mapstructure.DecodeHookFuncValue { + return func(from reflect.Value, to reflect.Value) (interface{}, error) { + if from.Kind() == reflect.Slice && from.IsNil() { + return from.Interface(), nil + } + + if to.CanSet() && to.Kind() == reflect.Slice { + to.Set(reflect.MakeSlice(reflect.SliceOf(to.Type().Elem()), 0, 0)) + } + + return from.Interface(), nil + } +} + // Unmarshaler interface may be implemented by types to customize their behavior when being unmarshaled from a Conf. type Unmarshaler interface { // Unmarshal a Conf into the struct in a custom way. diff --git a/confmap/confmap_test.go b/confmap/confmap_test.go index 2ed44035fdc..4f8c9c6093f 100644 --- a/confmap/confmap_test.go +++ b/confmap/confmap_test.go @@ -400,3 +400,122 @@ func TestUnmarshalerErr(t *testing.T) { assert.EqualError(t, cfgMap.Unmarshal(tc), expectErr) assert.Empty(t, tc.Err.Foo) } + +func TestUnmarshalerSlices(t *testing.T) { + type innerStructWithSlices struct { + Ints []int `mapstructure:"ints"` + } + type structWithSlices struct { + Strings []string `mapstructure:"strings"` + Nested innerStructWithSlices `mapstructure:"nested"` + } + + tests := []struct { + name string + cfg map[string]any + provided any + expected any + }{ + { + name: "overridden by slice", + cfg: map[string]any{ + "strings": []string{"111"}, + }, + provided: &structWithSlices{ + Strings: []string{"xxx", "yyyy", "zzzz"}, + }, + expected: &structWithSlices{ + Strings: []string{"111"}, + }, + }, + { + name: "overridden by zero slice", + cfg: map[string]any{ + "strings": []string{}, + }, + provided: &structWithSlices{ + Strings: []string{"xxx", "yyyy"}, + }, + expected: &structWithSlices{ + Strings: []string{}, + }, + }, + { + name: "not overridden by nil slice", + cfg: map[string]any{ + "strings": []string(nil), + }, + provided: &structWithSlices{ + Strings: []string{"xxx", "yyyy"}, + }, + expected: &structWithSlices{ + Strings: []string{"xxx", "yyyy"}, + }, + }, + { + name: "not overridden by nil", + cfg: map[string]any{ + "strings": nil, + }, + provided: &structWithSlices{ + Strings: []string{"xxx", "yyyy"}, + }, + expected: &structWithSlices{ + Strings: []string{"xxx", "yyyy"}, + }, + }, + { + name: "not overridden by missing value", + cfg: map[string]any{}, + provided: &structWithSlices{ + Strings: []string{"xxx", "yyyy"}, + }, + expected: &structWithSlices{ + Strings: []string{"xxx", "yyyy"}, + }, + }, + { + name: "overridden by nested slice", + cfg: map[string]any{ + "nested": map[string]any{ + "ints": []int{777}, + }, + }, + provided: &structWithSlices{ + Nested: innerStructWithSlices{ + Ints: []int{1, 2, 3}, + }, + }, + expected: &structWithSlices{ + Nested: innerStructWithSlices{ + Ints: []int{777}, + }, + }, + }, + { + name: "overridden by weakly typed input", + cfg: map[string]any{ + "strings": "111", + }, + provided: &structWithSlices{ + Strings: []string{"xxx", "yyyy", "zzzz"}, + }, + expected: &structWithSlices{ + Strings: []string{"111"}, + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + cfg := NewFromStringMap(tt.cfg) + + err := cfg.UnmarshalExact(tt.provided) + if assert.NoError(t, err) { + assert.Equal(t, tt.expected, tt.provided) + } + }) + } +}