Skip to content

Commit

Permalink
Add the Copier type (#8)
Browse files Browse the repository at this point in the history
* Checkpoint.

* Make Copy() optional; add Copier, remove ValueType.
  • Loading branch information
bobg authored Jul 3, 2024
1 parent 301b833 commit 2396dda
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 5 deletions.
19 changes: 18 additions & 1 deletion parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ func parseValuePos(args *[]string, argvals *[]reflect.Value, p Param) error {
if !ok {
return ParseErr{Err: fmt.Errorf("param %s is not a flag.Value", p.Name)}
}
if copier, ok := val.(Copier); ok {
val = copier.Copy()
}
if len(*args) > 0 {
if err := val.Set((*args)[0]); err != nil {
return ParseErr{Err: err}
Expand Down Expand Up @@ -406,9 +409,12 @@ func ToFlagSet(params []Param) (fs *flag.FlagSet, ptrs []reflect.Value, position
case Value:
val, ok := p.Default.(flag.Value)
if !ok {
err = fmt.Errorf("param %s has type Value but default value %v is not a flag.Value", p.Name, p.Default)
err = fmt.Errorf("param %s has type Value but default value %v is not a ValueType", p.Name, p.Default)
return
}
if copier, ok := val.(Copier); ok {
val = copier.Copy()
}
fs.Var(val, name, p.Doc)
v = val

Expand All @@ -422,3 +428,14 @@ func ToFlagSet(params []Param) (fs *flag.FlagSet, ptrs []reflect.Value, position

return fs, ptrs, positional, nil
}

// Copier is a [flag.Value] that can copy itself.
// Your type should implement Copier
// if you want to be able to use the same default value for multiple arguments
// without changes to one
// (e.g. when parsing command-line options into it)
// affecting the others.
type Copier interface {
flag.Value
Copy() flag.Value
}
5 changes: 3 additions & 2 deletions subcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ type Subcmd struct {
// where OPTS stands for a sequence of zero or more additional parameters
// corresponding to the types in Params.
//
// A Param with type Value supplies a flag.Value to the function.
// A Param with type Value supplies a [flag.Value] to the function.
// It's up to the function to type-assert the flag.Value to a more-specific type to read the value it contains.
F interface{}

Expand All @@ -104,7 +104,8 @@ type Param struct {
// Default is a default value for the parameter.
// Its type must be suitable for Type.
// If Type is Value,
// then Default must be a flag.Value.
// then Default must be a [flag.Value].
// It may optionally also be a [Copier], qv.
Default interface{}

// Doc is a docstring for the parameter.
Expand Down
49 changes: 47 additions & 2 deletions value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package subcmd
import (
"context"
"flag"
"fmt"
"reflect"
"strings"
"testing"
Expand Down Expand Up @@ -32,12 +33,32 @@ func TestValue(t *testing.T) {
}
}

func TestDifferentValues(t *testing.T) {
dflt := &valuetestvalue{result: []string{"dflt"}}
cmd := new(valuetestcmd)
cmd.subcmds = Commands(
"b", cmd.differentValues, "", Params(
"x", Value, dflt, "",
"y", Value, dflt, "",
),
)
if err := Run(context.Background(), cmd, []string{"b", "res1", "res2"}); err != nil {
t.Fatal(err)
}
}

type valuetestcmd struct {
x, y, pos1, pos2 valuetestvalue
rest []string

subcmds Map
}

func (c *valuetestcmd) Subcmds() Map {
if c.subcmds != nil {
return c.subcmds
}

return Commands(
"a", c.a, "", Params(
"-x", Value, &c.x, "",
Expand Down Expand Up @@ -65,12 +86,28 @@ func (c *valuetestcmd) a(_ context.Context, x, y, pos1, pos2 flag.Value, rest []
return nil
}

func (c *valuetestcmd) differentValues(_ context.Context, xv, yv flag.Value, _ []string) error {
x, ok := xv.(*valuetestvalue)
if !ok {
return fmt.Errorf("unexpected type %T for x", xv)
}
y, ok := yv.(*valuetestvalue)
if !ok {
return fmt.Errorf("unexpected type %T for y", yv)
}
if len(x.result) != 1 || x.result[0] != "res1" {
return fmt.Errorf("unexpected x value %v", x.result)
}
if len(y.result) != 1 || y.result[0] != "res2" {
return fmt.Errorf("unexpected y value %v", y.result)
}
return nil
}

type valuetestvalue struct {
result []string
}

var _ flag.Value = &valuetestvalue{}

func (v *valuetestvalue) String() string {
if v == nil {
return ""
Expand All @@ -82,3 +119,11 @@ func (v *valuetestvalue) Set(s string) error {
v.result = strings.Split(s, ",")
return nil
}

func (v *valuetestvalue) Copy() flag.Value {
result := &valuetestvalue{
result: make([]string, len(v.result)),
}
copy(result.result, v.result)
return result
}

0 comments on commit 2396dda

Please sign in to comment.