Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression,planner: support non-deterministic functions (e.g., now) in the prepared plan cache #8105

Merged
merged 1 commit into from
Nov 1, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
expression,planner: support non-deterministic functions (e.g., now) i…
…n the plan cache
dbjoa committed Oct 25, 2018
commit 1db4a0ef4e456aafc1d6729241cc2d2ae2ba5d1f
2 changes: 2 additions & 0 deletions executor/executor.go
Original file line number Diff line number Diff line change
@@ -1173,6 +1173,8 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc := new(stmtctx.StatementContext)
sc.TimeZone = vars.Location()
sc.MemTracker = memory.NewTracker(s.Text(), vars.MemQuotaQuery)
sc.NowTs = time.Time{}
sc.SysTs = time.Time{}
switch config.GetGlobalConfig().OOMAction {
case config.OOMActionCancel:
sc.MemTracker.SetActionOnExceed(&memory.PanicOnExceed{})
66 changes: 51 additions & 15 deletions expression/builtin_time.go
Original file line number Diff line number Diff line change
@@ -1971,7 +1971,11 @@ func (b *builtinCurrentDateSig) Clone() builtinFunc {
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_curdate
func (b *builtinCurrentDateSig) evalTime(row chunk.Row) (d types.Time, isNull bool, err error) {
tz := b.ctx.GetSessionVars().Location()
year, month, day := time.Now().In(tz).Date()
var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs
if nowTs.Equal(time.Time{}) {
*nowTs = time.Now()
}
year, month, day := nowTs.In(tz).Date()
result := types.Time{
Time: types.FromDate(year, int(month), day, 0, 0, 0, 0),
Type: mysql.TypeDate,
@@ -2026,7 +2030,11 @@ func (b *builtinCurrentTime0ArgSig) Clone() builtinFunc {

func (b *builtinCurrentTime0ArgSig) evalDuration(row chunk.Row) (types.Duration, bool, error) {
tz := b.ctx.GetSessionVars().Location()
dur := time.Now().In(tz).Format(types.TimeFormat)
var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs
if nowTs.Equal(time.Time{}) {
*nowTs = time.Now()
}
dur := nowTs.In(tz).Format(types.TimeFormat)
res, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, dur, types.MinFsp)
if err != nil {
return types.Duration{}, true, errors.Trace(err)
@@ -2050,7 +2058,11 @@ func (b *builtinCurrentTime1ArgSig) evalDuration(row chunk.Row) (types.Duration,
return types.Duration{}, true, errors.Trace(err)
}
tz := b.ctx.GetSessionVars().Location()
dur := time.Now().In(tz).Format(types.TimeFSPFormat)
var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs
if nowTs.Equal(time.Time{}) {
*nowTs = time.Now()
}
dur := nowTs.In(tz).Format(types.TimeFSPFormat)
res, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, dur, int(fsp))
if err != nil {
return types.Duration{}, true, errors.Trace(err)
@@ -2188,7 +2200,11 @@ func (b *builtinUTCDateSig) Clone() builtinFunc {
// evalTime evals UTC_DATE, UTC_DATE().
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-date
func (b *builtinUTCDateSig) evalTime(row chunk.Row) (types.Time, bool, error) {
year, month, day := time.Now().UTC().Date()
var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs
if nowTs.Equal(time.Time{}) {
*nowTs = time.Now()
}
year, month, day := nowTs.UTC().Date()
result := types.Time{
Time: types.FromGoTime(time.Date(year, month, day, 0, 0, 0, 0, time.UTC)),
Type: mysql.TypeDate,
@@ -2244,8 +2260,12 @@ func (c *utcTimestampFunctionClass) getFunction(ctx sessionctx.Context, args []E
return sig, nil
}

func evalUTCTimestampWithFsp(fsp int) (types.Time, bool, error) {
result, err := convertTimeToMysqlTime(time.Now().UTC(), fsp)
func evalUTCTimestampWithFsp(ctx sessionctx.Context, fsp int) (types.Time, bool, error) {
var nowTs = &ctx.GetSessionVars().StmtCtx.NowTs
if nowTs.Equal(time.Time{}) {
*nowTs = time.Now()
}
result, err := convertTimeToMysqlTime(nowTs.UTC(), fsp)
if err != nil {
return types.Time{}, true, errors.Trace(err)
}
@@ -2277,7 +2297,7 @@ func (b *builtinUTCTimestampWithArgSig) evalTime(row chunk.Row) (types.Time, boo
return types.Time{}, true, errors.Errorf("Invalid negative %d specified, must in [0, 6].", num)
}

result, isNull, err := evalUTCTimestampWithFsp(int(num))
result, isNull, err := evalUTCTimestampWithFsp(b.ctx, int(num))
return result, isNull, errors.Trace(err)
}

@@ -2294,7 +2314,7 @@ func (b *builtinUTCTimestampWithoutArgSig) Clone() builtinFunc {
// evalTime evals UTC_TIMESTAMP().
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-timestamp
func (b *builtinUTCTimestampWithoutArgSig) evalTime(row chunk.Row) (types.Time, bool, error) {
result, isNull, err := evalUTCTimestampWithFsp(0)
result, isNull, err := evalUTCTimestampWithFsp(b.ctx, 0)
return result, isNull, errors.Trace(err)
}

@@ -2328,12 +2348,16 @@ func (c *nowFunctionClass) getFunction(ctx sessionctx.Context, args []Expression
}

func evalNowWithFsp(ctx sessionctx.Context, fsp int) (types.Time, bool, error) {
sysTs, err := getSystemTimestamp(ctx)
if err != nil {
return types.Time{}, true, errors.Trace(err)
var sysTs = &ctx.GetSessionVars().StmtCtx.SysTs
if sysTs.Equal(time.Time{}) {
var err error
*sysTs, err = getSystemTimestamp(ctx)
if err != nil {
return types.Time{}, true, errors.Trace(err)
}
}

result, err := convertTimeToMysqlTime(sysTs, fsp)
result, err := convertTimeToMysqlTime(*sysTs, fsp)
if err != nil {
return types.Time{}, true, errors.Trace(err)
}
@@ -3557,7 +3581,11 @@ func (b *builtinUnixTimestampCurrentSig) Clone() builtinFunc {
// evalInt evals a UNIX_TIMESTAMP().
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_unix-timestamp
func (b *builtinUnixTimestampCurrentSig) evalInt(row chunk.Row) (int64, bool, error) {
dec, err := goTimeToMysqlUnixTimestamp(time.Now(), 1)
var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs
if nowTs.Equal(time.Time{}) {
*nowTs = time.Now()
}
dec, err := goTimeToMysqlUnixTimestamp(*nowTs, 1)
if err != nil {
return 0, true, errors.Trace(err)
}
@@ -5497,7 +5525,11 @@ func (b *builtinUTCTimeWithoutArgSig) Clone() builtinFunc {
// evalDuration evals a builtinUTCTimeWithoutArgSig.
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-time
func (b *builtinUTCTimeWithoutArgSig) evalDuration(row chunk.Row) (types.Duration, bool, error) {
v, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, time.Now().UTC().Format(types.TimeFormat), 0)
var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs
if nowTs.Equal(time.Time{}) {
*nowTs = time.Now()
}
v, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, nowTs.UTC().Format(types.TimeFormat), 0)
return v, false, err
}

@@ -5524,7 +5556,11 @@ func (b *builtinUTCTimeWithArgSig) evalDuration(row chunk.Row) (types.Duration,
if fsp < int64(types.MinFsp) {
return types.Duration{}, true, errors.Errorf("Invalid negative %d specified, must in [0, 6].", fsp)
}
v, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, time.Now().UTC().Format(types.TimeFSPFormat), int(fsp))
var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs
if nowTs.Equal(time.Time{}) {
*nowTs = time.Now()
}
v, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, nowTs.UTC().Format(types.TimeFSPFormat), int(fsp))
return v, false, err
}

24 changes: 24 additions & 0 deletions expression/builtin_time_test.go
Original file line number Diff line number Diff line change
@@ -762,6 +762,11 @@ func (s *testEvaluatorSuite) TestTime(c *C) {
c.Assert(err, IsNil)
}

func resetStmtContext(ctx sessionctx.Context) {
ctx.GetSessionVars().StmtCtx.NowTs = time.Time{}
ctx.GetSessionVars().StmtCtx.SysTs = time.Time{}
}

func (s *testEvaluatorSuite) TestNowAndUTCTimestamp(c *C) {
defer testleak.AfterTest(c)()

@@ -778,6 +783,7 @@ func (s *testEvaluatorSuite) TestNowAndUTCTimestamp(c *C) {
{funcs[ast.Now], func() time.Time { return time.Now() }},
{funcs[ast.UTCTimestamp], func() time.Time { return time.Now().UTC() }},
} {
resetStmtContext(s.ctx)
f, err := x.fc.getFunction(s.ctx, s.datumsToConstants(nil))
c.Assert(err, IsNil)
v, err := evalBuiltinFunc(f, chunk.Row{})
@@ -789,6 +795,7 @@ func (s *testEvaluatorSuite) TestNowAndUTCTimestamp(c *C) {
c.Assert(strings.Contains(t.String(), "."), IsFalse)
c.Assert(ts.Sub(gotime(t, ts.Location())), LessEqual, time.Second)

resetStmtContext(s.ctx)
f, err = x.fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(6)))
c.Assert(err, IsNil)
v, err = evalBuiltinFunc(f, chunk.Row{})
@@ -798,11 +805,13 @@ func (s *testEvaluatorSuite) TestNowAndUTCTimestamp(c *C) {
c.Assert(strings.Contains(t.String(), "."), IsTrue)
c.Assert(ts.Sub(gotime(t, ts.Location())), LessEqual, time.Millisecond)

resetStmtContext(s.ctx)
f, err = x.fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(8)))
c.Assert(err, IsNil)
_, err = evalBuiltinFunc(f, chunk.Row{})
c.Assert(err, NotNil)

resetStmtContext(s.ctx)
f, err = x.fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(-2)))
c.Assert(err, IsNil)
_, err = evalBuiltinFunc(f, chunk.Row{})
@@ -813,6 +822,7 @@ func (s *testEvaluatorSuite) TestNowAndUTCTimestamp(c *C) {
variable.SetSessionSystemVar(s.ctx.GetSessionVars(), "time_zone", types.NewDatum("+00:00"))
variable.SetSessionSystemVar(s.ctx.GetSessionVars(), "timestamp", types.NewDatum(1234))
fc := funcs[ast.Now]
resetStmtContext(s.ctx)
f, err := fc.getFunction(s.ctx, s.datumsToConstants(nil))
c.Assert(err, IsNil)
v, err := evalBuiltinFunc(f, chunk.Row{})
@@ -877,6 +887,7 @@ func (s *testEvaluatorSuite) TestAddTimeSig(c *C) {

// This is a test for issue 7334
du := newDateArighmeticalUtil()
resetStmtContext(s.ctx)
now, _, err := evalNowWithFsp(s.ctx, 0)
c.Assert(err, IsNil)
res, _, err := du.add(s.ctx, now, "1", "MICROSECOND")
@@ -1203,6 +1214,7 @@ func (s *testEvaluatorSuite) TestUTCTime(c *C) {
}{{0, 8}, {3, 12}, {6, 15}, {-1, 0}, {7, 0}}

for _, test := range tests {
resetStmtContext(s.ctx)
f, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(test.param)))
c.Assert(err, IsNil)
v, err := evalBuiltinFunc(f, chunk.Row{})
@@ -1229,6 +1241,7 @@ func (s *testEvaluatorSuite) TestUTCDate(c *C) {
defer testleak.AfterTest(c)()
last := time.Now().UTC()
fc := funcs[ast.UTCDate]
resetStmtContext(mock.NewContext())
f, err := fc.getFunction(mock.NewContext(), s.datumsToConstants(nil))
c.Assert(err, IsNil)
v, err := evalBuiltinFunc(f, chunk.Row{})
@@ -1500,6 +1513,7 @@ func (s *testEvaluatorSuite) TestTimestampDiff(c *C) {
types.NewStringDatum(test.t1),
types.NewStringDatum(test.t2),
}
resetStmtContext(s.ctx)
f, err := fc.getFunction(s.ctx, s.datumsToConstants(args))
c.Assert(err, IsNil)
d, err := evalBuiltinFunc(f, chunk.Row{})
@@ -1509,6 +1523,7 @@ func (s *testEvaluatorSuite) TestTimestampDiff(c *C) {
sc := s.ctx.GetSessionVars().StmtCtx
sc.IgnoreTruncate = true
sc.IgnoreZeroInDate = true
resetStmtContext(s.ctx)
f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{types.NewStringDatum("DAY"),
types.NewStringDatum("2017-01-00"),
types.NewStringDatum("2017-01-01")}))
@@ -1517,6 +1532,7 @@ func (s *testEvaluatorSuite) TestTimestampDiff(c *C) {
c.Assert(err, IsNil)
c.Assert(d.Kind(), Equals, types.KindNull)

resetStmtContext(s.ctx)
f, err = fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{types.NewStringDatum("DAY"),
{}, types.NewStringDatum("2017-01-01")}))
c.Assert(err, IsNil)
@@ -1528,6 +1544,7 @@ func (s *testEvaluatorSuite) TestTimestampDiff(c *C) {
func (s *testEvaluatorSuite) TestUnixTimestamp(c *C) {
// Test UNIX_TIMESTAMP().
fc := funcs[ast.UnixTimestamp]
resetStmtContext(s.ctx)
f, err := fc.getFunction(s.ctx, nil)
c.Assert(err, IsNil)
d, err := evalBuiltinFunc(f, chunk.Row{})
@@ -1537,12 +1554,14 @@ func (s *testEvaluatorSuite) TestUnixTimestamp(c *C) {

// https://github.com/pingcap/tidb/issues/2496
// Test UNIX_TIMESTAMP(NOW()).
resetStmtContext(s.ctx)
now, isNull, err := evalNowWithFsp(s.ctx, 0)
c.Assert(err, IsNil)
c.Assert(isNull, IsFalse)
n := types.Datum{}
n.SetMysqlTime(now)
args := []types.Datum{n}
resetStmtContext(s.ctx)
f, err = fc.getFunction(s.ctx, s.datumsToConstants(args))
c.Assert(err, IsNil)
d, err = evalBuiltinFunc(f, chunk.Row{})
@@ -1554,6 +1573,7 @@ func (s *testEvaluatorSuite) TestUnixTimestamp(c *C) {
// https://github.com/pingcap/tidb/issues/2852
// Test UNIX_TIMESTAMP(NULL).
args = []types.Datum{types.NewDatum(nil)}
resetStmtContext(s.ctx)
f, err = fc.getFunction(s.ctx, s.datumsToConstants(args))
c.Assert(err, IsNil)
d, err = evalBuiltinFunc(f, chunk.Row{})
@@ -1598,6 +1618,7 @@ func (s *testEvaluatorSuite) TestUnixTimestamp(c *C) {
fmt.Printf("Begin Test %v\n", test)
expr := s.datumsToConstants([]types.Datum{test.input})
expr[0].GetType().Decimal = test.inputDecimal
resetStmtContext(s.ctx)
f, err := fc.getFunction(s.ctx, expr)
c.Assert(err, IsNil, Commentf("%+v", test))
d, err := evalBuiltinFunc(f, chunk.Row{})
@@ -1681,6 +1702,7 @@ func (s *testEvaluatorSuite) TestTimestamp(c *C) {
}
fc := funcs[ast.Timestamp]
for _, test := range tests {
resetStmtContext(s.ctx)
f, err := fc.getFunction(s.ctx, s.datumsToConstants(test.t))
c.Assert(err, IsNil)
d, err := evalBuiltinFunc(f, chunk.Row{})
@@ -1690,6 +1712,7 @@ func (s *testEvaluatorSuite) TestTimestamp(c *C) {
}

nilDatum := types.NewDatum(nil)
resetStmtContext(s.ctx)
f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{nilDatum}))
c.Assert(err, IsNil)
d, err := evalBuiltinFunc(f, chunk.Row{})
@@ -2357,6 +2380,7 @@ func (s *testEvaluatorSuite) TestWithTimeZone(c *C) {

for _, t := range tests {
now := time.Now().In(sv.TimeZone)
resetStmtContext(s.ctx)
f, err := funcs[t.method].getFunction(s.ctx, s.datumsToConstants(t.Input))
c.Assert(err, IsNil)
d, err := evalBuiltinFunc(f, chunk.Row{})
40 changes: 23 additions & 17 deletions expression/function_traits.go
Original file line number Diff line number Diff line change
@@ -19,23 +19,12 @@ import (

// UnCacheableFunctions stores functions which can not be cached to plan cache.
var UnCacheableFunctions = map[string]struct{}{
ast.Now: {},
ast.CurrentTimestamp: {},
ast.UTCTime: {},
ast.Curtime: {},
ast.CurrentTime: {},
ast.UTCTimestamp: {},
ast.UnixTimestamp: {},
ast.Sysdate: {},
ast.Curdate: {},
ast.CurrentDate: {},
ast.UTCDate: {},
ast.Database: {},
ast.CurrentUser: {},
ast.User: {},
ast.ConnectionID: {},
ast.LastInsertId: {},
ast.Version: {},
ast.Database: {},
ast.CurrentUser: {},
ast.User: {},
ast.ConnectionID: {},
ast.LastInsertId: {},
ast.Version: {},
}

// unFoldableFunctions stores functions which can not be folded duration constant folding stage.
@@ -52,6 +41,23 @@ var unFoldableFunctions = map[string]struct{}{
ast.GetParam: {},
}

// DeferredFunctions stores non-deterministic functions, which can be deferred only when the plan cache is enabled.
var DeferredFunctions = map[string]struct{}{
ast.Now: {},
ast.CurrentTimestamp: {},
ast.UTCTime: {},
ast.Curtime: {},
ast.CurrentTime: {},
ast.UTCTimestamp: {},
ast.UnixTimestamp: {},
ast.Sysdate: {},
ast.Curdate: {},
ast.CurrentDate: {},
ast.UTCDate: {},
ast.Rand: {},
ast.UUID: {},
}

// inequalFunctions stores functions which cannot be propagated from column equal condition.
var inequalFunctions = map[string]struct{}{
ast.IsNull: {},
19 changes: 16 additions & 3 deletions expression/scalar_function.go
Original file line number Diff line number Diff line change
@@ -70,8 +70,8 @@ func (sf *ScalarFunction) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf("\"%s\"", sf)), nil
}

// NewFunction creates a new scalar function or constant.
func NewFunction(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
// newFunctionImpl creates a new scalar function or constant.
func newFunctionImpl(ctx sessionctx.Context, fold bool, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
if retType == nil {
return nil, errors.Errorf("RetType cannot be nil for ScalarFunction.")
}
@@ -96,7 +96,20 @@ func NewFunction(ctx sessionctx.Context, funcName string, retType *types.FieldTy
RetType: retType,
Function: f,
}
return FoldConstant(sf), nil
if fold {
return FoldConstant(sf), nil
}
return sf, nil
}

// NewFunction creates a new scalar function or constant via a constant folding.
func NewFunction(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
return newFunctionImpl(ctx, true, funcName, retType, args...)
}

// NewFunctionBase creates a new scalar function with no constant folding.
func NewFunctionBase(ctx sessionctx.Context, funcName string, retType *types.FieldType, args ...Expression) (Expression, error) {
return newFunctionImpl(ctx, false, funcName, retType, args...)
}

// NewFunctionInternal is similar to NewFunction, but do not returns error, should only be used internally.
42 changes: 22 additions & 20 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
@@ -757,13 +757,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
value := &expression.Constant{Value: v.Datum, RetType: &v.Type}
er.ctxStack = append(er.ctxStack, value)
case *driver.ParamMarkerExpr:
tp := types.NewFieldType(mysql.TypeUnspecified)
types.DefaultParamTypeForValue(v.GetValue(), tp)
value := &expression.Constant{Value: v.Datum, RetType: tp}
if er.useCache() {
value.DeferredExpr = er.getParamExpression(v)
}
er.ctxStack = append(er.ctxStack, value)
er.paramToExpression(v)
case *ast.VariableExpr:
er.rewriteVariable(v)
case *ast.FuncCallExpr:
@@ -820,17 +814,18 @@ func datumToConstant(d types.Datum, tp byte) *expression.Constant {
return &expression.Constant{Value: d, RetType: types.NewFieldType(tp)}
}

func (er *expressionRewriter) getParamExpression(v *driver.ParamMarkerExpr) expression.Expression {
f, err := expression.NewFunction(er.ctx,
ast.GetParam,
&v.Type,
datumToConstant(types.NewIntDatum(int64(v.Order)), mysql.TypeLonglong))
if err != nil {
er.err = errors.Trace(err)
return nil
}
f.GetType().Tp = v.Type.Tp
return f
func (er *expressionRewriter) paramToExpression(v *driver.ParamMarkerExpr) {
tp := types.NewFieldType(mysql.TypeUnspecified)
types.DefaultParamTypeForValue(v.GetValue(), tp)
value := &expression.Constant{Value: v.Datum, RetType: tp}
if er.useCache() {
var f expression.Expression
f, er.err = expression.NewFunctionBase(er.ctx, ast.GetParam, &v.Type,
datumToConstant(types.NewIntDatum(int64(v.Order)), mysql.TypeLonglong))
f.GetType().Tp = v.Type.Tp
value.DeferredExpr = f
}
er.ctxStack = append(er.ctxStack, value)
}

func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) {
@@ -1220,9 +1215,16 @@ func (er *expressionRewriter) funcCallToExpression(v *ast.FuncCallExpr) {
return
}
var function expression.Expression
function, er.err = expression.NewFunction(er.ctx, v.FnName.L, &v.Type, args...)
er.ctxStack = er.ctxStack[:stackLen-len(v.Args)]
er.ctxStack = append(er.ctxStack, function)
if _, ok := expression.DeferredFunctions[v.FnName.L]; er.useCache() && ok {
function, er.err = expression.NewFunctionBase(er.ctx, v.FnName.L, &v.Type, args...)
c := &expression.Constant{Value: types.NewDatum(nil), RetType: &v.Type, DeferredExpr: function}
c.GetType().Tp = function.GetType().Tp
er.ctxStack = append(er.ctxStack, c)
} else {
function, er.err = expression.NewFunction(er.ctx, v.FnName.L, &v.Type, args...)
er.ctxStack = append(er.ctxStack, function)
}
}

func (er *expressionRewriter) toColumn(v *ast.ColumnName) {
96 changes: 96 additions & 0 deletions planner/core/prepare_test.go
Original file line number Diff line number Diff line change
@@ -14,10 +14,16 @@
package core_test

import (
"time"

. "github.com/pingcap/check"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/util/testleak"
dto "github.com/prometheus/client_model/go"
)

var _ = Suite(&testPrepareSuite{})
@@ -89,3 +95,93 @@ func (s *testPrepareSuite) TestPrepareCacheIndexScan(c *C) {
tk.MustQuery("execute stmt1 using @a, @b").Check(testkit.Rows("1 3", "1 3"))
tk.MustQuery("execute stmt1 using @a, @b").Check(testkit.Rows("1 3", "1 3"))
}

func (s *testPlanSuite) TestPrepareCacheDeferredFunction(c *C) {
defer testleak.AfterTest(c)()
store, dom, err := newStoreWithBootstrap()
c.Assert(err, IsNil)
tk := testkit.NewTestKit(c, store)
orgEnable := core.PreparedPlanCacheEnabled()
orgCapacity := core.PreparedPlanCacheCapacity
defer func() {
dom.Close()
store.Close()
core.SetPreparedPlanCache(orgEnable)
core.PreparedPlanCacheCapacity = orgCapacity
}()
core.SetPreparedPlanCache(true)
core.PreparedPlanCacheCapacity = 100

defer testleak.AfterTest(c)()

tk.MustExec("use test")
tk.MustExec("drop table if exists t1")
tk.MustExec("create table t1 (id int PRIMARY KEY, c1 TIMESTAMP(3) NOT NULL DEFAULT '0000-00-00 00:00:00', KEY idx1 (c1))")
tk.MustExec("prepare sel1 from 'select id, c1 from t1 where c1 < now(3)'")

sql1 := "execute sel1"
expectedPattern := `IndexReader\(Index\(t1.idx1\)\[\[-inf,[0-9]{4}-(0[1-9]|1[0-2])-(0[1-9]|[1-2][0-9]|3[0-1]) (2[0-3]|[01][0-9]):[0-5][0-9]:[0-5][0-9].000\)\]\)`

var cnt [2]float64
var planStr [2]string
metrics.PlanCacheCounter.Reset()
counter := metrics.PlanCacheCounter.WithLabelValues("prepare")
for i := 0; i < 2; i++ {
stmt, err := s.ParseOneStmt(sql1, "", "")
c.Check(err, IsNil)
is := tk.Se.GetSessionVars().TxnCtx.InfoSchema.(infoschema.InfoSchema)
builder := core.NewPlanBuilder(tk.Se, is)
p, err := builder.Build(stmt)
c.Check(err, IsNil)
execPlan, ok := p.(*core.Execute)
c.Check(ok, IsTrue)
executor.ResetContextOfStmt(tk.Se, stmt)
err = execPlan.OptimizePreparedPlan(tk.Se, is)
c.Check(err, IsNil)
planStr[i] = core.ToString(execPlan.Plan)
c.Check(planStr[i], Matches, expectedPattern, Commentf("for %s", sql1))
pb := &dto.Metric{}
counter.Write(pb)
cnt[i] = pb.GetCounter().GetValue()
c.Check(cnt[i], Equals, float64(i))
time.Sleep(time.Second * 1)
}
c.Assert(planStr[0] < planStr[1], IsTrue)
}

func (s *testPrepareSuite) TestPrepareCacheNow(c *C) {
defer testleak.AfterTest(c)()
store, dom, err := newStoreWithBootstrap()
c.Assert(err, IsNil)
tk := testkit.NewTestKit(c, store)
orgEnable := core.PreparedPlanCacheEnabled()
orgCapacity := core.PreparedPlanCacheCapacity
defer func() {
dom.Close()
store.Close()
core.SetPreparedPlanCache(orgEnable)
core.PreparedPlanCacheCapacity = orgCapacity
}()
core.SetPreparedPlanCache(true)
core.PreparedPlanCacheCapacity = 100
tk.MustExec("use test")
tk.MustExec(`prepare stmt1 from "select now(), sleep(1), now()"`)
// When executing one statement at the first time, we don't use cache, so we need to execute it at least twice to test the cache.
rs := tk.MustQuery("execute stmt1").Rows()
c.Assert(rs[0][0].(string), Equals, rs[0][2].(string))

tk.MustExec(`prepare stmt2 from "select current_timestamp(), sleep(1), current_timestamp()"`)
// When executing one statement at the first time, we don't use cache, so we need to execute it at least twice to test the cache.
rs = tk.MustQuery("execute stmt2").Rows()
c.Assert(rs[0][0].(string), Equals, rs[0][2].(string))

tk.MustExec(`prepare stmt3 from "select utc_timestamp(), sleep(1), utc_timestamp()"`)
// When executing one statement at the first time, we don't use cache, so we need to execute it at least twice to test the cache.
rs = tk.MustQuery("execute stmt3").Rows()
c.Assert(rs[0][0].(string), Equals, rs[0][2].(string))

tk.MustExec(`prepare stmt4 from "select unix_timestamp(), sleep(1), unix_timestamp()"`)
// When executing one statement at the first time, we don't use cache, so we need to execute it at least twice to test the cache.
rs = tk.MustQuery("execute stmt4").Rows()
c.Assert(rs[0][0].(string), Equals, rs[0][2].(string))
}
2 changes: 2 additions & 0 deletions sessionctx/stmtctx/stmtctx.go
Original file line number Diff line number Diff line change
@@ -80,6 +80,8 @@ type StatementContext struct {
RuntimeStatsColl *execdetails.RuntimeStatsColl
TableIDs []int64
IndexIDs []int64
NowTs time.Time
SysTs time.Time
}

// AddAffectedRows adds affected rows.