Skip to content

Commit

Permalink
domain, executor: make flush privilege propagate via etcd (#27958)
Browse files Browse the repository at this point in the history
  • Loading branch information
morgo authored Sep 30, 2021
1 parent 042f498 commit e52dbd6
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 38 deletions.
14 changes: 7 additions & 7 deletions domain/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -1387,7 +1387,7 @@ const (

// NotifyUpdatePrivilege updates privilege key in etcd, TiDB client that watches
// the key will get notification.
func (do *Domain) NotifyUpdatePrivilege(ctx sessionctx.Context) {
func (do *Domain) NotifyUpdatePrivilege() error {
if do.etcdClient != nil {
row := do.etcdClient.KV
_, err := row.Put(context.Background(), privilegeKey, "")
Expand All @@ -1396,13 +1396,13 @@ func (do *Domain) NotifyUpdatePrivilege(ctx sessionctx.Context) {
}
}
// update locally
exec := ctx.(sqlexec.RestrictedSQLExecutor)
if stmt, err := exec.ParseWithParams(context.Background(), `FLUSH PRIVILEGES`); err == nil {
_, _, err := exec.ExecRestrictedStmt(context.Background(), stmt)
if err != nil {
logutil.BgLogger().Error("unable to update privileges", zap.Error(err))
}
sysSessionPool := do.SysSessionPool()
ctx, err := sysSessionPool.Get()
if err != nil {
return err
}
defer sysSessionPool.Put(ctx)
return do.PrivilegeHandle().Update(ctx.(sessionctx.Context))
}

// NotifyUpdateSysVarCache updates the sysvar cache key in etcd, which other TiDB
Expand Down
3 changes: 1 addition & 2 deletions executor/grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error {
return err
}
isCommit = true
domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx)
return nil
return domain.GetDomain(e.ctx).NotifyUpdatePrivilege()
}

func containsNonDynamicPriv(privList []*ast.PrivElem) bool {
Expand Down
3 changes: 1 addition & 2 deletions executor/revoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error {
return err
}
isCommit = true
domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx)
return nil
return domain.GetDomain(e.ctx).NotifyUpdatePrivilege()
}

// Checks that dynamic privileges are only of global scope.
Expand Down
40 changes: 13 additions & 27 deletions executor/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,7 @@ func (e *SimpleExec) executeSetDefaultRole(ctx context.Context, s *ast.SetDefaul
u, h := s.UserList[0].Username, s.UserList[0].Hostname
if u == sessionVars.User.Username && h == sessionVars.User.AuthHostname {
err = e.setDefaultRoleForCurrentUser(s)
domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx)
return
return domain.GetDomain(e.ctx).NotifyUpdatePrivilege()
}
}

Expand All @@ -411,8 +410,7 @@ func (e *SimpleExec) executeSetDefaultRole(ctx context.Context, s *ast.SetDefaul
if err != nil {
return
}
domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx)
return
return domain.GetDomain(e.ctx).NotifyUpdatePrivilege()
}

func (e *SimpleExec) setRoleRegular(s *ast.SetRoleStmt) error {
Expand Down Expand Up @@ -698,8 +696,7 @@ func (e *SimpleExec) executeRevokeRole(ctx context.Context, s *ast.RevokeRoleStm
if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil {
return err
}
domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx)
return nil
return domain.GetDomain(e.ctx).NotifyUpdatePrivilege()
}

func (e *SimpleExec) executeCommit(s *ast.CommitStmt) {
Expand Down Expand Up @@ -838,8 +835,7 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm
if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil {
return errors.Trace(err)
}
domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx)
return err
return domain.GetDomain(e.ctx).NotifyUpdatePrivilege()
}

func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) error {
Expand Down Expand Up @@ -970,8 +966,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt)
e.ctx.GetSessionVars().StmtCtx.AppendNote(err)
}
}
domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx)
return nil
return domain.GetDomain(e.ctx).NotifyUpdatePrivilege()
}

func (e *SimpleExec) executeGrantRole(ctx context.Context, s *ast.GrantRoleStmt) error {
Expand Down Expand Up @@ -1031,8 +1026,7 @@ func (e *SimpleExec) executeGrantRole(ctx context.Context, s *ast.GrantRoleStmt)
if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil {
return err
}
domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx)
return nil
return domain.GetDomain(e.ctx).NotifyUpdatePrivilege()
}

// Should cover same internal mysql.* tables as DROP USER, so this function is very similar
Expand Down Expand Up @@ -1138,8 +1132,7 @@ func (e *SimpleExec) executeRenameUser(s *ast.RenameUserStmt) error {
}
return ErrCannotUser.GenWithStackByArgs("RENAME USER", failedUser)
}
domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx)
return nil
return domain.GetDomain(e.ctx).NotifyUpdatePrivilege()
}

func renameUserHostInSystemTable(sqlExecutor sqlexec.SQLExecutor, tableName, usernameColumn, hostColumn string, users *ast.UserToUser) error {
Expand Down Expand Up @@ -1301,8 +1294,7 @@ func (e *SimpleExec) executeDropUser(ctx context.Context, s *ast.DropUserStmt) e
}
return ErrCannotUser.GenWithStackByArgs("DROP USER", strings.Join(failedUsers, ","))
}
domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx)
return nil
return domain.GetDomain(e.ctx).NotifyUpdatePrivilege()
}

func userExists(ctx context.Context, sctx sessionctx.Context, name string, host string) (bool, error) {
Expand Down Expand Up @@ -1391,8 +1383,10 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error
return err
}
_, _, err = exec.ExecRestrictedStmt(ctx, stmt)
domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx)
return err
if err != nil {
return err
}
return domain.GetDomain(e.ctx).NotifyUpdatePrivilege()
}

func (e *SimpleExec) executeKillStmt(ctx context.Context, s *ast.KillStmt) error {
Expand Down Expand Up @@ -1506,16 +1500,8 @@ func (e *SimpleExec) executeFlush(s *ast.FlushStmt) error {
if config.GetGlobalConfig().Security.SkipGrantTable {
return nil
}

dom := domain.GetDomain(e.ctx)
sysSessionPool := dom.SysSessionPool()
ctx, err := sysSessionPool.Get()
if err != nil {
return err
}
defer sysSessionPool.Put(ctx)
err = dom.PrivilegeHandle().Update(ctx.(sessionctx.Context))
return err
return dom.NotifyUpdatePrivilege()
case ast.FlushTiDBPlugin:
dom := domain.GetDomain(e.ctx)
for _, pluginName := range s.Plugins {
Expand Down

0 comments on commit e52dbd6

Please sign in to comment.