diff --git a/cli.go b/cli.go index c3ada92..e304717 100644 --- a/cli.go +++ b/cli.go @@ -1,6 +1,7 @@ package cli import ( + "fmt" "io" "os" "sync" @@ -26,6 +27,8 @@ type CLI struct { // Version of the CLI. Version string + SubcommandChooser func(*CLI) (CommandFactory, error) + // HelpFunc and HelpWriter are used to output help information, if // requested. // @@ -47,6 +50,24 @@ type CLI struct { isVersion bool } +func DefaultSubcommandChooser(c *CLI) (CommandFactory, error) { + helpCommandFactory := func() (Command, error) { + return HelpCommand{c.HelpFunc, c.HelpWriter, c.Commands}, nil + } + versionCommand := OutputTextCommand{c.HelpWriter, c.Version} + versionCommandFactory := func() (Command, error) { + return versionCommand, nil + } + + if commandFunc, ok := c.Commands[c.Subcommand()]; ok { + return commandFunc, nil + } else if c.IsVersion() { + return versionCommandFactory, fmt.Errorf("Failed to find subcommand") + } else { + return helpCommandFactory, fmt.Errorf("Failed to find subcommand") + } +} + // NewClI returns a new CLI instance with sensible defaults. func NewCLI(app, version string) *CLI { return &CLI{ @@ -75,12 +96,6 @@ func (c *CLI) IsVersion() bool { func (c *CLI) Run() (int, error) { c.once.Do(c.init) - // Just show the version and exit if instructed. - if c.IsVersion() && c.Version != "" { - c.HelpWriter.Write([]byte(c.Version + "\n")) - return 1, nil - } - // If there is an invalid flag, then error if len(c.topFlags) > 0 { c.HelpWriter.Write([]byte( @@ -90,25 +105,16 @@ func (c *CLI) Run() (int, error) { return 1, nil } - // Attempt to get the factory function for creating the command - // implementation. If the command is invalid or blank, it is an error. - commandFunc, ok := c.Commands[c.Subcommand()] - if !ok { - c.HelpWriter.Write([]byte(c.HelpFunc(c.Commands) + "\n")) - return 1, nil - } + commandFunc, _ := c.SubcommandChooser(c) command, err := commandFunc() + if err != nil { return 0, err } - - // If we've been instructed to just print the help, then print it if c.IsHelp() { - c.HelpWriter.Write([]byte(command.Help() + "\n")) - return 1, nil + command = OutputTextCommand{c.HelpWriter, command.Help()} } - return command.Run(c.SubcommandArgs()), nil } @@ -140,6 +146,10 @@ func (c *CLI) init() { c.HelpWriter = os.Stderr } + if c.SubcommandChooser == nil { + c.SubcommandChooser = DefaultSubcommandChooser + } + c.processArgs() } diff --git a/cli_test.go b/cli_test.go index d49d8d4..8c24658 100644 --- a/cli_test.go +++ b/cli_test.go @@ -158,8 +158,11 @@ func TestCLIRun_printHelp(t *testing.T) { continue } - if !strings.Contains(buf.String(), helpText) { - t.Errorf("Args: %#v. Text: %v", testCase, buf.String()) + expect := strings.TrimSpace(buf.String()) + got := strings.TrimSpace(helpText) + + if !strings.Contains(expect, got) { + t.Errorf("Args: %#v, expect: %#v, got %#v", testCase, expect, got) } } } @@ -195,8 +198,10 @@ func TestCLIRun_printCommandHelp(t *testing.T) { t.Fatalf("bad exit code: %d", exitCode) } - if buf.String() != (command.HelpText + "\n") { - t.Fatalf("bad: %#v", buf.String()) + expect := strings.TrimSpace(command.HelpText) + got := strings.TrimSpace(buf.String()) + if expect != got { + t.Fatalf("Expect %#v, got %#v.", expect, got) } } } diff --git a/command.go b/command.go index b18d3ef..416cedc 100644 --- a/command.go +++ b/command.go @@ -1,5 +1,7 @@ package cli +import "io" + // A command is a runnable sub-command of a CLI. type Command interface { // Help should return long-form help text that includes the command-line @@ -21,3 +23,21 @@ type Command interface { // We need a factory because we may need to setup some state on the // struct that implements the command itself. type CommandFactory func() (Command, error) + +type OutputTextCommand struct { + writer io.Writer + text string +} + +func (c OutputTextCommand) Help() string { + return c.text +} + +func (c OutputTextCommand) Synopsis() string { + return c.Help() +} + +func (c OutputTextCommand) Run(_ []string) int { + c.writer.Write([]byte(c.Help())) + return 1 +} diff --git a/help.go b/help.go index 67ea8c8..69aa4f3 100644 --- a/help.go +++ b/help.go @@ -3,6 +3,7 @@ package cli import ( "bytes" "fmt" + "io" "log" "sort" "strings" @@ -77,3 +78,20 @@ func FilteredHelpFunc(include []string, f HelpFunc) HelpFunc { return f(filtered) } } + +type HelpCommand struct { + helpFunc HelpFunc + writer io.Writer + subcommands map[string]CommandFactory +} + +func (c HelpCommand) Help() string { + return c.helpFunc(c.subcommands) + "\n" +} +func (c HelpCommand) Synopsis() string { + return c.Help() +} +func (c HelpCommand) Run(_ []string) int { + c.writer.Write([]byte(c.Help())) + return 1 +}