Skip to content

Commit

Permalink
fix some cli parsing bugs and add unit test coverage (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
bwagner5 authored May 19, 2020
1 parent 96b4a84 commit c05edf4
Show file tree
Hide file tree
Showing 8 changed files with 573 additions and 35 deletions.
89 changes: 58 additions & 31 deletions pkg/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,50 +44,91 @@ func New(binaryName string, shortUsage string, longUsage, examples string) Comma
}
}

// ParseAndValidateFlags will parse flags registered in this instance of CLI from os.Args
// and then perform validation
func (cl *CommandLineInterface) ParseAndValidateFlags() (map[string]interface{}, error) {
// ParseFlags will parse flags registered in this instance of CLI from os.Args
func (cl *CommandLineInterface) ParseFlags() (map[string]interface{}, error) {
cl.setUsageTemplate()
// Remove Suite Flags so that args only include Config and Filter Flags
cl.rootCmd.SetArgs(cl.removeIntersectingArgs(cl.suiteFlags))
cl.rootCmd.SetArgs(removeIntersectingArgs(cl.suiteFlags))
// This parses Config and Filter flags only
err := cl.rootCmd.Execute()
if err != nil {
return nil, err
}
// Remove Config and Filter flags so that only suite flags are parsed
err = cl.suiteFlags.Parse(cl.removeIntersectingArgs(cl.rootCmd.Flags()))
err = cl.suiteFlags.Parse(removeIntersectingArgs(cl.rootCmd.Flags()))
if err != nil {
return nil, err
}
// Add suite flags to rootCmd flagset so that other processing can occur
// This has to be done after usage is printed so that the flagsets can be grouped properly when printed
cl.rootCmd.Flags().AddFlagSet(cl.suiteFlags)
err = cl.SetUntouchedFlagValuesToNil()
err = cl.setUntouchedFlagValuesToNil()
if err != nil {
return nil, err
}
err = cl.ProcessRangeFilterFlags()
err = cl.processRangeFilterFlags()
if err != nil {
return nil, err
}
return cl.Flags, nil
}

// ParseAndValidateFlags will parse flags registered in this instance of CLI from os.Args
// and then perform validation
func (cl *CommandLineInterface) ParseAndValidateFlags() (map[string]interface{}, error) {
flags, err := cl.ParseFlags()
if err != nil {
return nil, err
}
err = cl.ValidateFlags()
if err != nil {
return nil, err
}
return cl.Flags, nil
return flags, nil
}

func (cl *CommandLineInterface) removeIntersectingArgs(flagSet *pflag.FlagSet) []string {
newArgs := os.Args[1:]
for i, arg := range newArgs {
if flagSet.Lookup(strings.Replace(arg, "--", "", 1)) != nil || (len(arg) == 2 && flagSet.ShorthandLookup(strings.Replace(arg, "-", "", 1)) != nil) {
newArgs = append(newArgs[:i], newArgs[i+1:]...)
// ValidateFlags iterates through any registered validators and executes them
func (cl *CommandLineInterface) ValidateFlags() error {
for flagName, validationFn := range cl.validators {
if validationFn == nil {
continue
}
err := validationFn(cl.Flags[flagName])
if err != nil {
return err
}
}
return nil
}

func removeIntersectingArgs(flagSet *pflag.FlagSet) []string {
newArgs := []string{}
skipNext := false
for i, arg := range os.Args {
if skipNext {
skipNext = false
continue
}
arg = strings.Split(arg, "=")[0]
longFlag := strings.Replace(arg, "--", "", 1)
if flagSet.Lookup(longFlag) != nil || shorthandLookup(flagSet, arg) != nil {
if len(os.Args) > i+1 && os.Args[i+1][0] != '-' {
skipNext = true
}
continue
}
newArgs = append(newArgs, os.Args[i])
}
return newArgs
}

func shorthandLookup(flagSet *pflag.FlagSet, arg string) *pflag.Flag {
if len(arg) == 2 && arg[0] == '-' && arg[1] != '-' {
return flagSet.ShorthandLookup(strings.Replace(arg, "-", "", 1))
}
return nil
}

func (cl *CommandLineInterface) setUsageTemplate() {
transformedUsage := usageTemplate
suiteFlagCount := 0
Expand All @@ -104,9 +145,9 @@ func (cl *CommandLineInterface) setUsageTemplate() {
cl.rootCmd.Flags().Usage = func() {}
}

// SetUntouchedFlagValuesToNil iterates through all flags and sets their value to nil if they were not specifically set by the user
// setUntouchedFlagValuesToNil iterates through all flags and sets their value to nil if they were not specifically set by the user
// This allows for a specified value, a negative value (like false or empty string), or an unspecified (nil) entry.
func (cl *CommandLineInterface) SetUntouchedFlagValuesToNil() error {
func (cl *CommandLineInterface) setUntouchedFlagValuesToNil() error {
defaultHandlerErrMsg := "Unable to find a default value handler for %v, marking as no default value. This could be an error"
defaultHandlerFlags := []string{}

Expand Down Expand Up @@ -141,8 +182,8 @@ func (cl *CommandLineInterface) SetUntouchedFlagValuesToNil() error {
return nil
}

// ProcessRangeFilterFlags sets min and max to the appropriate 0 or maxInt bounds based on the 3-tuple that a user specifies for base flag, min, and/or max
func (cl *CommandLineInterface) ProcessRangeFilterFlags() error {
// processRangeFilterFlags sets min and max to the appropriate 0 or maxInt bounds based on the 3-tuple that a user specifies for base flag, min, and/or max
func (cl *CommandLineInterface) processRangeFilterFlags() error {
for flagName := range cl.intRangeFlags {
rangeHelperMin := fmt.Sprintf("%s-%s", flagName, "min")
rangeHelperMax := fmt.Sprintf("%s-%s", flagName, "max")
Expand All @@ -167,17 +208,3 @@ func (cl *CommandLineInterface) ProcessRangeFilterFlags() error {
}
return nil
}

// ValidateFlags iterates through any registered validators and executes them
func (cl *CommandLineInterface) ValidateFlags() error {
for flagName, validationFn := range cl.validators {
if validationFn == nil {
continue
}
err := validationFn(cl.Flags[flagName])
if err != nil {
return err
}
}
return nil
}
48 changes: 48 additions & 0 deletions pkg/cli/cli_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package cli

import (
"os"
"testing"

h "github.com/aws/amazon-ec2-instance-selector/pkg/test"
"github.com/spf13/pflag"
)

// Tests

func TestRemoveIntersectingArgs(t *testing.T) {
flagSet := pflag.NewFlagSet("test-flag-set", pflag.ContinueOnError)
flagSet.Bool("test-bool", false, "test usage")
os.Args = []string{"ec2-instance-selector", "--test-bool", "--this-should-stay"}
newArgs := removeIntersectingArgs(flagSet)
h.Assert(t, len(newArgs) == 2, "NewArgs should only include the bin name and one argument after removing intersections")
}

func TestRemoveIntersectingArgs_NextArg(t *testing.T) {
flagSet := pflag.NewFlagSet("test-flag-set", pflag.ContinueOnError)
flagSet.String("test-str", "", "test usage")
os.Args = []string{"ec2-instance-selector", "--test-str", "somevalue", "--this-should-stay", "valuetostay"}
newArgs := removeIntersectingArgs(flagSet)
h.Assert(t, len(newArgs) == 3, "NewArgs should only include the bin name and a flag + input after removing intersections")
}

func TestRemoveIntersectingArgs_ShorthandArg(t *testing.T) {
flagSet := pflag.NewFlagSet("test-flag-set", pflag.ContinueOnError)
flagSet.StringP("test-str", "t", "", "test usage")
os.Args = []string{"ec2-instance-selector", "--test-str", "somevalue", "--this-should-stay", "valuetostay", "-t", "test"}
newArgs := removeIntersectingArgs(flagSet)
h.Assert(t, len(newArgs) == 3, "NewArgs should only include the bin name and a flag + input after removing intersections")
}
Loading

0 comments on commit c05edf4

Please sign in to comment.