Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try all valid types when type-asserting Param.Default #5

Merged
merged 1 commit into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 145 additions & 60 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,7 @@ func parseBoolPos(args *[]string, argvals *[]reflect.Value, p Param) error {
}

func parseIntPos(args *[]string, argvals *[]reflect.Value, p Param) error {
var val int64
if dflt, ok := p.Default.(int); ok {
val = int64(dflt)
} else if dflt, ok := p.Default.(int64); ok {
val = dflt
}
val := int64(asInt(p.Default)) // has to be int64 to receive the result of ParseInt below

if len(*args) > 0 {
var err error
Expand All @@ -128,12 +123,7 @@ func parseIntPos(args *[]string, argvals *[]reflect.Value, p Param) error {
}

func parseInt64Pos(args *[]string, argvals *[]reflect.Value, p Param) error {
var val int64
if dflt, ok := p.Default.(int); ok {
val = int64(dflt)
} else if dflt, ok := p.Default.(int64); ok {
val = dflt
}
val := asInt64(p.Default)

if len(*args) > 0 {
var err error
Expand All @@ -148,12 +138,7 @@ func parseInt64Pos(args *[]string, argvals *[]reflect.Value, p Param) error {
}

func parseUintPos(args *[]string, argvals *[]reflect.Value, p Param) error {
var val uint64
if dflt, ok := p.Default.(uint); ok {
val = uint64(dflt)
} else if dflt, ok := p.Default.(uint64); ok {
val = dflt
}
val := uint64(asUint(p.Default)) // has to be uint64 to receive the result of ParseUint below

if len(*args) > 0 {
var err error
Expand All @@ -168,12 +153,7 @@ func parseUintPos(args *[]string, argvals *[]reflect.Value, p Param) error {
}

func parseUint64Pos(args *[]string, argvals *[]reflect.Value, p Param) error {
var val uint64
if dflt, ok := p.Default.(uint); ok {
val = uint64(dflt)
} else if dflt, ok := p.Default.(uint64); ok {
val = dflt
}
val := asUint64(p.Default)

if len(*args) > 0 {
var err error
Expand All @@ -198,21 +178,7 @@ func parseStringPos(args *[]string, argvals *[]reflect.Value, p Param) error {
}

func parseFloat64Pos(args *[]string, argvals *[]reflect.Value, p Param) error {
var val float64
switch dflt := p.Default.(type) {
case int:
val = float64(dflt)
case int64:
val = float64(dflt)
case uint:
val = float64(dflt)
case uint64:
val = float64(dflt)
case float32:
val = float64(dflt)
case float64:
val = dflt
}
val := asFloat64(p.Default)

if len(*args) > 0 {
var err error
Expand All @@ -227,15 +193,7 @@ func parseFloat64Pos(args *[]string, argvals *[]reflect.Value, p Param) error {
}

func parseDurationPos(args *[]string, argvals *[]reflect.Value, p Param) error {
var val time.Duration
switch dflt := p.Default.(type) {
case int:
val = time.Duration(dflt)
case int64:
val = time.Duration(dflt)
case time.Duration:
val = dflt
}
val := asDuration(p.Default)

if len(*args) > 0 {
var err error
Expand Down Expand Up @@ -264,6 +222,139 @@ func parseValuePos(args *[]string, argvals *[]reflect.Value, p Param) error {
return nil
}

func asInt(val interface{}) int {
switch v := val.(type) {
case int:
return v

case int8:
return int(v)
case int16:
return int(v)
case int32:
return int(v)

case uint8:
return int(v)
case uint16:
return int(v)
}
return 0
}

func asInt64(val interface{}) int64 {
switch v := val.(type) {
case int:
return int64(v)
case int64:
return v

case int8:
return int64(v)
case int16:
return int64(v)
case int32:
return int64(v)

case uint8:
return int64(v)
case uint16:
return int64(v)
case uint32:
return int64(v)
}
return 0
}

func asUint(val interface{}) uint {
switch v := val.(type) {
case uint:
return v

case int8:
return uint(v)
case int16:
return uint(v)

case uint8:
return uint(v)
case uint16:
return uint(v)
case uint32:
return uint(v)
}
return 0
}

func asUint64(val interface{}) uint64 {
switch v := val.(type) {
case uint:
return uint64(v)
case uint64:
return v

case int8:
return uint64(v)
case int16:
return uint64(v)
case int32:
return uint64(v)

case uint8:
return uint64(v)
case uint16:
return uint64(v)
case uint32:
return uint64(v)
}
return 0
}

func asFloat64(val interface{}) float64 {
switch v := val.(type) {
case int:
return float64(v)
case uint:
return float64(v)

case int8:
return float64(v)
case int16:
return float64(v)
case int32:
return float64(v)
case int64:
return float64(v)

case uint8:
return float64(v)
case uint16:
return float64(v)
case uint32:
return float64(v)
case uint64:
return float64(v)

case float32:
return float64(v)
case float64:
return v
}
return 0
}

func asDuration(val interface{}) time.Duration {
switch v := val.(type) {
case int:
return time.Duration(v)
case int64:
return time.Duration(v)
case time.Duration:
return v
}
return 0
}

// ToFlagSet takes a slice of [Param] and produces:
//
// - a [flag.FlagSet],
Expand Down Expand Up @@ -291,32 +382,26 @@ func ToFlagSet(params []Param) (fs *flag.FlagSet, ptrs []reflect.Value, position
v = fs.Bool(name, dflt, p.Doc)

case Int:
dflt, _ := p.Default.(int)
v = fs.Int(name, dflt, p.Doc)
v = fs.Int(name, asInt(p.Default), p.Doc)

case Int64:
dflt, _ := p.Default.(int64)
v = fs.Int64(name, dflt, p.Doc)
v = fs.Int64(name, asInt64(p.Default), p.Doc)

case Uint:
dflt, _ := p.Default.(uint)
v = fs.Uint(name, dflt, p.Doc)
v = fs.Uint(name, asUint(p.Default), p.Doc)

case Uint64:
dflt, _ := p.Default.(uint64)
v = fs.Uint64(name, dflt, p.Doc)
v = fs.Uint64(name, asUint64(p.Default), p.Doc)

case String:
dflt, _ := p.Default.(string)
v = fs.String(name, dflt, p.Doc)

case Float64:
dflt, _ := p.Default.(float64)
v = fs.Float64(name, dflt, p.Doc)
v = fs.Float64(name, asFloat64(p.Default), p.Doc)

case Duration:
dflt, _ := p.Default.(time.Duration)
v = fs.Duration(name, dflt, p.Doc)
v = fs.Duration(name, asDuration(p.Default), p.Doc)

case Value:
val, ok := p.Default.(flag.Value)
Expand Down
45 changes: 45 additions & 0 deletions parse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package subcmd

import (
"flag"
"testing"
)

func TestToFlagSet(t *testing.T) {
cases := []struct {
name string
params []Param
wantfs map[string]interface{}
}{{
name: "float64_with_int_default",
params: []Param{{
Name: "-float64",
Type: Float64,
Default: 1,
}},
wantfs: map[string]interface{}{
"float64": 1.0,
},
}}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
fs, _, _, err := ToFlagSet(c.params)
if err != nil {
t.Fatal(err)
}

for k, v := range c.wantfs {
val := fs.Lookup(k)
if val == nil {
t.Fatalf("flag %s not found", k)
}
getter := val.Value.(flag.Getter)
got := getter.Get()
if got != v {
t.Errorf("got %v (%T), want %v (%T)", got, got, v, v)
}
}
})
}
}
Loading