From 4fc9551dd13d8cb2dbbe55ae1ff707ce8e3e4637 Mon Sep 17 00:00:00 2001
From: djshow832 <zhangming@pingcap.com>
Date: Wed, 15 Jun 2022 21:02:34 +0800
Subject: [PATCH] sessionctx, types, executor: support encoding and decoding
 user-defined variables (#35343)

close pingcap/tidb#35288
---
 executor/executor_test.go                     |  9 --
 executor/show.go                              | 24 +++++
 executor/simple.go                            | 11 ++-
 session/session.go                            | 11 +++
 sessionctx/context.go                         | 10 ++
 sessionctx/sessionstates/session_states.go    | 27 ++++++
 .../sessionstates/session_states_test.go      | 91 +++++++++++++++++++
 sessionctx/variable/session.go                | 35 +++++++
 types/datum.go                                | 57 ++++++++++++
 types/datum_test.go                           | 55 +++++++++++
 types/mydecimal.go                            | 36 ++++++++
 types/mydecimal_test.go                       | 20 ++++
 types/time.go                                 | 11 +++
 types/time_test.go                            | 12 +++
 util/mock/context.go                          | 11 +++
 15 files changed, 410 insertions(+), 10 deletions(-)
 create mode 100644 sessionctx/sessionstates/session_states.go
 create mode 100644 sessionctx/sessionstates/session_states_test.go

diff --git a/executor/executor_test.go b/executor/executor_test.go
index 1d486bd7ec598..c8e4304f1c22c 100644
--- a/executor/executor_test.go
+++ b/executor/executor_test.go
@@ -6070,12 +6070,3 @@ func TestIsFastPlan(t *testing.T) {
 		require.Equal(t, ca.isFastPlan, ok)
 	}
 }
-
-func TestShowSessionStates(t *testing.T) {
-	store, clean := testkit.CreateMockStore(t)
-	defer clean()
-	tk := testkit.NewTestKit(t, store)
-	tk.MustQuery("show session_states").Check(testkit.Rows())
-	tk.MustExec("set session_states 'x'")
-	tk.MustGetErrCode("set session_states 1", errno.ErrParse)
-}
diff --git a/executor/show.go b/executor/show.go
index 04153d27c0bd2..4b6d35d8ef187 100644
--- a/executor/show.go
+++ b/executor/show.go
@@ -45,6 +45,7 @@ import (
 	"github.com/pingcap/tidb/privilege"
 	"github.com/pingcap/tidb/privilege/privileges"
 	"github.com/pingcap/tidb/sessionctx"
+	"github.com/pingcap/tidb/sessionctx/sessionstates"
 	"github.com/pingcap/tidb/sessionctx/stmtctx"
 	"github.com/pingcap/tidb/sessionctx/variable"
 	"github.com/pingcap/tidb/store/helper"
@@ -1930,6 +1931,29 @@ func (e *ShowExec) fetchShowBuiltins() error {
 }
 
 func (e *ShowExec) fetchShowSessionStates(ctx context.Context) error {
+	sessionStates := &sessionstates.SessionStates{}
+	err := e.ctx.EncodeSessionStates(ctx, e.ctx, sessionStates)
+	if err != nil {
+		return err
+	}
+	stateBytes, err := gjson.Marshal(sessionStates)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	stateJSON := json.BinaryJSON{}
+	if err = stateJSON.UnmarshalJSON(stateBytes); err != nil {
+		return err
+	}
+	// This will be implemented in future PRs.
+	tokenBytes, err := gjson.Marshal("")
+	if err != nil {
+		return errors.Trace(err)
+	}
+	tokenJSON := json.BinaryJSON{}
+	if err = tokenJSON.UnmarshalJSON(tokenBytes); err != nil {
+		return err
+	}
+	e.appendRow([]interface{}{stateJSON, tokenJSON})
 	return nil
 }
 
diff --git a/executor/simple.go b/executor/simple.go
index 45982938af878..cfede41d49057 100644
--- a/executor/simple.go
+++ b/executor/simple.go
@@ -15,7 +15,9 @@
 package executor
 
 import (
+	"bytes"
 	"context"
+	"encoding/json"
 	"fmt"
 	"os"
 	"strings"
@@ -41,6 +43,7 @@ import (
 	"github.com/pingcap/tidb/plugin"
 	"github.com/pingcap/tidb/privilege"
 	"github.com/pingcap/tidb/sessionctx"
+	"github.com/pingcap/tidb/sessionctx/sessionstates"
 	"github.com/pingcap/tidb/sessionctx/variable"
 	"github.com/pingcap/tidb/sessiontxn"
 	"github.com/pingcap/tidb/types"
@@ -1687,7 +1690,13 @@ func asyncDelayShutdown(p *os.Process, delay time.Duration) {
 }
 
 func (e *SimpleExec) executeSetSessionStates(ctx context.Context, s *ast.SetSessionStatesStmt) error {
-	return nil
+	var sessionStates sessionstates.SessionStates
+	decoder := json.NewDecoder(bytes.NewReader([]byte(s.SessionStates)))
+	decoder.UseNumber()
+	if err := decoder.Decode(&sessionStates); err != nil {
+		return errors.Trace(err)
+	}
+	return e.ctx.DecodeSessionStates(ctx, e.ctx, &sessionStates)
 }
 
 func (e *SimpleExec) executeAdmin(s *ast.AdminStmt) error {
diff --git a/session/session.go b/session/session.go
index c6c82da7f8025..d913a2e808071 100644
--- a/session/session.go
+++ b/session/session.go
@@ -47,6 +47,7 @@ import (
 	"github.com/pingcap/tidb/parser/model"
 	"github.com/pingcap/tidb/parser/mysql"
 	"github.com/pingcap/tidb/parser/terror"
+	"github.com/pingcap/tidb/sessionctx/sessionstates"
 	"github.com/pingcap/tidb/sessiontxn"
 	"github.com/pingcap/tidb/sessiontxn/legacy"
 	"github.com/pingcap/tidb/sessiontxn/staleread"
@@ -3500,3 +3501,13 @@ func (s *session) getSnapshotInterceptor() kv.SnapshotInterceptor {
 func (s *session) GetStmtStats() *stmtstats.StatementStats {
 	return s.stmtStats
 }
+
+// EncodeSessionStates implements SessionStatesHandler.EncodeSessionStates interface.
+func (s *session) EncodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) {
+	return s.sessionVars.EncodeSessionStates(ctx, sessionStates)
+}
+
+// DecodeSessionStates implements SessionStatesHandler.DecodeSessionStates interface.
+func (s *session) DecodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) {
+	return s.sessionVars.DecodeSessionStates(ctx, sessionStates)
+}
diff --git a/sessionctx/context.go b/sessionctx/context.go
index be3329c1de4d3..852253f819a08 100644
--- a/sessionctx/context.go
+++ b/sessionctx/context.go
@@ -24,6 +24,7 @@ import (
 	"github.com/pingcap/tidb/kv"
 	"github.com/pingcap/tidb/metrics"
 	"github.com/pingcap/tidb/parser/model"
+	"github.com/pingcap/tidb/sessionctx/sessionstates"
 	"github.com/pingcap/tidb/sessionctx/variable"
 	"github.com/pingcap/tidb/util"
 	"github.com/pingcap/tidb/util/kvcache"
@@ -41,8 +42,17 @@ type InfoschemaMetaVersion interface {
 	SchemaMetaVersion() int64
 }
 
+// SessionStatesHandler is an interface for encoding and decoding session states.
+type SessionStatesHandler interface {
+	// EncodeSessionStates encodes session states into a JSON.
+	EncodeSessionStates(context.Context, Context, *sessionstates.SessionStates) error
+	// DecodeSessionStates decodes a map into session states.
+	DecodeSessionStates(context.Context, Context, *sessionstates.SessionStates) error
+}
+
 // Context is an interface for transaction and executive args environment.
 type Context interface {
+	SessionStatesHandler
 	// NewTxn creates a new transaction for further execution.
 	// If old transaction is valid, it is committed first.
 	// It's used in BEGIN statement and DDL statements to commit old transaction.
diff --git a/sessionctx/sessionstates/session_states.go b/sessionctx/sessionstates/session_states.go
new file mode 100644
index 0000000000000..43adb554f5758
--- /dev/null
+++ b/sessionctx/sessionstates/session_states.go
@@ -0,0 +1,27 @@
+// Copyright 2022 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sessionstates
+
+import (
+	ptypes "github.com/pingcap/tidb/parser/types"
+	"github.com/pingcap/tidb/types"
+)
+
+// SessionStates contains all the states in the session that should be migrated when the session
+// is migrated to another server. It is shown by `show session_states` and recovered by `set session_states`.
+type SessionStates struct {
+	UserVars     map[string]*types.Datum      `json:"user-var-values,omitempty"`
+	UserVarTypes map[string]*ptypes.FieldType `json:"user-var-types,omitempty"`
+}
diff --git a/sessionctx/sessionstates/session_states_test.go b/sessionctx/sessionstates/session_states_test.go
new file mode 100644
index 0000000000000..61413039f29f1
--- /dev/null
+++ b/sessionctx/sessionstates/session_states_test.go
@@ -0,0 +1,91 @@
+// Copyright 2022 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sessionstates_test
+
+import (
+	"fmt"
+	"strings"
+	"testing"
+
+	"github.com/pingcap/tidb/errno"
+	"github.com/pingcap/tidb/testkit"
+	"github.com/stretchr/testify/require"
+)
+
+func TestGrammar(t *testing.T) {
+	store, clean := testkit.CreateMockStore(t)
+	defer clean()
+	tk := testkit.NewTestKit(t, store)
+	rows := tk.MustQuery("show session_states").Rows()
+	require.Len(t, rows, 1)
+	tk.MustExec("set session_states '{}'")
+	tk.MustGetErrCode("set session_states 1", errno.ErrParse)
+}
+
+func TestUserVars(t *testing.T) {
+	store, clean := testkit.CreateMockStore(t)
+	defer clean()
+	tk := testkit.NewTestKit(t, store)
+	tk.MustExec("create table test.t1(" +
+		"j json, b blob, s varchar(255), st set('red', 'green', 'blue'), en enum('red', 'green', 'blue'))")
+	tk.MustExec("insert into test.t1 values('{\"color:\": \"red\"}', 'red', 'red', 'red,green', 'red')")
+
+	tests := []string{
+		"",
+		"set @%s=null",
+		"set @%s=1",
+		"set @%s=1.0e10",
+		"set @%s=1.0-1",
+		"set @%s=now()",
+		"set @%s=1, @%s=1.0-1",
+		"select @%s:=1+1",
+		// TiDB doesn't support following features.
+		//"select j into @%s from test.t1",
+		//"select j,b,s,st,en into @%s,@%s,@%s,@%s,@%s from test.t1",
+	}
+
+	for _, tt := range tests {
+		tk1 := testkit.NewTestKit(t, store)
+		tk2 := testkit.NewTestKit(t, store)
+		namesNum := strings.Count(tt, "%s")
+		names := make([]any, 0, namesNum)
+		for i := 0; i < namesNum; i++ {
+			names = append(names, fmt.Sprintf("a%d", i))
+		}
+		var sql string
+		if len(tt) > 0 {
+			sql = fmt.Sprintf(tt, names...)
+			tk1.MustExec(sql)
+		}
+		showSessionStatesAndSet(t, tk1, tk2)
+		for _, name := range names {
+			sql := fmt.Sprintf("select @%s", name)
+			msg := fmt.Sprintf("sql: %s, var name: %s", sql, name)
+			value1 := tk1.MustQuery(sql).Rows()[0][0]
+			value2 := tk2.MustQuery(sql).Rows()[0][0]
+			require.Equal(t, value1, value2, msg)
+		}
+	}
+}
+
+func showSessionStatesAndSet(t *testing.T, tk1, tk2 *testkit.TestKit) {
+	rows := tk1.MustQuery("show session_states").Rows()
+	require.Len(t, rows, 1)
+	state := rows[0][0].(string)
+	state = strings.ReplaceAll(state, "\\", "\\\\")
+	state = strings.ReplaceAll(state, "'", "\\'")
+	setSQL := fmt.Sprintf("set session_states '%s'", state)
+	tk2.MustExec(setSQL)
+}
diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go
index edc2942459fa2..735105c57d4db 100644
--- a/sessionctx/variable/session.go
+++ b/sessionctx/variable/session.go
@@ -16,6 +16,7 @@ package variable
 
 import (
 	"bytes"
+	"context"
 	"crypto/tls"
 	"encoding/binary"
 	"fmt"
@@ -41,6 +42,8 @@ import (
 	"github.com/pingcap/tidb/parser/model"
 	"github.com/pingcap/tidb/parser/mysql"
 	"github.com/pingcap/tidb/parser/terror"
+	ptypes "github.com/pingcap/tidb/parser/types"
+	"github.com/pingcap/tidb/sessionctx/sessionstates"
 	"github.com/pingcap/tidb/sessionctx/stmtctx"
 	pumpcli "github.com/pingcap/tidb/tidb-binlog/pump_client"
 	"github.com/pingcap/tidb/types"
@@ -1834,6 +1837,38 @@ func (s *SessionVars) GetTemporaryTable(tblInfo *model.TableInfo) tableutil.Temp
 	return nil
 }
 
+// EncodeSessionStates saves session states into SessionStates.
+func (s *SessionVars) EncodeSessionStates(ctx context.Context, sessionStates *sessionstates.SessionStates) (err error) {
+	// Encode user-defined variables.
+	s.UsersLock.RLock()
+	sessionStates.UserVars = make(map[string]*types.Datum, len(s.Users))
+	for name, userVar := range s.Users {
+		sessionStates.UserVars[name] = userVar.Clone()
+	}
+	sessionStates.UserVarTypes = make(map[string]*ptypes.FieldType, len(s.UserVarTypes))
+	for name, userVarType := range s.UserVarTypes {
+		sessionStates.UserVarTypes[name] = userVarType.Clone()
+	}
+	s.UsersLock.RUnlock()
+	return
+}
+
+// DecodeSessionStates restores session states from SessionStates.
+func (s *SessionVars) DecodeSessionStates(ctx context.Context, sessionStates *sessionstates.SessionStates) (err error) {
+	// Decode user-defined variables.
+	s.UsersLock.Lock()
+	s.Users = make(map[string]types.Datum, len(sessionStates.UserVars))
+	for name, userVar := range sessionStates.UserVars {
+		s.Users[name] = *userVar.Clone()
+	}
+	s.UserVarTypes = make(map[string]*ptypes.FieldType, len(sessionStates.UserVarTypes))
+	for name, userVarType := range sessionStates.UserVarTypes {
+		s.UserVarTypes[name] = userVarType.Clone()
+	}
+	s.UsersLock.Unlock()
+	return
+}
+
 // TableDelta stands for the changed count for one table or partition.
 type TableDelta struct {
 	Delta    int64
diff --git a/types/datum.go b/types/datum.go
index 8bc8fddafe870..5613972c24016 100644
--- a/types/datum.go
+++ b/types/datum.go
@@ -15,6 +15,7 @@
 package types
 
 import (
+	gjson "encoding/json"
 	"fmt"
 	"math"
 	"sort"
@@ -2008,6 +2009,62 @@ func (d *Datum) MemUsage() (sum int64) {
 	return EmptyDatumSize + int64(cap(d.b)) + int64(len(d.collation))
 }
 
+type jsonDatum struct {
+	K         byte       `json:"k"`
+	Decimal   uint16     `json:"decimal,omitempty"`
+	Length    uint32     `json:"length,omitempty"`
+	I         int64      `json:"i,omitempty"`
+	Collation string     `json:"collation,omitempty"`
+	B         []byte     `json:"b,omitempty"`
+	Time      Time       `json:"time,omitempty"`
+	MyDecimal *MyDecimal `json:"mydecimal,omitempty"`
+}
+
+// MarshalJSON implements Marshaler.MarshalJSON interface.
+func (d *Datum) MarshalJSON() ([]byte, error) {
+	jd := &jsonDatum{
+		K:         d.k,
+		Decimal:   d.decimal,
+		Length:    d.length,
+		I:         d.i,
+		Collation: d.collation,
+		B:         d.b,
+	}
+	switch d.k {
+	case KindMysqlTime:
+		jd.Time = d.GetMysqlTime()
+	case KindMysqlDecimal:
+		jd.MyDecimal = d.GetMysqlDecimal()
+	default:
+		if d.x != nil {
+			return nil, errors.New(fmt.Sprintf("unsupported type: %d", d.k))
+		}
+	}
+	return gjson.Marshal(jd)
+}
+
+// UnmarshalJSON implements Unmarshaler.UnmarshalJSON interface.
+func (d *Datum) UnmarshalJSON(data []byte) error {
+	var jd jsonDatum
+	if err := gjson.Unmarshal(data, &jd); err != nil {
+		return err
+	}
+	d.k = jd.K
+	d.decimal = jd.Decimal
+	d.length = jd.Length
+	d.i = jd.I
+	d.collation = jd.Collation
+	d.b = jd.B
+
+	switch jd.K {
+	case KindMysqlTime:
+		d.SetMysqlTime(jd.Time)
+	case KindMysqlDecimal:
+		d.SetMysqlDecimal(jd.MyDecimal)
+	}
+	return nil
+}
+
 func invalidConv(d *Datum, tp byte) (Datum, error) {
 	return Datum{}, errors.Errorf("cannot convert datum from %s to type %s", KindStr(d.Kind()), TypeStr(tp))
 }
diff --git a/types/datum_test.go b/types/datum_test.go
index fbcafa9b29b04..75627791c2342 100644
--- a/types/datum_test.go
+++ b/types/datum_test.go
@@ -15,6 +15,7 @@
 package types
 
 import (
+	gjson "encoding/json"
 	"fmt"
 	"math"
 	"reflect"
@@ -22,6 +23,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/pingcap/tidb/parser/charset"
 	"github.com/pingcap/tidb/parser/mysql"
 	"github.com/pingcap/tidb/sessionctx/stmtctx"
 	"github.com/pingcap/tidb/types/json"
@@ -546,6 +548,59 @@ func TestStringToMysqlBit(t *testing.T) {
 	}
 }
 
+func TestMarshalDatum(t *testing.T) {
+	e, err := ParseSetValue([]string{"a", "b", "c", "d", "e"}, uint64(1))
+	require.NoError(t, err)
+	tests := []Datum{
+		NewIntDatum(1),
+		NewUintDatum(72),
+		NewFloat32Datum(1.23),
+		NewFloat64Datum(1.23),
+		NewDatum(math.Inf(-1)),
+		NewDecimalDatum(NewDecFromStringForTest("1.2345")),
+		NewStringDatum("abcde"),
+		NewCollationStringDatum("abcde", charset.CollationBin),
+		NewDurationDatum(Duration{Duration: time.Duration(1)}),
+		NewTimeDatum(NewTime(FromGoTime(time.Date(2018, 3, 8, 16, 1, 0, 315313000, time.UTC)), mysql.TypeTimestamp, 6)),
+		NewBytesDatum([]byte("abcde")),
+		NewBinaryLiteralDatum([]byte{0x81}),
+		NewMysqlBitDatum(NewBinaryLiteralFromUint(0x98765432, 4)),
+		NewMysqlEnumDatum(Enum{Name: "a", Value: 1}),
+		NewCollateMysqlEnumDatum(Enum{Name: "a", Value: 1}, charset.CollationASCII),
+		NewMysqlSetDatum(e, charset.CollationGBKBin),
+		NewJSONDatum(json.CreateBinary(int64(1))),
+		MinNotNullDatum(),
+		MaxValueDatum(),
+	}
+	// Marshal the datum and then unmarshal it to see if they are equal.
+	for i, tt := range tests {
+		msg := fmt.Sprintf("failed at %dth test", i)
+		bytes, err := gjson.Marshal(&tt)
+		require.NoError(t, err, msg)
+		var datum Datum
+		err = gjson.Unmarshal(bytes, &datum)
+		require.NoError(t, err, msg)
+		require.Equal(t, tt.k, datum.k, msg)
+		require.Equal(t, tt.decimal, datum.decimal, msg)
+		require.Equal(t, tt.length, datum.length, msg)
+		require.Equal(t, tt.i, datum.i, msg)
+		require.Equal(t, tt.collation, datum.collation, msg)
+		require.Equal(t, tt.b, datum.b, msg)
+		if tt.x == nil {
+			require.Nil(t, datum.x, msg)
+		}
+		require.Equal(t, reflect.TypeOf(tt.x), reflect.TypeOf(datum.x), msg)
+		switch tt.x.(type) {
+		case Time:
+			require.Equal(t, 0, tt.x.(Time).Compare(datum.x.(Time)))
+		case *MyDecimal:
+			require.Equal(t, 0, tt.x.(*MyDecimal).Compare(datum.x.(*MyDecimal)))
+		default:
+			require.EqualValues(t, tt.x, datum.x, msg)
+		}
+	}
+}
+
 func BenchmarkCompareDatum(b *testing.B) {
 	vals, vals1 := prepareCompareDatums()
 	sc := new(stmtctx.StatementContext)
diff --git a/types/mydecimal.go b/types/mydecimal.go
index c1e09808c6d20..c2229462dc8ae 100644
--- a/types/mydecimal.go
+++ b/types/mydecimal.go
@@ -15,6 +15,7 @@
 package types
 
 import (
+	"encoding/json"
 	"math"
 	"strconv"
 	"strings"
@@ -1538,6 +1539,41 @@ func (d *MyDecimal) Compare(to *MyDecimal) int {
 	return 1
 }
 
+// None of ToBin, ToFloat64, or ToString can encode MyDecimal without loss.
+// So we still need a MarshalJSON/UnmarshalJSON function.
+type jsonMyDecimal struct {
+	DigitsInt  int8
+	DigitsFrac int8
+	ResultFrac int8
+	Negative   bool
+	WordBuf    [maxWordBufLen]int32
+}
+
+// MarshalJSON implements Marshaler.MarshalJSON interface.
+func (d *MyDecimal) MarshalJSON() ([]byte, error) {
+	var r jsonMyDecimal
+	r.DigitsInt = d.digitsInt
+	r.DigitsFrac = d.digitsFrac
+	r.ResultFrac = d.resultFrac
+	r.Negative = d.negative
+	r.WordBuf = d.wordBuf
+	return json.Marshal(r)
+}
+
+// UnmarshalJSON implements Unmarshaler.UnmarshalJSON interface.
+func (d *MyDecimal) UnmarshalJSON(data []byte) error {
+	var r jsonMyDecimal
+	err := json.Unmarshal(data, &r)
+	if err == nil {
+		d.digitsInt = r.DigitsInt
+		d.digitsFrac = r.DigitsFrac
+		d.resultFrac = r.ResultFrac
+		d.negative = r.Negative
+		d.wordBuf = r.WordBuf
+	}
+	return err
+}
+
 // DecimalNeg reverses decimal's sign.
 func DecimalNeg(from *MyDecimal) *MyDecimal {
 	to := *from
diff --git a/types/mydecimal_test.go b/types/mydecimal_test.go
index 07e8df28c28b9..584c1a8bd7f08 100644
--- a/types/mydecimal_test.go
+++ b/types/mydecimal_test.go
@@ -15,6 +15,7 @@
 package types
 
 import (
+	"encoding/json"
 	"fmt"
 	"strconv"
 	"strings"
@@ -992,3 +993,22 @@ func TestFromStringMyDecimal(t *testing.T) {
 	// reset
 	wordBufLen = maxWordBufLen
 }
+
+func TestMarshalMyDecimal(t *testing.T) {
+	cases := []string{
+		"12345",
+		"12345.",
+		".00012345000098765",
+		".12345000098765",
+		"-.000000012345000098765",
+		"123E-2",
+	}
+	for _, tt := range cases {
+		var v1, v2 MyDecimal
+		require.NoError(t, v1.FromString([]byte(tt)))
+		j, err := json.Marshal(&v1)
+		require.NoError(t, err)
+		require.NoError(t, json.Unmarshal(j, &v2))
+		require.Equal(t, 0, v1.Compare(&v2))
+	}
+}
diff --git a/types/time.go b/types/time.go
index a9c46aa7912ad..38c2d721a40fa 100644
--- a/types/time.go
+++ b/types/time.go
@@ -16,6 +16,7 @@ package types
 
 import (
 	"bytes"
+	"encoding/json"
 	"fmt"
 	"math"
 	"regexp"
@@ -545,6 +546,16 @@ func (t Time) RoundFrac(sc *stmtctx.StatementContext, fsp int) (Time, error) {
 	return NewTime(nt, t.Type(), fsp), nil
 }
 
+// MarshalJSON implements Marshaler.MarshalJSON interface.
+func (t Time) MarshalJSON() ([]byte, error) {
+	return json.Marshal(t.coreTime)
+}
+
+// UnmarshalJSON implements Unmarshaler.UnmarshalJSON interface.
+func (t *Time) UnmarshalJSON(data []byte) error {
+	return json.Unmarshal(data, &t.coreTime)
+}
+
 // GetFsp gets the fsp of a string.
 func GetFsp(s string) int {
 	index := GetFracIndex(s)
diff --git a/types/time_test.go b/types/time_test.go
index c575fa730c2b7..220cf25899e27 100644
--- a/types/time_test.go
+++ b/types/time_test.go
@@ -15,6 +15,7 @@
 package types_test
 
 import (
+	"encoding/json"
 	"fmt"
 	"math"
 	"testing"
@@ -2035,6 +2036,17 @@ func TestParseWithTimezone(t *testing.T) {
 	}
 }
 
+func TestMarshalTime(t *testing.T) {
+	sc := mock.NewContext().GetSessionVars().StmtCtx
+	v1, err := types.ParseTime(sc, "2017-01-18 01:01:01.123456", mysql.TypeDatetime, types.MaxFsp)
+	require.NoError(t, err)
+	j, err := json.Marshal(v1)
+	require.NoError(t, err)
+	var v2 types.Time
+	require.NoError(t, json.Unmarshal(j, &v2))
+	require.Equal(t, 0, v1.Compare(v2))
+}
+
 func BenchmarkFormat(b *testing.B) {
 	t1 := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 0)
 	for i := 0; i < b.N; i++ {
diff --git a/util/mock/context.go b/util/mock/context.go
index a0525c3e4ef42..5877b12fa6225 100644
--- a/util/mock/context.go
+++ b/util/mock/context.go
@@ -27,6 +27,7 @@ import (
 	"github.com/pingcap/tidb/parser/model"
 	"github.com/pingcap/tidb/parser/terror"
 	"github.com/pingcap/tidb/sessionctx"
+	"github.com/pingcap/tidb/sessionctx/sessionstates"
 	"github.com/pingcap/tidb/sessionctx/variable"
 	"github.com/pingcap/tidb/util"
 	"github.com/pingcap/tidb/util/disk"
@@ -386,6 +387,16 @@ func (c *Context) ReleaseAllAdvisoryLocks() int {
 	return 0
 }
 
+// EncodeSessionStates implements sessionctx.Context EncodeSessionStates interface.
+func (c *Context) EncodeSessionStates(context.Context, sessionctx.Context, *sessionstates.SessionStates) error {
+	return errors.Errorf("Not Supported")
+}
+
+// DecodeSessionStates implements sessionctx.Context DecodeSessionStates interface.
+func (c *Context) DecodeSessionStates(context.Context, sessionctx.Context, *sessionstates.SessionStates) error {
+	return errors.Errorf("Not Supported")
+}
+
 // Close implements the sessionctx.Context interface.
 func (c *Context) Close() {
 }