diff --git a/check.go b/check.go index 4310fb9..713fa1a 100644 --- a/check.go +++ b/check.go @@ -1,6 +1,7 @@ package subcmd import ( + "flag" "reflect" "github.com/pkg/errors" @@ -60,8 +61,51 @@ func checkFuncType(ft reflect.Type, params []Param) error { } func checkParam(param Param) error { - if !reflect.TypeOf(param.Default).AssignableTo(param.Type.reflectType()) { - return ParamDefaultErr{Param: param} + switch param.Type { + case Bool: + if _, ok := param.Default.(bool); !ok { + return ParamDefaultErr{Param: param} + } + + case Int: + if _, err := asInt(param.Default); err != nil { + return ParamDefaultErr{Param: param} + } + + case Int64: + if _, err := asInt64(param.Default); err != nil { + return ParamDefaultErr{Param: param} + } + + case Uint: + if _, err := asUint(param.Default); err != nil { + return ParamDefaultErr{Param: param} + } + + case Uint64: + if _, err := asUint64(param.Default); err != nil { + return ParamDefaultErr{Param: param} + } + + case String: + if _, ok := param.Default.(string); !ok { + return ParamDefaultErr{Param: param} + } + + case Float64: + if _, err := asFloat64(param.Default); err != nil { + return ParamDefaultErr{Param: param} + } + + case Duration: + if _, err := asDuration(param.Default); err != nil { + return ParamDefaultErr{Param: param} + } + + case Value: + if _, ok := param.Default.(flag.Value); !ok { + return ParamDefaultErr{Param: param} + } } return nil diff --git a/check_test.go b/check_test.go index 68ecf15..366062d 100644 --- a/check_test.go +++ b/check_test.go @@ -154,16 +154,56 @@ func TestCheckParam(t *testing.T) { t.Run(ptyp.String(), func(t *testing.T) { for ptyp2 := Bool; ptyp2 <= Duration; ptyp2++ { t.Run(ptyp2.String(), func(t *testing.T) { - err := checkParam(Param{Type: ptyp, Default: dflts[ptyp2]}) + wantErr := true + + switch ptyp { + case Int64: + wantErr = ptyp2 != Int + + case Uint: + wantErr = ptyp2 != Int + + case Uint64: + switch ptyp2 { + case Int, Int64, Uint: + wantErr = false + default: + wantErr = true + } + + case Float64: + switch ptyp2 { + case Int, Int64, Uint, Uint64: + wantErr = false + default: + wantErr = true + } + + case Duration: + switch ptyp2 { + case Int, Int64: + wantErr = false + default: + wantErr = true + } + } + if ptyp == ptyp2 { - if err != nil { + wantErr = false + } + + err := checkParam(Param{Type: ptyp, Default: dflts[ptyp2]}) + if err != nil { + if !wantErr { t.Error(err) + } else { + var e ParamDefaultErr + if !errors.As(err, &e) { + t.Errorf("got %v, want ParamDefaultErr", err) + } } - } else { - var e ParamDefaultErr - if !errors.As(err, &e) { - t.Errorf("got %v, want ParamDefaultErr", err) - } + } else if wantErr { + t.Errorf("got nil, want ParamDefaultErr") } }) } diff --git a/context.go b/context.go index 49d2b23..4885660 100644 --- a/context.go +++ b/context.go @@ -61,4 +61,4 @@ func WithSuppressCheck(ctx context.Context, suppress bool) context.Context { func SuppressCheck(ctx context.Context) bool { val, _ := ctx.Value(suppressCheckKey).(bool) return val -} \ No newline at end of file +} diff --git a/errors.go b/errors.go index b8be3a1..f11adf6 100644 --- a/errors.go +++ b/errors.go @@ -240,5 +240,5 @@ type ParamDefaultErr struct { } func (e ParamDefaultErr) Error() string { - return fmt.Sprintf("default value %v is not of type %v", e.Param.Default, e.Param.Type) + return fmt.Sprintf("default value %v (%T) is not of type %v", e.Param.Default, e.Param.Default, e.Param.Type) } diff --git a/parse.go b/parse.go index a3eca29..add811c 100644 --- a/parse.go +++ b/parse.go @@ -94,7 +94,10 @@ func parsePositionalArg(p Param, args *[]string, argvals *[]reflect.Value) error } func parseBoolPos(args *[]string, argvals *[]reflect.Value, p Param) error { - val, _ := p.Default.(bool) + val, ok := p.Default.(bool) + if !ok { + return ParseErr{Err: fmt.Errorf("param %s has type Bool but default value %v has type %T", p.Name, p.Default, p.Default)} + } if len(*args) > 0 { var err error val, err = strconv.ParseBool((*args)[0]) @@ -108,10 +111,13 @@ func parseBoolPos(args *[]string, argvals *[]reflect.Value, p Param) error { } func parseIntPos(args *[]string, argvals *[]reflect.Value, p Param) error { - val := int64(asInt(p.Default)) // has to be int64 to receive the result of ParseInt below + valInt, err := asInt(p.Default) + if err != nil { + return ParseErr{Err: err} + } + val := int64(valInt) // has to be int64 to receive the result of ParseInt below if len(*args) > 0 { - var err error val, err = strconv.ParseInt((*args)[0], 10, 32) if err != nil { return ParseErr{Err: err} @@ -123,10 +129,12 @@ func parseIntPos(args *[]string, argvals *[]reflect.Value, p Param) error { } func parseInt64Pos(args *[]string, argvals *[]reflect.Value, p Param) error { - val := asInt64(p.Default) + val, err := asInt64(p.Default) + if err != nil { + return ParseErr{Err: err} + } if len(*args) > 0 { - var err error val, err = strconv.ParseInt((*args)[0], 10, 64) if err != nil { return ParseErr{Err: err} @@ -138,10 +146,13 @@ func parseInt64Pos(args *[]string, argvals *[]reflect.Value, p Param) error { } func parseUintPos(args *[]string, argvals *[]reflect.Value, p Param) error { - val := uint64(asUint(p.Default)) // has to be uint64 to receive the result of ParseUint below + valUint, err := asUint(p.Default) + if err != nil { + return ParseErr{Err: err} + } + val := uint64(valUint) // has to be uint64 to receive the result of ParseUint below if len(*args) > 0 { - var err error val, err = strconv.ParseUint((*args)[0], 10, 32) if err != nil { return ParseErr{Err: err} @@ -153,10 +164,12 @@ func parseUintPos(args *[]string, argvals *[]reflect.Value, p Param) error { } func parseUint64Pos(args *[]string, argvals *[]reflect.Value, p Param) error { - val := asUint64(p.Default) + val, err := asUint64(p.Default) + if err != nil { + return ParseErr{Err: err} + } if len(*args) > 0 { - var err error val, err = strconv.ParseUint((*args)[0], 10, 64) if err != nil { return ParseErr{Err: err} @@ -168,7 +181,10 @@ func parseUint64Pos(args *[]string, argvals *[]reflect.Value, p Param) error { } func parseStringPos(args *[]string, argvals *[]reflect.Value, p Param) error { - val, _ := p.Default.(string) + val, ok := p.Default.(string) + if !ok { + return ParseErr{Err: fmt.Errorf("param %s has type String but default value %v has type %T", p.Name, p.Default, p.Default)} + } if len(*args) > 0 { val = (*args)[0] *args = (*args)[1:] @@ -178,10 +194,12 @@ func parseStringPos(args *[]string, argvals *[]reflect.Value, p Param) error { } func parseFloat64Pos(args *[]string, argvals *[]reflect.Value, p Param) error { - val := asFloat64(p.Default) + val, err := asFloat64(p.Default) + if err != nil { + return ParseErr{Err: err} + } if len(*args) > 0 { - var err error val, err = strconv.ParseFloat((*args)[0], 64) if err != nil { return ParseErr{Err: err} @@ -193,10 +211,12 @@ func parseFloat64Pos(args *[]string, argvals *[]reflect.Value, p Param) error { } func parseDurationPos(args *[]string, argvals *[]reflect.Value, p Param) error { - val := asDuration(p.Default) + val, err := asDuration(p.Default) + if err != nil { + return ParseErr{Err: err} + } if len(*args) > 0 { - var err error val, err = time.ParseDuration((*args)[0]) if err != nil { return ParseErr{Err: err} @@ -222,137 +242,175 @@ func parseValuePos(args *[]string, argvals *[]reflect.Value, p Param) error { return nil } -func asInt(val interface{}) int { +func asInt(val interface{}) (int, error) { switch v := val.(type) { case int: - return v + return v, nil case int8: - return int(v) + return int(v), nil case int16: - return int(v) + return int(v), nil case int32: - return int(v) + return int(v), nil case uint8: - return int(v) + return int(v), nil case uint16: - return int(v) + return int(v), nil } - return 0 + + return 0, fmt.Errorf("cannot convert %v (type %T) to int", val, val) } -func asInt64(val interface{}) int64 { +func asInt64(val interface{}) (int64, error) { switch v := val.(type) { case int: - return int64(v) + return int64(v), nil case int64: - return v + return v, nil case int8: - return int64(v) + return int64(v), nil case int16: - return int64(v) + return int64(v), nil case int32: - return int64(v) + return int64(v), nil case uint8: - return int64(v) + return int64(v), nil case uint16: - return int64(v) + return int64(v), nil case uint32: - return int64(v) + return int64(v), nil } - return 0 + + return 0, fmt.Errorf("cannot convert %v (type %T) to int64", val, val) } -func asUint(val interface{}) uint { +func asUint(val interface{}) (uint, error) { switch v := val.(type) { + case int: + if v >= 0 { + return uint(v), nil + } + case uint: - return v + return v, nil case int8: - return uint(v) + if v >= 0 { + return uint(v), nil + } + case int16: - return uint(v) + if v >= 0 { + return uint(v), nil + } + + case int32: + if v >= 0 { + return uint(v), nil + } case uint8: - return uint(v) + return uint(v), nil case uint16: - return uint(v) + return uint(v), nil case uint32: - return uint(v) + return uint(v), nil } - return 0 + + return 0, fmt.Errorf("cannot convert %v (type %T) to uint", val, val) } -func asUint64(val interface{}) uint64 { +func asUint64(val interface{}) (uint64, error) { switch v := val.(type) { + case int: + if v >= 0 { + return uint64(v), nil + } + case uint: - return uint64(v) + return uint64(v), nil case uint64: - return v + return v, nil case int8: - return uint64(v) + if v >= 0 { + return uint64(v), nil + } + case int16: - return uint64(v) + if v >= 0 { + return uint64(v), nil + } + case int32: - return uint64(v) + if v >= 0 { + return uint64(v), nil + } + + case int64: + if v >= 0 { + return uint64(v), nil + } case uint8: - return uint64(v) + return uint64(v), nil case uint16: - return uint64(v) + return uint64(v), nil case uint32: - return uint64(v) + return uint64(v), nil } - return 0 + + return 0, fmt.Errorf("cannot convert %v (type %T) to uint64", val, val) } -func asFloat64(val interface{}) float64 { +func asFloat64(val interface{}) (float64, error) { switch v := val.(type) { case int: - return float64(v) + return float64(v), nil case uint: - return float64(v) + return float64(v), nil case int8: - return float64(v) + return float64(v), nil case int16: - return float64(v) + return float64(v), nil case int32: - return float64(v) + return float64(v), nil case int64: - return float64(v) + return float64(v), nil case uint8: - return float64(v) + return float64(v), nil case uint16: - return float64(v) + return float64(v), nil case uint32: - return float64(v) + return float64(v), nil case uint64: - return float64(v) + return float64(v), nil case float32: - return float64(v) + return float64(v), nil case float64: - return v + return v, nil } - return 0 + + return 0, fmt.Errorf("cannot convert %v (type %T) to float64", val, val) } -func asDuration(val interface{}) time.Duration { - switch v := val.(type) { - case int: - return time.Duration(v) - case int64: - return time.Duration(v) - case time.Duration: - return v +func asDuration(val interface{}) (time.Duration, error) { + if v, ok := val.(time.Duration); ok { + return v, nil } - return 0 + + v, err := asInt64(val) + if err != nil { + return 0, fmt.Errorf("cannot convert %v (type %T) to time.Duration", val, val) + } + return time.Duration(v), nil } // ToFlagSet takes a slice of [Param] and produces: @@ -382,31 +440,58 @@ func ToFlagSet(params []Param) (fs *flag.FlagSet, ptrs []reflect.Value, position v = fs.Bool(name, dflt, p.Doc) case Int: - v = fs.Int(name, asInt(p.Default), p.Doc) + dflt, err := asInt(p.Default) + if err != nil { + return nil, nil, nil, err + } + v = fs.Int(name, dflt, p.Doc) case Int64: - v = fs.Int64(name, asInt64(p.Default), p.Doc) + dflt, err := asInt64(p.Default) + if err != nil { + return nil, nil, nil, err + } + v = fs.Int64(name, dflt, p.Doc) case Uint: - v = fs.Uint(name, asUint(p.Default), p.Doc) + dflt, err := asUint(p.Default) + if err != nil { + return nil, nil, nil, err + } + v = fs.Uint(name, dflt, p.Doc) case Uint64: - v = fs.Uint64(name, asUint64(p.Default), p.Doc) + dflt, err := asUint64(p.Default) + if err != nil { + return nil, nil, nil, err + } + v = fs.Uint64(name, dflt, p.Doc) case String: - dflt, _ := p.Default.(string) + dflt, ok := p.Default.(string) + if !ok { + return nil, nil, nil, fmt.Errorf("param %s has type String but default value %v has type %T", p.Name, p.Default, p.Default) + } v = fs.String(name, dflt, p.Doc) case Float64: - v = fs.Float64(name, asFloat64(p.Default), p.Doc) + dflt, err := asFloat64(p.Default) + if err != nil { + return nil, nil, nil, err + } + v = fs.Float64(name, dflt, p.Doc) case Duration: - v = fs.Duration(name, asDuration(p.Default), p.Doc) + dflt, err := asDuration(p.Default) + if err != nil { + return nil, nil, nil, err + } + v = fs.Duration(name, dflt, p.Doc) case Value: val, ok := p.Default.(flag.Value) if !ok { - err = fmt.Errorf("param %s has type Value but default value %v is not a flag.Value", p.Name, p.Default) + err = fmt.Errorf("param %s has type Value but default value %v has type %T", p.Name, p.Default, p.Default) return } fs.Var(val, name, p.Doc) diff --git a/parse_test.go b/parse_test.go deleted file mode 100644 index eb78351..0000000 --- a/parse_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package subcmd - -import ( - "flag" - "testing" -) - -func TestToFlagSet(t *testing.T) { - cases := []struct { - name string - params []Param - wantfs map[string]interface{} - }{{ - name: "float64_with_int_default", - params: []Param{{ - Name: "-float64", - Type: Float64, - Default: 1, - }}, - wantfs: map[string]interface{}{ - "float64": 1.0, - }, - }} - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - fs, _, _, err := ToFlagSet(c.params) - if err != nil { - t.Fatal(err) - } - - for k, v := range c.wantfs { - val := fs.Lookup(k) - if val == nil { - t.Fatalf("flag %s not found", k) - } - getter := val.Value.(flag.Getter) - got := getter.Get() - if got != v { - t.Errorf("got %v (%T), want %v (%T)", got, got, v, v) - } - } - }) - } -}