diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 5101d25..25a1ce1 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -44,30 +44,39 @@ func New(binaryName string, shortUsage string, longUsage, examples string) Comma } } -// ParseAndValidateFlags will parse flags registered in this instance of CLI from os.Args -// and then perform validation -func (cl *CommandLineInterface) ParseAndValidateFlags() (map[string]interface{}, error) { +// ParseFlags will parse flags registered in this instance of CLI from os.Args +func (cl *CommandLineInterface) ParseFlags() (map[string]interface{}, error) { cl.setUsageTemplate() // Remove Suite Flags so that args only include Config and Filter Flags - cl.rootCmd.SetArgs(cl.removeIntersectingArgs(cl.suiteFlags)) + cl.rootCmd.SetArgs(removeIntersectingArgs(cl.suiteFlags)) // This parses Config and Filter flags only err := cl.rootCmd.Execute() if err != nil { return nil, err } // Remove Config and Filter flags so that only suite flags are parsed - err = cl.suiteFlags.Parse(cl.removeIntersectingArgs(cl.rootCmd.Flags())) + err = cl.suiteFlags.Parse(removeIntersectingArgs(cl.rootCmd.Flags())) if err != nil { return nil, err } // Add suite flags to rootCmd flagset so that other processing can occur // This has to be done after usage is printed so that the flagsets can be grouped properly when printed cl.rootCmd.Flags().AddFlagSet(cl.suiteFlags) - err = cl.SetUntouchedFlagValuesToNil() + err = cl.setUntouchedFlagValuesToNil() if err != nil { return nil, err } - err = cl.ProcessRangeFilterFlags() + err = cl.processRangeFilterFlags() + if err != nil { + return nil, err + } + return cl.Flags, nil +} + +// ParseAndValidateFlags will parse flags registered in this instance of CLI from os.Args +// and then perform validation +func (cl *CommandLineInterface) ParseAndValidateFlags() (map[string]interface{}, error) { + flags, err := cl.ParseFlags() if err != nil { return nil, err } @@ -75,19 +84,51 @@ func (cl *CommandLineInterface) ParseAndValidateFlags() (map[string]interface{}, if err != nil { return nil, err } - return cl.Flags, nil + return flags, nil } -func (cl *CommandLineInterface) removeIntersectingArgs(flagSet *pflag.FlagSet) []string { - newArgs := os.Args[1:] - for i, arg := range newArgs { - if flagSet.Lookup(strings.Replace(arg, "--", "", 1)) != nil || (len(arg) == 2 && flagSet.ShorthandLookup(strings.Replace(arg, "-", "", 1)) != nil) { - newArgs = append(newArgs[:i], newArgs[i+1:]...) +// ValidateFlags iterates through any registered validators and executes them +func (cl *CommandLineInterface) ValidateFlags() error { + for flagName, validationFn := range cl.validators { + if validationFn == nil { + continue + } + err := validationFn(cl.Flags[flagName]) + if err != nil { + return err } } + return nil +} + +func removeIntersectingArgs(flagSet *pflag.FlagSet) []string { + newArgs := []string{} + skipNext := false + for i, arg := range os.Args { + if skipNext { + skipNext = false + continue + } + arg = strings.Split(arg, "=")[0] + longFlag := strings.Replace(arg, "--", "", 1) + if flagSet.Lookup(longFlag) != nil || shorthandLookup(flagSet, arg) != nil { + if len(os.Args) > i+1 && os.Args[i+1][0] != '-' { + skipNext = true + } + continue + } + newArgs = append(newArgs, os.Args[i]) + } return newArgs } +func shorthandLookup(flagSet *pflag.FlagSet, arg string) *pflag.Flag { + if len(arg) == 2 && arg[0] == '-' && arg[1] != '-' { + return flagSet.ShorthandLookup(strings.Replace(arg, "-", "", 1)) + } + return nil +} + func (cl *CommandLineInterface) setUsageTemplate() { transformedUsage := usageTemplate suiteFlagCount := 0 @@ -104,9 +145,9 @@ func (cl *CommandLineInterface) setUsageTemplate() { cl.rootCmd.Flags().Usage = func() {} } -// SetUntouchedFlagValuesToNil iterates through all flags and sets their value to nil if they were not specifically set by the user +// setUntouchedFlagValuesToNil iterates through all flags and sets their value to nil if they were not specifically set by the user // This allows for a specified value, a negative value (like false or empty string), or an unspecified (nil) entry. -func (cl *CommandLineInterface) SetUntouchedFlagValuesToNil() error { +func (cl *CommandLineInterface) setUntouchedFlagValuesToNil() error { defaultHandlerErrMsg := "Unable to find a default value handler for %v, marking as no default value. This could be an error" defaultHandlerFlags := []string{} @@ -141,8 +182,8 @@ func (cl *CommandLineInterface) SetUntouchedFlagValuesToNil() error { return nil } -// ProcessRangeFilterFlags sets min and max to the appropriate 0 or maxInt bounds based on the 3-tuple that a user specifies for base flag, min, and/or max -func (cl *CommandLineInterface) ProcessRangeFilterFlags() error { +// processRangeFilterFlags sets min and max to the appropriate 0 or maxInt bounds based on the 3-tuple that a user specifies for base flag, min, and/or max +func (cl *CommandLineInterface) processRangeFilterFlags() error { for flagName := range cl.intRangeFlags { rangeHelperMin := fmt.Sprintf("%s-%s", flagName, "min") rangeHelperMax := fmt.Sprintf("%s-%s", flagName, "max") @@ -167,17 +208,3 @@ func (cl *CommandLineInterface) ProcessRangeFilterFlags() error { } return nil } - -// ValidateFlags iterates through any registered validators and executes them -func (cl *CommandLineInterface) ValidateFlags() error { - for flagName, validationFn := range cl.validators { - if validationFn == nil { - continue - } - err := validationFn(cl.Flags[flagName]) - if err != nil { - return err - } - } - return nil -} diff --git a/pkg/cli/cli_internal_test.go b/pkg/cli/cli_internal_test.go new file mode 100644 index 0000000..9d78c9d --- /dev/null +++ b/pkg/cli/cli_internal_test.go @@ -0,0 +1,48 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package cli + +import ( + "os" + "testing" + + h "github.com/aws/amazon-ec2-instance-selector/pkg/test" + "github.com/spf13/pflag" +) + +// Tests + +func TestRemoveIntersectingArgs(t *testing.T) { + flagSet := pflag.NewFlagSet("test-flag-set", pflag.ContinueOnError) + flagSet.Bool("test-bool", false, "test usage") + os.Args = []string{"ec2-instance-selector", "--test-bool", "--this-should-stay"} + newArgs := removeIntersectingArgs(flagSet) + h.Assert(t, len(newArgs) == 2, "NewArgs should only include the bin name and one argument after removing intersections") +} + +func TestRemoveIntersectingArgs_NextArg(t *testing.T) { + flagSet := pflag.NewFlagSet("test-flag-set", pflag.ContinueOnError) + flagSet.String("test-str", "", "test usage") + os.Args = []string{"ec2-instance-selector", "--test-str", "somevalue", "--this-should-stay", "valuetostay"} + newArgs := removeIntersectingArgs(flagSet) + h.Assert(t, len(newArgs) == 3, "NewArgs should only include the bin name and a flag + input after removing intersections") +} + +func TestRemoveIntersectingArgs_ShorthandArg(t *testing.T) { + flagSet := pflag.NewFlagSet("test-flag-set", pflag.ContinueOnError) + flagSet.StringP("test-str", "t", "", "test usage") + os.Args = []string{"ec2-instance-selector", "--test-str", "somevalue", "--this-should-stay", "valuetostay", "-t", "test"} + newArgs := removeIntersectingArgs(flagSet) + h.Assert(t, len(newArgs) == 3, "NewArgs should only include the bin name and a flag + input after removing intersections") +} diff --git a/pkg/cli/cli_test.go b/pkg/cli/cli_test.go new file mode 100644 index 0000000..40e66d1 --- /dev/null +++ b/pkg/cli/cli_test.go @@ -0,0 +1,244 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package cli_test + +import ( + "fmt" + "os" + "testing" + + "github.com/aws/amazon-ec2-instance-selector/pkg/cli" + h "github.com/aws/amazon-ec2-instance-selector/pkg/test" +) + +const ( + maxInt = int(^uint(0) >> 1) +) + +// Helpers + +func getTestCLI() cli.CommandLineInterface { + return cli.New("test", "test short usage", "test long usage", "test examples") +} + +// Tests + +func TestValidateFlags(t *testing.T) { + // Nil validator should succeed validation + cli := getTestCLI() + flagName := "test-flag" + cli.StringFlag(flagName, nil, nil, "Test String w/o validation", nil) + err := cli.ValidateFlags() + h.Ok(t, err) + + // Validator which returns nil error should succeed validation + cli = getTestCLI() + flagName = "test-flag" + cli.StringFlag(flagName, nil, nil, "Test String w/ successful validation", func(val interface{}) error { + return nil + }) + err = cli.ValidateFlags() + h.Ok(t, err) + + // Validator which returns error should fail validation + cli = getTestCLI() + cli.StringFlag(flagName, nil, nil, "Test String w/ validation failure", func(val interface{}) error { + return fmt.Errorf("validation failed") + }) + err = cli.ValidateFlags() + h.Nok(t, err) +} + +func TestParseFlags(t *testing.T) { + cli := getTestCLI() + flagName := "test-flag" + flagArg := fmt.Sprintf("--%s", flagName) + cli.StringFlag(flagName, nil, nil, "Test String w/o validation", nil) + os.Args = []string{"ec2-instance-selector", flagArg, "test"} + flags, err := cli.ParseFlags() + h.Ok(t, err) + flagOutput := flags[flagName].(*string) + h.Assert(t, *flagOutput == "test", "Flag %s should have been parsed", flagArg) +} + +func TestParseFlags_IntRange(t *testing.T) { + flagName := "test-flag" + flagMinArg := fmt.Sprintf("%s-%s", flagName, "min") + flagMaxArg := fmt.Sprintf("%s-%s", flagName, "max") + flagArg := fmt.Sprintf("--%s", flagName) + + // Root set Min and Max to the same val + cli := getTestCLI() + cli.IntMinMaxRangeFlags(flagName, nil, nil, "Test") + os.Args = []string{"ec2-instance-selector", flagArg, "5"} + flags, err := cli.ParseFlags() + h.Ok(t, err) + flagMinOutput := flags[flagMinArg].(*int) + flagMaxOutput := flags[flagMaxArg].(*int) + h.Assert(t, *flagMinOutput == 5 && *flagMaxOutput == 5, "Flag %s min and max should have been parsed to the same number", flagArg) + + // Min is set to a val and max is set to maxInt + cli = getTestCLI() + cli.IntMinMaxRangeFlags(flagName, nil, nil, "Test") + os.Args = []string{"ec2-instance-selector", "--" + flagMinArg, "5"} + flags, err = cli.ParseFlags() + h.Ok(t, err) + flagMinOutput = flags[flagMinArg].(*int) + flagMaxOutput = flags[flagMaxArg].(*int) + h.Assert(t, *flagMinOutput == 5 && *flagMaxOutput == maxInt, "Flag %s min should have been parsed from cmdline and max set to maxInt", flagArg) + + // Max is set to a val and min is set to 0 + cli = getTestCLI() + cli.IntMinMaxRangeFlags(flagName, nil, nil, "Test") + os.Args = []string{"ec2-instance-selector", "--" + flagMaxArg, "50"} + flags, err = cli.ParseFlags() + h.Ok(t, err) + flagMinOutput = flags[flagMinArg].(*int) + flagMaxOutput = flags[flagMaxArg].(*int) + h.Assert(t, *flagMinOutput == 0 && *flagMaxOutput == 50, "Flag %s max should have been parsed from cmdline and min set to 0", flagArg) + + // Min and Max are set to separate values + cli = getTestCLI() + cli.IntMinMaxRangeFlags(flagName, nil, nil, "Test") + os.Args = []string{"ec2-instance-selector", "--" + flagMinArg, "10", "--" + flagMaxArg, "500"} + flags, err = cli.ParseFlags() + h.Ok(t, err) + flagMinOutput = flags[flagMinArg].(*int) + flagMaxOutput = flags[flagMaxArg].(*int) + h.Assert(t, *flagMinOutput == 10 && *flagMaxOutput == 500, "Flag %s max should have been parsed from cmdline and min set to 0", flagArg) +} + +func TestParseFlags_IntRangeErr(t *testing.T) { + cli := getTestCLI() + flagName := "test-flag" + flagArg := fmt.Sprintf("--%s", flagName) + cli.IntMinMaxRangeFlags(flagName, nil, nil, "Test") + os.Args = []string{"ec2-instance-selector", flagArg, "1", flagArg + "-min", "1", flagArg + "-max", "2"} + _, err := cli.ParseFlags() + h.Nok(t, err) +} + +func TestParseFlags_RootErr(t *testing.T) { + cli := getTestCLI() + os.Args = []string{"ec2-instance-selector", "--test", "test"} + _, err := cli.ParseFlags() + h.Nok(t, err) +} + +func TestParseFlags_SuiteErr(t *testing.T) { + cli := getTestCLI() + cli.SuiteBoolFlag("test", nil, nil, "") + os.Args = []string{"ec2-instance-selector", "-----test"} + _, err := cli.ParseFlags() + h.Nok(t, err) +} + +func TestParseFlags_SuiteFlags(t *testing.T) { + cli := getTestCLI() + flagName := "test-flag" + flagArg := fmt.Sprintf("--%s", flagName) + cli.SuiteBoolFlag(flagName, nil, nil, "Test Suite Flag") + os.Args = []string{"ec2-instance-selector", flagArg} + flags, err := cli.ParseFlags() + h.Ok(t, err) + flagOutput := flags[flagName].(*bool) + h.Assert(t, *flagOutput == true, "Suite Flag %s should have been parsed", flagArg) +} + +func TestParseFlags_ConfigFlags(t *testing.T) { + cli := getTestCLI() + flagName := "test-flag" + flagArg := fmt.Sprintf("--%s", flagName) + cli.ConfigBoolFlag(flagName, nil, nil, "Test Config Flag") + os.Args = []string{"ec2-instance-selector", flagArg} + flags, err := cli.ParseFlags() + h.Ok(t, err) + flagOutput := flags[flagName].(*bool) + h.Assert(t, *flagOutput == true, "Config Flag %s should have been parsed", flagArg) +} + +func TestParseFlags_AllTypes(t *testing.T) { + cli := getTestCLI() + flagName := "test-flag" + configName := flagName + "-config" + suiteName := flagName + "-suite" + flagArg := fmt.Sprintf("--%s", flagName) + configArg := fmt.Sprintf("--%s", configName) + suiteArg := fmt.Sprintf("--%s", suiteName) + + cli.BoolFlag(flagName, nil, nil, "Test Filter Flag") + cli.ConfigBoolFlag(configName, nil, nil, "Test Config Flag") + cli.SuiteBoolFlag(suiteName, nil, nil, "Test Suite Flag") + os.Args = []string{"ec2-instance-selector", flagArg, configArg, suiteArg} + flags, err := cli.ParseFlags() + h.Ok(t, err) + flagOutput := flags[flagName].(*bool) + configOutput := flags[configName].(*bool) + suiteOutput := flags[suiteName].(*bool) + h.Assert(t, *flagOutput && *configOutput && *suiteOutput, "Filter, Config, and Sutie Flags %s should have been parsed", flagArg) +} + +func TestParseFlags_UntouchedFlags(t *testing.T) { + cli := getTestCLI() + flagName := "test-flag" + flagArg := fmt.Sprintf("--%s", flagName) + + cli.BoolFlag(flagName, nil, nil, "Test Filter Flag") + os.Args = []string{"ec2-instance-selector"} + flags, err := cli.ParseFlags() + h.Ok(t, err) + val, ok := flags[flagName] + h.Assert(t, ok, "Flag %s should exist in flags map", flagArg) + h.Assert(t, val == nil, "Flag %s should be set to nil when not explicitly set", flagArg) +} + +func TestParseFlags_UntouchedFlagsAllTypes(t *testing.T) { + cli := getTestCLI() + flagName := "test-flag" + ratioName := flagName + "-ratio" + configName := flagName + "-config" + suiteName := flagName + "-suite" + flagArg := fmt.Sprintf("--%s", flagName) + ratioArg := fmt.Sprintf("--%s", ratioName) + configArg := fmt.Sprintf("--%s", configName) + suiteArg := fmt.Sprintf("--%s", suiteName) + + cli.IntFlag(flagName, nil, nil, "Test Filter Flag") + cli.RatioFlag(ratioName, nil, nil, "Test Ratio Flag") + cli.ConfigStringFlag(configName, nil, nil, "Test Config Flag", nil) + cli.SuiteBoolFlag(suiteName, nil, nil, "Test Suite Flag") + + os.Args = []string{"ec2-instance-selector"} + flags, err := cli.ParseFlags() + h.Ok(t, err) + val, ok := flags[flagName] + ratioVal, ratioOk := flags[ratioName] + configVal, configOk := flags[configName] + suiteVal, suiteOk := flags[suiteName] + h.Assert(t, ok && ratioOk && configOk && suiteOk, "Flags %s, %s, %s should exist for all types in flags map", flagArg, ratioArg, configArg, suiteArg) + h.Assert(t, val == nil && ratioVal == nil && configVal == nil && suiteVal == nil, + "Flag %s, %s, %s should be set to nil when not explicitly set", flagArg, ratioArg, configArg, suiteArg) +} + +func TestParseAndValidateFlags_Err(t *testing.T) { + cli := getTestCLI() + flagName := "test-flag" + flagArg := fmt.Sprintf("--%s", flagName) + flagMin := flagArg + "-min" + flagMax := flagArg + "-max" + cli.IntMinMaxRangeFlags(flagName, nil, nil, "Test with validation") + os.Args = []string{"ec2-instance-selector", flagMin, "5", flagMax, "1"} + _, err := cli.ParseAndValidateFlags() + h.Nok(t, err) +} diff --git a/pkg/cli/flags.go b/pkg/cli/flags.go index ebce05e..a202ec1 100644 --- a/pkg/cli/flags.go +++ b/pkg/cli/flags.go @@ -47,10 +47,7 @@ func (cl *CommandLineInterface) RatioFlag(name string, shorthand *string, defaul // IntMinMaxRangeFlags creates and registers a min, max, and helper flag each accepting an Integer func (cl *CommandLineInterface) IntMinMaxRangeFlags(name string, shorthand *string, defaultValue *int, description string) { - cl.IntFlagOnFlagSet(cl.rootCmd.Flags(), name, shorthand, defaultValue, fmt.Sprintf("%s (sets --%s-min and -max to the same value)", description, name)) - cl.IntFlagOnFlagSet(cl.rootCmd.Flags(), name+"-min", nil, nil, fmt.Sprintf("Minimum %s If --%s-max is not specified, the upper bound will be infinity", description, name)) - cl.IntFlagOnFlagSet(cl.rootCmd.Flags(), name+"-max", nil, nil, fmt.Sprintf("Maximum %s If --%s-min is not specified, the lower bound will be 0", description, name)) - cl.intRangeFlags[name] = true + cl.IntMinMaxRangeFlagOnFlagSet(cl.rootCmd.Flags(), name, shorthand, defaultValue, description) } // IntFlag creates and registers a flag accepting an Integer @@ -111,6 +108,19 @@ func (cl *CommandLineInterface) IntMinMaxRangeFlagOnFlagSet(flagSet *pflag.FlagS cl.IntFlagOnFlagSet(flagSet, name, shorthand, defaultValue, fmt.Sprintf("%s (sets --%s-min and -max to the same value)", description, name)) cl.IntFlagOnFlagSet(flagSet, name+"-min", nil, nil, fmt.Sprintf("Minimum %s If --%s-max is not specified, the upper bound will be infinity", description, name)) cl.IntFlagOnFlagSet(flagSet, name+"-max", nil, nil, fmt.Sprintf("Maximum %s If --%s-min is not specified, the lower bound will be 0", description, name)) + cl.validators[name] = func(val interface{}) error { + if cl.Flags[name+"-min"] == nil || cl.Flags[name+"-max"] == nil { + return nil + } + minArg := name + "-min" + maxArg := name + "-max" + minVal := cl.Flags[minArg].(*int) + maxVal := cl.Flags[maxArg].(*int) + if *minVal > *maxVal { + return fmt.Errorf("Invalid input for --%s and --%s. %s must be less than or equal to %s", minArg, maxArg, minArg, maxArg) + } + return nil + } cl.intRangeFlags[name] = true } @@ -136,6 +146,7 @@ func (cl *CommandLineInterface) StringFlagOnFlagSet(flagSet *pflag.FlagSet, name } if shorthand != nil { cl.Flags[name] = flagSet.StringP(name, string(*shorthand), *defaultValue, description) + cl.validators[name] = validationFn return } cl.Flags[name] = flagSet.String(name, *defaultValue, description) diff --git a/pkg/cli/flags_test.go b/pkg/cli/flags_test.go new file mode 100644 index 0000000..d46e56f --- /dev/null +++ b/pkg/cli/flags_test.go @@ -0,0 +1,110 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package cli_test + +import ( + "fmt" + "testing" + + h "github.com/aws/amazon-ec2-instance-selector/pkg/test" +) + +// Tests + +func TestBoolFlag(t *testing.T) { + cli := getTestCLI() + for _, flagFn := range []func(string, *string, *bool, string){cli.BoolFlag, cli.ConfigBoolFlag, cli.SuiteBoolFlag} { + flagName := "test-int" + flagFn(flagName, cli.StringMe("t"), nil, "Test Bool") + _, ok := cli.Flags[flagName] + h.Assert(t, len(cli.Flags) == 1, "Should contain 1 flag") + h.Assert(t, ok, "Should contain %s flag", flagName) + + cli = getTestCLI() + cli.BoolFlag(flagName, nil, nil, "Test Bool") + h.Assert(t, len(cli.Flags) == 1, "Should contain 1 flag w/ no shorthand") + h.Assert(t, ok, "Should contain %s flag w/ no shorthand", flagName) + } +} + +func TestIntFlag(t *testing.T) { + cli := getTestCLI() + for _, flagFn := range []func(string, *string, *int, string){cli.IntFlag, cli.ConfigIntFlag} { + flagName := "test-int" + flagFn(flagName, cli.StringMe("t"), nil, "Test Int") + _, ok := cli.Flags[flagName] + h.Assert(t, len(cli.Flags) == 1, "Should contain 1 flag") + h.Assert(t, ok, "Should contain %s flag", flagName) + + cli = getTestCLI() + cli.IntFlag(flagName, nil, nil, "Test Int") + h.Assert(t, len(cli.Flags) == 1, "Should contain 1 flag w/ no shorthand") + h.Assert(t, ok, "Should contain %s flag w/ no shorthand", flagName) + } +} + +func TestStringFlag(t *testing.T) { + cli := getTestCLI() + for _, flagFn := range []func(string, *string, *string, string, func(interface{}) error){cli.StringFlag, cli.ConfigStringFlag} { + flagName := "test-string" + flagFn(flagName, cli.StringMe("t"), nil, "Test String", nil) + _, ok := cli.Flags[flagName] + h.Assert(t, len(cli.Flags) == 1, "Should contain 1 flag") + h.Assert(t, ok, "Should contain %s flag", flagName) + + cli = getTestCLI() + flagFn(flagName, cli.StringMe("t"), nil, "Test String w/ validation", func(val interface{}) error { + return fmt.Errorf("validation failed") + }) + h.Assert(t, len(cli.Flags) == 1, "Should contain 1 flag") + h.Assert(t, ok, "Should contain %s flag with validation", flagName) + + cli = getTestCLI() + flagFn(flagName, nil, nil, "Test String", nil) + h.Assert(t, len(cli.Flags) == 1, "Should contain 1 flag w/ no shorthand") + h.Assert(t, ok, "Should contain %s flag w/ no shorthand", flagName) + } +} + +func TestRatioFlag(t *testing.T) { + cli := getTestCLI() + flagName := "test-ratio" + cli.RatioFlag(flagName, cli.StringMe("t"), nil, "Test Ratio") + _, ok := cli.Flags[flagName] + h.Assert(t, len(cli.Flags) == 1, "Should contain 1 flag") + h.Assert(t, ok, "Should contain %s flag", flagName) + + cli = getTestCLI() + cli.RatioFlag(flagName, nil, nil, "Test Ratio") + h.Assert(t, len(cli.Flags) == 1, "Should contain 1 flag w/ no shorthand") + h.Assert(t, ok, "Should contain %s flag w/ no shorthand", flagName) +} + +func TestIntMinMaxRangeFlags(t *testing.T) { + cli := getTestCLI() + flagName := "test-int-min-max-range" + cli.IntMinMaxRangeFlags(flagName, cli.StringMe("t"), nil, "Test Min Max Range") + _, ok := cli.Flags[flagName] + _, minOk := cli.Flags[flagName+"-min"] + _, maxOk := cli.Flags[flagName+"-max"] + h.Assert(t, len(cli.Flags) == 3, "Should contain 3 flags") + h.Assert(t, ok, "Should contain %s flag", flagName) + h.Assert(t, minOk, "Should contain %s flag", flagName) + h.Assert(t, maxOk, "Should contain %s flag", flagName) + + cli = getTestCLI() + cli.IntMinMaxRangeFlags(flagName, nil, nil, "Test Min Max Range") + h.Assert(t, len(cli.Flags) == 3, "Should contain 3 flags w/ no shorthand") + h.Assert(t, ok, "Should contain %s flag w/ no shorthand", flagName) +} diff --git a/pkg/cli/types.go b/pkg/cli/types.go index 85f5f36..1804e1f 100644 --- a/pkg/cli/types.go +++ b/pkg/cli/types.go @@ -46,6 +46,7 @@ Global Flags: {{end}}` ) +// validator defines the function for providing validation on a flag type validator = func(val interface{}) error // CommandLineInterface is a type to group CLI funcs and state diff --git a/pkg/cli/types_test.go b/pkg/cli/types_test.go new file mode 100644 index 0000000..84fbfe2 --- /dev/null +++ b/pkg/cli/types_test.go @@ -0,0 +1,88 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package cli_test + +import ( + "testing" + + "github.com/aws/amazon-ec2-instance-selector/pkg/selector" + h "github.com/aws/amazon-ec2-instance-selector/pkg/test" +) + +// Tests + +func TestBoolMe(t *testing.T) { + cli := getTestCLI() + boolTrue := true + val := cli.BoolMe(boolTrue) + h.Assert(t, *val == true, "Should return true from passed in value bool") + val = cli.BoolMe(&boolTrue) + h.Assert(t, *val == true, "Should return true from passed in pointer bool") + val = cli.BoolMe(7) + h.Assert(t, val == nil, "Should return nil from other data type passed in") + val = cli.BoolMe(nil) + h.Assert(t, val == nil, "Should return nil if nil is passed in") +} + +func TestStringMe(t *testing.T) { + cli := getTestCLI() + stringVal := "test" + val := cli.StringMe(stringVal) + h.Assert(t, *val == stringVal, "Should return %s from passed in string value", stringVal) + val = cli.StringMe(&stringVal) + h.Assert(t, *val == stringVal, "Should return %s from passed in string pointer", stringVal) + val = cli.StringMe(7) + h.Assert(t, val == nil, "Should return nil from other data type passed in") + val = cli.StringMe(nil) + h.Assert(t, val == nil, "Should return nil if nil is passed in") +} + +func TestIntMe(t *testing.T) { + cli := getTestCLI() + intVal := 10 + val := cli.IntMe(intVal) + h.Assert(t, *val == intVal, "Should return %s from passed in int value", intVal) + val = cli.IntMe(&intVal) + h.Assert(t, *val == intVal, "Should return %s from passed in int pointer", intVal) + val = cli.IntMe(true) + h.Assert(t, val == nil, "Should return nil from other data type passed in") + val = cli.IntMe(nil) + h.Assert(t, val == nil, "Should return nil if nil is passed in") +} + +func TestFloat64Me(t *testing.T) { + cli := getTestCLI() + fVal := 10.01 + val := cli.Float64Me(fVal) + h.Assert(t, *val == fVal, "Should return %s from passed in float64 value", fVal) + val = cli.Float64Me(&fVal) + h.Assert(t, *val == fVal, "Should return %s from passed in float64 pointer", fVal) + val = cli.Float64Me(true) + h.Assert(t, val == nil, "Should return nil from other data type passed in") + val = cli.Float64Me(nil) + h.Assert(t, val == nil, "Should return nil if nil is passed in") +} + +func TestIntRangeMe(t *testing.T) { + cli := getTestCLI() + intRangeVal := selector.IntRangeFilter{LowerBound: 1, UpperBound: 2} + val := cli.IntRangeMe(intRangeVal) + h.Assert(t, *val == intRangeVal, "Should return %s from passed in int range value", intRangeVal) + val = cli.IntRangeMe(&intRangeVal) + h.Assert(t, *val == intRangeVal, "Should return %s from passed in range pointer", intRangeVal) + val = cli.IntRangeMe(true) + h.Assert(t, val == nil, "Should return nil from other data type passed in") + val = cli.IntRangeMe(nil) + h.Assert(t, val == nil, "Should return nil if nil is passed in") +} diff --git a/pkg/test/helpers.go b/pkg/test/helpers.go index ef1ed5b..f5fd95e 100644 --- a/pkg/test/helpers.go +++ b/pkg/test/helpers.go @@ -39,6 +39,15 @@ func Ok(tb testing.TB, err error) { } } +// Nok fails the test if an err is nil. +func Nok(tb testing.TB, err error) { + if err == nil { + _, file, line, _ := runtime.Caller(1) + fmt.Printf("\033[31m%s:%d: unexpected success \033[39m\n\n", filepath.Base(file), line) + tb.FailNow() + } +} + // Equals fails the test if exp is not equal to act. func Equals(tb testing.TB, exp, act interface{}) { if !reflect.DeepEqual(exp, act) {