diff --git a/providers/basicflag/basicflag.go b/providers/basicflag/basicflag.go index 6dbf9bac..fb11fd2a 100644 --- a/providers/basicflag/basicflag.go +++ b/providers/basicflag/basicflag.go @@ -9,11 +9,24 @@ import ( "github.com/knadh/koanf/maps" ) +// Opt represents optional options (yup) passed to the provider. +type Opt struct { + KeyMap KoanfIntf +} + +// KoanfIntf is an interface that represents a small subset of methods +// used by this package from Koanf{}. When using this package, a live +// instance of Koanf{} should be passed. +type KoanfIntf interface { + Exists(string) bool +} + // Pflag implements a pflag command line provider. type Pflag struct { delim string flagset *flag.FlagSet cb func(key string, value string) (string, interface{}) + opt *Opt } // Provider returns a commandline flags provider that returns @@ -21,18 +34,42 @@ type Pflag struct { // nesting hierarchy of keys are defined by delim. For instance, the // delim "." will convert the key `parent.child.key: 1` // to `{parent: {child: {key: 1}}}`. -func Provider(f *flag.FlagSet, delim string) *Pflag { - return &Pflag{ +// +// It takes an optional Opt{} (but recommended) argument with a Koanf instance (opt.KeyMap) to see if the +// the flags defined have been set from other providers, for instance, +// a config file. If they are not, then the default values of the flags +// are merged. If they do exist, the flag values are not merged but only +// the values that have been explicitly set in the command line are merged. +// It is a variadic function as a hack to ensure backwards compatibility with the +// function definition. +// See https://github.com/knadh/koanf/issues/255 +func Provider(f *flag.FlagSet, delim string, opt ...*Opt) *Pflag { + pf := &Pflag{ flagset: f, delim: delim, } + + if len(opt) > 0 { + pf.opt = opt[0] + } + + return pf } // ProviderWithValue works exactly the same as Provider except the callback // takes a (key, value) with the variable name and value and allows you // to modify both. This is useful for cases where you may want to return // other types like a string slice instead of just a string. -func ProviderWithValue(f *flag.FlagSet, delim string, cb func(key string, value string) (string, interface{})) *Pflag { +// +// It takes an optional Opt{} (but recommended) argument with a Koanf instance (opt.KeyMap) to see if the +// the flags defined have been set from other providers, for instance, +// a config file. If they are not, then the default values of the flags +// are merged. If they do exist, the flag values are not merged but only +// the values that have been explicitly set in the command line are merged. +// It is a variadic function as a hack to ensure backwards compatibility with the +// function definition. +// See https://github.com/knadh/koanf/issues/255 +func ProviderWithValue(f *flag.FlagSet, delim string, cb func(key string, value string) (string, interface{}), ko ...KoanfIntf) *Pflag { return &Pflag{ flagset: f, delim: delim, @@ -42,18 +79,54 @@ func ProviderWithValue(f *flag.FlagSet, delim string, cb func(key string, value // Read reads the flag variables and returns a nested conf map. func (p *Pflag) Read() (map[string]interface{}, error) { + var changed map[string]struct{} + + // Prepare a map of flags that have been explicitly set by the user as aa KeyMap instance of Koanf + // has been provided. + if p.opt != nil && p.opt.KeyMap != nil { + changed = map[string]struct{}{} + + p.flagset.Visit(func(f *flag.Flag) { + key := f.Name + if p.cb != nil { + key, _ = p.cb(f.Name, "") + } + if key == "" { + return + } + + changed[key] = struct{}{} + }) + } + mp := make(map[string]interface{}) p.flagset.VisitAll(func(f *flag.Flag) { + var ( + key = f.Name + val interface{} = f.Value.String() + ) if p.cb != nil { - key, value := p.cb(f.Name, f.Value.String()) + k, v := p.cb(f.Name, f.Value.String()) // If the callback blanked the key, it should be omitted - if key == "" { + if k == "" { return } - mp[key] = value - } else { - mp[f.Name] = f.Value.String() + + key = k + val = v } + + // If the default value of the flag was never changed by the user, + // it should not override the value in the conf map (if it exists in the first place). + if changed != nil { + if _, ok := changed[key]; !ok { + if p.opt.KeyMap.Exists(key) { + return + } + } + } + + mp[key] = val }) return maps.Unflatten(mp, p.delim), nil } diff --git a/tests/koanf_test.go b/tests/koanf_test.go index 4f87ea58..edf5eeb8 100644 --- a/tests/koanf_test.go +++ b/tests/koanf_test.go @@ -587,11 +587,9 @@ func TestLoadMerge(t *testing.T) { func TestFlags(t *testing.T) { var ( assert = assert.New(t) - k = koanf.New(delim) + def = koanf.New(delim) ) - assert.Nil(k.Load(file.Provider(mockJSON), json.Parser()), "error loading file") - k2 := k.Copy() - k3 := k.Copy() + assert.Nil(def.Load(file.Provider(mockJSON), json.Parser()), "error loading file") // Override with the posflag provider. f := pflag.NewFlagSet("test", pflag.ContinueOnError) @@ -610,41 +608,79 @@ func TestFlags(t *testing.T) { // Initialize the provider with the Koanf instance passed where default values // will merge if the keys are not present in the conf map. - assert.Nil(k.Load(posflag.Provider(f, ".", k), nil), "error loading posflag") - assert.Equal("flag", k.String("parent1.child1.type"), "types don't match") - assert.Equal("flag", k.String("flagkey"), "value doesn't match") - assert.NotEqual("flag", k.String("parent1.name"), "value doesn't match") - assert.Equal([]string{"a", "b", "c"}, k.Strings("stringslice"), "value doesn't match") - assert.Equal([]int{1, 2, 3}, k.Ints("intslice"), "value doesn't match") + { + k := def.Copy() + assert.Nil(k.Load(posflag.Provider(f, ".", k), nil), "error loading posflag") + assert.Equal("flag", k.String("parent1.child1.type"), "types don't match") + assert.Equal("flag", k.String("flagkey"), "value doesn't match") + assert.NotEqual("flag", k.String("parent1.name"), "value doesn't match") + assert.Equal([]string{"a", "b", "c"}, k.Strings("stringslice"), "value doesn't match") + assert.Equal([]int{1, 2, 3}, k.Ints("intslice"), "value doesn't match") + } // Test the posflag provider can mutate the value to upper case - assert.Nil(k3.Load(posflag.ProviderWithValue(f, ".", nil, func(k string, v string) (string, interface{}) { - return strings.Replace(strings.ToLower(k), "prefix_", "", -1), strings.ToUpper(v) - }), nil), "error loading posflag") - assert.Equal("FLAG", k3.String("parent1.child1.type"), "types don't match") + { + k := def.Copy() + assert.Nil(k.Load(posflag.ProviderWithValue(f, ".", nil, func(k string, v string) (string, interface{}) { + return strings.Replace(strings.ToLower(k), "prefix_", "", -1), strings.ToUpper(v) + }), nil), "error loading posflag") + assert.Equal("FLAG", k.String("parent1.child1.type"), "types don't match") + } // Test without passing the Koanf instance where default values will not merge. - assert.Nil(k2.Load(posflag.Provider(f, ".", nil), nil), "error loading posflag") - assert.Equal("flag", k2.String("parent1.child1.type"), "types don't match") - assert.Equal("", k2.String("flagkey"), "value doesn't match") - assert.NotEqual("", k2.String("parent1.name"), "value doesn't match") - - // Override with the flag provider. - bf := flag.NewFlagSet("test", flag.ContinueOnError) - bf.String("parent1.child1.type", "flag", "") - bf.Set("parent1.child1.type", "basicflag") - assert.Nil(k.Load(basicflag.Provider(bf, "."), nil), "error loading basicflag") - assert.Equal("basicflag", k.String("parent1.child1.type"), "types don't match") + { + k := def.Copy() + assert.Nil(k.Load(posflag.Provider(f, ".", nil), nil), "error loading posflag") + assert.Equal("flag", k.String("parent1.child1.type"), "types don't match") + assert.Equal("", k.String("flagkey"), "value doesn't match") + assert.NotEqual("", k.String("parent1.name"), "value doesn't match") + } - // Test the basicflag provider can mutate the value to upper case - bf2 := flag.NewFlagSet("test", flag.ContinueOnError) - bf2.String("parent1.child1.type", "flag", "") - bf2.Set("parent1.child1.type", "basicflag") - assert.Nil(k.Load(basicflag.ProviderWithValue(bf2, ".", func(k string, v string) (string, interface{}) { - return strings.Replace(strings.ToLower(k), "prefix_", "", -1), strings.ToUpper(v) - }), nil), "error loading basicflag") - assert.Equal("BASICFLAG", k.String("parent1.child1.type"), "types don't match") + // Override with the basicflag provider. + { + k := def.Copy() + bf := flag.NewFlagSet("test", flag.ContinueOnError) + bf.String("parent1.child1.type", "flag", "") + bf.String("parent2.child2.name", "override-default", "") + bf.Set("parent1.child1.type", "basicflag") + assert.Nil(k.Load(basicflag.Provider(bf, "."), nil), "error loading basicflag") + assert.Equal("basicflag", k.String("parent1.child1.type"), "types don't match") + assert.Equal("override-default", k.String("parent2.child2.name"), "basicflag default value override failed") + } + + // No defualt-value override behaviour. + { + k := def.Copy() + bf := flag.NewFlagSet("test", flag.ContinueOnError) + bf.String("parent1.child1.name", "override-default", "") + bf.String("parent2.child2.name", "override-default", "") + bf.Set("parent2.child2.name", "custom") + assert.Nil(k.Load(basicflag.Provider(bf, ".", &basicflag.Opt{KeyMap: def}), nil), "error loading basicflag") + assert.Equal("child1", k.String("parent1.child1.name"), "basicflag default overwrote") + assert.Equal("custom", k.String("parent2.child2.name"), "basicflag set failed") + } + // Override with the basicflag provider. + { + k := def.Copy() + bf := flag.NewFlagSet("test", flag.ContinueOnError) + bf.String("parent1.child1.type", "flag", "") + bf.Set("parent1.child1.type", "basicflag") + assert.Nil(k.Load(basicflag.Provider(bf, "."), nil), "error loading basicflag") + assert.Equal("basicflag", k.String("parent1.child1.type"), "types don't match") + } + + // Test the basicflag provider can mutate the value to upper case + { + k := def.Copy() + bf := flag.NewFlagSet("test", flag.ContinueOnError) + bf.String("parent1.child1.type", "flag", "") + bf.Set("parent1.child1.type", "basicflag") + assert.Nil(k.Load(basicflag.ProviderWithValue(bf, ".", func(k string, v string) (string, interface{}) { + return strings.Replace(strings.ToLower(k), "prefix_", "", -1), strings.ToUpper(v) + }), nil), "error loading basicflag") + assert.Equal("BASICFLAG", k.String("parent1.child1.type"), "types don't match") + } } func TestConfMapValues(t *testing.T) {