From ea2159bb4cf4d321c19eff8dced9f7e7d6ee794f Mon Sep 17 00:00:00 2001 From: Lucas Wang Date: Thu, 6 Dec 2018 14:15:07 -0800 Subject: [PATCH] Adding the acl subcommand to support acl features (#2795) --- dgraph/cmd/alpha/run.go | 33 ++++- dgraph/cmd/live/run.go | 41 +----- dgraph/cmd/root.go | 3 +- edgraph/access.go | 34 +++++ edgraph/access_ee.go | 240 +++++++++++++++++++++++++++++++ edgraph/config.go | 5 + edgraph/server.go | 5 + ee/acl/acl_test.go | 122 ++++++++++++++++ ee/acl/cmd/groups.go | 239 ++++++++++++++++++++++++++++++ ee/acl/cmd/groups_test.go | 63 ++++++++ ee/acl/cmd/run.go | 33 +++++ ee/acl/cmd/run_ee.go | 296 ++++++++++++++++++++++++++++++++++++++ ee/acl/cmd/users.go | 280 ++++++++++++++++++++++++++++++++++++ ee/acl/utils.go | 94 ++++++++++++ query/mutation.go | 3 + worker/backup.go | 10 +- worker/groups.go | 40 +++++- x/tls_helper.go | 8 +- x/x.go | 50 ++++++- 19 files changed, 1538 insertions(+), 61 deletions(-) create mode 100644 edgraph/access.go create mode 100644 edgraph/access_ee.go create mode 100644 ee/acl/acl_test.go create mode 100644 ee/acl/cmd/groups.go create mode 100644 ee/acl/cmd/groups_test.go create mode 100644 ee/acl/cmd/run.go create mode 100644 ee/acl/cmd/run_ee.go create mode 100644 ee/acl/cmd/users.go create mode 100644 ee/acl/utils.go diff --git a/dgraph/cmd/alpha/run.go b/dgraph/cmd/alpha/run.go index 872ec83140e..f73d6a8185c 100644 --- a/dgraph/cmd/alpha/run.go +++ b/dgraph/cmd/alpha/run.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "errors" "fmt" + "io/ioutil" "log" "net" "net/http" @@ -54,6 +55,11 @@ import ( hapi "google.golang.org/grpc/health/grpc_health_v1" ) +const ( + tlsNodeCert = "node.crt" + tlsNodeKey = "node.key" +) + var ( bindall bool tlsConf x.TLSHelperConfig @@ -120,6 +126,12 @@ they form a Raft group and provide synchronous replication. "If set, all Alter requests to Dgraph would need to have this token."+ " The token can be passed as follows: For HTTP requests, in X-Dgraph-AuthToken header."+ " For Grpc, in auth-token key in the context.") + flag.String("hmac_secret_file", "", "The file storing the HMAC secret"+ + " that is used for signing the JWT. Enterprise feature.") + flag.Duration("access_jwt_ttl", 6*time.Hour, "The TTL for the access jwt. "+ + "Enterprise feature.") + flag.Duration("refresh_jwt_ttl", 30*24*time.Hour, "The TTL for the refresh jwt. "+ + "Enterprise feature.") flag.Float64P("lru_mb", "l", -1, "Estimated memory the LRU cache can take. "+ "Actual usage by the process would be more than specified here.") @@ -380,7 +392,7 @@ var shutdownCh chan struct{} func run() { bindall = Alpha.Conf.GetBool("bindall") - edgraph.SetConfiguration(edgraph.Options{ + opts := edgraph.Options{ BadgerTables: Alpha.Conf.GetString("badger.tables"), BadgerVlog: Alpha.Conf.GetString("badger.vlog"), @@ -390,7 +402,22 @@ func run() { Nomutations: Alpha.Conf.GetBool("nomutations"), AuthToken: Alpha.Conf.GetString("auth_token"), AllottedMemory: Alpha.Conf.GetFloat64("lru_mb"), - }) + } + + secretFile := Alpha.Conf.GetString("hmac_secret_file") + if secretFile != "" { + hmacSecret, err := ioutil.ReadFile(secretFile) + if err != nil { + glog.Fatalf("Unable to read HMAC secret from file: %v", secretFile) + } + + opts.HmacSecret = hmacSecret + opts.AccessJwtTtl = Alpha.Conf.GetDuration("access_jwt_ttl") + opts.RefreshJwtTtl = Alpha.Conf.GetDuration("refresh_jwt_ttl") + + glog.Info("HMAC secret loaded successfully.") + } + edgraph.SetConfiguration(opts) ips, err := parseIPsFromString(Alpha.Conf.GetString("whitelist")) x.Check(err) @@ -406,7 +433,7 @@ func run() { MaxRetries: Alpha.Conf.GetInt("max_retries"), } - x.LoadTLSConfig(&tlsConf, Alpha.Conf) + x.LoadTLSConfig(&tlsConf, Alpha.Conf, tlsNodeCert, tlsNodeKey) tlsConf.ClientAuth = Alpha.Conf.GetString("tls_client_auth") setupCustomTokenizers() diff --git a/dgraph/cmd/live/run.go b/dgraph/cmd/live/run.go index 8db0f5c0733..8a06777ac67 100644 --- a/dgraph/cmd/live/run.go +++ b/dgraph/cmd/live/run.go @@ -34,8 +34,6 @@ import ( "strings" "time" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" "github.com/dgraph-io/badger" @@ -48,11 +46,6 @@ import ( "github.com/spf13/cobra" ) -const ( - tlsLiveCert = "client.live.crt" - tlsLiveKey = "client.live.key" -) - type options struct { files string schemaFile string @@ -239,34 +232,6 @@ func (l *loader) processFile(ctx context.Context, file string) error { return nil } -func setupConnection(host string, insecure bool) (*grpc.ClientConn, error) { - if insecure { - return grpc.Dial(host, - grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(x.GrpcMaxSize), - grpc.MaxCallSendMsgSize(x.GrpcMaxSize)), - grpc.WithInsecure(), - grpc.WithBlock(), - grpc.WithTimeout(10*time.Second)) - } - - tlsConf.ConfigType = x.TLSClientConfig - tlsConf.Cert = filepath.Join(tlsConf.CertDir, tlsLiveCert) - tlsConf.Key = filepath.Join(tlsConf.CertDir, tlsLiveKey) - tlsCfg, _, err := x.GenerateTLSConfig(tlsConf) - if err != nil { - return nil, err - } - - return grpc.Dial(host, - grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(x.GrpcMaxSize), - grpc.MaxCallSendMsgSize(x.GrpcMaxSize)), - grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)), - grpc.WithBlock(), - grpc.WithTimeout(10*time.Second)) -} - func fileList(files string) []string { if len(files) == 0 { return []string{} @@ -285,7 +250,7 @@ func setup(opts batchMutationOptions, dc *dgo.Dgraph) *loader { kv, err := badger.Open(o) x.Checkf(err, "Error while creating badger KV posting store") - connzero, err := setupConnection(opt.zero, true) + connzero, err := x.SetupConnection(opt.zero, &tlsConf) x.Checkf(err, "Unable to connect to zero, Is it running at %s?", opt.zero) alloc := xidmap.New( @@ -329,7 +294,7 @@ func run() error { ignoreIndexConflict: Live.Conf.GetBool("ignore_index_conflict"), authToken: Live.Conf.GetString("auth_token"), } - x.LoadTLSConfig(&tlsConf, Live.Conf) + x.LoadTLSConfig(&tlsConf, Live.Conf, x.TlsClientCert, x.TlsClientKey) tlsConf.ServerName = Live.Conf.GetString("tls_server_name") go http.ListenAndServe("localhost:6060", nil) @@ -345,7 +310,7 @@ func run() error { ds := strings.Split(opt.dgraph, ",") var clients []api.DgraphClient for _, d := range ds { - conn, err := setupConnection(d, !tlsConf.CertRequired) + conn, err := x.SetupConnection(d, &tlsConf) x.Checkf(err, "While trying to setup connection to Dgraph alpha.") defer conn.Close() diff --git a/dgraph/cmd/root.go b/dgraph/cmd/root.go index 1a288a0592b..c392f9e3020 100644 --- a/dgraph/cmd/root.go +++ b/dgraph/cmd/root.go @@ -29,6 +29,7 @@ import ( "github.com/dgraph-io/dgraph/dgraph/cmd/live" "github.com/dgraph-io/dgraph/dgraph/cmd/version" "github.com/dgraph-io/dgraph/dgraph/cmd/zero" + "github.com/dgraph-io/dgraph/ee/acl/cmd" "github.com/dgraph-io/dgraph/x" "github.com/spf13/cobra" flag "github.com/spf13/pflag" @@ -86,7 +87,7 @@ func init() { var subcommands = []*x.SubCommand{ &bulk.Bulk, &cert.Cert, &conv.Conv, &live.Live, &alpha.Alpha, &zero.Zero, - &version.Version, &debug.Debug, + &version.Version, &debug.Debug, &acl.CmdAcl, } for _, sc := range subcommands { RootCmd.AddCommand(sc.Cmd) diff --git a/edgraph/access.go b/edgraph/access.go new file mode 100644 index 00000000000..f347f489294 --- /dev/null +++ b/edgraph/access.go @@ -0,0 +1,34 @@ +// +build oss + +/* + * Copyright 2018 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package edgraph + +import ( + "context" + + "github.com/dgraph-io/dgo/protos/api" + "github.com/dgraph-io/dgraph/x" + "github.com/golang/glog" +) + +func (s *Server) Login(ctx context.Context, + request *api.LoginRequest) (*api.Response, error) { + + glog.Warningf("Login failed: %s", x.ErrNotSupported) + return &api.Response{}, x.ErrNotSupported +} diff --git a/edgraph/access_ee.go b/edgraph/access_ee.go new file mode 100644 index 00000000000..c3a6e8e7e9b --- /dev/null +++ b/edgraph/access_ee.go @@ -0,0 +1,240 @@ +// +build !oss + +/* + * Copyright 2018 Dgraph Labs, Inc. All rights reserved. + * + * Licensed under the Dgraph Community License (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * https://github.com/dgraph-io/dgraph/blob/master/licenses/DCL.txt + */ + +package edgraph + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "time" + + "github.com/dgraph-io/dgo/protos/api" + "github.com/dgraph-io/dgraph/ee/acl" + "github.com/dgrijalva/jwt-go" + "github.com/golang/glog" + "google.golang.org/grpc/peer" + + otrace "go.opencensus.io/trace" +) + +func (s *Server) Login(ctx context.Context, + request *api.LoginRequest) (*api.Response, error) { + ctx, span := otrace.StartSpan(ctx, "server.Login") + defer span.End() + + // record the client ip for this login request + var addr string + if ip, ok := peer.FromContext(ctx); ok { + addr = ip.Addr.String() + glog.Infof("Login request from: %s", addr) + span.Annotate([]otrace.Attribute{ + otrace.StringAttribute("client_ip", addr), + }, "client ip for login") + } + + user, err := s.authenticate(ctx, request) + if err != nil { + errMsg := fmt.Sprintf("authentication from address %s failed: %v", addr, err) + glog.Errorf(errMsg) + return nil, fmt.Errorf(errMsg) + } + + resp := &api.Response{} + accessJwt, err := getAccessJwt(request.Userid, user.Groups) + if err != nil { + errMsg := fmt.Sprintf("unable to get access jwt (userid=%s,addr=%s):%v", + request.Userid, addr, err) + glog.Errorf(errMsg) + return nil, fmt.Errorf(errMsg) + } + refreshJwt, err := getRefreshJwt(request.Userid) + if err != nil { + errMsg := fmt.Sprintf("unable to get refresh jwt (userid=%s,addr=%s):%v", + request.Userid, addr, err) + glog.Errorf(errMsg) + return nil, fmt.Errorf(errMsg) + } + + loginJwt := api.Jwt{ + AccessJwt: accessJwt, + RefreshJwt: refreshJwt, + } + + jwtBytes, err := loginJwt.Marshal() + if err != nil { + errMsg := fmt.Sprintf("unable to marshal jwt (userid=%s,addr=%s):%v", + request.Userid, addr, err) + glog.Errorf(errMsg) + return nil, fmt.Errorf(errMsg) + } + resp.Json = jwtBytes + return resp, nil +} + +func (s *Server) authenticate(ctx context.Context, request *api.LoginRequest) (*acl.User, error) { + if err := validateLoginRequest(request); err != nil { + return nil, fmt.Errorf("invalid login request: %v", err) + } + + var user *acl.User + if len(request.RefreshToken) > 0 { + userId, err := authenticateRefreshToken(request.RefreshToken) + if err != nil { + return nil, fmt.Errorf("unable to authenticate the refresh token %v: %v", + request.RefreshToken, err) + } + + user, err = s.queryUser(ctx, userId, "") + if err != nil { + return nil, fmt.Errorf("error while querying user with id: %v", + request.Userid) + } + + if user == nil { + return nil, fmt.Errorf("user not found for id %v", request.Userid) + } + } else { + var err error + user, err = s.queryUser(ctx, request.Userid, request.Password) + if err != nil { + return nil, fmt.Errorf("error while querying user with id: %v", + request.Userid) + } + + if user == nil { + return nil, fmt.Errorf("user not found for id %v", request.Userid) + } + if !user.PasswordMatch { + return nil, fmt.Errorf("password mismatch for user: %v", request.Userid) + } + } + + return user, nil +} + +func authenticateRefreshToken(refreshToken string) (string, error) { + token, err := jwt.Parse(refreshToken, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return Config.HmacSecret, nil + }) + + if err != nil { + return "", fmt.Errorf("unable to parse refresh token:%v", err) + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return "", fmt.Errorf("claims in refresh token is not map claims:%v", refreshToken) + } + + // by default, the MapClaims.Valid will return true if the exp field is not set + // here we enforce the checking to make sure that the refresh token has not expired + now := time.Now().Unix() + if !claims.VerifyExpiresAt(now, true) { + return "", fmt.Errorf("refresh token has expired: %v", refreshToken) + } + + userId, ok := claims["userid"].(string) + if !ok { + return "", fmt.Errorf("userid in claims is not a string:%v", userId) + } + return userId, nil +} + +func validateLoginRequest(request *api.LoginRequest) error { + if request == nil { + return fmt.Errorf("the request should not be nil") + } + // we will use the refresh token for authentication if it's set + if len(request.RefreshToken) > 0 { + return nil + } + + // otherwise make sure both userid and password are set + if len(request.Userid) == 0 { + return fmt.Errorf("the userid should not be empty") + } + if len(request.Password) == 0 { + return fmt.Errorf("the password should not be empty") + } + return nil +} + +func getAccessJwt(userId string, groups []acl.Group) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "userid": userId, + "groups": acl.GetGroupIDs(groups), + // set the jwt exp according to the ttl + "exp": json.Number( + strconv.FormatInt(time.Now().Add(Config.AccessJwtTtl).Unix(), 10)), + }) + + jwtString, err := token.SignedString(Config.HmacSecret) + if err != nil { + return "", fmt.Errorf("unable to encode jwt to string: %v", err) + } + return jwtString, nil +} + +func getRefreshJwt(userId string) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "userid": userId, + // set the jwt exp according to the ttl + "exp": json.Number( + strconv.FormatInt(time.Now().Add(Config.RefreshJwtTtl).Unix(), 10)), + }) + + jwtString, err := token.SignedString(Config.HmacSecret) + if err != nil { + return "", fmt.Errorf("unable to encode jwt to string: %v", err) + } + return jwtString, nil +} + +const queryUser = ` + query search($userid: string, $password: string){ + user(func: eq(dgraph.xid, $userid)) { + uid + password_match: checkpwd(dgraph.password, $password) + dgraph.user.group { + uid + dgraph.xid + } + } + }` + +func (s *Server) queryUser(ctx context.Context, userid string, password string) (user *acl.User, + err error) { + queryVars := map[string]string{ + "$userid": userid, + "$password": password, + } + queryRequest := api.Request{ + Query: queryUser, + Vars: queryVars, + } + + queryResp, err := s.Query(ctx, &queryRequest) + if err != nil { + glog.Errorf("Error while query user with id %s: %v", userid, err) + return nil, err + } + user, err = acl.UnmarshalUser(queryResp, "user") + if err != nil { + return nil, err + } + return user, nil +} diff --git a/edgraph/config.go b/edgraph/config.go index 13a83fd70fd..6ba4e9ad7de 100644 --- a/edgraph/config.go +++ b/edgraph/config.go @@ -19,6 +19,7 @@ package edgraph import ( "expvar" "path/filepath" + "time" "github.com/dgraph-io/dgraph/posting" "github.com/dgraph-io/dgraph/worker" @@ -34,6 +35,10 @@ type Options struct { AuthToken string AllottedMemory float64 + + HmacSecret []byte + AccessJwtTtl time.Duration + RefreshJwtTtl time.Duration } var Config Options diff --git a/edgraph/server.go b/edgraph/server.go index 3cb878681c2..425e22f68b3 100644 --- a/edgraph/server.go +++ b/edgraph/server.go @@ -598,6 +598,11 @@ func parseNQuads(b []byte) ([]*api.NQuad, error) { return nqs, nil } +// parseMutationObject tries to consolidate fields of the api.Mutation into the +// corresponding field of the returned gql.Mutation. For example, the 3 fields, +// api.Mutation#SetJson, api.Mutation#SetNquads and api.Mutation#Set are consolidated into the +// gql.Mutation.Set field. Similarly the 3 fields api.Mutation#DeleteJson, api.Mutation#DelNquads +// and api.Mutation#Del are merged into the gql.Mutation#Del field. func parseMutationObject(mu *api.Mutation) (*gql.Mutation, error) { res := &gql.Mutation{} if len(mu.SetJson) > 0 { diff --git a/ee/acl/acl_test.go b/ee/acl/acl_test.go new file mode 100644 index 00000000000..806c8fcc78c --- /dev/null +++ b/ee/acl/acl_test.go @@ -0,0 +1,122 @@ +// +build !oss + +/* + * Copyright 2018 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Dgraph Community License (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * https://github.com/dgraph-io/dgraph/blob/master/licenses/DCL.txt + */ + +package acl + +import ( + "os/exec" + "testing" +) + +const ( + userid = "alice" + userpassword = "simplepassword" + dgraphEndpoint = "localhost:9180" +) + +func TestAcl(t *testing.T) { + t.Run("create user", CreateAndDeleteUsers) + // t.Run("login", LogIn) +} + +func checkOutput(t *testing.T, cmd *exec.Cmd, shouldFail bool) string { + out, err := cmd.CombinedOutput() + if (!shouldFail && err != nil) || (shouldFail && err == nil) { + t.Errorf("Error output from command:%v", string(out)) + t.Fatal(err) + } + + return string(out) +} + +func CreateAndDeleteUsers(t *testing.T) { + createUserCmd1 := exec.Command("dgraph", "acl", "useradd", "-d", dgraphEndpoint, "-u", userid, + "-p", userpassword) + createUserOutput1 := checkOutput(t, createUserCmd1, false) + t.Logf("Got output when creating user:%v", createUserOutput1) + + createUserCmd2 := exec.Command("dgraph", "acl", "useradd", "-d", dgraphEndpoint, "-u", userid, + "-p", userpassword) + + // create the user again should fail + createUserOutput2 := checkOutput(t, createUserCmd2, true) + t.Logf("Got output when creating user:%v", createUserOutput2) + + // delete the user + deleteUserCmd := exec.Command("dgraph", "acl", "userdel", "-d", dgraphEndpoint, "-u", userid) + deleteUserOutput := checkOutput(t, deleteUserCmd, false) + t.Logf("Got output when deleting user:%v", deleteUserOutput) + + // now we should be able to create the user again + createUserCmd3 := exec.Command("dgraph", "acl", "useradd", "-d", dgraphEndpoint, "-u", userid, + "-p", userpassword) + createUserOutput3 := checkOutput(t, createUserCmd3, false) + t.Logf("Got output when creating user:%v", createUserOutput3) +} + +// TODO(gitlw): Finish this later. +// func LogIn(t *testing.T) { +// delete and recreate the user to ensure a clean state +/* + deleteUserCmd := exec.Command("dgraph", "acl", "userdel", "-d", dgraphEndpoint, "-u", "lucas") + deleteUserOutput := checkOutput(t, deleteUserCmd, false) + createUserCmd := exec.Command("dgraph", "acl", "useradd", "-d", dgraphEndpoint, "-u", "lucas", + "-p", "haha") + createUserOutput := checkOutput(t, createUserCmd, false) +*/ + +// now try to login with the wrong password + +//loginWithWrongPassword(t, ctx, adminClient) +//loginWithCorrectPassword(t, ctx, adminClient) +// } + +/* +func loginWithCorrectPassword(t *testing.T, ctx context.Context, + adminClient api.DgraphAccessClient) { + loginRequest := api.LogInRequest{ + Userid: userid, + Password: userpassword, + } + response2, err := adminClient.LogIn(ctx, &loginRequest) + require.NoError(t, err) + if response2.Code != api.AclResponseCode_OK { + t.Errorf("Login with the correct password should result in the code %v", + api.AclResponseCode_OK) + } + jwt := acl.Jwt{} + jwt.DecodeString(response2.Context.Jwt, false, nil) + if jwt.Payload.Userid != userid { + t.Errorf("the jwt token should have the user id encoded") + } + jwtTime := time.Unix(jwt.Payload.Exp, 0) + jwtValidDays := jwtTime.Sub(time.Now()).Round(time.Hour).Hours() / 24 + if jwtValidDays != 30.0 { + t.Errorf("The jwt token should be valid for 30 days, received %v days", jwtValidDays) + } +} + +func loginWithWrongPassword(t *testing.T, ctx context.Context, + adminClient api.DgraphAccessClient) { + loginRequestWithWrongPassword := api.LogInRequest{ + Userid: userid, + Password: userpassword + "123", + } + + response, err := adminClient.LogIn(ctx, &loginRequestWithWrongPassword) + require.NoError(t, err) + if response.Code != api.AclResponseCode_UNAUTHENTICATED { + t.Errorf("Login with the wrong password should result in the code %v", api.AclResponseCode_UNAUTHENTICATED) + } +} + +*/ diff --git a/ee/acl/cmd/groups.go b/ee/acl/cmd/groups.go new file mode 100644 index 00000000000..ea6b9ab21b9 --- /dev/null +++ b/ee/acl/cmd/groups.go @@ -0,0 +1,239 @@ +// +build !oss + +/* + * Copyright 2018 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Dgraph Community License (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * https://github.com/dgraph-io/dgraph/blob/master/licenses/DCL.txt + */ +package acl + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/dgraph-io/dgo" + "github.com/dgraph-io/dgo/protos/api" + "github.com/dgraph-io/dgraph/ee/acl" + "github.com/dgraph-io/dgraph/x" + "github.com/golang/glog" + "github.com/spf13/viper" +) + +func groupAdd(conf *viper.Viper) error { + groupId := conf.GetString("group") + if len(groupId) == 0 { + return fmt.Errorf("the group id should not be empty") + } + + dc, close := getDgraphClient(conf) + defer close() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + txn := dc.NewTxn() + defer func() { + if err := txn.Discard(ctx); err != nil { + glog.Errorf("Unable to discard transaction:%v", err) + } + }() + + group, err := queryGroup(ctx, txn, groupId) + if err != nil { + return fmt.Errorf("error while querying group:%v", err) + } + if group != nil { + return fmt.Errorf("the group with id %v already exists", groupId) + } + + createGroupNQuads := []*api.NQuad{ + { + Subject: "_:newgroup", + Predicate: "dgraph.xid", + ObjectValue: &api.Value{Val: &api.Value_StrVal{StrVal: groupId}}, + }, + } + + mu := &api.Mutation{ + CommitNow: true, + Set: createGroupNQuads, + } + if _, err = txn.Mutate(ctx, mu); err != nil { + return fmt.Errorf("unable to create group: %v", err) + } + + glog.Infof("Created new group with id %v", groupId) + return nil +} + +func groupDel(conf *viper.Viper) error { + groupId := conf.GetString("group") + if len(groupId) == 0 { + return fmt.Errorf("the group id should not be empty") + } + + dc, close := getDgraphClient(conf) + defer close() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + txn := dc.NewTxn() + defer func() { + if err := txn.Discard(ctx); err != nil { + glog.Errorf("Unable to discard transaction:%v", err) + } + }() + + group, err := queryGroup(ctx, txn, groupId) + if err != nil { + return fmt.Errorf("error while querying group:%v", err) + } + if group == nil || len(group.Uid) == 0 { + return fmt.Errorf("unable to delete group because it does not exist: %v", groupId) + } + + deleteGroupNQuads := []*api.NQuad{ + { + Subject: group.Uid, + Predicate: x.Star, + ObjectValue: &api.Value{Val: &api.Value_DefaultVal{DefaultVal: x.Star}}, + }, + } + mu := &api.Mutation{ + CommitNow: true, + Del: deleteGroupNQuads, + } + if _, err := txn.Mutate(ctx, mu); err != nil { + return fmt.Errorf("unable to delete group: %v", err) + } + + glog.Infof("Deleted group with id %v", groupId) + return nil +} + +func queryGroup(ctx context.Context, txn *dgo.Txn, groupid string, + fields ...string) (group *acl.Group, err error) { + + // write query header + query := fmt.Sprintf(`query search($groupid: string){ + group(func: eq(dgraph.xid, $groupid)) { + uid + %s }}`, strings.Join(fields, ", ")) + + queryVars := map[string]string{ + "$groupid": groupid, + } + + queryResp, err := txn.QueryWithVars(ctx, query, queryVars) + if err != nil { + glog.Errorf("Error while query group with id %s: %v", groupid, err) + return nil, err + } + group, err = acl.UnmarshalGroup(queryResp.GetJson(), "group") + if err != nil { + return nil, err + } + return group, nil +} + +type Acl struct { + Predicate string `json:"predicate"` + Perm int32 `json:"perm"` +} + +func chMod(conf *viper.Viper) error { + groupId := conf.GetString("group") + predicate := conf.GetString("pred") + perm := conf.GetInt("perm") + if len(groupId) == 0 { + return fmt.Errorf("the groupid must not be empty") + } + if len(predicate) == 0 { + return fmt.Errorf("the predicate must not be empty") + } + + dc, close := getDgraphClient(conf) + defer close() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + txn := dc.NewTxn() + defer func() { + if err := txn.Discard(ctx); err != nil { + glog.Errorf("Unable to discard transaction:%v", err) + } + }() + + group, err := queryGroup(ctx, txn, groupId, "dgraph.group.acl") + if err != nil { + return fmt.Errorf("error while querying group:%v", err) + } + if group == nil || len(group.Uid) == 0 { + return fmt.Errorf("unable to change permission for group because it does not exist: %v", + groupId) + } + + var currentAcls []Acl + if len(group.Acls) != 0 { + if err := json.Unmarshal([]byte(group.Acls), ¤tAcls); err != nil { + return fmt.Errorf("unable to unmarshal the acls associated with the group %v:%v", + groupId, err) + } + } + + newAcls, updated := updateAcl(currentAcls, Acl{ + Predicate: predicate, + Perm: int32(perm), + }) + if !updated { + glog.Infof("Nothing needs to be changed for the permission of group:%v", groupId) + return nil + } + + newAclBytes, err := json.Marshal(newAcls) + if err != nil { + return fmt.Errorf("unable to marshal the updated acls:%v", err) + } + + chModNQuads := &api.NQuad{ + Subject: group.Uid, + Predicate: "dgraph.group.acl", + ObjectValue: &api.Value{Val: &api.Value_BytesVal{BytesVal: newAclBytes}}, + } + mu := &api.Mutation{ + CommitNow: true, + Set: []*api.NQuad{chModNQuads}, + } + + if _, err = txn.Mutate(ctx, mu); err != nil { + return fmt.Errorf("unable to change mutations for the group %v on predicate %v: %v", + groupId, predicate, err) + } + glog.Infof("Successfully changed permission for group %v on predicate %v to %v", + groupId, predicate, perm) + return nil +} + +// returns whether the existing acls slice is changed +func updateAcl(acls []Acl, newAcl Acl) ([]Acl, bool) { + for idx, aclEntry := range acls { + if aclEntry.Predicate == newAcl.Predicate { + if aclEntry.Perm == newAcl.Perm { + return acls, false + } + if newAcl.Perm < 0 { + // remove the current aclEntry from the array + copy(acls[idx:], acls[idx+1:]) + return acls[:len(acls)-1], true + } + acls[idx].Perm = newAcl.Perm + return acls, true + } + } + + // we do not find any existing aclEntry matching the newAcl predicate + return append(acls, newAcl), true +} diff --git a/ee/acl/cmd/groups_test.go b/ee/acl/cmd/groups_test.go new file mode 100644 index 00000000000..c4b7592d38a --- /dev/null +++ b/ee/acl/cmd/groups_test.go @@ -0,0 +1,63 @@ +// +build !oss + +/* + * Copyright 2018 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Dgraph Community License (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * https://github.com/dgraph-io/dgraph/blob/master/licenses/DCL.txt + */ +package acl + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUpdateAcl(t *testing.T) { + var currenAcls []Acl + newAcl := Acl{ + Predicate: "friend", + Perm: 4, + } + updatedAcls1, changed := updateAcl(currenAcls, newAcl) + require.True(t, changed, "the acl list should be changed") + require.Equal(t, 1, len(updatedAcls1), "the updated acl list should have 1 element") + + // trying to update the acl list again with the exactly same acl won't change it + updatedAcls2, changed := updateAcl(updatedAcls1, newAcl) + require.False(t, changed, "the acl list should not be changed through update with "+ + "an existing element") + require.Equal(t, 1, len(updatedAcls2), "the updated acl list should still have 1 element") + require.Equal(t, int32(4), updatedAcls2[0].Perm, "the perm should still have the value of 4") + + newAcl.Perm = 6 + updatedAcls3, changed := updateAcl(updatedAcls1, newAcl) + require.True(t, changed, "the acl list should be changed through update "+ + "with element of new perm") + require.Equal(t, 1, len(updatedAcls3), "the updated acl list should still have 1 element") + require.Equal(t, int32(6), updatedAcls3[0].Perm, "the updated perm should be 6 now") + + newAcl = Acl{ + Predicate: "buddy", + Perm: 6, + } + updatedAcls4, changed := updateAcl(updatedAcls3, newAcl) + require.True(t, changed, "the acl should be changed through update "+ + "with element of new predicate") + require.Equal(t, 2, len(updatedAcls4), "the acl list should have 2 elements now") + + newAcl = Acl{ + Predicate: "buddy", + Perm: -3, + } + updatedAcls5, changed := updateAcl(updatedAcls4, newAcl) + require.True(t, changed, "the acl should be changed through update "+ + "with element of negative predicate") + require.Equal(t, 1, len(updatedAcls5), "the acl list should have 1 element now") + require.Equal(t, "friend", updatedAcls5[0].Predicate, "the left acl should have the original "+ + "first predicate") +} diff --git a/ee/acl/cmd/run.go b/ee/acl/cmd/run.go new file mode 100644 index 00000000000..efb614b9b72 --- /dev/null +++ b/ee/acl/cmd/run.go @@ -0,0 +1,33 @@ +// +build oss + +/* + * Copyright 2018 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package acl + +import ( + "github.com/dgraph-io/dgraph/x" + "github.com/spf13/cobra" +) + +var CmdAcl x.SubCommand + +func init() { + CmdAcl.Cmd = &cobra.Command{ + Use: "acl", + Short: "Enterprise feature. Not supported in oss version", + } +} diff --git a/ee/acl/cmd/run_ee.go b/ee/acl/cmd/run_ee.go new file mode 100644 index 00000000000..203e96ff949 --- /dev/null +++ b/ee/acl/cmd/run_ee.go @@ -0,0 +1,296 @@ +// +build !oss + +/* + * Copyright 2018 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Dgraph Community License (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * https://github.com/dgraph-io/dgraph/blob/master/licenses/DCL.txt + */ + +package acl + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strings" + "time" + + "github.com/dgraph-io/dgo" + "github.com/dgraph-io/dgo/protos/api" + "github.com/dgraph-io/dgraph/x" + "github.com/golang/glog" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +type options struct { + dgraph string +} + +var opt options +var tlsConf x.TLSHelperConfig + +var CmdAcl x.SubCommand + +func init() { + CmdAcl.Cmd = &cobra.Command{ + Use: "acl", + Short: "Run the Dgraph acl tool", + } + + flag := CmdAcl.Cmd.PersistentFlags() + flag.StringP("dgraph", "d", "127.0.0.1:9080", "Dgraph gRPC server address") + + // TLS configuration + x.RegisterTLSFlags(flag) + flag.String("tls_server_name", "", "Used to verify the server hostname.") + + subcommands := initSubcommands() + for _, sc := range subcommands { + CmdAcl.Cmd.AddCommand(sc.Cmd) + sc.Conf = viper.New() + if err := sc.Conf.BindPFlags(sc.Cmd.Flags()); err != nil { + glog.Fatalf("Unable to bind flags for command %v:%v", sc, err) + } + if err := sc.Conf.BindPFlags(CmdAcl.Cmd.PersistentFlags()); err != nil { + glog.Fatalf("Unable to bind persistent flags from acl for command %v:%v", sc, err) + } + sc.Conf.SetEnvPrefix(sc.EnvPrefix) + } +} + +func initSubcommands() []*x.SubCommand { + // user creation command + var cmdUserAdd x.SubCommand + cmdUserAdd.Cmd = &cobra.Command{ + Use: "useradd", + Short: "Run Dgraph acl tool to add a user", + Run: func(cmd *cobra.Command, args []string) { + if err := userAdd(cmdUserAdd.Conf); err != nil { + glog.Errorf("Unable to add user:%v", err) + os.Exit(1) + } + }, + } + userAddFlags := cmdUserAdd.Cmd.Flags() + userAddFlags.StringP("user", "u", "", "The user id to be created") + userAddFlags.StringP("password", "p", "", "The password for the user") + + // user deletion command + var cmdUserDel x.SubCommand + cmdUserDel.Cmd = &cobra.Command{ + Use: "userdel", + Short: "Run Dgraph acl tool to delete a user", + Run: func(cmd *cobra.Command, args []string) { + if err := userDel(cmdUserDel.Conf); err != nil { + glog.Errorf("Unable to delete the user:%v", err) + os.Exit(1) + } + }, + } + userDelFlags := cmdUserDel.Cmd.Flags() + userDelFlags.StringP("user", "u", "", "The user id to be deleted") + + // login command + var cmdLogIn x.SubCommand + cmdLogIn.Cmd = &cobra.Command{ + Use: "login", + Short: "Login to dgraph in order to get a jwt token", + Run: func(cmd *cobra.Command, args []string) { + if err := userLogin(cmdLogIn.Conf); err != nil { + glog.Errorf("Unable to login:%v", err) + os.Exit(1) + } + }, + } + loginFlags := cmdLogIn.Cmd.Flags() + loginFlags.StringP("user", "u", "", "The user id to be created") + loginFlags.StringP("password", "p", "", "The password for the user") + + // group creation command + var cmdGroupAdd x.SubCommand + cmdGroupAdd.Cmd = &cobra.Command{ + Use: "groupadd", + Short: "Run Dgraph acl tool to add a group", + Run: func(cmd *cobra.Command, args []string) { + if err := groupAdd(cmdGroupAdd.Conf); err != nil { + glog.Errorf("Unable to add group:%v", err) + os.Exit(1) + } + }, + } + groupAddFlags := cmdGroupAdd.Cmd.Flags() + groupAddFlags.StringP("group", "g", "", "The group id to be created") + + // group deletion command + var cmdGroupDel x.SubCommand + cmdGroupDel.Cmd = &cobra.Command{ + Use: "groupdel", + Short: "Run Dgraph acl tool to delete a group", + Run: func(cmd *cobra.Command, args []string) { + if err := groupDel(cmdGroupDel.Conf); err != nil { + glog.Errorf("Unable to delete group:%v", err) + os.Exit(1) + } + }, + } + groupDelFlags := cmdGroupDel.Cmd.Flags() + groupDelFlags.StringP("group", "g", "", "The group id to be deleted") + + // the usermod command used to set a user's groups + var cmdUserMod x.SubCommand + cmdUserMod.Cmd = &cobra.Command{ + Use: "usermod", + Short: "Run Dgraph acl tool to change a user's groups", + Run: func(cmd *cobra.Command, args []string) { + if err := userMod(cmdUserMod.Conf); err != nil { + glog.Errorf("Unable to modify user:%v", err) + os.Exit(1) + } + }, + } + userModFlags := cmdUserMod.Cmd.Flags() + userModFlags.StringP("user", "u", "", "The user id to be changed") + userModFlags.StringP("groups", "g", "", "The groups to be set for the user") + + // the chmod command is used to change a group's permissions + var cmdChMod x.SubCommand + cmdChMod.Cmd = &cobra.Command{ + Use: "chmod", + Short: "Run Dgraph acl tool to change a group's permissions", + Run: func(cmd *cobra.Command, args []string) { + if err := chMod(cmdChMod.Conf); err != nil { + glog.Errorf("Unable to change permisson for group:%v", err) + os.Exit(1) + } + }, + } + chModFlags := cmdChMod.Cmd.Flags() + chModFlags.StringP("group", "g", "", "The group whose permission "+ + "is to be changed") + chModFlags.StringP("pred", "p", "", "The predicates whose acls"+ + " are to be changed") + chModFlags.IntP("perm", "P", 0, "The acl represented using "+ + "an integer, 4 for read-only, 2 for write-only, and 1 for modify-only") + + var cmdInfo x.SubCommand + cmdInfo.Cmd = &cobra.Command{ + Use: "info", + Short: "Show info about a user or group", + Run: func(cmd *cobra.Command, args []string) { + if err := info(cmdInfo.Conf); err != nil { + glog.Errorf("Unable to show info:%v", err) + os.Exit(1) + } + }, + } + infoFlags := cmdInfo.Cmd.Flags() + infoFlags.StringP("user", "u", "", "The user to be shown") + infoFlags.StringP("group", "g", "", "The group to be shown") + return []*x.SubCommand{ + &cmdUserAdd, &cmdUserDel, &cmdLogIn, &cmdGroupAdd, &cmdGroupDel, &cmdUserMod, + &cmdChMod, &cmdInfo, + } +} + +type CloseFunc func() + +func getDgraphClient(conf *viper.Viper) (*dgo.Dgraph, CloseFunc) { + opt = options{ + dgraph: conf.GetString("dgraph"), + } + glog.Infof("Running transaction with dgraph endpoint: %v", opt.dgraph) + + if len(opt.dgraph) == 0 { + glog.Fatalf("The --dgraph option must be set in order to connect to dgraph") + } + + x.LoadTLSConfig(&tlsConf, CmdAcl.Conf, x.TlsClientCert, x.TlsClientKey) + tlsConf.ServerName = CmdAcl.Conf.GetString("tls_server_name") + + conn, err := x.SetupConnection(opt.dgraph, &tlsConf) + x.Checkf(err, "While trying to setup connection to Dgraph alpha.") + + dc := api.NewDgraphClient(conn) + return dgo.NewDgraphClient(dc), func() { + if err := conn.Close(); err != nil { + glog.Errorf("Error while closing connection:%v", err) + } + } +} + +func info(conf *viper.Viper) error { + userId := conf.GetString("user") + groupId := conf.GetString("group") + if (len(userId) == 0 && len(groupId) == 0) || + (len(userId) != 0 && len(groupId) != 0) { + return fmt.Errorf("either the user or group should be specified, not both") + } + + dc, close := getDgraphClient(conf) + defer close() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + txn := dc.NewTxn() + defer func() { + if err := txn.Discard(ctx); err != nil { + glog.Errorf("Unable to discard transaction:%v", err) + } + }() + + if len(userId) != 0 { + user, err := queryUser(ctx, txn, userId) + if err != nil { + return err + } + + var userBuf strings.Builder + userBuf.WriteString(fmt.Sprintf("user %v:\n", userId)) + userBuf.WriteString(fmt.Sprintf("uid:%v\nid:%v\n", user.Uid, user.UserID)) + var groupNames []string + for _, group := range user.Groups { + groupNames = append(groupNames, group.GroupID) + } + userBuf.WriteString(fmt.Sprintf("groups:%v\n", strings.Join(groupNames, " "))) + glog.Infof(userBuf.String()) + } + + if len(groupId) != 0 { + group, err := queryGroup(ctx, txn, groupId, "dgraph.xid", "~dgraph.user.group{dgraph.xid}", + "dgraph.group.acl") + if err != nil { + return err + } + // build the info string for group + var groupSB strings.Builder + groupSB.WriteString(fmt.Sprintf("group %v:\n", groupId)) + groupSB.WriteString(fmt.Sprintf("uid:%v\nid:%v\n", group.Uid, group.GroupID)) + + var userNames []string + for _, user := range group.Users { + userNames = append(userNames, user.UserID) + } + groupSB.WriteString(fmt.Sprintf("users:%v\n", strings.Join(userNames, " "))) + + var aclStrs []string + var acls []Acl + if err := json.Unmarshal([]byte(group.Acls), &acls); err != nil { + return fmt.Errorf("unable to unmarshal the acls associated with the group %v:%v", + groupId, err) + } + + for _, acl := range acls { + aclStrs = append(aclStrs, fmt.Sprintf("(predicate:%v,perm:%v)", acl.Predicate, acl.Perm)) + } + groupSB.WriteString(fmt.Sprintf("acls:%v\n", strings.Join(aclStrs, " "))) + + glog.Infof(groupSB.String()) + } + + return nil +} diff --git a/ee/acl/cmd/users.go b/ee/acl/cmd/users.go new file mode 100644 index 00000000000..b71c72cdc8b --- /dev/null +++ b/ee/acl/cmd/users.go @@ -0,0 +1,280 @@ +// +build !oss + +/* + * Copyright 2018 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Dgraph Community License (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * https://github.com/dgraph-io/dgraph/blob/master/licenses/DCL.txt + */ + +package acl + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/dgraph-io/dgo" + "github.com/dgraph-io/dgo/protos/api" + "github.com/dgraph-io/dgraph/ee/acl" + "github.com/dgraph-io/dgraph/x" + "github.com/golang/glog" + "github.com/spf13/viper" +) + +func userAdd(conf *viper.Viper) error { + userid := conf.GetString("user") + password := conf.GetString("password") + + if len(userid) == 0 { + return fmt.Errorf("the user must not be empty") + } + if len(password) == 0 { + return fmt.Errorf("the password must not be empty") + } + + dc, close := getDgraphClient(conf) + defer close() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + txn := dc.NewTxn() + defer func() { + if err := txn.Discard(ctx); err != nil { + glog.Errorf("Unable to discard transaction:%v", err) + } + }() + + user, err := queryUser(ctx, txn, userid) + if err != nil { + return fmt.Errorf("error while querying user:%v", err) + } + if user != nil { + return fmt.Errorf("unable to create user because of conflict: %v", userid) + } + + createUserNQuads := []*api.NQuad{ + { + Subject: "_:newuser", + Predicate: "dgraph.xid", + ObjectValue: &api.Value{Val: &api.Value_StrVal{StrVal: userid}}, + }, + { + Subject: "_:newuser", + Predicate: "dgraph.password", + ObjectValue: &api.Value{Val: &api.Value_StrVal{StrVal: password}}, + }} + + mu := &api.Mutation{ + CommitNow: true, + Set: createUserNQuads, + } + + if _, err := txn.Mutate(ctx, mu); err != nil { + return fmt.Errorf("unable to create user: %v", err) + } + + glog.Infof("Created new user with id %v", userid) + return nil +} + +func userDel(conf *viper.Viper) error { + userid := conf.GetString("user") + // validate the userid + if len(userid) == 0 { + return fmt.Errorf("the user id should not be empty") + } + + dc, close := getDgraphClient(conf) + defer close() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + txn := dc.NewTxn() + defer func() { + if err := txn.Discard(ctx); err != nil { + glog.Errorf("Unable to discard transaction:%v", err) + } + }() + + user, err := queryUser(ctx, txn, userid) + if err != nil { + return fmt.Errorf("error while querying user:%v", err) + } + + if user == nil || len(user.Uid) == 0 { + return fmt.Errorf("unable to delete user because it does not exist: %v", userid) + } + + deleteUserNQuads := []*api.NQuad{ + { + Subject: user.Uid, + Predicate: x.Star, + ObjectValue: &api.Value{Val: &api.Value_DefaultVal{DefaultVal: x.Star}}, + }} + + mu := &api.Mutation{ + CommitNow: true, + Del: deleteUserNQuads, + } + + if _, err = txn.Mutate(ctx, mu); err != nil { + return fmt.Errorf("unable to delete user: %v", err) + } + + glog.Infof("Deleted user with id %v", userid) + return nil +} + +func userLogin(conf *viper.Viper) error { + userid := conf.GetString("user") + password := conf.GetString("password") + + if len(userid) == 0 { + return fmt.Errorf("the user must not be empty") + } + if len(password) == 0 { + return fmt.Errorf("the password must not be empty") + } + + dc, close := getDgraphClient(conf) + defer close() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + txn := dc.NewTxn() + defer func() { + if err := txn.Discard(ctx); err != nil { + glog.Errorf("Unable to discard transaction:%v", err) + } + }() + + if err := dc.Login(ctx, userid, password); err != nil { + return fmt.Errorf("unable to login:%v", err) + } + updatedContext := dc.GetContext(ctx) + glog.Infof("Login successfully.\naccess jwt:\n%v\nrefresh jwt:\n%v", + updatedContext.Value("accessJwt"), updatedContext.Value("refreshJwt")) + return nil +} + +func queryUser(ctx context.Context, txn *dgo.Txn, userid string) (user *acl.User, err error) { + query := ` + query search($userid: string){ + user(func: eq(dgraph.xid, $userid)) { + uid + dgraph.xid + dgraph.user.group { + uid + dgraph.xid + } + } + }` + + queryVars := make(map[string]string) + queryVars["$userid"] = userid + + queryResp, err := txn.QueryWithVars(ctx, query, queryVars) + if err != nil { + return nil, fmt.Errorf("error while query user with id %s: %v", userid, err) + } + user, err = acl.UnmarshalUser(queryResp, "user") + if err != nil { + return nil, err + } + return user, nil +} + +func userMod(conf *viper.Viper) error { + userId := conf.GetString("user") + groups := conf.GetString("groups") + if len(userId) == 0 { + return fmt.Errorf("the user must not be empty") + } + + dc, close := getDgraphClient(conf) + defer close() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + txn := dc.NewTxn() + defer func() { + if err := txn.Discard(ctx); err != nil { + glog.Errorf("Unable to discard transaction:%v", err) + } + }() + + user, err := queryUser(ctx, txn, userId) + if err != nil { + return fmt.Errorf("error while querying user:%v", err) + } + if user == nil { + return fmt.Errorf("the user does not exist: %v", userId) + } + + targetGroupsMap := make(map[string]struct{}) + if len(groups) > 0 { + for _, g := range strings.Split(groups, ",") { + targetGroupsMap[g] = struct{}{} + } + } + + existingGroupsMap := make(map[string]struct{}) + for _, g := range user.Groups { + existingGroupsMap[g.GroupID] = struct{}{} + } + newGroups, groupsToBeDeleted := x.Diff(targetGroupsMap, existingGroupsMap) + + mu := &api.Mutation{ + CommitNow: true, + Set: []*api.NQuad{}, + Del: []*api.NQuad{}, + } + + for _, g := range newGroups { + glog.Infof("Adding user %v to group %v", userId, g) + nquad, err := getUserModNQuad(ctx, txn, user.Uid, g) + if err != nil { + return fmt.Errorf("error while getting the user mod nquad:%v", err) + } + mu.Set = append(mu.Set, nquad) + } + + for _, g := range groupsToBeDeleted { + glog.Infof("Deleting user %v from group %v", userId, g) + nquad, err := getUserModNQuad(ctx, txn, user.Uid, g) + if err != nil { + return fmt.Errorf("error while getting the user mod nquad:%v", err) + } + mu.Del = append(mu.Del, nquad) + } + if len(mu.Del) == 0 && len(mu.Set) == 0 { + glog.Infof("Nothing needs to be changed for the groups of user:%v", userId) + return nil + } + + if _, err := txn.Mutate(ctx, mu); err != nil { + return fmt.Errorf("error while mutating the group:%+v", err) + } + glog.Infof("Successfully modified groups for user %v", userId) + return nil +} + +func getUserModNQuad(ctx context.Context, txn *dgo.Txn, userId string, + groupId string) (*api.NQuad, error) { + group, err := queryGroup(ctx, txn, groupId) + if err != nil { + return nil, err + } + if group == nil { + return nil, fmt.Errorf("the group does not exist:%v", groupId) + } + + createUserGroupNQuads := &api.NQuad{ + Subject: userId, + Predicate: "dgraph.user.group", + ObjectId: group.Uid, + } + + return createUserGroupNQuads, nil +} diff --git a/ee/acl/utils.go b/ee/acl/utils.go new file mode 100644 index 00000000000..173ab28e673 --- /dev/null +++ b/ee/acl/utils.go @@ -0,0 +1,94 @@ +// +build !oss + +/* + * Copyright 2018 Dgraph Labs, Inc. All rights reserved. + * + * Licensed under the Dgraph Community License (the "License"); you + * may not use this file except in compliance with the License. You + * may obtain a copy of the License at + * + * https://github.com/dgraph-io/dgraph/blob/master/licenses/DCL.txt + */ + +package acl + +import ( + "encoding/json" + "fmt" + + "github.com/dgraph-io/dgo/protos/api" + "github.com/dgraph-io/dgraph/x" + "github.com/golang/glog" +) + +func GetGroupIDs(groups []Group) []string { + if len(groups) == 0 { + // the user does not have any groups + return nil + } + + jwtGroups := make([]string, 0, len(groups)) + for _, g := range groups { + jwtGroups = append(jwtGroups, g.GroupID) + } + return jwtGroups +} + +type User struct { + Uid string `json:"uid"` + UserID string `json:"dgraph.xid"` + Password string `json:"dgraph.password"` + PasswordMatch bool `json:"password_match"` + Groups []Group `json:"dgraph.user.group"` +} + +// Extract the first User pointed by the userKey in the query response +func UnmarshalUser(resp *api.Response, userKey string) (user *User, err error) { + m := make(map[string][]User) + + err = json.Unmarshal(resp.GetJson(), &m) + if err != nil { + return nil, fmt.Errorf("Unable to unmarshal the query user response for user:%v", err) + } + users := m[userKey] + if len(users) == 0 { + // the user does not exist + return nil, nil + } + if len(users) > 1 { + return nil, x.Errorf("Found multiple users: %s", resp.GetJson()) + } + return &users[0], nil +} + +// parse the response and check existing of the uid +type Group struct { + Uid string `json:"uid"` + GroupID string `json:"dgraph.xid"` + Users []User `json:"~dgraph.user.group"` + Acls string `json:"dgraph.group.acl"` +} + +// Extract the first User pointed by the userKey in the query response +func UnmarshalGroup(input []byte, groupKey string) (group *Group, err error) { + m := make(map[string][]Group) + + err = json.Unmarshal(input, &m) + if err != nil { + glog.Errorf("Unable to unmarshal the query group response:%v", err) + return nil, err + } + groups := m[groupKey] + if len(groups) == 0 { + // the group does not exist + return nil, nil + } + if len(groups) > 1 { + return nil, x.Errorf("Found multiple groups: %s", input) + } + return &groups[0], nil +} + +type JwtGroup struct { + Group string +} diff --git a/query/mutation.go b/query/mutation.go index a0c97f53091..3f47142386d 100644 --- a/query/mutation.go +++ b/query/mutation.go @@ -120,6 +120,9 @@ func verifyUid(ctx context.Context, uid uint64) error { return nil } +// AssignUids tries to assign unique ids to each identity in the subjects and objects in the +// format of _:xxx. An identity, e.g. _:a, will only be assigned one uid regardless how many times +// it shows up in the subjects or objects func AssignUids(ctx context.Context, nquads []*api.NQuad) (map[string]uint64, error) { newUids := make(map[string]uint64) num := &pb.Num{} diff --git a/worker/backup.go b/worker/backup.go index c4abb9d8ece..10f073e7f6a 100644 --- a/worker/backup.go +++ b/worker/backup.go @@ -19,23 +19,19 @@ package worker import ( - "errors" - "github.com/dgraph-io/dgraph/protos/pb" "github.com/dgraph-io/dgraph/x" "github.com/golang/glog" "golang.org/x/net/context" ) -var errNotSupported = errors.New("Feature available only in Dgraph Enterprise Edition.") - // Backup implements the Worker interface. func (w *grpcWorker) Backup(ctx context.Context, req *pb.BackupRequest) (*pb.Status, error) { - glog.Infof("Backup failed: %s", errNotSupported) - return &pb.Status{}, nil + glog.Warningf("Backup failed: %v", x.ErrNotSupported) + return &pb.Status{}, x.ErrNotSupported } // BackupOverNetwork handles a request coming from an HTTP client. func BackupOverNetwork(pctx context.Context, target string) error { - return x.Errorf("Backup failed: %s", errNotSupported) + return x.ErrNotSupported } diff --git a/worker/groups.go b/worker/groups.go index bde5fc55c10..a05fb5a3a94 100644 --- a/worker/groups.go +++ b/worker/groups.go @@ -139,11 +139,43 @@ func StartRaftNodes(walStore *badger.DB, bindall bool) { } func (g *groupi) proposeInitialSchema() { + // propose the schema for _predicate_ if !Config.ExpandEdge { return } + g.upsertSchema(&pb.SchemaUpdate{ + Predicate: x.PredicateListAttr, + ValueType: pb.Posting_STRING, + List: true, + }) + + // propose the schema update for acl predicates + g.upsertSchema(&pb.SchemaUpdate{ + Predicate: "dgraph.xid", + ValueType: pb.Posting_STRING, + Directive: pb.SchemaUpdate_INDEX, + Tokenizer: []string{"exact"}, + }) + + g.upsertSchema(&pb.SchemaUpdate{ + Predicate: "dgraph.password", + ValueType: pb.Posting_PASSWORD, + }) + + g.upsertSchema(&pb.SchemaUpdate{ + Predicate: "dgraph.user.group", + Directive: pb.SchemaUpdate_REVERSE, + ValueType: pb.Posting_UID, + }) + g.upsertSchema(&pb.SchemaUpdate{ + Predicate: "dgraph.group.acl", + ValueType: pb.Posting_STRING, + }) +} + +func (g *groupi) upsertSchema(schema *pb.SchemaUpdate) { g.RLock() - _, ok := g.tablets[x.PredicateListAttr] + _, ok := g.tablets[schema.Predicate] g.RUnlock() if ok { return @@ -153,11 +185,7 @@ func (g *groupi) proposeInitialSchema() { var m pb.Mutations // schema for _predicate_ is not changed once set. m.StartTs = 1 - m.Schema = append(m.Schema, &pb.SchemaUpdate{ - Predicate: x.PredicateListAttr, - ValueType: pb.Posting_STRING, - List: true, - }) + m.Schema = append(m.Schema, schema) // This would propose the schema mutation and make sure some node serves this predicate // and has the schema defined above. diff --git a/x/tls_helper.go b/x/tls_helper.go index ff5b10124a3..f556022ab8a 100644 --- a/x/tls_helper.go +++ b/x/tls_helper.go @@ -39,8 +39,6 @@ const ( const ( tlsRootCert = "ca.crt" - tlsNodeCert = "node.crt" - tlsNodeKey = "node.key" ) // TLSHelperConfig define params used to create a tls.Config @@ -61,13 +59,13 @@ func RegisterTLSFlags(flag *pflag.FlagSet) { flag.Bool("tls_use_system_ca", true, "Include System CA into CA Certs.") } -func LoadTLSConfig(conf *TLSHelperConfig, v *viper.Viper) { +func LoadTLSConfig(conf *TLSHelperConfig, v *viper.Viper, tlsCertFile string, tlsKeyFile string) { conf.CertDir = v.GetString("tls_dir") if conf.CertDir != "" { conf.CertRequired = true conf.RootCACert = path.Join(conf.CertDir, tlsRootCert) - conf.Cert = path.Join(conf.CertDir, tlsNodeCert) - conf.Key = path.Join(conf.CertDir, tlsNodeKey) + conf.Cert = path.Join(conf.CertDir, tlsCertFile) + conf.Key = path.Join(conf.CertDir, tlsKeyFile) conf.ClientAuth = v.GetString("tls_client_auth") } conf.UseSystemCACerts = v.GetBool("tls_use_system_ca") diff --git a/x/x.go b/x/x.go index c082d8eecd5..27bfca7a924 100644 --- a/x/x.go +++ b/x/x.go @@ -31,9 +31,14 @@ import ( "time" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" ) // Error constants representing different types of errors. +var ( + ErrNotSupported = fmt.Errorf("Feature available only in Dgraph Enterprise Edition.") +) + const ( Success = "Success" ErrorUnauthorized = "ErrorUnauthorized" @@ -46,7 +51,8 @@ const ( ErrorNoPermission = "ErrorNoPermission" ErrorInvalidMutation = "ErrorInvalidMutation" ErrorServiceUnavailable = "ErrorServiceUnavailable" - ValidHostnameRegex = "^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\\-]*[a-zA-Z0-9])\\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\\-]*[A-Za-z0-9])$" + + ValidHostnameRegex = "^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\\-]*[a-zA-Z0-9])\\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\\-]*[A-Za-z0-9])$" // When changing this value also remember to change in in client/client.go:DeleteEdges. Star = "_STAR_ALL" @@ -66,6 +72,9 @@ const ( // If the difference between AppliedUntil - TxnMarks.DoneUntil() is greater than this, we // start aborting old transactions. ForceAbortDifference = 5000 + + TlsClientCert = "client.crt" + TlsClientKey = "client.key" ) var ( @@ -413,3 +422,42 @@ func DivideAndRule(num int) (numGo, width int) { } return } + +func SetupConnection(host string, tlsConf *TLSHelperConfig) (*grpc.ClientConn, error) { + opts := append([]grpc.DialOption{}, + grpc.WithDefaultCallOptions( + grpc.MaxCallRecvMsgSize(GrpcMaxSize), + grpc.MaxCallSendMsgSize(GrpcMaxSize)), + grpc.WithBlock(), + grpc.WithTimeout(10*time.Second)) + + if tlsConf.CertRequired { + tlsConf.ConfigType = TLSClientConfig + tlsCfg, _, err := GenerateTLSConfig(*tlsConf) + if err != nil { + return nil, err + } + opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg))) + } else { + opts = append(opts, grpc.WithInsecure()) + } + return grpc.Dial(host, opts...) +} + +func Diff(targetMap map[string]struct{}, existingMap map[string]struct{}) ([]string, []string) { + var newGroups []string + var groupsToBeDeleted []string + + for g := range targetMap { + if _, ok := existingMap[g]; !ok { + newGroups = append(newGroups, g) + } + } + for g := range existingMap { + if _, ok := targetMap[g]; !ok { + groupsToBeDeleted = append(groupsToBeDeleted, g) + } + } + + return newGroups, groupsToBeDeleted +}