Skip to content

Commit

Permalink
Add TLS support to increment command (dgraph-io#3257)
Browse files Browse the repository at this point in the history
* Fix dgraph increment options.

* Show error if TLS options are incomplete.

* Use TLS connection with dgraph increment if requested.

* Replace hard-coded alpha port in tests with z.SockAddr.

* Remove unnecessary comment.
  • Loading branch information
codexnull authored and dna2github committed Jul 19, 2019
1 parent d0598b3 commit f6804d5
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 14 deletions.
17 changes: 11 additions & 6 deletions dgraph/cmd/counter/increment.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"github.com/dgraph-io/dgraph/x"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"google.golang.org/grpc"
)

var Increment x.SubCommand
Expand All @@ -49,12 +48,15 @@ func init() {
flag := Increment.Cmd.Flags()
flag.String("addr", "localhost:9080", "Address of Dgraph alpha.")
flag.Int("num", 1, "How many times to run.")
flag.Bool("ro", false, "Only read the counter value, don't update it.")
flag.Bool("be", false, "Read counter value without retrieving timestamp from Zero.")
flag.Duration("wait", 0*time.Second, "How long to wait.")
flag.String("pred", "counter.val", "Predicate to use for storing the counter.")
flag.String("user", "", "Username if login is required.")
flag.String("password", "", "Password of the user.")
flag.String("pred", "counter.val",
"Predicate to use for storing the counter.")
flag.Bool("ro", false,
"Read-only. Read the counter value without updating it.")
flag.Bool("be", false,
"Best-effort. Read counter value without retrieving timestamp from Zero.")
// TLS configuration
x.RegisterClientTLSFlags(flag)
}
Expand Down Expand Up @@ -135,14 +137,17 @@ func run(conf *viper.Viper) {
waitDur := conf.GetDuration("wait")
num := conf.GetInt("num")

conn, err := grpc.Dial(addr, grpc.WithInsecure())
tlsCfg, err := x.LoadClientTLSConfig(conf)
x.CheckfNoTrace(err)

conn, err := x.SetupConnection(addr, tlsCfg, false)
if err != nil {
log.Fatal(err)
}
dc := api.NewDgraphClient(conn)
dg := dgo.NewDgraphClient(dc)
if user := conf.GetString("user"); len(user) > 0 {
x.Check(dg.Login(context.Background(), user, conf.GetString("password")))
x.CheckfNoTrace(dg.Login(context.Background(), user, conf.GetString("password")))
}

for num > 0 {
Expand Down
3 changes: 2 additions & 1 deletion tlstest/acl/acl_over_tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/dgraph-io/dgo"
"github.com/dgraph-io/dgo/protos/api"
"github.com/dgraph-io/dgraph/z"
"github.com/golang/glog"
"github.com/spf13/viper"
"google.golang.org/grpc"
Expand Down Expand Up @@ -98,7 +99,7 @@ func ExampleLoginOverTLS() {
conf.Set("tls_cacert", "../tls/ca.crt")
conf.Set("tls_server_name", "node")

dg, err := dgraphClientWithCerts(":9180", conf)
dg, err := dgraphClientWithCerts(z.SockAddr, conf)
if err != nil {
glog.Fatalf("Unable to get dgraph client: %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions tlstest/certrequest/certrequest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

func TestAccessOverPlaintext(t *testing.T) {
dg := z.DgraphClient(":9180")
dg := z.DgraphClient(z.SockAddr)
err := dg.Alter(context.Background(), &api.Operation{DropAll: true})
require.Error(t, err, "The authentication handshake should have failed")
}
Expand All @@ -21,7 +21,7 @@ func TestAccessWithCaCert(t *testing.T) {
conf.Set("tls_cacert", "../tls/ca.crt")
conf.Set("tls_server_name", "node")

dg, err := z.DgraphClientWithCerts(":9180", conf)
dg, err := z.DgraphClientWithCerts(z.SockAddr, conf)
require.NoError(t, err, "Unable to get dgraph client: %v", err)
err = dg.Alter(context.Background(), &api.Operation{DropAll: true})
require.NoError(t, err, "Unable to perform dropall: %v", err)
Expand Down
4 changes: 2 additions & 2 deletions tlstest/certrequireandverify/certrequireandverify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestAccessWithoutClientCert(t *testing.T) {
conf.Set("tls_cacert", "../tls/ca.crt")
conf.Set("tls_server_name", "node")

dg, err := z.DgraphClientWithCerts(":9180", conf)
dg, err := z.DgraphClientWithCerts(z.SockAddr, conf)
require.NoError(t, err, "Unable to get dgraph client: %v", err)
err = dg.Alter(context.Background(), &api.Operation{DropAll: true})
require.Error(t, err, "The authentication handshake should have failed")
Expand All @@ -28,7 +28,7 @@ func TestAccessWithClientCert(t *testing.T) {
conf.Set("tls_cert", "../tls/client.acl.crt")
conf.Set("tls_key", "../tls/client.acl.key")

dg, err := z.DgraphClientWithCerts(":9180", conf)
dg, err := z.DgraphClientWithCerts(z.SockAddr, conf)
require.NoError(t, err, "Unable to get dgraph client: %v", err)
err = dg.Alter(context.Background(), &api.Operation{DropAll: true})
require.NoError(t, err, "Unable to perform dropall: %v", err)
Expand Down
4 changes: 2 additions & 2 deletions tlstest/certverifyifgiven/certverifyifgiven_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func TestAccessWithoutClientCert(t *testing.T) {
conf.Set("tls_cacert", "../tls/ca.crt")
conf.Set("tls_server_name", "node")

dg, err := z.DgraphClientWithCerts(":9180", conf)
dg, err := z.DgraphClientWithCerts(z.SockAddr, conf)
require.NoError(t, err, "Unable to get dgraph client: %v", err)
err = dg.Alter(context.Background(), &api.Operation{DropAll: true})
require.NoError(t, err, "Unable to perform dropall: %v", err)
Expand All @@ -28,7 +28,7 @@ func TestAccessWithClientCert(t *testing.T) {
conf.Set("tls_cert", "../tls/client.acl.crt")
conf.Set("tls_key", "../tls/client.acl.key")

dg, err := z.DgraphClientWithCerts(":9180", conf)
dg, err := z.DgraphClientWithCerts(z.SockAddr, conf)
require.NoError(t, err, "Unable to get dgraph client: %v", err)
err = dg.Alter(context.Background(), &api.Operation{DropAll: true})
require.NoError(t, err, "Unable to perform dropall: %v", err)
Expand Down
11 changes: 10 additions & 1 deletion x/tls_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ type TLSHelperConfig struct {
}

func RegisterClientTLSFlags(flag *pflag.FlagSet) {
flag.String("tls_cacert", "", "The CA Cert file used to verify server certificates.")
flag.String("tls_cacert", "",
"The CA Cert file used to verify server certificates. Required for enabling TLS.")
flag.Bool("tls_use_system_ca", true, "Include System CA into CA Certs.")
flag.String("tls_server_name", "", "Used to verify the server hostname.")
flag.String("tls_cert", "", "(optional) The Cert file provided by the client to the server.")
Expand Down Expand Up @@ -107,6 +108,14 @@ func LoadClientTLSConfig(v *viper.Viper) (*tls.Config, error) {
}

return &tlsCfg, nil
} else
// Attempt to determine if user specified *any* TLS option. Unfortunately and contrary to
// Viper's own documentation, there's no way to tell whether an option value came from a
// command-line option or a built-it default.
if v.GetString("tls_server_name") != "" ||
v.GetString("tls_cert") != "" ||
v.GetString("tls_key") != "" {
return nil, fmt.Errorf("--tls_cacert is required for enabling TLS")
}
return nil, nil
}
Expand Down

0 comments on commit f6804d5

Please sign in to comment.