-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
606 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package subcmd | ||
|
||
import ( | ||
"reflect" | ||
|
||
"github.com/pkg/errors" | ||
) | ||
|
||
// Check checks that the type of subcmd.F matches the expectations set by subcmd.Params: | ||
// | ||
// - It must be a function; | ||
// - It must return no more than one value; | ||
// - If it returns a value, that value must be of type error; | ||
// - It must take an initial context.Context parameter; | ||
// - It must take a final []string parameter; | ||
// - The length of subcmd.Params must match the number of parameters subcmd.F takes (not counting the initial context.Context and final []string parameters); | ||
// - Each parameter in subcmd.Params must match the corresponding parameter in subcmd.F. | ||
// | ||
// It also checks that the default value of each parameter in subcmd.Params matches the parameter's type. | ||
func Check(subcmd Subcmd) error { | ||
fv := reflect.ValueOf(subcmd.F) | ||
ft := fv.Type() | ||
|
||
wantFuncTypeErr, wantFuncTypeNoErr := funcTypeForParams(subcmd.Params) | ||
switch ft { | ||
case wantFuncTypeErr, wantFuncTypeNoErr: | ||
// ok | ||
|
||
default: | ||
return FuncTypeErr{Want: wantFuncTypeErr, Got: ft} | ||
} | ||
|
||
for i, param := range subcmd.Params { | ||
if err := checkParam(param); err != nil { | ||
return errors.Wrapf(err, "checking parameter %d", i+1) | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func checkParam(param Param) error { | ||
if !reflect.TypeOf(param.Default).AssignableTo(param.Type.reflectType()) { | ||
return ParamDefaultErr{Param: param} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// CheckMap calls Check on each of the entries in the Map. | ||
func CheckMap(m Map) error { | ||
for name, subcmd := range m { | ||
if err := Check(subcmd); err != nil { | ||
return errors.Wrapf(err, "checking subcommand %s", name) | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
func funcTypeForParams(params []Param) (withErr, withoutErr reflect.Type) { | ||
in := []reflect.Type{ctxType} | ||
for _, param := range params { | ||
in = append(in, param.Type.reflectType()) | ||
} | ||
in = append(in, reflect.TypeOf([]string(nil))) | ||
|
||
withoutErr = reflect.FuncOf(in, nil, false) | ||
|
||
out := []reflect.Type{errType} | ||
|
||
withErr = reflect.FuncOf(in, out, false) | ||
|
||
return withErr, withoutErr | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
package subcmd | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"testing" | ||
"time" | ||
) | ||
|
||
func TestCheckZeroArgs(t *testing.T) { | ||
cases := []struct { | ||
name string | ||
f interface{} | ||
wantErr bool | ||
}{{ | ||
name: "noErr", | ||
f: func(context.Context, []string) {}, | ||
}, { | ||
name: "err", | ||
f: func(context.Context, []string) error { return nil }, | ||
}, { | ||
name: "noContext", | ||
f: func([]string) {}, | ||
wantErr: true, | ||
}, { | ||
name: "noArgs", | ||
f: func(context.Context) {}, | ||
wantErr: true, | ||
}, { | ||
name: "extraArg", | ||
f: func(context.Context, int, []string) {}, | ||
wantErr: true, | ||
}} | ||
|
||
for _, tc := range cases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
err := Check(Subcmd{F: tc.f}) | ||
if gotErr := err != nil; gotErr != tc.wantErr { | ||
t.Errorf("got err %v, wantErr is %v", err, tc.wantErr) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestCheckOneArg(t *testing.T) { | ||
for ptyp := Bool; ptyp <= Duration; ptyp++ { | ||
t.Run(ptyp.String(), func(t *testing.T) { | ||
var fOK, fTooMany interface{} | ||
switch ptyp { | ||
case Bool: | ||
fOK = func(context.Context, bool, []string) {} | ||
fTooMany = func(context.Context, bool, bool, []string) {} | ||
case Int: | ||
fOK = func(context.Context, int, []string) {} | ||
fTooMany = func(context.Context, int, int, []string) {} | ||
case Int64: | ||
fOK = func(context.Context, int64, []string) {} | ||
fTooMany = func(context.Context, int64, int64, []string) {} | ||
case Uint: | ||
fOK = func(context.Context, uint, []string) {} | ||
fTooMany = func(context.Context, uint, uint, []string) {} | ||
case Uint64: | ||
fOK = func(context.Context, uint64, []string) {} | ||
fTooMany = func(context.Context, uint64, uint64, []string) {} | ||
case String: | ||
fOK = func(context.Context, string, []string) {} | ||
fTooMany = func(context.Context, string, string, []string) {} | ||
case Float64: | ||
fOK = func(context.Context, float64, []string) {} | ||
fTooMany = func(context.Context, float64, float64, []string) {} | ||
case Duration: | ||
fOK = func(context.Context, time.Duration, []string) {} | ||
fTooMany = func(context.Context, time.Duration, time.Duration, []string) {} | ||
} | ||
|
||
t.Run("one", func(t *testing.T) { | ||
if err := Check(Subcmd{F: fOK, Params: []Param{{Type: ptyp, Default: dflts[ptyp]}}}); err != nil { | ||
t.Error(err) | ||
} | ||
}) | ||
|
||
t.Run("toomany", func(t *testing.T) { | ||
err := Check(Subcmd{F: fTooMany, Params: []Param{{Type: ptyp, Default: dflts[ptyp]}}}) | ||
var e FuncTypeErr | ||
if !errors.As(err, &e) { | ||
t.Errorf("got %v, want FuncTypeErr", err) | ||
} | ||
}) | ||
|
||
t.Run("toofew", func(t *testing.T) { | ||
err := Check(Subcmd{F: func(context.Context, []string) {}, Params: []Param{{Type: ptyp, Default: dflts[ptyp]}}}) | ||
var e FuncTypeErr | ||
if !errors.As(err, &e) { | ||
t.Errorf("got %v, want FuncTypeErr", err) | ||
} | ||
}) | ||
|
||
t.Run("wrongtype", func(t *testing.T) { | ||
for ptyp2 := Bool; ptyp2 <= Duration; ptyp2++ { | ||
if ptyp2 == ptyp { | ||
continue | ||
} | ||
|
||
t.Run(ptyp2.String(), func(t *testing.T) { | ||
err := Check(Subcmd{F: fOK, Params: []Param{{Type: ptyp2, Default: dflts[ptyp2]}}}) | ||
var e FuncTypeErr | ||
if !errors.As(err, &e) { | ||
t.Errorf("got %v, want FuncTypeErr", err) | ||
} | ||
}) | ||
} | ||
}) | ||
}) | ||
} | ||
} | ||
|
||
func TestCheckNotFunc(t *testing.T) { | ||
var e FuncTypeErr | ||
if err := Check(Subcmd{F: 42}); !errors.As(err, &e) { | ||
t.Errorf("got %v, want FuncTypeErr", err) | ||
} | ||
} | ||
|
||
func TestCheckNoContext(t *testing.T) { | ||
var e FuncTypeErr | ||
if err := Check(Subcmd{F: func(int, []string) {}}); !errors.As(err, &e) { | ||
t.Errorf("got %v, want FuncTypeErr", err) | ||
} | ||
} | ||
|
||
func TestCheckNoStringSlice(t *testing.T) { | ||
var e FuncTypeErr | ||
if err := Check(Subcmd{F: func(context.Context, int) {}}); !errors.As(err, &e) { | ||
t.Errorf("got %v, want FuncTypeErr", err) | ||
} | ||
} | ||
|
||
func TestCheckNoError(t *testing.T) { | ||
var e FuncTypeErr | ||
if err := Check(Subcmd{F: func(context.Context, []string) int { return 0 }}); !errors.As(err, &e) { | ||
t.Errorf("got %v, want FuncTypeErr", err) | ||
} | ||
} | ||
|
||
func TestTooManyReturns(t *testing.T) { | ||
var e FuncTypeErr | ||
if err := Check(Subcmd{F: func(context.Context, []string) (int, int) { return 0, 0 }}); !errors.As(err, &e) { | ||
t.Errorf("got %v, want FuncTypeErr", err) | ||
} | ||
} | ||
|
||
func TestCheckParam(t *testing.T) { | ||
for ptyp := Bool; ptyp <= Duration; ptyp++ { | ||
t.Run(ptyp.String(), func(t *testing.T) { | ||
for ptyp2 := Bool; ptyp2 <= Duration; ptyp2++ { | ||
t.Run(ptyp2.String(), func(t *testing.T) { | ||
err := checkParam(Param{Type: ptyp, Default: dflts[ptyp2]}) | ||
if ptyp == ptyp2 { | ||
if err != nil { | ||
t.Error(err) | ||
} | ||
} else { | ||
var e ParamDefaultErr | ||
if !errors.As(err, &e) { | ||
t.Errorf("got %v, want ParamDefaultErr", err) | ||
} | ||
} | ||
}) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
var dflts = map[Type]interface{}{ | ||
Bool: false, | ||
Int: 0, | ||
Int64: int64(0), | ||
Uint: uint(0), | ||
Uint64: uint64(0), | ||
String: "", | ||
Float64: float64(0), | ||
Duration: time.Duration(0), | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
package subcmd | ||
|
||
import ( | ||
"encoding/json" | ||
"testing" | ||
) | ||
|
||
func TestParseEnv(t *testing.T) { | ||
j, err := json.Marshal(map[string]interface{}{ | ||
"foo": "bar", | ||
"baz": 42, | ||
}) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
restoreEnv := testSetenv(EnvVar, string(j)) | ||
defer restoreEnv() | ||
|
||
type s struct { | ||
Foo string `json:"foo"` | ||
Baz int `json:"baz"` | ||
} | ||
|
||
var ( | ||
got s | ||
want = s{ | ||
Foo: "bar", | ||
Baz: 42, | ||
} | ||
) | ||
|
||
if err := ParseEnv(&got); err != nil { | ||
t.Fatal(err) | ||
} | ||
if got != want { | ||
t.Errorf("got %#v, want %#v", got, want) | ||
} | ||
} |
Oops, something went wrong.