diff --git a/domain/domain.go b/domain/domain.go index e472ae096639e..6e8ce2abb3fc3 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -1387,7 +1387,7 @@ const ( // NotifyUpdatePrivilege updates privilege key in etcd, TiDB client that watches // the key will get notification. -func (do *Domain) NotifyUpdatePrivilege(ctx sessionctx.Context) { +func (do *Domain) NotifyUpdatePrivilege() error { if do.etcdClient != nil { row := do.etcdClient.KV _, err := row.Put(context.Background(), privilegeKey, "") @@ -1396,13 +1396,13 @@ func (do *Domain) NotifyUpdatePrivilege(ctx sessionctx.Context) { } } // update locally - exec := ctx.(sqlexec.RestrictedSQLExecutor) - if stmt, err := exec.ParseWithParams(context.Background(), `FLUSH PRIVILEGES`); err == nil { - _, _, err := exec.ExecRestrictedStmt(context.Background(), stmt) - if err != nil { - logutil.BgLogger().Error("unable to update privileges", zap.Error(err)) - } + sysSessionPool := do.SysSessionPool() + ctx, err := sysSessionPool.Get() + if err != nil { + return err } + defer sysSessionPool.Put(ctx) + return do.PrivilegeHandle().Update(ctx.(sessionctx.Context)) } // NotifyUpdateSysVarCache updates the sysvar cache key in etcd, which other TiDB diff --git a/executor/grant.go b/executor/grant.go index 913cbf52c2333..8523d8a915297 100644 --- a/executor/grant.go +++ b/executor/grant.go @@ -214,8 +214,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { return err } isCommit = true - domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) - return nil + return domain.GetDomain(e.ctx).NotifyUpdatePrivilege() } func containsNonDynamicPriv(privList []*ast.PrivElem) bool { diff --git a/executor/revoke.go b/executor/revoke.go index 606f9ea95785f..a00bedd375193 100644 --- a/executor/revoke.go +++ b/executor/revoke.go @@ -119,8 +119,7 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error { return err } isCommit = true - domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) - return nil + return domain.GetDomain(e.ctx).NotifyUpdatePrivilege() } // Checks that dynamic privileges are only of global scope. diff --git a/executor/simple.go b/executor/simple.go index f6b560843d86f..f4fb4715e3a90 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -388,8 +388,7 @@ func (e *SimpleExec) executeSetDefaultRole(ctx context.Context, s *ast.SetDefaul u, h := s.UserList[0].Username, s.UserList[0].Hostname if u == sessionVars.User.Username && h == sessionVars.User.AuthHostname { err = e.setDefaultRoleForCurrentUser(s) - domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) - return + return domain.GetDomain(e.ctx).NotifyUpdatePrivilege() } } @@ -411,8 +410,7 @@ func (e *SimpleExec) executeSetDefaultRole(ctx context.Context, s *ast.SetDefaul if err != nil { return } - domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) - return + return domain.GetDomain(e.ctx).NotifyUpdatePrivilege() } func (e *SimpleExec) setRoleRegular(s *ast.SetRoleStmt) error { @@ -698,8 +696,7 @@ func (e *SimpleExec) executeRevokeRole(ctx context.Context, s *ast.RevokeRoleStm if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } - domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) - return nil + return domain.GetDomain(e.ctx).NotifyUpdatePrivilege() } func (e *SimpleExec) executeCommit(s *ast.CommitStmt) { @@ -838,8 +835,7 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return errors.Trace(err) } - domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) - return err + return domain.GetDomain(e.ctx).NotifyUpdatePrivilege() } func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) error { @@ -970,8 +966,7 @@ func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) e.ctx.GetSessionVars().StmtCtx.AppendNote(err) } } - domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) - return nil + return domain.GetDomain(e.ctx).NotifyUpdatePrivilege() } func (e *SimpleExec) executeGrantRole(ctx context.Context, s *ast.GrantRoleStmt) error { @@ -1031,8 +1026,7 @@ func (e *SimpleExec) executeGrantRole(ctx context.Context, s *ast.GrantRoleStmt) if _, err := sqlExecutor.ExecuteInternal(context.TODO(), "commit"); err != nil { return err } - domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) - return nil + return domain.GetDomain(e.ctx).NotifyUpdatePrivilege() } // Should cover same internal mysql.* tables as DROP USER, so this function is very similar @@ -1138,8 +1132,7 @@ func (e *SimpleExec) executeRenameUser(s *ast.RenameUserStmt) error { } return ErrCannotUser.GenWithStackByArgs("RENAME USER", failedUser) } - domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) - return nil + return domain.GetDomain(e.ctx).NotifyUpdatePrivilege() } func renameUserHostInSystemTable(sqlExecutor sqlexec.SQLExecutor, tableName, usernameColumn, hostColumn string, users *ast.UserToUser) error { @@ -1301,8 +1294,7 @@ func (e *SimpleExec) executeDropUser(ctx context.Context, s *ast.DropUserStmt) e } return ErrCannotUser.GenWithStackByArgs("DROP USER", strings.Join(failedUsers, ",")) } - domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) - return nil + return domain.GetDomain(e.ctx).NotifyUpdatePrivilege() } func userExists(ctx context.Context, sctx sessionctx.Context, name string, host string) (bool, error) { @@ -1391,8 +1383,10 @@ func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error return err } _, _, err = exec.ExecRestrictedStmt(ctx, stmt) - domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) - return err + if err != nil { + return err + } + return domain.GetDomain(e.ctx).NotifyUpdatePrivilege() } func (e *SimpleExec) executeKillStmt(ctx context.Context, s *ast.KillStmt) error { @@ -1506,16 +1500,8 @@ func (e *SimpleExec) executeFlush(s *ast.FlushStmt) error { if config.GetGlobalConfig().Security.SkipGrantTable { return nil } - dom := domain.GetDomain(e.ctx) - sysSessionPool := dom.SysSessionPool() - ctx, err := sysSessionPool.Get() - if err != nil { - return err - } - defer sysSessionPool.Put(ctx) - err = dom.PrivilegeHandle().Update(ctx.(sessionctx.Context)) - return err + return dom.NotifyUpdatePrivilege() case ast.FlushTiDBPlugin: dom := domain.GetDomain(e.ctx) for _, pluginName := range s.Plugins {