diff --git a/flag_bool_with_inverse.go b/flag_bool_with_inverse.go index 7b6e18c7a7..0e256b3e8d 100644 --- a/flag_bool_with_inverse.go +++ b/flag_bool_with_inverse.go @@ -90,7 +90,7 @@ func (parent *BoolWithInverseFlag) initialize() { parent.negativeFlag = &BoolFlag{ Category: child.Category, DefaultText: child.DefaultText, - Sources: ValueSourceChain{Chain: append([]ValueSource{}, child.Sources.Chain...)}, + Sources: NewValueSourceChain(child.Sources.Chain...), Usage: child.Usage, Required: child.Required, Hidden: child.Hidden, @@ -109,12 +109,13 @@ func (parent *BoolWithInverseFlag) initialize() { parent.negativeFlag.Name = parent.inverseName() parent.negativeFlag.Aliases = parent.inverseAliases() - if len(child.Sources.Chain) > 0 { - parent.negativeFlag.Sources = ValueSourceChain{Chain: make([]ValueSource, len(child.Sources.Chain))} + if len(child.Sources.EnvKeys()) > 0 { + sources := []ValueSource{} - for idx, envVar := range child.GetEnvVars() { - parent.negativeFlag.Sources.Chain[idx] = &envVarValueSource{Key: strings.ToUpper(parent.InversePrefix) + envVar} + for _, envVar := range child.GetEnvVars() { + sources = append(sources, &envVarValueSource{Key: strings.ToUpper(parent.InversePrefix) + envVar}) } + parent.negativeFlag.Sources = NewValueSourceChain(sources...) } } diff --git a/flag_impl.go b/flag_impl.go index 014a2a97dc..3478744807 100644 --- a/flag_impl.go +++ b/flag_impl.go @@ -238,15 +238,7 @@ func (f *FlagBase[T, C, V]) GetUsage() string { // GetEnvVars returns the env vars for this flag func (f *FlagBase[T, C, V]) GetEnvVars() []string { - vals := []string{} - - for _, src := range f.Sources.Chain { - if v, ok := src.(*envVarValueSource); ok { - vals = append(vals, v.Key) - } - } - - return vals + return f.Sources.EnvKeys() } // TakesValue returns true if the flag takes a value, otherwise false diff --git a/godoc-current.txt b/godoc-current.txt index bb6f4c2e9d..ece68dd4d6 100644 --- a/godoc-current.txt +++ b/godoc-current.txt @@ -967,6 +967,8 @@ type ValueSource interface { ValueSource is a source which can be used to look up a value, typically for use with a cli.Flag +func EnvVar(key string) ValueSource + type ValueSourceChain struct { Chain []ValueSource } @@ -981,6 +983,12 @@ func Files(paths ...string) ValueSourceChain Files is a helper function to encapsulate a number of fileValueSource together as a ValueSourceChain +func NewValueSourceChain(src ...ValueSource) ValueSourceChain + +func (vsc *ValueSourceChain) Append(other ValueSourceChain) + +func (vsc *ValueSourceChain) EnvKeys() []string + func (vsc *ValueSourceChain) GoString() string func (vsc *ValueSourceChain) Lookup() (string, bool) diff --git a/testdata/godoc-v3.x.txt b/testdata/godoc-v3.x.txt index bb6f4c2e9d..ece68dd4d6 100644 --- a/testdata/godoc-v3.x.txt +++ b/testdata/godoc-v3.x.txt @@ -967,6 +967,8 @@ type ValueSource interface { ValueSource is a source which can be used to look up a value, typically for use with a cli.Flag +func EnvVar(key string) ValueSource + type ValueSourceChain struct { Chain []ValueSource } @@ -981,6 +983,12 @@ func Files(paths ...string) ValueSourceChain Files is a helper function to encapsulate a number of fileValueSource together as a ValueSourceChain +func NewValueSourceChain(src ...ValueSource) ValueSourceChain + +func (vsc *ValueSourceChain) Append(other ValueSourceChain) + +func (vsc *ValueSourceChain) EnvKeys() []string + func (vsc *ValueSourceChain) GoString() string func (vsc *ValueSourceChain) Lookup() (string, bool) diff --git a/value_source.go b/value_source.go index 84445a8655..266f5e54bf 100644 --- a/value_source.go +++ b/value_source.go @@ -24,6 +24,28 @@ type ValueSourceChain struct { Chain []ValueSource } +func NewValueSourceChain(src ...ValueSource) ValueSourceChain { + return ValueSourceChain{ + Chain: src, + } +} + +func (vsc *ValueSourceChain) Append(other ValueSourceChain) { + vsc.Chain = append(vsc.Chain, other.Chain...) +} + +func (vsc *ValueSourceChain) EnvKeys() []string { + vals := []string{} + + for _, src := range vsc.Chain { + if v, ok := src.(*envVarValueSource); ok { + vals = append(vals, v.Key) + } + } + + return vals +} + func (vsc *ValueSourceChain) String() string { s := []string{} @@ -73,6 +95,12 @@ func (e *envVarValueSource) GoString() string { return fmt.Sprintf("&envVarValueSource{Key:%[1]q}", e.Key) } +func EnvVar(key string) ValueSource { + return &envVarValueSource{ + Key: key, + } +} + // EnvVars is a helper function to encapsulate a number of // envVarValueSource together as a ValueSourceChain func EnvVars(keys ...string) ValueSourceChain { diff --git a/value_source_test.go b/value_source_test.go index a5ca4e02f1..3e44f965d4 100644 --- a/value_source_test.go +++ b/value_source_test.go @@ -12,13 +12,13 @@ import ( func TestEnvVarValueSource(t *testing.T) { t.Run("implements ValueSource", func(t *testing.T) { - src := &envVarValueSource{Key: "foo"} + src := EnvVar("foo") require.Implements(t, (*ValueSource)(nil), src) t.Run("not found", func(t *testing.T) { t.Setenv("foo", "bar") - src := &envVarValueSource{Key: "foo_1"} + src := EnvVar("foo_1") _, ok := src.Lookup() require.False(t, ok) }) @@ -27,7 +27,7 @@ func TestEnvVarValueSource(t *testing.T) { t.Setenv("foo", "bar") r := require.New(t) - src := &envVarValueSource{Key: "foo"} + src := EnvVar("foo") str, ok := src.Lookup() r.True(ok) @@ -37,7 +37,7 @@ func TestEnvVarValueSource(t *testing.T) { }) t.Run("implements fmt.Stringer", func(t *testing.T) { - src := &envVarValueSource{Key: "foo"} + src := EnvVar("foo") r := require.New(t) r.Implements((*fmt.Stringer)(nil), src) @@ -45,7 +45,7 @@ func TestEnvVarValueSource(t *testing.T) { }) t.Run("implements fmt.GoStringer", func(t *testing.T) { - src := &envVarValueSource{Key: "foo"} + src := EnvVar("foo") r := require.New(t) r.Implements((*fmt.GoStringer)(nil), src) @@ -122,6 +122,16 @@ func TestFilePaths(t *testing.T) { r.Contains(src.String(), fmt.Sprintf("%[1]q", fileName)) } +func TestValueSourceChainEnvKeys(t *testing.T) { + chain := NewValueSourceChain( + &staticValueSource{"hello"}, + ) + chain.Append(EnvVars("foo", "bar")) + + r := require.New(t) + r.Equal([]string{"foo", "bar"}, chain.EnvKeys()) +} + func TestValueSourceChain(t *testing.T) { t.Run("implements ValueSource", func(t *testing.T) { vsc := &ValueSourceChain{} @@ -140,11 +150,10 @@ func TestValueSourceChain(t *testing.T) { r.Implements((*fmt.GoStringer)(nil), vsc) r.Equal("&ValueSourceChain{Chain:{}}", vsc.GoString()) - vsc.Chain = []ValueSource{ - &staticValueSource{v: "yahtzee"}, + vsc1 := NewValueSourceChain(&staticValueSource{v: "yahtzee"}, &staticValueSource{v: "matzoh"}, - } - r.Equal("&ValueSourceChain{Chain:{&staticValueSource{v:\"yahtzee\"},&staticValueSource{v:\"matzoh\"}}}", vsc.GoString()) + ) + r.Equal("&ValueSourceChain{Chain:{&staticValueSource{v:\"yahtzee\"},&staticValueSource{v:\"matzoh\"}}}", vsc1.GoString()) }) t.Run("implements fmt.Stringer", func(t *testing.T) { @@ -154,12 +163,12 @@ func TestValueSourceChain(t *testing.T) { r.Implements((*fmt.Stringer)(nil), vsc) r.Equal("", vsc.String()) - vsc.Chain = []ValueSource{ + vsc1 := NewValueSourceChain( &staticValueSource{v: "soup"}, &staticValueSource{v: "salad"}, &staticValueSource{v: "pumpkins"}, - } - r.Equal("soup,salad,pumpkins", vsc.String()) + ) + r.Equal("soup,salad,pumpkins", vsc1.String()) }) }