diff --git a/contrib/integration/increment/.gitignore b/dgraph/cmd/counter/.gitignore similarity index 100% rename from contrib/integration/increment/.gitignore rename to dgraph/cmd/counter/.gitignore diff --git a/contrib/integration/increment/main.go b/dgraph/cmd/counter/increment.go similarity index 66% rename from contrib/integration/increment/main.go rename to dgraph/cmd/counter/increment.go index 1c9d994fd34..7cd67eaf70e 100644 --- a/contrib/integration/increment/main.go +++ b/dgraph/cmd/counter/increment.go @@ -17,28 +17,41 @@ // This binary would retrieve a value for UID=0x01, and increment it by 1. If // successful, it would print out the incremented value. It assumes that it has // access to UID=0x01, and that `val` predicate is of type int. -package main +package counter import ( "context" "encoding/json" - "flag" "fmt" "log" "time" "github.com/dgraph-io/dgo" "github.com/dgraph-io/dgo/protos/api" + "github.com/dgraph-io/dgraph/x" + "github.com/spf13/cobra" + "github.com/spf13/viper" "google.golang.org/grpc" ) -var ( - addr = flag.String("addr", "localhost:9080", "Address of Dgraph alpha.") - num = flag.Int("num", 1, "How many times to run.") - ro = flag.Bool("ro", false, "Only read the counter value, don't update it.") - wait = flag.String("wait", "0", "How long to wait.") - pred = flag.String("pred", "counter.val", "Predicate to use for storing the counter.") -) +var Increment x.SubCommand + +func init() { + Increment.Cmd = &cobra.Command{ + Use: "increment", + Short: "Increment a counter transactionally", + Run: func(cmd *cobra.Command, args []string) { + run(Increment.Conf) + }, + } + + flag := Increment.Cmd.Flags() + flag.String("addr", "localhost:9080", "Address of Dgraph alpha.") + flag.Int("num", 1, "How many times to run.") + flag.Bool("ro", false, "Only read the counter value, don't update it.") + flag.Duration("wait", 0*time.Second, "How long to wait.") + flag.String("pred", "counter.val", "Predicate to use for storing the counter.") +} type Counter struct { Uid string `json:"uid"` @@ -47,12 +60,12 @@ type Counter struct { startTs uint64 // Only used for internal testing. } -func queryCounter(txn *dgo.Txn) (Counter, error) { +func queryCounter(txn *dgo.Txn, pred string) (Counter, error) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() var counter Counter - query := fmt.Sprintf("{ q(func: has(%s)) { uid, val: %s }}", *pred, *pred) + query := fmt.Sprintf("{ q(func: has(%s)) { uid, val: %s }}", pred, pred) resp, err := txn.Query(ctx, query) if err != nil { return counter, fmt.Errorf("Query error: %v", err) @@ -72,15 +85,15 @@ func queryCounter(txn *dgo.Txn) (Counter, error) { return counter, nil } -func process(dg *dgo.Dgraph, readOnly bool) (Counter, error) { +func process(dg *dgo.Dgraph, readOnly bool, pred string) (Counter, error) { if readOnly { txn := dg.NewReadOnlyTxn() defer txn.Discard(nil) - return queryCounter(txn) + return queryCounter(txn, pred) } txn := dg.NewTxn() - counter, err := queryCounter(txn) + counter, err := queryCounter(txn, pred) if err != nil { return Counter{}, err } @@ -90,7 +103,7 @@ func process(dg *dgo.Dgraph, readOnly bool) (Counter, error) { if len(counter.Uid) == 0 { counter.Uid = "_:new" } - mu.SetNquads = []byte(fmt.Sprintf(`<%s> <%s> "%d"^^ .`, counter.Uid, *pred, counter.Val)) + mu.SetNquads = []byte(fmt.Sprintf(`<%s> <%s> "%d"^^ .`, counter.Uid, pred, counter.Val)) // Don't put any timeout for mutation. _, err = txn.Mutate(context.Background(), &mu) @@ -100,23 +113,21 @@ func process(dg *dgo.Dgraph, readOnly bool) (Counter, error) { return counter, txn.Commit(context.Background()) } -func main() { - flag.Parse() - - conn, err := grpc.Dial(*addr, grpc.WithInsecure()) +func run(conf *viper.Viper) { + addr := conf.GetString("addr") + waitDur := conf.GetDuration("wait") + num := conf.GetInt("num") + ro := conf.GetBool("ro") + pred := conf.GetString("pred") + conn, err := grpc.Dial(addr, grpc.WithInsecure()) if err != nil { log.Fatal(err) } dc := api.NewDgraphClient(conn) dg := dgo.NewDgraphClient(dc) - waitDur, err := time.ParseDuration(*wait) - if err != nil { - log.Fatal(err) - } - - for *num > 0 { - cnt, err := process(dg, *ro) + for num > 0 { + cnt, err := process(dg, ro, pred) now := time.Now().UTC().Format("0102 03:04:05.999") if err != nil { fmt.Printf("%-17s While trying to process counter: %v. Retrying...\n", now, err) @@ -124,7 +135,7 @@ func main() { continue } fmt.Printf("%-17s Counter VAL: %d [ Ts: %d ]\n", now, cnt.Val, cnt.startTs) - *num-- + num-- time.Sleep(waitDur) } } diff --git a/contrib/integration/increment/main_test.go b/dgraph/cmd/counter/increment_test.go similarity index 93% rename from contrib/integration/increment/main_test.go rename to dgraph/cmd/counter/increment_test.go index 29bdbe7a502..c6b08e45ab9 100644 --- a/contrib/integration/increment/main_test.go +++ b/dgraph/cmd/counter/increment_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package main +package counter import ( "context" @@ -32,6 +32,7 @@ import ( ) const N = 10 +const pred = "counter" func increment(t *testing.T, dg *dgo.Dgraph) int { var max int @@ -51,7 +52,7 @@ func increment(t *testing.T, dg *dgo.Dgraph) int { go func() { defer wg.Done() for i := 0; i < N; i++ { - cnt, err := process(dg, false) + cnt, err := process(dg, false, pred) if err != nil { if strings.Index(err.Error(), "Transaction has been aborted") >= 0 { // pass @@ -69,7 +70,7 @@ func increment(t *testing.T, dg *dgo.Dgraph) int { } func read(t *testing.T, dg *dgo.Dgraph, expected int) { - cnt, err := process(dg, true) + cnt, err := process(dg, true, pred) require.NoError(t, err) ts := cnt.startTs t.Logf("Readonly stage counter: %+v\n", cnt) @@ -80,7 +81,7 @@ func read(t *testing.T, dg *dgo.Dgraph, expected int) { go func() { defer wg.Done() for i := 0; i < N; i++ { - cnt, err := process(dg, true) + cnt, err := process(dg, true, pred) if err != nil { t.Logf("Error while reading: %v\n", err) } else { @@ -110,7 +111,7 @@ func TestIncrement(t *testing.T) { ctx := metadata.NewOutgoingContext(context.Background(), md) x.Check(dg.Alter(ctx, &op)) - cnt, err := process(dg, false) + cnt, err := process(dg, false, pred) if err != nil { t.Logf("Error while reading: %v\n", err) } else { diff --git a/dgraph/cmd/root.go b/dgraph/cmd/root.go index b235e085970..a32f78e5d5d 100644 --- a/dgraph/cmd/root.go +++ b/dgraph/cmd/root.go @@ -25,6 +25,7 @@ import ( "github.com/dgraph-io/dgraph/dgraph/cmd/bulk" "github.com/dgraph-io/dgraph/dgraph/cmd/cert" "github.com/dgraph-io/dgraph/dgraph/cmd/conv" + "github.com/dgraph-io/dgraph/dgraph/cmd/counter" "github.com/dgraph-io/dgraph/dgraph/cmd/debug" "github.com/dgraph-io/dgraph/dgraph/cmd/live" "github.com/dgraph-io/dgraph/dgraph/cmd/version" @@ -66,7 +67,7 @@ var rootConf = viper.New() // subcommands initially contains all default sub-commands. var subcommands = []*x.SubCommand{ &bulk.Bulk, &cert.Cert, &conv.Conv, &live.Live, &alpha.Alpha, &zero.Zero, &version.Version, - &debug.Debug, + &debug.Debug, &counter.Increment, } func initCmds() {