Skip to content

Commit

Permalink
Misc. fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
bobg committed Nov 28, 2023
1 parent 1857801 commit 5004ae8
Show file tree
Hide file tree
Showing 6 changed files with 261 additions and 137 deletions.
48 changes: 46 additions & 2 deletions check.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package subcmd

import (
"flag"
"reflect"

"github.com/pkg/errors"
Expand Down Expand Up @@ -60,8 +61,51 @@ func checkFuncType(ft reflect.Type, params []Param) error {
}

func checkParam(param Param) error {
if !reflect.TypeOf(param.Default).AssignableTo(param.Type.reflectType()) {
return ParamDefaultErr{Param: param}
switch param.Type {
case Bool:
if _, ok := param.Default.(bool); !ok {
return ParamDefaultErr{Param: param}
}

case Int:
if _, err := asInt(param.Default); err != nil {
return ParamDefaultErr{Param: param}
}

case Int64:
if _, err := asInt64(param.Default); err != nil {
return ParamDefaultErr{Param: param}
}

case Uint:
if _, err := asUint(param.Default); err != nil {
return ParamDefaultErr{Param: param}
}

case Uint64:
if _, err := asUint64(param.Default); err != nil {
return ParamDefaultErr{Param: param}
}

case String:
if _, ok := param.Default.(string); !ok {
return ParamDefaultErr{Param: param}
}

case Float64:
if _, err := asFloat64(param.Default); err != nil {
return ParamDefaultErr{Param: param}
}

case Duration:
if _, err := asDuration(param.Default); err != nil {
return ParamDefaultErr{Param: param}
}

case Value:
if _, ok := param.Default.(flag.Value); !ok {
return ParamDefaultErr{Param: param}
}
}

return nil
Expand Down
54 changes: 47 additions & 7 deletions check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,56 @@ func TestCheckParam(t *testing.T) {
t.Run(ptyp.String(), func(t *testing.T) {
for ptyp2 := Bool; ptyp2 <= Duration; ptyp2++ {
t.Run(ptyp2.String(), func(t *testing.T) {
err := checkParam(Param{Type: ptyp, Default: dflts[ptyp2]})
wantErr := true

switch ptyp {
case Int64:
wantErr = ptyp2 != Int

case Uint:
wantErr = ptyp2 != Int

case Uint64:
switch ptyp2 {
case Int, Int64, Uint:
wantErr = false
default:
wantErr = true
}

case Float64:
switch ptyp2 {
case Int, Int64, Uint, Uint64:
wantErr = false
default:
wantErr = true
}

case Duration:
switch ptyp2 {
case Int, Int64:
wantErr = false
default:
wantErr = true
}
}

if ptyp == ptyp2 {
if err != nil {
wantErr = false
}

err := checkParam(Param{Type: ptyp, Default: dflts[ptyp2]})
if err != nil {
if !wantErr {
t.Error(err)
} else {
var e ParamDefaultErr
if !errors.As(err, &e) {
t.Errorf("got %v, want ParamDefaultErr", err)
}
}
} else {
var e ParamDefaultErr
if !errors.As(err, &e) {
t.Errorf("got %v, want ParamDefaultErr", err)
}
} else if wantErr {
t.Errorf("got nil, want ParamDefaultErr")
}
})
}
Expand Down
2 changes: 1 addition & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ func WithSuppressCheck(ctx context.Context, suppress bool) context.Context {
func SuppressCheck(ctx context.Context) bool {
val, _ := ctx.Value(suppressCheckKey).(bool)
return val
}
}
2 changes: 1 addition & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,5 +240,5 @@ type ParamDefaultErr struct {
}

func (e ParamDefaultErr) Error() string {
return fmt.Sprintf("default value %v is not of type %v", e.Param.Default, e.Param.Type)
return fmt.Sprintf("default value %v (%T) is not of type %v", e.Param.Default, e.Param.Default, e.Param.Type)
}
Loading

0 comments on commit 5004ae8

Please sign in to comment.