Skip to content

Commit

Permalink
move SuperFlag to Ristretto (#246)
Browse files Browse the repository at this point in the history
  • Loading branch information
karlmcguire authored Feb 5, 2021
1 parent 426327c commit 1fb8d28
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 0 deletions.
118 changes: 118 additions & 0 deletions z/flags.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package z

import (
"fmt"
"log"
"strconv"
"strings"

"github.com/pkg/errors"
)

func parseFlag(flag string) map[string]string {
kvm := make(map[string]string)
for _, kv := range strings.Split(flag, ";") {
if strings.TrimSpace(kv) == "" {
continue
}
splits := strings.SplitN(kv, "=", 2)
k := strings.TrimSpace(splits[0])
k = strings.ToLower(k)
k = strings.ReplaceAll(k, "_", "-")
kvm[k] = strings.TrimSpace(splits[1])
}
return kvm
}

type SuperFlag struct {
m map[string]string
}

func NewSuperFlag(flag string) *SuperFlag {
return &SuperFlag{
m: parseFlag(flag),
}
}
func (sf *SuperFlag) String() string {
if sf == nil {
return ""
}
var kvs []string
for k, v := range sf.m {
kvs = append(kvs, fmt.Sprintf("%s=%s", k, v))
}
return strings.Join(kvs, "; ")
}
func (sf *SuperFlag) MergeAndCheckDefault(flag string) *SuperFlag {
if sf == nil {
sf = &SuperFlag{
m: parseFlag(flag),
}
return sf
}
numKeys := len(sf.m)
src := parseFlag(flag)
for k := range src {
if _, ok := sf.m[k]; ok {
numKeys--
}
}
if numKeys != 0 {
msg := fmt.Sprintf("Found invalid options in %s. Valid options: %v", sf, flag)
panic(msg)
}
for k, v := range src {
if _, ok := sf.m[k]; !ok {
sf.m[k] = v
}
}
return sf
}
func (sf *SuperFlag) GetBool(opt string) bool {
val := sf.GetString(opt)
if val == "" {
return false
}
b, err := strconv.ParseBool(val)
if err != nil {
err = errors.Wrapf(err,
"Unable to parse %s as bool for key: %s. Options: %s\n",
val, opt, sf)
log.Fatalf("%+v", err)
}
return b
}
func (sf *SuperFlag) GetUint64(opt string) uint64 {
val := sf.GetString(opt)
if val == "" {
return 0
}
u, err := strconv.ParseUint(val, 0, 64)
if err != nil {
err = errors.Wrapf(err,
"Unable to parse %s as uint64 for key: %s. Options: %s\n",
val, opt, sf)
log.Fatalf("%+v", err)
}
return u
}
func (sf *SuperFlag) GetUint32(opt string) uint32 {
val := sf.GetString(opt)
if val == "" {
return 0
}
u, err := strconv.ParseUint(val, 0, 32)
if err != nil {
err = errors.Wrapf(err,
"Unable to parse %s as uint32 for key: %s. Options: %s\n",
val, opt, sf)
log.Fatalf("%+v", err)
}
return uint32(u)
}
func (sf *SuperFlag) GetString(opt string) string {
if sf == nil {
return ""
}
return sf.m[opt]
}
28 changes: 28 additions & 0 deletions z/flags_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package z

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestFlag(t *testing.T) {
opt := `bool_key=true; int-key=5; float-key=0.05; string_key=value; ;`
sf := NewSuperFlag(opt)
t.Logf("Got SuperFlag: %s\n", sf)

def := `bool_key=false; int-key=0; float-key=1.0; string-key=; other-key=5`

// bool-key and int-key should not be overwritten. Only other-key should be set.
sf.MergeAndCheckDefault(def)

c := func() {
// Has a typo.
NewSuperFlag("boolo-key=true").MergeAndCheckDefault(def)
}
require.Panics(t, c)
require.Equal(t, true, sf.GetBool("bool-key"))
require.Equal(t, uint64(5), sf.GetUint64("int-key"))
require.Equal(t, "value", sf.GetString("string-key"))
require.Equal(t, uint64(5), sf.GetUint64("other-key"))
}

0 comments on commit 1fb8d28

Please sign in to comment.