diff --git a/session/session.go b/session/session.go index bb1f15836fd90..3f163146e8796 100644 --- a/session/session.go +++ b/session/session.go @@ -56,6 +56,7 @@ import ( "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/temptable" "github.com/pingcap/tidb/util/logutil/consistency" + "github.com/pingcap/tidb/util/sem" "github.com/pingcap/tidb/util/topsql" topsqlstate "github.com/pingcap/tidb/util/topsql/state" "github.com/pingcap/tidb/util/topsql/stmtstats" @@ -3496,10 +3497,45 @@ func (s *session) GetStmtStats() *stmtstats.StatementStats { // EncodeSessionStates implements SessionStatesHandler.EncodeSessionStates interface. func (s *session) EncodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) { - return s.sessionVars.EncodeSessionStates(ctx, sessionStates) + if err = s.sessionVars.EncodeSessionStates(ctx, sessionStates); err != nil { + return err + } + + // Encode session variables. We put it here instead of SessionVars to avoid cycle import. + sessionStates.SystemVars = make(map[string]string) + for _, sv := range variable.GetSysVars() { + switch { + case sv.Hidden, sv.HasNoneScope(), sv.HasInstanceScope(), !sv.HasSessionScope(): + // Hidden and none-scoped variables cannot be modified. + // Instance-scoped variables don't need to be encoded. + // Noop variables should also be migrated even if they are noop. + continue + case sv.ReadOnly: + // Skip read-only variables here. We encode them into SessionStates manually. + continue + case sem.IsEnabled() && sem.IsInvisibleSysVar(sv.Name): + // If they are shown, there will be a security issue. + continue + } + // Get all session variables because the default values may change between versions. + if val, keep, err := variable.GetSessionStatesSystemVar(s.sessionVars, sv.Name); err == nil && keep { + sessionStates.SystemVars[sv.Name] = val + } + } + return } // DecodeSessionStates implements SessionStatesHandler.DecodeSessionStates interface. func (s *session) DecodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) { - return s.sessionVars.DecodeSessionStates(ctx, sessionStates) + if err = s.sessionVars.DecodeSessionStates(ctx, sessionStates); err != nil { + return err + } + + // Decode session variables. + for name, val := range sessionStates.SystemVars { + if err = variable.SetSessionSystemVar(s.sessionVars, name, val); err != nil { + return err + } + } + return err } diff --git a/sessionctx/sessionstates/session_states.go b/sessionctx/sessionstates/session_states.go index 43adb554f5758..312cf891ec80e 100644 --- a/sessionctx/sessionstates/session_states.go +++ b/sessionctx/sessionstates/session_states.go @@ -24,4 +24,5 @@ import ( type SessionStates struct { UserVars map[string]*types.Datum `json:"user-var-values,omitempty"` UserVarTypes map[string]*ptypes.FieldType `json:"user-var-types,omitempty"` + SystemVars map[string]string `json:"sys-vars,omitempty"` } diff --git a/sessionctx/sessionstates/session_states_test.go b/sessionctx/sessionstates/session_states_test.go index 61413039f29f1..81e4cb6d5285a 100644 --- a/sessionctx/sessionstates/session_states_test.go +++ b/sessionctx/sessionstates/session_states_test.go @@ -16,11 +16,14 @@ package sessionstates_test import ( "fmt" + "strconv" "strings" "testing" "github.com/pingcap/tidb/errno" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" + "github.com/pingcap/tidb/util/sem" "github.com/stretchr/testify/require" ) @@ -80,12 +83,134 @@ func TestUserVars(t *testing.T) { } } +func TestSystemVars(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + + tests := []struct { + stmts []string + varName string + inSessionStates bool + checkStmt string + expectedValue string + }{ + { + // normal variable + inSessionStates: true, + varName: variable.TiDBMaxTiFlashThreads, + expectedValue: strconv.Itoa(variable.DefTiFlashMaxThreads), + }, + { + // hidden variable + inSessionStates: false, + varName: variable.TiDBTxnReadTS, + }, + { + // none-scoped variable + inSessionStates: false, + varName: variable.DataDir, + expectedValue: "/usr/local/mysql/data/", + }, + { + // instance-scoped variable + inSessionStates: false, + varName: variable.TiDBGeneralLog, + expectedValue: "0", + }, + { + // global-scoped variable + inSessionStates: false, + varName: variable.TiDBAutoAnalyzeStartTime, + expectedValue: variable.DefAutoAnalyzeStartTime, + }, + { + // sem invisible variable + inSessionStates: false, + varName: variable.TiDBAllowRemoveAutoInc, + }, + { + // noop variables + stmts: []string{"set sql_buffer_result=true"}, + inSessionStates: true, + varName: "sql_buffer_result", + expectedValue: "1", + }, + { + stmts: []string{"set transaction isolation level repeatable read"}, + inSessionStates: true, + varName: "tx_isolation_one_shot", + expectedValue: "REPEATABLE-READ", + }, + { + inSessionStates: false, + varName: variable.Timestamp, + }, + { + stmts: []string{"set timestamp=100"}, + inSessionStates: true, + varName: variable.Timestamp, + expectedValue: "100", + }, + { + stmts: []string{"set rand_seed1=10000000, rand_seed2=1000000"}, + inSessionStates: true, + varName: variable.RandSeed1, + checkStmt: "select rand()", + expectedValue: "0.028870999839968048", + }, + { + stmts: []string{"set rand_seed1=10000000, rand_seed2=1000000", "select rand()"}, + inSessionStates: true, + varName: variable.RandSeed1, + checkStmt: "select rand()", + expectedValue: "0.11641535266900002", + }, + } + + sem.Enable() + for _, tt := range tests { + tk1 := testkit.NewTestKit(t, store) + for _, stmt := range tt.stmts { + if strings.HasPrefix(stmt, "select") { + tk1.MustQuery(stmt) + } else { + tk1.MustExec(stmt) + } + } + tk2 := testkit.NewTestKit(t, store) + rows := tk1.MustQuery("show session_states").Rows() + state := rows[0][0].(string) + msg := fmt.Sprintf("var name: '%s', expected value: '%s'", tt.varName, tt.expectedValue) + require.Equal(t, tt.inSessionStates, strings.Contains(state, tt.varName), msg) + state = strconv.Quote(state) + setSQL := fmt.Sprintf("set session_states %s", state) + tk2.MustExec(setSQL) + if len(tt.expectedValue) > 0 { + checkStmt := tt.checkStmt + if len(checkStmt) == 0 { + checkStmt = fmt.Sprintf("select @@%s", tt.varName) + } + tk2.MustQuery(checkStmt).Check(testkit.Rows(tt.expectedValue)) + } + } + + { + // The session value should not change even if the global value changes. + tk1 := testkit.NewTestKit(t, store) + tk1.MustQuery("select @@autocommit").Check(testkit.Rows("1")) + tk2 := testkit.NewTestKit(t, store) + tk2.MustExec("set global autocommit=0") + tk3 := testkit.NewTestKit(t, store) + showSessionStatesAndSet(t, tk1, tk3) + tk3.MustQuery("select @@autocommit").Check(testkit.Rows("1")) + } +} + func showSessionStatesAndSet(t *testing.T, tk1, tk2 *testkit.TestKit) { rows := tk1.MustQuery("show session_states").Rows() require.Len(t, rows, 1) state := rows[0][0].(string) - state = strings.ReplaceAll(state, "\\", "\\\\") - state = strings.ReplaceAll(state, "'", "\\'") - setSQL := fmt.Sprintf("set session_states '%s'", state) + state = strconv.Quote(state) + setSQL := fmt.Sprintf("set session_states %s", state) tk2.MustExec(setSQL) } diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 3605e94625902..ce3cdc66bfc7b 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -77,6 +77,9 @@ var defaultSysVars = []*SysVar{ } timestamp := s.StmtCtx.GetOrStoreStmtCache(stmtctx.StmtNowTsCacheKey, time.Now()).(time.Time) return types.ToString(float64(timestamp.UnixNano()) / float64(time.Second)) + }, GetStateValue: func(s *SessionVars) (string, bool, error) { + timestamp, ok := s.systems[Timestamp] + return timestamp, ok && timestamp != DefTimestamp, nil }}, {Scope: ScopeSession, Name: WarningCount, Value: "0", ReadOnly: true, skipInit: true, GetSession: func(s *SessionVars) (string, error) { return strconv.Itoa(s.SysWarningCount), nil @@ -86,9 +89,13 @@ var defaultSysVars = []*SysVar{ }}, {Scope: ScopeSession, Name: LastInsertID, Value: "", skipInit: true, Type: TypeInt, AllowEmpty: true, MinValue: 0, MaxValue: math.MaxInt64, GetSession: func(s *SessionVars) (string, error) { return strconv.FormatUint(s.StmtCtx.PrevLastInsertID, 10), nil + }, GetStateValue: func(s *SessionVars) (string, bool, error) { + return "", false, nil }}, {Scope: ScopeSession, Name: Identity, Value: "", skipInit: true, Type: TypeInt, AllowEmpty: true, MinValue: 0, MaxValue: math.MaxInt64, GetSession: func(s *SessionVars) (string, error) { return strconv.FormatUint(s.StmtCtx.PrevLastInsertID, 10), nil + }, GetStateValue: func(s *SessionVars) (string, bool, error) { + return "", false, nil }}, /* TiDB specific variables */ // TODO: TiDBTxnScope is hidden because local txn feature is not done. @@ -192,6 +199,11 @@ var defaultSysVars = []*SysVar{ s.txnIsolationLevelOneShot.state = oneShotSet s.txnIsolationLevelOneShot.value = val return nil + }, GetStateValue: func(s *SessionVars) (string, bool, error) { + if s.txnIsolationLevelOneShot.state != oneShotDef { + return s.txnIsolationLevelOneShot.value, true, nil + } + return "", false, nil }}, {Scope: ScopeSession, Name: TiDBOptimizerSelectivityLevel, Value: strconv.Itoa(DefTiDBOptimizerSelectivityLevel), skipInit: true, Type: TypeUnsigned, MinValue: 0, MaxValue: math.MaxInt32, SetSession: func(s *SessionVars, val string) error { s.OptimizerSelectivityLevel = tidbOptPositiveInt32(val, DefTiDBOptimizerSelectivityLevel) @@ -307,12 +319,16 @@ var defaultSysVars = []*SysVar{ return nil }, GetSession: func(s *SessionVars) (string, error) { return "0", nil + }, GetStateValue: func(s *SessionVars) (string, bool, error) { + return strconv.FormatUint(uint64(s.Rng.GetSeed1()), 10), true, nil }}, {Scope: ScopeSession, Name: RandSeed2, Type: TypeInt, Value: "0", skipInit: true, MaxValue: math.MaxInt32, SetSession: func(s *SessionVars, val string) error { s.Rng.SetSeed2(uint32(tidbOptPositiveInt32(val, 0))) return nil }, GetSession: func(s *SessionVars) (string, error) { return "0", nil + }, GetStateValue: func(s *SessionVars) (string, bool, error) { + return strconv.FormatUint(uint64(s.Rng.GetSeed2()), 10), true, nil }}, {Scope: ScopeSession, Name: TiDBReadConsistency, Value: string(ReadConsistencyStrict), Type: TypeStr, Hidden: true, Validation: func(_ *SessionVars, normalized string, _ string, _ ScopeFlag) (string, error) { diff --git a/sessionctx/variable/variable.go b/sessionctx/variable/variable.go index db747819dee42..8a882f1d6e4f2 100644 --- a/sessionctx/variable/variable.go +++ b/sessionctx/variable/variable.go @@ -132,6 +132,9 @@ type SysVar struct { GetSession func(*SessionVars) (string, error) // GetGlobal is a getter function for global scope. GetGlobal func(*SessionVars) (string, error) + // GetStateValue gets the value for session states, which is used for migrating sessions. + // We need a function to override GetSession sometimes, because GetSession may not return the real value. + GetStateValue func(*SessionVars) (string, bool, error) // skipInit defines if the sysvar should be loaded into the session on init. // This is only important to set for sysvars that include session scope, // since global scoped sysvars are not-applicable. diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 39ec20cbe2fb1..ab878d2bb3054 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -193,6 +193,29 @@ func GetSessionOrGlobalSystemVar(s *SessionVars, name string) (string, error) { return sv.GetGlobalFromHook(s) } +// GetSessionStatesSystemVar gets the session variable value for session states. +// It's only used for encoding session states when migrating a session. +// The returned boolean indicates whether to keep this value in the session states. +func GetSessionStatesSystemVar(s *SessionVars, name string) (string, bool, error) { + sv := GetSysVar(name) + if sv == nil { + return "", false, ErrUnknownSystemVar.GenWithStackByArgs(name) + } + // Call GetStateValue first if it exists. Otherwise, call GetSession. + if sv.GetStateValue != nil { + return sv.GetStateValue(s) + } + if sv.GetSession != nil { + val, err := sv.GetSessionFromHook(s) + return val, err == nil, err + } + // Only get the cached value. No need to check the global or default value. + if val, ok := s.systems[sv.Name]; ok { + return val, true, nil + } + return "", false, nil +} + // GetGlobalSystemVar gets a global system variable. func GetGlobalSystemVar(s *SessionVars, name string) (string, error) { sv := GetSysVar(name) diff --git a/sessionctx/variable/varsutil_test.go b/sessionctx/variable/varsutil_test.go index 59cfb4cfca81b..4641a8c2f1e0d 100644 --- a/sessionctx/variable/varsutil_test.go +++ b/sessionctx/variable/varsutil_test.go @@ -673,3 +673,22 @@ func TestStmtVars(t *testing.T) { err = SetStmtVar(vars, MaxExecutionTime, "100") require.NoError(t, err) } + +func TestSessionStatesSystemVar(t *testing.T) { + vars := NewSessionVars() + err := SetSessionSystemVar(vars, "autocommit", "1") + require.NoError(t, err) + val, keep, err := GetSessionStatesSystemVar(vars, "autocommit") + require.NoError(t, err) + require.Equal(t, "ON", val) + require.Equal(t, true, keep) + _, keep, err = GetSessionStatesSystemVar(vars, Timestamp) + require.NoError(t, err) + require.Equal(t, false, keep) + err = SetSessionSystemVar(vars, MaxAllowedPacket, "1024") + require.NoError(t, err) + val, keep, err = GetSessionStatesSystemVar(vars, MaxAllowedPacket) + require.NoError(t, err) + require.Equal(t, "1024", val) + require.Equal(t, true, keep) +} diff --git a/util/mathutil/rand.go b/util/mathutil/rand.go index 6c93588a91129..a58c88281d638 100644 --- a/util/mathutil/rand.go +++ b/util/mathutil/rand.go @@ -67,3 +67,17 @@ func (rng *MysqlRng) SetSeed2(seed uint32) { defer rng.mu.Unlock() rng.seed2 = seed } + +// GetSeed1 is an interface to get seed1. It's only used for getting session states. +func (rng *MysqlRng) GetSeed1() uint32 { + rng.mu.Lock() + defer rng.mu.Unlock() + return rng.seed1 +} + +// GetSeed2 is an interface to get seed2. It's only used for getting session states. +func (rng *MysqlRng) GetSeed2() uint32 { + rng.mu.Lock() + defer rng.mu.Unlock() + return rng.seed2 +} diff --git a/util/mathutil/rand_test.go b/util/mathutil/rand_test.go index d0164f4de201f..0cc026604431c 100644 --- a/util/mathutil/rand_test.go +++ b/util/mathutil/rand_test.go @@ -68,4 +68,6 @@ func TestRandWithSeed1AndSeed2(t *testing.T) { require.Equal(t, rng.Gen(), 0.028870999839968048) require.Equal(t, rng.Gen(), 0.11641535266900002) require.Equal(t, rng.Gen(), 0.49546379455874096) + require.Equal(t, rng.GetSeed1(), uint32(532000198)) + require.Equal(t, rng.GetSeed2(), uint32(689000330)) }