Skip to content

Commit

Permalink
Check and external subcommands (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
bobg authored Jul 1, 2023
1 parent b38597b commit a47d0d3
Show file tree
Hide file tree
Showing 11 changed files with 606 additions and 67 deletions.
2 changes: 2 additions & 0 deletions Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ command -globalopt subcommand -subopt1 FOO -subopt2 ARG1 ARG2
```

Subcommands may have sub-subcommands and so on.
Subcommands may also be implemented as separate executables.
(See the documentation for [Prefixer](https://pkg.go.dev/github.com/bobg/subcmd/v2#Prefixer).)

This is a layer on top of the standard Go `flag` package.

Expand Down
74 changes: 74 additions & 0 deletions check.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package subcmd

import (
"reflect"

"github.com/pkg/errors"
)

// Check checks that the type of subcmd.F matches the expectations set by subcmd.Params:
//
// - It must be a function;
// - It must return no more than one value;
// - If it returns a value, that value must be of type error;
// - It must take an initial context.Context parameter;
// - It must take a final []string parameter;
// - The length of subcmd.Params must match the number of parameters subcmd.F takes (not counting the initial context.Context and final []string parameters);
// - Each parameter in subcmd.Params must match the corresponding parameter in subcmd.F.
//
// It also checks that the default value of each parameter in subcmd.Params matches the parameter's type.
func Check(subcmd Subcmd) error {
fv := reflect.ValueOf(subcmd.F)
ft := fv.Type()

wantFuncTypeErr, wantFuncTypeNoErr := funcTypeForParams(subcmd.Params)
switch ft {
case wantFuncTypeErr, wantFuncTypeNoErr:
// ok

default:
return FuncTypeErr{Want: wantFuncTypeErr, Got: ft}
}

for i, param := range subcmd.Params {
if err := checkParam(param); err != nil {
return errors.Wrapf(err, "checking parameter %d", i+1)
}
}

return nil
}

func checkParam(param Param) error {
if !reflect.TypeOf(param.Default).AssignableTo(param.Type.reflectType()) {
return ParamDefaultErr{Param: param}
}

return nil
}

// CheckMap calls Check on each of the entries in the Map.
func CheckMap(m Map) error {
for name, subcmd := range m {
if err := Check(subcmd); err != nil {
return errors.Wrapf(err, "checking subcommand %s", name)
}
}
return nil
}

func funcTypeForParams(params []Param) (withErr, withoutErr reflect.Type) {
in := []reflect.Type{ctxType}
for _, param := range params {
in = append(in, param.Type.reflectType())
}
in = append(in, reflect.TypeOf([]string(nil)))

withoutErr = reflect.FuncOf(in, nil, false)

out := []reflect.Type{errType}

withErr = reflect.FuncOf(in, out, false)

return withErr, withoutErr
}
183 changes: 183 additions & 0 deletions check_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package subcmd

import (
"context"
"errors"
"testing"
"time"
)

func TestCheckZeroArgs(t *testing.T) {
cases := []struct {
name string
f interface{}
wantErr bool
}{{
name: "noErr",
f: func(context.Context, []string) {},
}, {
name: "err",
f: func(context.Context, []string) error { return nil },
}, {
name: "noContext",
f: func([]string) {},
wantErr: true,
}, {
name: "noArgs",
f: func(context.Context) {},
wantErr: true,
}, {
name: "extraArg",
f: func(context.Context, int, []string) {},
wantErr: true,
}}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := Check(Subcmd{F: tc.f})
if gotErr := err != nil; gotErr != tc.wantErr {
t.Errorf("got err %v, wantErr is %v", err, tc.wantErr)
}
})
}
}

func TestCheckOneArg(t *testing.T) {
for ptyp := Bool; ptyp <= Duration; ptyp++ {
t.Run(ptyp.String(), func(t *testing.T) {
var fOK, fTooMany interface{}
switch ptyp {
case Bool:
fOK = func(context.Context, bool, []string) {}
fTooMany = func(context.Context, bool, bool, []string) {}
case Int:
fOK = func(context.Context, int, []string) {}
fTooMany = func(context.Context, int, int, []string) {}
case Int64:
fOK = func(context.Context, int64, []string) {}
fTooMany = func(context.Context, int64, int64, []string) {}
case Uint:
fOK = func(context.Context, uint, []string) {}
fTooMany = func(context.Context, uint, uint, []string) {}
case Uint64:
fOK = func(context.Context, uint64, []string) {}
fTooMany = func(context.Context, uint64, uint64, []string) {}
case String:
fOK = func(context.Context, string, []string) {}
fTooMany = func(context.Context, string, string, []string) {}
case Float64:
fOK = func(context.Context, float64, []string) {}
fTooMany = func(context.Context, float64, float64, []string) {}
case Duration:
fOK = func(context.Context, time.Duration, []string) {}
fTooMany = func(context.Context, time.Duration, time.Duration, []string) {}
}

t.Run("one", func(t *testing.T) {
if err := Check(Subcmd{F: fOK, Params: []Param{{Type: ptyp, Default: dflts[ptyp]}}}); err != nil {
t.Error(err)
}
})

t.Run("toomany", func(t *testing.T) {
err := Check(Subcmd{F: fTooMany, Params: []Param{{Type: ptyp, Default: dflts[ptyp]}}})
var e FuncTypeErr
if !errors.As(err, &e) {
t.Errorf("got %v, want FuncTypeErr", err)
}
})

t.Run("toofew", func(t *testing.T) {
err := Check(Subcmd{F: func(context.Context, []string) {}, Params: []Param{{Type: ptyp, Default: dflts[ptyp]}}})
var e FuncTypeErr
if !errors.As(err, &e) {
t.Errorf("got %v, want FuncTypeErr", err)
}
})

t.Run("wrongtype", func(t *testing.T) {
for ptyp2 := Bool; ptyp2 <= Duration; ptyp2++ {
if ptyp2 == ptyp {
continue
}

t.Run(ptyp2.String(), func(t *testing.T) {
err := Check(Subcmd{F: fOK, Params: []Param{{Type: ptyp2, Default: dflts[ptyp2]}}})
var e FuncTypeErr
if !errors.As(err, &e) {
t.Errorf("got %v, want FuncTypeErr", err)
}
})
}
})
})
}
}

func TestCheckNotFunc(t *testing.T) {
var e FuncTypeErr
if err := Check(Subcmd{F: 42}); !errors.As(err, &e) {
t.Errorf("got %v, want FuncTypeErr", err)
}
}

func TestCheckNoContext(t *testing.T) {
var e FuncTypeErr
if err := Check(Subcmd{F: func(int, []string) {}}); !errors.As(err, &e) {
t.Errorf("got %v, want FuncTypeErr", err)
}
}

func TestCheckNoStringSlice(t *testing.T) {
var e FuncTypeErr
if err := Check(Subcmd{F: func(context.Context, int) {}}); !errors.As(err, &e) {
t.Errorf("got %v, want FuncTypeErr", err)
}
}

func TestCheckNoError(t *testing.T) {
var e FuncTypeErr
if err := Check(Subcmd{F: func(context.Context, []string) int { return 0 }}); !errors.As(err, &e) {
t.Errorf("got %v, want FuncTypeErr", err)
}
}

func TestTooManyReturns(t *testing.T) {
var e FuncTypeErr
if err := Check(Subcmd{F: func(context.Context, []string) (int, int) { return 0, 0 }}); !errors.As(err, &e) {
t.Errorf("got %v, want FuncTypeErr", err)
}
}

func TestCheckParam(t *testing.T) {
for ptyp := Bool; ptyp <= Duration; ptyp++ {
t.Run(ptyp.String(), func(t *testing.T) {
for ptyp2 := Bool; ptyp2 <= Duration; ptyp2++ {
t.Run(ptyp2.String(), func(t *testing.T) {
err := checkParam(Param{Type: ptyp, Default: dflts[ptyp2]})
if ptyp == ptyp2 {
if err != nil {
t.Error(err)
}
} else {
var e ParamDefaultErr
if !errors.As(err, &e) {
t.Errorf("got %v, want ParamDefaultErr", err)
}
}
})
}
})
}
}

var dflts = map[Type]interface{}{
Bool: false,
Int: 0,
Int64: int64(0),
Uint: uint(0),
Uint64: uint64(0),
String: "",
Float64: float64(0),
Duration: time.Duration(0),
}
26 changes: 26 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"flag"
"fmt"
"os"
"reflect"
"strings"
)

Expand Down Expand Up @@ -214,3 +215,28 @@ func missingUnknownSubcmd(line1 string, cmd Cmd) string {
}
return b.String()
}

// FuncTypeErr means a Subcmd's F field has a type that does not match the function signature implied by its Params field.
type FuncTypeErr struct {
// Got is the type of the F field.
Got reflect.Type

// Want is the expected function type implied by the Params field.
// Note: for simplicity, this includes the optional error return,
// even if the type in Got does not
// (which is not, in itself, an error).
Want reflect.Type
}

func (e FuncTypeErr) Error() string {
return fmt.Sprintf("function has type %v, want %v", e.Got, e.Want)
}

// ParamDefaultErr is the error when a Param has a default value that is not of the correct type.
type ParamDefaultErr struct {
Param Param
}

func (e ParamDefaultErr) Error() string {
return fmt.Sprintf("default value %v is not of type %v", e.Param.Default, e.Param.Type)
}
1 change: 1 addition & 0 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/pkg/errors"
)

// The length of the resulting slice is len(params)+2.
func parseArgs(ctx context.Context, params []Param, args []string) ([]reflect.Value, error) {
fs, ptrs, positional, err := ToFlagSet(params)
if err != nil {
Expand Down
39 changes: 39 additions & 0 deletions parseenv_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package subcmd

import (
"encoding/json"
"testing"
)

func TestParseEnv(t *testing.T) {
j, err := json.Marshal(map[string]interface{}{
"foo": "bar",
"baz": 42,
})
if err != nil {
t.Fatal(err)
}

restoreEnv := testSetenv(EnvVar, string(j))
defer restoreEnv()

type s struct {
Foo string `json:"foo"`
Baz int `json:"baz"`
}

var (
got s
want = s{
Foo: "bar",
Baz: 42,
}
)

if err := ParseEnv(&got); err != nil {
t.Fatal(err)
}
if got != want {
t.Errorf("got %#v, want %#v", got, want)
}
}
Loading

0 comments on commit a47d0d3

Please sign in to comment.