diff --git a/pkg/executor/aggfuncs/builder.go b/pkg/executor/aggfuncs/builder.go index 72ac5fadc08a5..9965976eda5b2 100644 --- a/pkg/executor/aggfuncs/builder.go +++ b/pkg/executor/aggfuncs/builder.go @@ -92,9 +92,9 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag case ast.WindowFuncCumeDist: return buildCumeDist(ordinal, orderByCols) case ast.WindowFuncNthValue: - return buildNthValue(windowFuncDesc, ordinal) + return buildNthValue(ctx, windowFuncDesc, ordinal) case ast.WindowFuncNtile: - return buildNtile(windowFuncDesc, ordinal) + return buildNtile(ctx, windowFuncDesc, ordinal) case ast.WindowFuncPercentRank: return buildPercentRank(ordinal, orderByCols) case ast.WindowFuncLead: @@ -668,22 +668,22 @@ func buildCumeDist(ordinal int, orderByCols []*expression.Column) AggFunc { return r } -func buildNthValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { +func buildNthValue(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { base := baseAggFunc{ args: aggFuncDesc.Args, ordinal: ordinal, } // Already checked when building the function description. - nth, _, _ := expression.GetUint64FromConstant(aggFuncDesc.Args[1]) + nth, _, _ := expression.GetUint64FromConstant(ctx, aggFuncDesc.Args[1]) return &nthValue{baseAggFunc: base, tp: aggFuncDesc.RetTp, nth: nth} } -func buildNtile(aggFuncDes *aggregation.AggFuncDesc, ordinal int) AggFunc { +func buildNtile(ctx sessionctx.Context, aggFuncDes *aggregation.AggFuncDesc, ordinal int) AggFunc { base := baseAggFunc{ args: aggFuncDes.Args, ordinal: ordinal, } - n, _, _ := expression.GetUint64FromConstant(aggFuncDes.Args[0]) + n, _, _ := expression.GetUint64FromConstant(ctx, aggFuncDes.Args[0]) return &ntile{baseAggFunc: base, n: n} } @@ -697,7 +697,7 @@ func buildPercentRank(ordinal int, orderByCols []*expression.Column) AggFunc { func buildLeadLag(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal int) baseLeadLag { offset := uint64(1) if len(aggFuncDesc.Args) >= 2 { - offset, _, _ = expression.GetUint64FromConstant(aggFuncDesc.Args[1]) + offset, _, _ = expression.GetUint64FromConstant(ctx, aggFuncDesc.Args[1]) } var defaultExpr expression.Expression defaultExpr = expression.NewNull() diff --git a/pkg/executor/distsql.go b/pkg/executor/distsql.go index 925049d40d61d..12555d386cc37 100644 --- a/pkg/executor/distsql.go +++ b/pkg/executor/distsql.go @@ -151,7 +151,7 @@ func closeAll(objs ...Closeable) error { func rebuildIndexRanges(ctx sessionctx.Context, is *plannercore.PhysicalIndexScan, idxCols []*expression.Column, colLens []int) (ranges []*ranger.Range, err error) { access := make([]expression.Expression, 0, len(is.AccessCondition)) for _, cond := range is.AccessCondition { - newCond, err1 := expression.SubstituteCorCol2Constant(cond) + newCond, err1 := expression.SubstituteCorCol2Constant(ctx, cond) if err1 != nil { return nil, err1 } diff --git a/pkg/expression/aggregation/window_func.go b/pkg/expression/aggregation/window_func.go index 1ce06b72d6d2a..ac6ff1f5dfdbc 100644 --- a/pkg/expression/aggregation/window_func.go +++ b/pkg/expression/aggregation/window_func.go @@ -36,13 +36,13 @@ func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Ex if !skipCheckArgs { switch strings.ToLower(name) { case ast.WindowFuncNthValue: - val, isNull, ok := expression.GetUint64FromConstant(args[1]) + val, isNull, ok := expression.GetUint64FromConstant(ctx, args[1]) // nth_value does not allow `0`, but allows `null`. if !ok || (val == 0 && !isNull) { return nil, nil } case ast.WindowFuncNtile: - val, isNull, ok := expression.GetUint64FromConstant(args[0]) + val, isNull, ok := expression.GetUint64FromConstant(ctx, args[0]) // ntile does not allow `0`, but allows `null`. if !ok || (val == 0 && !isNull) { return nil, nil @@ -51,7 +51,7 @@ func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Ex if len(args) < 2 { break } - _, isNull, ok := expression.GetUint64FromConstant(args[1]) + _, isNull, ok := expression.GetUint64FromConstant(ctx, args[1]) if !ok || isNull { return nil, nil } diff --git a/pkg/expression/builtin_cast.go b/pkg/expression/builtin_cast.go index e9b81745b2aee..d5ec84df6a735 100644 --- a/pkg/expression/builtin_cast.go +++ b/pkg/expression/builtin_cast.go @@ -2153,7 +2153,7 @@ func BuildCastFunctionWithCheck(ctx sessionctx.Context, expr Expression, tp *typ // since we may reset the flag of the field type of CastAsJson later which // would affect the evaluation of it. if tp.EvalType() != types.ETJson && err == nil { - res = FoldConstant(res) + res = FoldConstant(ctx, res) } return res, err } diff --git a/pkg/expression/builtin_convert_charset.go b/pkg/expression/builtin_convert_charset.go index c515c19d20807..7ecfd8d0fa5e3 100644 --- a/pkg/expression/builtin_convert_charset.go +++ b/pkg/expression/builtin_convert_charset.go @@ -230,7 +230,7 @@ func BuildToBinaryFunction(ctx sessionctx.Context, expr Expression) (res Express Function: f, ctx: ctx, } - return FoldConstant(res) + return FoldConstant(ctx, res) } // BuildFromBinaryFunction builds from_binary function. @@ -246,7 +246,7 @@ func BuildFromBinaryFunction(ctx sessionctx.Context, expr Expression, tp *types. Function: f, ctx: ctx, } - return FoldConstant(res) + return FoldConstant(ctx, res) } type funcProp int8 diff --git a/pkg/expression/constant_fold.go b/pkg/expression/constant_fold.go index 70bd26577d37e..aae42545248f5 100644 --- a/pkg/expression/constant_fold.go +++ b/pkg/expression/constant_fold.go @@ -17,6 +17,7 @@ package expression import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/logutil" @@ -24,10 +25,10 @@ import ( ) // specialFoldHandler stores functions for special UDF to constant fold -var specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){} +var specialFoldHandler = map[string]func(sessionctx.Context, *ScalarFunction) (Expression, bool){} func init() { - specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){ + specialFoldHandler = map[string]func(sessionctx.Context, *ScalarFunction) (Expression, bool){ ast.If: ifFoldHandler, ast.Ifnull: ifNullFoldHandler, ast.Case: caseWhenHandler, @@ -36,8 +37,8 @@ func init() { } // FoldConstant does constant folding optimization on an expression excluding deferred ones. -func FoldConstant(expr Expression) Expression { - e, _ := foldConstant(expr) +func FoldConstant(ctx sessionctx.Context, expr Expression) Expression { + e, _ := foldConstant(ctx, expr) // keep the original coercibility, charset, collation and repertoire values after folding e.SetCoercibility(expr.Coercibility()) @@ -48,11 +49,11 @@ func FoldConstant(expr Expression) Expression { return e } -func isNullHandler(expr *ScalarFunction) (Expression, bool) { +func isNullHandler(ctx sessionctx.Context, expr *ScalarFunction) (Expression, bool) { arg0 := expr.GetArgs()[0] if constArg, isConst := arg0.(*Constant); isConst { isDeferredConst := constArg.DeferredExpr != nil || constArg.ParamMarker != nil - value, err := expr.EvalWithInnerCtx(chunk.Row{}) + value, err := expr.Eval(ctx, chunk.Row{}) if err != nil { // Failed to fold this expr to a constant, print the DEBUG log and // return the original expression to let the error to be evaluated @@ -71,11 +72,11 @@ func isNullHandler(expr *ScalarFunction) (Expression, bool) { return expr, false } -func ifFoldHandler(expr *ScalarFunction) (Expression, bool) { +func ifFoldHandler(ctx sessionctx.Context, expr *ScalarFunction) (Expression, bool) { args := expr.GetArgs() - foldedArg0, _ := foldConstant(args[0]) + foldedArg0, _ := foldConstant(ctx, args[0]) if constArg, isConst := foldedArg0.(*Constant); isConst { - arg0, isNull0, err := constArg.EvalInt(expr.GetCtx(), chunk.Row{}) + arg0, isNull0, err := constArg.EvalInt(ctx, chunk.Row{}) if err != nil { // Failed to fold this expr to a constant, print the DEBUG log and // return the original expression to let the error to be evaluated @@ -84,23 +85,23 @@ func ifFoldHandler(expr *ScalarFunction) (Expression, bool) { return expr, false } if !isNull0 && arg0 != 0 { - return foldConstant(args[1]) + return foldConstant(ctx, args[1]) } - return foldConstant(args[2]) + return foldConstant(ctx, args[2]) } // if the condition is not const, which branch is unknown to run, so directly return. return expr, false } -func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) { +func ifNullFoldHandler(ctx sessionctx.Context, expr *ScalarFunction) (Expression, bool) { args := expr.GetArgs() - foldedArg0, isDeferred := foldConstant(args[0]) + foldedArg0, isDeferred := foldConstant(ctx, args[0]) if constArg, isConst := foldedArg0.(*Constant); isConst { // Only check constArg.Value here. Because deferred expression is // evaluated to constArg.Value after foldConstant(args[0]), it's not // needed to be checked. if constArg.Value.IsNull() { - return foldConstant(args[1]) + return foldConstant(ctx, args[1]) } return constArg, isDeferred } @@ -108,11 +109,11 @@ func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) { return expr, false } -func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { +func caseWhenHandler(ctx sessionctx.Context, expr *ScalarFunction) (Expression, bool) { args, l := expr.GetArgs(), len(expr.GetArgs()) var isDeferred, isDeferredConst bool for i := 0; i < l-1; i += 2 { - expr.GetArgs()[i], isDeferred = foldConstant(args[i]) + expr.GetArgs()[i], isDeferred = foldConstant(ctx, args[i]) isDeferredConst = isDeferredConst || isDeferred if _, isConst := expr.GetArgs()[i].(*Constant); !isConst { // for no-const, here should return directly, because the following branches are unknown to be run or not @@ -121,12 +122,12 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { // If the condition is const and true, and the previous conditions // has no expr, then the folded execution body is returned, otherwise // the arguments of the casewhen are folded and replaced. - val, isNull, err := args[i].EvalInt(expr.GetCtx(), chunk.Row{}) + val, isNull, err := args[i].EvalInt(ctx, chunk.Row{}) if err != nil { return expr, false } if val != 0 && !isNull { - foldedExpr, isDeferred := foldConstant(args[i+1]) + foldedExpr, isDeferred := foldConstant(ctx, args[i+1]) isDeferredConst = isDeferredConst || isDeferred if _, isConst := foldedExpr.(*Constant); isConst { foldedExpr.GetType().SetDecimal(expr.GetType().GetDecimal()) @@ -139,7 +140,7 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { // is false, then the folded else execution body is returned. otherwise // the execution body of the else are folded and replaced. if l%2 == 1 { - foldedExpr, isDeferred := foldConstant(args[l-1]) + foldedExpr, isDeferred := foldConstant(ctx, args[l-1]) isDeferredConst = isDeferredConst || isDeferred if _, isConst := foldedExpr.(*Constant); isConst { foldedExpr.GetType().SetDecimal(expr.GetType().GetDecimal()) @@ -150,18 +151,18 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { return expr, isDeferredConst } -func foldConstant(expr Expression) (Expression, bool) { +func foldConstant(ctx sessionctx.Context, expr Expression) (Expression, bool) { switch x := expr.(type) { case *ScalarFunction: if _, ok := unFoldableFunctions[x.FuncName.L]; ok { return expr, false } - if function := specialFoldHandler[x.FuncName.L]; function != nil && !MaybeOverOptimized4PlanCache(x.GetCtx(), []Expression{expr}) { - return function(x) + if function := specialFoldHandler[x.FuncName.L]; function != nil && !MaybeOverOptimized4PlanCache(ctx, []Expression{expr}) { + return function(ctx, x) } args := x.GetArgs() - sc := x.GetCtx().GetSessionVars().StmtCtx + sc := ctx.GetSessionVars().StmtCtx argIsConst := make([]bool, len(args)) hasNullArg := false allConstArg := true @@ -193,11 +194,11 @@ func foldConstant(expr Expression) (Expression, bool) { constArgs[i] = NewOne() } } - dummyScalarFunc, err := NewFunctionBase(x.GetCtx(), x.FuncName.L, x.GetType(), constArgs...) + dummyScalarFunc, err := NewFunctionBase(ctx, x.FuncName.L, x.GetType(), constArgs...) if err != nil { return expr, isDeferredConst } - value, err := dummyScalarFunc.EvalWithInnerCtx(chunk.Row{}) + value, err := dummyScalarFunc.Eval(ctx, chunk.Row{}) if err != nil { return expr, isDeferredConst } @@ -217,7 +218,7 @@ func foldConstant(expr Expression) (Expression, bool) { } return expr, isDeferredConst } - value, err := x.EvalWithInnerCtx(chunk.Row{}) + value, err := x.Eval(ctx, chunk.Row{}) retType := x.RetType.Clone() if !hasNullArg { // set right not null flag for constant value @@ -245,7 +246,7 @@ func foldConstant(expr Expression) (Expression, bool) { ParamMarker: x.ParamMarker, }, true } else if x.DeferredExpr != nil { - value, err := x.DeferredExpr.EvalWithInnerCtx(chunk.Row{}) + value, err := x.DeferredExpr.Eval(ctx, chunk.Row{}) if err != nil { logutil.BgLogger().Debug("fold expression to constant", zap.String("expression", x.ExplainInfo()), zap.Error(err)) return expr, true diff --git a/pkg/expression/constant_propagation.go b/pkg/expression/constant_propagation.go index c69b2d3867717..0e6b843df0d7c 100644 --- a/pkg/expression/constant_propagation.go +++ b/pkg/expression/constant_propagation.go @@ -212,7 +212,7 @@ func (s *propConstSolver) propagateConstantEQ() { } for i, cond := range s.conditions { if !visited[i] { - s.conditions[i] = ColumnSubstitute(cond, NewSchema(cols...), cons) + s.conditions[i] = ColumnSubstitute(s.ctx, cond, NewSchema(cols...), cons) } } } @@ -470,7 +470,7 @@ func (s *propOuterJoinConstSolver) propagateConstantEQ() { } for i, cond := range s.joinConds { if !visited[i+lenFilters] { - s.joinConds[i] = ColumnSubstitute(cond, NewSchema(cols...), cons) + s.joinConds[i] = ColumnSubstitute(s.ctx, cond, NewSchema(cols...), cons) } } } diff --git a/pkg/expression/constant_test.go b/pkg/expression/constant_test.go index e3cee448792c3..9dc8f2f9aede9 100644 --- a/pkg/expression/constant_test.go +++ b/pkg/expression/constant_test.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" "github.com/pingcap/tidb/pkg/util/mock" @@ -54,12 +55,16 @@ func newString(value string, collation string) *Constant { } } -func newFunction(funcName string, args ...Expression) Expression { - return newFunctionWithType(funcName, types.NewFieldType(mysql.TypeLonglong), args...) +func newFunctionWithMockCtx(funcName string, args ...Expression) Expression { + return newFunction(mock.NewContext(), funcName, args...) } -func newFunctionWithType(funcName string, tp *types.FieldType, args ...Expression) Expression { - return NewFunctionInternal(mock.NewContext(), funcName, tp, args...) +func newFunction(ctx sessionctx.Context, funcName string, args ...Expression) Expression { + return newFunctionWithType(ctx, funcName, types.NewFieldType(mysql.TypeLonglong), args...) +} + +func newFunctionWithType(ctx sessionctx.Context, funcName string, tp *types.FieldType, args ...Expression) Expression { + return NewFunctionInternal(ctx, funcName, tp, args...) } func TestConstantPropagation(t *testing.T) { @@ -71,59 +76,59 @@ func TestConstantPropagation(t *testing.T) { { solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ - newFunction(ast.EQ, newColumn(0), newColumn(1)), - newFunction(ast.EQ, newColumn(1), newColumn(2)), - newFunction(ast.EQ, newColumn(2), newColumn(3)), - newFunction(ast.EQ, newColumn(3), newLonglong(1)), - newFunction(ast.LogicOr, newLonglong(1), newColumn(0)), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newColumn(1)), + newFunctionWithMockCtx(ast.EQ, newColumn(1), newColumn(2)), + newFunctionWithMockCtx(ast.EQ, newColumn(2), newColumn(3)), + newFunctionWithMockCtx(ast.EQ, newColumn(3), newLonglong(1)), + newFunctionWithMockCtx(ast.LogicOr, newLonglong(1), newColumn(0)), }, result: "1, eq(Column#0, 1), eq(Column#1, 1), eq(Column#2, 1), eq(Column#3, 1)", }, { solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ - newFunction(ast.EQ, newColumn(0), newColumn(1)), - newFunction(ast.EQ, newColumn(1), newLonglong(1)), - newFunction(ast.NE, newColumn(2), newLonglong(2)), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newColumn(1)), + newFunctionWithMockCtx(ast.EQ, newColumn(1), newLonglong(1)), + newFunctionWithMockCtx(ast.NE, newColumn(2), newLonglong(2)), }, result: "eq(Column#0, 1), eq(Column#1, 1), ne(Column#2, 2)", }, { solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ - newFunction(ast.EQ, newColumn(0), newColumn(1)), - newFunction(ast.EQ, newColumn(1), newLonglong(1)), - newFunction(ast.EQ, newColumn(2), newColumn(3)), - newFunction(ast.GE, newColumn(2), newLonglong(2)), - newFunction(ast.NE, newColumn(2), newLonglong(4)), - newFunction(ast.NE, newColumn(3), newLonglong(5)), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newColumn(1)), + newFunctionWithMockCtx(ast.EQ, newColumn(1), newLonglong(1)), + newFunctionWithMockCtx(ast.EQ, newColumn(2), newColumn(3)), + newFunctionWithMockCtx(ast.GE, newColumn(2), newLonglong(2)), + newFunctionWithMockCtx(ast.NE, newColumn(2), newLonglong(4)), + newFunctionWithMockCtx(ast.NE, newColumn(3), newLonglong(5)), }, result: "eq(Column#0, 1), eq(Column#1, 1), eq(Column#2, Column#3), ge(Column#2, 2), ge(Column#3, 2), ne(Column#2, 4), ne(Column#2, 5), ne(Column#3, 4), ne(Column#3, 5)", }, { solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ - newFunction(ast.EQ, newColumn(0), newColumn(1)), - newFunction(ast.EQ, newColumn(0), newColumn(2)), - newFunction(ast.GE, newColumn(1), newLonglong(0)), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newColumn(1)), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newColumn(2)), + newFunctionWithMockCtx(ast.GE, newColumn(1), newLonglong(0)), }, result: "eq(Column#0, Column#1), eq(Column#0, Column#2), ge(Column#0, 0), ge(Column#1, 0), ge(Column#2, 0)", }, { solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ - newFunction(ast.EQ, newColumn(0), newColumn(1)), - newFunction(ast.GT, newColumn(0), newLonglong(2)), - newFunction(ast.GT, newColumn(1), newLonglong(3)), - newFunction(ast.LT, newColumn(0), newLonglong(1)), - newFunction(ast.GT, newLonglong(2), newColumn(1)), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newColumn(1)), + newFunctionWithMockCtx(ast.GT, newColumn(0), newLonglong(2)), + newFunctionWithMockCtx(ast.GT, newColumn(1), newLonglong(3)), + newFunctionWithMockCtx(ast.LT, newColumn(0), newLonglong(1)), + newFunctionWithMockCtx(ast.GT, newLonglong(2), newColumn(1)), }, result: "eq(Column#0, Column#1), gt(2, Column#0), gt(2, Column#1), gt(Column#0, 2), gt(Column#0, 3), gt(Column#1, 2), gt(Column#1, 3), lt(Column#0, 1), lt(Column#1, 1)", }, { solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ - newFunction(ast.EQ, newLonglong(1), newColumn(0)), + newFunctionWithMockCtx(ast.EQ, newLonglong(1), newColumn(0)), newLonglong(0), }, result: "0", @@ -131,41 +136,41 @@ func TestConstantPropagation(t *testing.T) { { solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ - newFunction(ast.EQ, newColumn(0), newColumn(1)), - newFunction(ast.In, newColumn(0), newLonglong(1), newLonglong(2)), - newFunction(ast.In, newColumn(1), newLonglong(3), newLonglong(4)), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newColumn(1)), + newFunctionWithMockCtx(ast.In, newColumn(0), newLonglong(1), newLonglong(2)), + newFunctionWithMockCtx(ast.In, newColumn(1), newLonglong(3), newLonglong(4)), }, result: "eq(Column#0, Column#1), in(Column#0, 1, 2), in(Column#0, 3, 4), in(Column#1, 1, 2), in(Column#1, 3, 4)", }, { solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ - newFunction(ast.EQ, newColumn(0), newColumn(1)), - newFunction(ast.EQ, newColumn(0), newFunction(ast.BitLength, newColumn(2))), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newColumn(1)), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newFunctionWithMockCtx(ast.BitLength, newColumn(2))), }, result: "eq(Column#0, Column#1), eq(Column#0, bit_length(cast(Column#2, var_string(20)))), eq(Column#1, bit_length(cast(Column#2, var_string(20))))", }, { solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ - newFunction(ast.EQ, newColumn(0), newColumn(1)), - newFunction(ast.LE, newFunction(ast.Mul, newColumn(0), newColumn(0)), newLonglong(50)), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newColumn(1)), + newFunctionWithMockCtx(ast.LE, newFunctionWithMockCtx(ast.Mul, newColumn(0), newColumn(0)), newLonglong(50)), }, result: "eq(Column#0, Column#1), le(mul(Column#0, Column#0), 50), le(mul(Column#1, Column#1), 50)", }, { solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ - newFunction(ast.EQ, newColumn(0), newColumn(1)), - newFunction(ast.LE, newColumn(0), newFunction(ast.Plus, newColumn(1), newLonglong(1))), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newColumn(1)), + newFunctionWithMockCtx(ast.LE, newColumn(0), newFunctionWithMockCtx(ast.Plus, newColumn(1), newLonglong(1))), }, result: "eq(Column#0, Column#1), le(Column#0, plus(Column#0, 1)), le(Column#0, plus(Column#1, 1)), le(Column#1, plus(Column#1, 1))", }, { solver: []PropagateConstantSolver{newPropConstSolver()}, conditions: []Expression{ - newFunction(ast.EQ, newColumn(0), newColumn(1)), - newFunction(ast.LE, newColumn(0), newFunction(ast.Rand)), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newColumn(1)), + newFunctionWithMockCtx(ast.LE, newColumn(0), newFunctionWithMockCtx(ast.Rand)), }, result: "eq(Column#0, Column#1), le(cast(Column#0, double BINARY), rand())", }, @@ -175,7 +180,7 @@ func TestConstantPropagation(t *testing.T) { ctx := mock.NewContext() conds := make([]Expression, 0, len(tt.conditions)) for _, cd := range tt.conditions { - conds = append(conds, FoldConstant(cd)) + conds = append(conds, FoldConstant(ctx, cd)) } newConds := solver.PropagateConstant(ctx, conds) var result []string @@ -190,75 +195,93 @@ func TestConstantPropagation(t *testing.T) { func TestConstantFolding(t *testing.T) { tests := []struct { - condition Expression + condition func(ctx sessionctx.Context) Expression result string }{ { - condition: newFunction(ast.LT, newColumn(0), newFunction(ast.Plus, newLonglong(1), newLonglong(2))), - result: "lt(Column#0, 3)", + condition: func(ctx sessionctx.Context) Expression { + return newFunction(ctx, ast.LT, newColumn(0), newFunction(ctx, ast.Plus, newLonglong(1), newLonglong(2))) + }, + result: "lt(Column#0, 3)", }, { - condition: newFunction(ast.LT, newColumn(0), newFunction(ast.Greatest, newLonglong(1), newLonglong(2))), - result: "lt(Column#0, 2)", + condition: func(ctx sessionctx.Context) Expression { + return newFunction(ctx, ast.LT, newColumn(0), newFunction(ctx, ast.Greatest, newLonglong(1), newLonglong(2))) + }, + result: "lt(Column#0, 2)", }, { - condition: newFunction(ast.EQ, newColumn(0), newFunction(ast.Rand)), - result: "eq(cast(Column#0, double BINARY), rand())", + condition: func(ctx sessionctx.Context) Expression { + return newFunction(ctx, ast.EQ, newColumn(0), newFunction(ctx, ast.Rand)) + }, + result: "eq(cast(Column#0, double BINARY), rand())", }, { - condition: newFunction(ast.IsNull, newLonglong(1)), - result: "0", + condition: func(ctx sessionctx.Context) Expression { + return newFunction(ctx, ast.IsNull, newLonglong(1)) + }, + result: "0", }, { - condition: newFunction(ast.EQ, newColumn(0), newFunction(ast.UnaryNot, newFunction(ast.Plus, newLonglong(1), newLonglong(1)))), - result: "eq(Column#0, 0)", + condition: func(ctx sessionctx.Context) Expression { + return newFunction(ctx, ast.EQ, newColumn(0), newFunction(ctx, ast.UnaryNot, newFunctionWithMockCtx(ast.Plus, newLonglong(1), newLonglong(1)))) + }, + result: "eq(Column#0, 0)", }, { - condition: newFunction(ast.LT, newColumn(0), newFunction(ast.Plus, newColumn(1), newFunction(ast.Plus, newLonglong(2), newLonglong(1)))), - result: "lt(Column#0, plus(Column#1, 3))", + condition: func(ctx sessionctx.Context) Expression { + return newFunction(ctx, ast.LT, newColumn(0), newFunction(ctx, ast.Plus, newColumn(1), newFunctionWithMockCtx(ast.Plus, newLonglong(2), newLonglong(1)))) + }, + result: "lt(Column#0, plus(Column#1, 3))", }, { - condition: func() Expression { - expr := newFunction(ast.ConcatWS, newColumn(0), NewNull()) - function := expr.(*ScalarFunction) - function.GetCtx().GetSessionVars().StmtCtx.InNullRejectCheck = true - return function - }(), + condition: func(ctx sessionctx.Context) Expression { + expr := newFunction(ctx, ast.ConcatWS, newColumn(0), NewNull()) + ctx.GetSessionVars().StmtCtx.InNullRejectCheck = true + return expr + }, result: "concat_ws(cast(Column#0, var_string(20)), )", }, } for _, tt := range tests { - newConds := FoldConstant(tt.condition) + ctx := mock.NewContext() + expr := tt.condition(ctx) + newConds := FoldConstant(ctx, expr) require.Equalf(t, tt.result, newConds.String(), "different for expr %s", tt.condition) } } func TestConstantFoldingCharsetConvert(t *testing.T) { + ctx := mock.NewContext() tests := []struct { condition Expression result string }{ { - condition: newFunction(ast.Length, newFunctionWithType( + condition: newFunction(ctx, ast.Length, newFunctionWithType( + ctx, InternalFuncToBinary, types.NewFieldType(mysql.TypeVarchar), newString("中文", "gbk_bin"))), result: "4", }, { - condition: newFunction(ast.Length, newFunctionWithType( + condition: newFunction(ctx, ast.Length, newFunctionWithType( + ctx, InternalFuncToBinary, types.NewFieldType(mysql.TypeVarchar), newString("中文", "utf8mb4_bin"))), result: "6", }, { - condition: newFunction(ast.Concat, newFunctionWithType( + condition: newFunction(ctx, ast.Concat, newFunctionWithType( + ctx, InternalFuncFromBinary, types.NewFieldType(mysql.TypeVarchar), newString("中文", "binary"))), result: "中文", }, { - condition: newFunction(ast.Concat, + condition: newFunction(ctx, ast.Concat, newFunctionWithType( + ctx, InternalFuncFromBinary, types.NewFieldTypeWithCollation(mysql.TypeVarchar, "gbk_bin", -1), newString("\xd2\xbb", "binary")), newString("中文", "gbk_bin"), @@ -266,9 +289,10 @@ func TestConstantFoldingCharsetConvert(t *testing.T) { result: "一中文", }, { - condition: newFunction(ast.Concat, + condition: newFunction(ctx, ast.Concat, newString("中文", "gbk_bin"), newFunctionWithType( + ctx, InternalFuncFromBinary, types.NewFieldTypeWithCollation(mysql.TypeVarchar, "gbk_bin", -1), newString("\xd2\xbb", "binary")), ), @@ -276,7 +300,7 @@ func TestConstantFoldingCharsetConvert(t *testing.T) { }, // The result is binary charset, so gbk constant will convert to binary which is \xd6\xd0\xce\xc4. { - condition: newFunction(ast.Concat, + condition: newFunction(ctx, ast.Concat, newString("中文", "gbk_bin"), newString("\xd2\xbb", "binary"), ), @@ -284,7 +308,7 @@ func TestConstantFoldingCharsetConvert(t *testing.T) { }, } for _, tt := range tests { - newConds := FoldConstant(tt.condition) + newConds := FoldConstant(ctx, tt.condition) require.Equalf(t, tt.result, newConds.String(), "different for expr %s", tt.condition) } } diff --git a/pkg/expression/expression.go b/pkg/expression/expression.go index ff48cae555dff..7e844660abc91 100644 --- a/pkg/expression/expression.go +++ b/pkg/expression/expression.go @@ -857,7 +857,7 @@ func evaluateExprWithNull(ctx sessionctx.Context, schema *Schema, expr Expressio return &Constant{Value: types.Datum{}, RetType: types.NewFieldType(mysql.TypeNull)} case *Constant: if x.DeferredExpr != nil { - return FoldConstant(x) + return FoldConstant(ctx, x) } } return expr @@ -922,7 +922,7 @@ func evaluateExprWithNullInNullRejectCheck(ctx sessionctx.Context, schema *Schem return &Constant{Value: types.Datum{}, RetType: types.NewFieldType(mysql.TypeNull)}, true case *Constant: if x.DeferredExpr != nil { - return FoldConstant(x), false + return FoldConstant(ctx, x), false } } return expr, false @@ -1526,7 +1526,7 @@ func wrapWithIsTrue(ctx sessionctx.Context, keepNull bool, arg Expression, wrapF if keepNull { sf.FuncName = model.NewCIStr(ast.IsTruthWithNull) } - return FoldConstant(sf), nil + return FoldConstant(ctx, sf), nil } // PropagateType propagates the type information to the `expr`. diff --git a/pkg/expression/expression_test.go b/pkg/expression/expression_test.go index 11c12539a1b13..28c1974dfaea2 100644 --- a/pkg/expression/expression_test.go +++ b/pkg/expression/expression_test.go @@ -135,19 +135,19 @@ func TestIsBinaryLiteral(t *testing.T) { func TestConstItem(t *testing.T) { ctx := createContext(t) - sf := newFunction(ast.Rand) + sf := newFunctionWithMockCtx(ast.Rand) require.False(t, sf.ConstItem(ctx.GetSessionVars().StmtCtx)) - sf = newFunction(ast.UUID) + sf = newFunctionWithMockCtx(ast.UUID) require.False(t, sf.ConstItem(ctx.GetSessionVars().StmtCtx)) - sf = newFunction(ast.GetParam, NewOne()) + sf = newFunctionWithMockCtx(ast.GetParam, NewOne()) require.False(t, sf.ConstItem(ctx.GetSessionVars().StmtCtx)) - sf = newFunction(ast.Abs, NewOne()) + sf = newFunctionWithMockCtx(ast.Abs, NewOne()) require.True(t, sf.ConstItem(ctx.GetSessionVars().StmtCtx)) } func TestVectorizable(t *testing.T) { exprs := make([]Expression, 0, 4) - sf := newFunction(ast.Rand) + sf := newFunctionWithMockCtx(ast.Rand) column := &Column{ UniqueID: 0, RetType: types.NewFieldType(mysql.TypeLonglong), @@ -171,21 +171,21 @@ func TestVectorizable(t *testing.T) { RetType: types.NewFieldType(mysql.TypeLonglong), } exprs = exprs[:0] - sf = newFunction(ast.SetVar, column0, column1) + sf = newFunctionWithMockCtx(ast.SetVar, column0, column1) exprs = append(exprs, sf) require.False(t, Vectorizable(exprs)) exprs = exprs[:0] - sf = newFunction(ast.GetVar, column0) + sf = newFunctionWithMockCtx(ast.GetVar, column0) exprs = append(exprs, sf) require.False(t, Vectorizable(exprs)) exprs = exprs[:0] - sf = newFunction(ast.NextVal, column0) + sf = newFunctionWithMockCtx(ast.NextVal, column0) exprs = append(exprs, sf) - sf = newFunction(ast.LastVal, column0) + sf = newFunctionWithMockCtx(ast.LastVal, column0) exprs = append(exprs, sf) - sf = newFunction(ast.SetVal, column1, column2) + sf = newFunctionWithMockCtx(ast.SetVal, column1, column2) exprs = append(exprs, sf) require.False(t, Vectorizable(exprs)) } diff --git a/pkg/expression/grouping_sets_test.go b/pkg/expression/grouping_sets_test.go index 68bf202860220..f61f00ccd032c 100644 --- a/pkg/expression/grouping_sets_test.go +++ b/pkg/expression/grouping_sets_test.go @@ -114,25 +114,25 @@ func TestGroupSetsTargetOneCompoundArgs(t *testing.T) { require.Equal(t, offset, 0) // default // mock normal agg count(d+1) - normalAggArgs = newFunction(ast.Plus, d, newLonglong(1)) + normalAggArgs = newFunctionWithMockCtx(ast.Plus, d, newLonglong(1)) offset = newGroupingSets.TargetOne([]Expression{normalAggArgs}) require.NotEqual(t, offset, -1) require.Equal(t, offset, 0) // default // mock normal agg count(d+c) - normalAggArgs = newFunction(ast.Plus, d, c) + normalAggArgs = newFunctionWithMockCtx(ast.Plus, d, c) offset = newGroupingSets.TargetOne([]Expression{normalAggArgs}) require.NotEqual(t, offset, -1) require.Equal(t, offset, 1) // only {c} can supply d and c // mock normal agg count(d+a) - normalAggArgs = newFunction(ast.Plus, d, a) + normalAggArgs = newFunctionWithMockCtx(ast.Plus, d, a) offset = newGroupingSets.TargetOne([]Expression{normalAggArgs}) require.NotEqual(t, offset, -1) require.Equal(t, offset, 0) // only {a,b} can supply d and a // mock normal agg count(d+a+c) - normalAggArgs = newFunction(ast.Plus, d, newFunction(ast.Plus, a, c)) + normalAggArgs = newFunctionWithMockCtx(ast.Plus, d, newFunctionWithMockCtx(ast.Plus, a, c)) offset = newGroupingSets.TargetOne([]Expression{normalAggArgs}) require.Equal(t, offset, -1) // couldn't find a group that supply d, a and c simultaneously. } diff --git a/pkg/expression/scalar_function.go b/pkg/expression/scalar_function.go index 27281648cb4a6..d155654f30823 100644 --- a/pkg/expression/scalar_function.go +++ b/pkg/expression/scalar_function.go @@ -262,12 +262,12 @@ func newFunctionImpl(ctx sessionctx.Context, fold int, funcName string, retType ctx: ctx, } if fold == 1 { - return FoldConstant(sf), nil + return FoldConstant(ctx, sf), nil } else if fold == -1 { // try to fold constants, and return the original function if errors/warnings occur sc := ctx.GetSessionVars().StmtCtx beforeWarns := sc.WarningCount() - newSf := FoldConstant(sf) + newSf := FoldConstant(ctx, sf) afterWarns := sc.WarningCount() if afterWarns > beforeWarns { sc.TruncateWarnings(int(beforeWarns)) diff --git a/pkg/expression/scalar_function_test.go b/pkg/expression/scalar_function_test.go index 97fc0897479d5..10f48eadd6063 100644 --- a/pkg/expression/scalar_function_test.go +++ b/pkg/expression/scalar_function_test.go @@ -36,57 +36,57 @@ func TestExpressionSemanticEqual(t *testing.T) { } // order sensitive cases // a < b; b > a - sf1 := newFunction(ast.LT, a, b) - sf2 := newFunction(ast.GT, b, a) + sf1 := newFunctionWithMockCtx(ast.LT, a, b) + sf2 := newFunctionWithMockCtx(ast.GT, b, a) require.True(t, ExpressionsSemanticEqual(sf1, sf2)) // a > b; b < a - sf3 := newFunction(ast.GT, a, b) - sf4 := newFunction(ast.LT, b, a) + sf3 := newFunctionWithMockCtx(ast.GT, a, b) + sf4 := newFunctionWithMockCtx(ast.LT, b, a) require.True(t, ExpressionsSemanticEqual(sf3, sf4)) // a<=b; b>=a - sf5 := newFunction(ast.LE, a, b) - sf6 := newFunction(ast.GE, b, a) + sf5 := newFunctionWithMockCtx(ast.LE, a, b) + sf6 := newFunctionWithMockCtx(ast.GE, b, a) require.True(t, ExpressionsSemanticEqual(sf5, sf6)) // a>=b; b<=a - sf7 := newFunction(ast.GE, a, b) - sf8 := newFunction(ast.LE, b, a) + sf7 := newFunctionWithMockCtx(ast.GE, a, b) + sf8 := newFunctionWithMockCtx(ast.LE, b, a) require.True(t, ExpressionsSemanticEqual(sf7, sf8)) // not(a= b - sf9 := newFunction(ast.UnaryNot, sf1) + sf9 := newFunctionWithMockCtx(ast.UnaryNot, sf1) require.True(t, ExpressionsSemanticEqual(sf9, sf7)) // a < b; not(a>=b) - sf10 := newFunction(ast.UnaryNot, sf7) + sf10 := newFunctionWithMockCtx(ast.UnaryNot, sf7) require.True(t, ExpressionsSemanticEqual(sf1, sf10)) // order insensitive cases // a + b; b + a - p1 := newFunction(ast.Plus, a, b) - p2 := newFunction(ast.Plus, b, a) + p1 := newFunctionWithMockCtx(ast.Plus, a, b) + p2 := newFunctionWithMockCtx(ast.Plus, b, a) require.True(t, ExpressionsSemanticEqual(p1, p2)) // a * b; b * a - m1 := newFunction(ast.Mul, a, b) - m2 := newFunction(ast.Mul, b, a) + m1 := newFunctionWithMockCtx(ast.Mul, a, b) + m2 := newFunctionWithMockCtx(ast.Mul, b, a) require.True(t, ExpressionsSemanticEqual(m1, m2)) // a = b; b = a - e1 := newFunction(ast.EQ, a, b) - e2 := newFunction(ast.EQ, b, a) + e1 := newFunctionWithMockCtx(ast.EQ, a, b) + e2 := newFunctionWithMockCtx(ast.EQ, b, a) require.True(t, ExpressionsSemanticEqual(e1, e2)) // a = b AND b + a; a + b AND b = a - a1 := newFunction(ast.LogicAnd, e1, p2) - a2 := newFunction(ast.LogicAnd, p1, e2) + a1 := newFunctionWithMockCtx(ast.LogicAnd, e1, p2) + a2 := newFunctionWithMockCtx(ast.LogicAnd, p1, e2) require.True(t, ExpressionsSemanticEqual(a1, a2)) // a * b OR a + b; b + a OR b * a - o1 := newFunction(ast.LogicOr, m1, p1) - o2 := newFunction(ast.LogicOr, p2, m2) + o1 := newFunctionWithMockCtx(ast.LogicOr, m1, p1) + o2 := newFunctionWithMockCtx(ast.LogicOr, p2, m2) require.True(t, ExpressionsSemanticEqual(o1, o2)) } @@ -96,7 +96,8 @@ func TestScalarFunction(t *testing.T) { UniqueID: 1, RetType: types.NewFieldType(mysql.TypeDouble), } - sf := newFunction(ast.LT, a, NewOne()) + + sf := newFunctionWithMockCtx(ast.LT, a, NewOne()) res, err := sf.MarshalJSON() require.NoError(t, err) require.EqualValues(t, []byte{0x22, 0x6c, 0x74, 0x28, 0x43, 0x6f, 0x6c, 0x75, 0x6d, 0x6e, 0x23, 0x31, 0x2c, 0x20, 0x31, 0x29, 0x22}, res) @@ -125,7 +126,7 @@ func TestIssue23309(t *testing.T) { a.RetType.SetFlag(a.RetType.GetFlag() | mysql.NotNullFlag) null := NewNull() null.RetType = types.NewFieldType(mysql.TypeNull) - sf, _ := newFunction(ast.NE, a, null).(*ScalarFunction) + sf, _ := newFunctionWithMockCtx(ast.NE, a, null).(*ScalarFunction) v, err := sf.GetArgs()[1].Eval(mock.NewContext(), chunk.Row{}) require.NoError(t, err) require.True(t, v.IsNull()) @@ -138,8 +139,8 @@ func TestScalarFuncs2Exprs(t *testing.T) { UniqueID: 1, RetType: types.NewFieldType(mysql.TypeDouble), } - sf0, _ := newFunction(ast.LT, a, NewZero()).(*ScalarFunction) - sf1, _ := newFunction(ast.LT, a, NewOne()).(*ScalarFunction) + sf0, _ := newFunctionWithMockCtx(ast.LT, a, NewZero()).(*ScalarFunction) + sf1, _ := newFunctionWithMockCtx(ast.LT, a, NewOne()).(*ScalarFunction) funcs := []*ScalarFunction{sf0, sf1} exprs := ScalarFuncs2Exprs(funcs) diff --git a/pkg/expression/util.go b/pkg/expression/util.go index d8e47d5ef1f9d..b8d08b1d84f5b 100644 --- a/pkg/expression/util.go +++ b/pkg/expression/util.go @@ -409,8 +409,8 @@ func SetExprColumnInOperand(expr Expression) Expression { // ColumnSubstitute substitutes the columns in filter to expressions in select fields. // e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k. -func ColumnSubstitute(expr Expression, schema *Schema, newExprs []Expression) Expression { - _, _, resExpr := ColumnSubstituteImpl(expr, schema, newExprs, false) +func ColumnSubstitute(ctx sessionctx.Context, expr Expression, schema *Schema, newExprs []Expression) Expression { + _, _, resExpr := ColumnSubstituteImpl(ctx, expr, schema, newExprs, false) return resExpr } @@ -419,8 +419,8 @@ func ColumnSubstitute(expr Expression, schema *Schema, newExprs []Expression) Ex // // 1: substitute them all once find col in schema. // 2: nothing in expr can be substituted. -func ColumnSubstituteAll(expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) { - _, hasFail, resExpr := ColumnSubstituteImpl(expr, schema, newExprs, true) +func ColumnSubstituteAll(ctx sessionctx.Context, expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) { + _, hasFail, resExpr := ColumnSubstituteImpl(ctx, expr, schema, newExprs, true) return hasFail, resExpr } @@ -429,7 +429,7 @@ func ColumnSubstituteAll(expr Expression, schema *Schema, newExprs []Expression) // @return bool means whether the expr has changed. // @return bool means whether the expr should change (has the dependency in schema, while the corresponding expr has some compatibility), but finally fallback. // @return Expression, the original expr or the changed expr, it depends on the first @return bool. -func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression, fail1Return bool) (bool, bool, Expression) { +func ColumnSubstituteImpl(ctx sessionctx.Context, expr Expression, schema *Schema, newExprs []Expression, fail1Return bool) (bool, bool, Expression) { switch v := expr.(type) { case *Column: id := schema.ColumnIndex(v) @@ -446,7 +446,7 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression hasFail := false if v.FuncName.L == ast.Cast || v.FuncName.L == ast.Grouping { var newArg Expression - substituted, hasFail, newArg = ColumnSubstituteImpl(v.GetArgs()[0], schema, newExprs, fail1Return) + substituted, hasFail, newArg = ColumnSubstituteImpl(ctx, v.GetArgs()[0], schema, newExprs, fail1Return) if fail1Return && hasFail { return substituted, hasFail, v } @@ -454,7 +454,7 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression flag := v.RetType.GetFlag() var e Expression if v.FuncName.L == ast.Cast { - e = BuildCastFunction(v.GetCtx(), newArg, v.RetType) + e = BuildCastFunction(ctx, newArg, v.RetType) } else { // for grouping function recreation, use clone (meta included) instead of newFunction e = v.Clone() @@ -469,7 +469,7 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression // cowExprRef is a copy-on-write util, args array allocation happens only // when expr in args is changed refExprArr := cowExprRef{v.GetArgs(), nil} - oldCollEt, err := CheckAndDeriveCollationFromExprs(v.GetCtx(), v.FuncName.L, v.RetType.EvalType(), v.GetArgs()...) + oldCollEt, err := CheckAndDeriveCollationFromExprs(ctx, v.FuncName.L, v.RetType.EvalType(), v.GetArgs()...) if err != nil { logutil.BgLogger().Error("Unexpected error happened during ColumnSubstitution", zap.Stack("stack")) return false, false, v @@ -479,7 +479,7 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression tmpArgForCollCheck = make([]Expression, len(v.GetArgs())) } for idx, arg := range v.GetArgs() { - changed, failed, newFuncExpr := ColumnSubstituteImpl(arg, schema, newExprs, fail1Return) + changed, failed, newFuncExpr := ColumnSubstituteImpl(ctx, arg, schema, newExprs, fail1Return) if fail1Return && failed { return changed, failed, v } @@ -489,7 +489,7 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression changed = false copy(tmpArgForCollCheck, refExprArr.Result()) tmpArgForCollCheck[idx] = newFuncExpr - newCollEt, err := CheckAndDeriveCollationFromExprs(v.GetCtx(), v.FuncName.L, v.RetType.EvalType(), tmpArgForCollCheck...) + newCollEt, err := CheckAndDeriveCollationFromExprs(ctx, v.FuncName.L, v.RetType.EvalType(), tmpArgForCollCheck...) if err != nil { logutil.BgLogger().Error("Unexpected error happened during ColumnSubstitution", zap.Stack("stack")) return false, failed, v @@ -518,7 +518,7 @@ func ColumnSubstituteImpl(expr Expression, schema *Schema, newExprs []Expression } } if substituted { - return true, hasFail, NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, refExprArr.Result()...) + return true, hasFail, NewFunctionInternal(ctx, v.FuncName.L, v.RetType, refExprArr.Result()...) } } return false, false, expr @@ -585,13 +585,13 @@ Loop: // SubstituteCorCol2Constant will substitute correlated column to constant value which it contains. // If the args of one scalar function are all constant, we will substitute it to constant. -func SubstituteCorCol2Constant(expr Expression) (Expression, error) { +func SubstituteCorCol2Constant(ctx sessionctx.Context, expr Expression) (Expression, error) { switch x := expr.(type) { case *ScalarFunction: allConstant := true newArgs := make([]Expression, 0, len(x.GetArgs())) for _, arg := range x.GetArgs() { - newArg, err := SubstituteCorCol2Constant(arg) + newArg, err := SubstituteCorCol2Constant(ctx, arg) if err != nil { return nil, err } @@ -600,7 +600,7 @@ func SubstituteCorCol2Constant(expr Expression) (Expression, error) { allConstant = allConstant && ok } if allConstant { - val, err := x.EvalWithInnerCtx(chunk.Row{}) + val, err := x.Eval(ctx, chunk.Row{}) if err != nil { return nil, err } @@ -611,19 +611,19 @@ func SubstituteCorCol2Constant(expr Expression) (Expression, error) { newSf Expression ) if x.FuncName.L == ast.Cast { - newSf = BuildCastFunction(x.GetCtx(), newArgs[0], x.RetType) + newSf = BuildCastFunction(ctx, newArgs[0], x.RetType) } else if x.FuncName.L == ast.Grouping { newSf = x.Clone() newSf.(*ScalarFunction).GetArgs()[0] = newArgs[0] } else { - newSf, err = NewFunction(x.GetCtx(), x.FuncName.L, x.GetType(), newArgs...) + newSf, err = NewFunction(ctx, x.FuncName.L, x.GetType(), newArgs...) } return newSf, err case *CorrelatedColumn: return &Constant{Value: *x.Data, RetType: x.GetType()}, nil case *Constant: if x.DeferredExpr != nil { - newExpr := FoldConstant(x) + newExpr := FoldConstant(ctx, x) return &Constant{Value: newExpr.(*Constant).Value, RetType: x.GetType()}, nil } } @@ -879,20 +879,20 @@ func pushNotAcrossExpr(ctx sessionctx.Context, expr Expression, not bool) (_ Exp return expr, false } var childExpr Expression - childExpr, changed = pushNotAcrossExpr(f.GetCtx(), child, !not) + childExpr, changed = pushNotAcrossExpr(ctx, child, !not) if !changed && !not { return expr, false } return childExpr, true case ast.LT, ast.GE, ast.GT, ast.LE, ast.EQ, ast.NE: if not { - return NewFunctionInternal(f.GetCtx(), oppositeOp[f.FuncName.L], f.GetType(), f.GetArgs()...), true + return NewFunctionInternal(ctx, oppositeOp[f.FuncName.L], f.GetType(), f.GetArgs()...), true } - newArgs, changed := pushNotAcrossArgs(f.GetCtx(), f.GetArgs(), false) + newArgs, changed := pushNotAcrossArgs(ctx, f.GetArgs(), false) if !changed { return f, false } - return NewFunctionInternal(f.GetCtx(), f.FuncName.L, f.GetType(), newArgs...), true + return NewFunctionInternal(ctx, f.FuncName.L, f.GetType(), newArgs...), true case ast.LogicAnd, ast.LogicOr: var ( newArgs []Expression @@ -900,16 +900,16 @@ func pushNotAcrossExpr(ctx sessionctx.Context, expr Expression, not bool) (_ Exp ) funcName := f.FuncName.L if not { - newArgs, _ = pushNotAcrossArgs(f.GetCtx(), f.GetArgs(), true) + newArgs, _ = pushNotAcrossArgs(ctx, f.GetArgs(), true) funcName = oppositeOp[f.FuncName.L] changed = true } else { - newArgs, changed = pushNotAcrossArgs(f.GetCtx(), f.GetArgs(), false) + newArgs, changed = pushNotAcrossArgs(ctx, f.GetArgs(), false) } if !changed { return f, false } - return NewFunctionInternal(f.GetCtx(), funcName, f.GetType(), newArgs...), true + return NewFunctionInternal(ctx, funcName, f.GetType(), newArgs...), true } } if not { @@ -1081,12 +1081,11 @@ func extractFiltersFromDNF(ctx sessionctx.Context, dnfFunc *ScalarFunction) ([]E // the original expression must satisfy the derived expression. Return nil when the derived expression is universal set. // A running example is: for schema of t1, `(t1.a=1 and t2.a=1) or (t1.a=2 and t2.a=2)` would be derived as // `t1.a=1 or t1.a=2`, while `t1.a=1 or t2.a=1` would get nil. -func DeriveRelaxedFiltersFromDNF(expr Expression, schema *Schema) Expression { +func DeriveRelaxedFiltersFromDNF(ctx sessionctx.Context, expr Expression, schema *Schema) Expression { sf, ok := expr.(*ScalarFunction) if !ok || sf.FuncName.L != ast.LogicOr { return nil } - ctx := sf.GetCtx() dnfItems := FlattenDNFConditions(sf) newDNFItems := make([]Expression, 0, len(dnfItems)) for _, dnfItem := range dnfItems { @@ -1094,7 +1093,7 @@ func DeriveRelaxedFiltersFromDNF(expr Expression, schema *Schema) Expression { newCNFItems := make([]Expression, 0, len(cnfItems)) for _, cnfItem := range cnfItems { if itemSF, ok := cnfItem.(*ScalarFunction); ok && itemSF.FuncName.L == ast.LogicOr { - relaxedCNFItem := DeriveRelaxedFiltersFromDNF(cnfItem, schema) + relaxedCNFItem := DeriveRelaxedFiltersFromDNF(ctx, cnfItem, schema) if relaxedCNFItem != nil { newCNFItems = append(newCNFItems, relaxedCNFItem) } @@ -1392,7 +1391,7 @@ func RemoveDupExprs(exprs []Expression) []Expression { } // GetUint64FromConstant gets a uint64 from constant expression. -func GetUint64FromConstant(expr Expression) (uint64, bool, bool) { +func GetUint64FromConstant(ctx sessionctx.Context, expr Expression) (uint64, bool, bool) { con, ok := expr.(*Constant) if !ok { logutil.BgLogger().Warn("not a constant expression", zap.String("expression", expr.ExplainInfo())) @@ -1403,7 +1402,7 @@ func GetUint64FromConstant(expr Expression) (uint64, bool, bool) { dt = con.ParamMarker.GetUserVar() } else if con.DeferredExpr != nil { var err error - dt, err = con.DeferredExpr.EvalWithInnerCtx(chunk.Row{}) + dt, err = con.DeferredExpr.Eval(ctx, chunk.Row{}) if err != nil { logutil.BgLogger().Warn("eval deferred expr failed", zap.Error(err)) return 0, false, false diff --git a/pkg/expression/util_test.go b/pkg/expression/util_test.go index 4a7198caf4a96..8de7c7b1d56b8 100644 --- a/pkg/expression/util_test.go +++ b/pkg/expression/util_test.go @@ -154,37 +154,37 @@ func TestClone(t *testing.T) { } func TestGetUint64FromConstant(t *testing.T) { + ctx := mock.NewContext() con := &Constant{ Value: types.NewDatum(nil), } - _, isNull, ok := GetUint64FromConstant(con) + _, isNull, ok := GetUint64FromConstant(ctx, con) require.True(t, ok) require.True(t, isNull) con = &Constant{ Value: types.NewIntDatum(-1), } - _, _, ok = GetUint64FromConstant(con) + _, _, ok = GetUint64FromConstant(ctx, con) require.False(t, ok) con.Value = types.NewIntDatum(1) - num, isNull, ok := GetUint64FromConstant(con) + num, isNull, ok := GetUint64FromConstant(ctx, con) require.True(t, ok) require.False(t, isNull) require.Equal(t, uint64(1), num) con.Value = types.NewUintDatum(1) - num, _, _ = GetUint64FromConstant(con) + num, _, _ = GetUint64FromConstant(ctx, con) require.Equal(t, uint64(1), num) con.DeferredExpr = &Constant{Value: types.NewIntDatum(1)} - num, _, _ = GetUint64FromConstant(con) + num, _, _ = GetUint64FromConstant(ctx, con) require.Equal(t, uint64(1), num) - ctx := mock.NewContext() ctx.GetSessionVars().PlanCacheParams.Append(types.NewUintDatum(100)) con.ParamMarker = &ParamMarker{order: 0, ctx: ctx} - num, _, _ = GetUint64FromConstant(con) + num, _, _ = GetUint64FromConstant(ctx, con) require.Equal(t, uint64(100), num) } @@ -245,21 +245,21 @@ func TestSubstituteCorCol2Constant(t *testing.T) { corCol2 := &CorrelatedColumn{Data: &NewOne().Value} corCol2.RetType = types.NewFieldType(mysql.TypeLonglong) cast := BuildCastFunction(ctx, corCol1, types.NewFieldType(mysql.TypeLonglong)) - plus := newFunction(ast.Plus, cast, corCol2) - plus2 := newFunction(ast.Plus, plus, NewOne()) + plus := newFunctionWithMockCtx(ast.Plus, cast, corCol2) + plus2 := newFunctionWithMockCtx(ast.Plus, plus, NewOne()) ans1 := &Constant{Value: types.NewIntDatum(3), RetType: types.NewFieldType(mysql.TypeLonglong)} - ret, err := SubstituteCorCol2Constant(plus2) + ret, err := SubstituteCorCol2Constant(ctx, plus2) require.NoError(t, err) require.True(t, ret.Equal(ctx, ans1)) col1 := &Column{Index: 1, RetType: types.NewFieldType(mysql.TypeLonglong)} - ret, err = SubstituteCorCol2Constant(col1) + ret, err = SubstituteCorCol2Constant(ctx, col1) require.NoError(t, err) ans2 := col1 require.True(t, ret.Equal(ctx, ans2)) - plus3 := newFunction(ast.Plus, plus2, col1) - ret, err = SubstituteCorCol2Constant(plus3) + plus3 := newFunctionWithMockCtx(ast.Plus, plus2, col1) + ret, err = SubstituteCorCol2Constant(ctx, plus3) require.NoError(t, err) - ans3 := newFunction(ast.Plus, ans1, col1) + ans3 := newFunctionWithMockCtx(ast.Plus, ans1, col1) require.True(t, ret.Equal(ctx, ans3)) } @@ -267,14 +267,14 @@ func TestPushDownNot(t *testing.T) { ctx := mock.NewContext() col := &Column{Index: 1, RetType: types.NewFieldType(mysql.TypeLonglong)} // !((a=1||a=1)&&a=1) - eqFunc := newFunction(ast.EQ, col, NewOne()) - orFunc := newFunction(ast.LogicOr, eqFunc, eqFunc) - andFunc := newFunction(ast.LogicAnd, orFunc, eqFunc) - notFunc := newFunction(ast.UnaryNot, andFunc) + eqFunc := newFunctionWithMockCtx(ast.EQ, col, NewOne()) + orFunc := newFunctionWithMockCtx(ast.LogicOr, eqFunc, eqFunc) + andFunc := newFunctionWithMockCtx(ast.LogicAnd, orFunc, eqFunc) + notFunc := newFunctionWithMockCtx(ast.UnaryNot, andFunc) // (a!=1&&a!=1)||a=1 - neFunc := newFunction(ast.NE, col, NewOne()) - andFunc2 := newFunction(ast.LogicAnd, neFunc, neFunc) - orFunc2 := newFunction(ast.LogicOr, andFunc2, neFunc) + neFunc := newFunctionWithMockCtx(ast.NE, col, NewOne()) + andFunc2 := newFunctionWithMockCtx(ast.LogicAnd, neFunc, neFunc) + orFunc2 := newFunctionWithMockCtx(ast.LogicOr, andFunc2, neFunc) notFuncCopy := notFunc.Clone() ret := PushDownNot(ctx, notFunc) require.True(t, ret.Equal(ctx, orFunc2)) @@ -282,37 +282,37 @@ func TestPushDownNot(t *testing.T) { // issue 15725 // (not not a) should be optimized to (a is true) - notFunc = newFunction(ast.UnaryNot, col) - notFunc = newFunction(ast.UnaryNot, notFunc) + notFunc = newFunctionWithMockCtx(ast.UnaryNot, col) + notFunc = newFunctionWithMockCtx(ast.UnaryNot, notFunc) ret = PushDownNot(ctx, notFunc) - require.True(t, ret.Equal(ctx, newFunction(ast.IsTruthWithNull, col))) + require.True(t, ret.Equal(ctx, newFunctionWithMockCtx(ast.IsTruthWithNull, col))) // (not not (a+1)) should be optimized to (a+1 is true) - plusFunc := newFunction(ast.Plus, col, NewOne()) - notFunc = newFunction(ast.UnaryNot, plusFunc) - notFunc = newFunction(ast.UnaryNot, notFunc) + plusFunc := newFunctionWithMockCtx(ast.Plus, col, NewOne()) + notFunc = newFunctionWithMockCtx(ast.UnaryNot, plusFunc) + notFunc = newFunctionWithMockCtx(ast.UnaryNot, notFunc) ret = PushDownNot(ctx, notFunc) - require.True(t, ret.Equal(ctx, newFunction(ast.IsTruthWithNull, plusFunc))) + require.True(t, ret.Equal(ctx, newFunctionWithMockCtx(ast.IsTruthWithNull, plusFunc))) // (not not not a) should be optimized to (not (a is true)) - notFunc = newFunction(ast.UnaryNot, col) - notFunc = newFunction(ast.UnaryNot, notFunc) - notFunc = newFunction(ast.UnaryNot, notFunc) + notFunc = newFunctionWithMockCtx(ast.UnaryNot, col) + notFunc = newFunctionWithMockCtx(ast.UnaryNot, notFunc) + notFunc = newFunctionWithMockCtx(ast.UnaryNot, notFunc) ret = PushDownNot(ctx, notFunc) - require.True(t, ret.Equal(ctx, newFunction(ast.UnaryNot, newFunction(ast.IsTruthWithNull, col)))) + require.True(t, ret.Equal(ctx, newFunctionWithMockCtx(ast.UnaryNot, newFunctionWithMockCtx(ast.IsTruthWithNull, col)))) // (not not not not a) should be optimized to (a is true) - notFunc = newFunction(ast.UnaryNot, col) - notFunc = newFunction(ast.UnaryNot, notFunc) - notFunc = newFunction(ast.UnaryNot, notFunc) - notFunc = newFunction(ast.UnaryNot, notFunc) + notFunc = newFunctionWithMockCtx(ast.UnaryNot, col) + notFunc = newFunctionWithMockCtx(ast.UnaryNot, notFunc) + notFunc = newFunctionWithMockCtx(ast.UnaryNot, notFunc) + notFunc = newFunctionWithMockCtx(ast.UnaryNot, notFunc) ret = PushDownNot(ctx, notFunc) - require.True(t, ret.Equal(ctx, newFunction(ast.IsTruthWithNull, col))) + require.True(t, ret.Equal(ctx, newFunctionWithMockCtx(ast.IsTruthWithNull, col))) } func TestFilter(t *testing.T) { conditions := []Expression{ - newFunction(ast.EQ, newColumn(0), newColumn(1)), - newFunction(ast.EQ, newColumn(1), newColumn(2)), - newFunction(ast.LogicOr, newLonglong(1), newColumn(0)), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newColumn(1)), + newFunctionWithMockCtx(ast.EQ, newColumn(1), newColumn(2)), + newFunctionWithMockCtx(ast.LogicOr, newLonglong(1), newColumn(0)), } result := make([]Expression, 0, 5) result = Filter(result, conditions, isLogicOrFunction) @@ -321,9 +321,9 @@ func TestFilter(t *testing.T) { func TestFilterOutInPlace(t *testing.T) { conditions := []Expression{ - newFunction(ast.EQ, newColumn(0), newColumn(1)), - newFunction(ast.EQ, newColumn(1), newColumn(2)), - newFunction(ast.LogicOr, newLonglong(1), newColumn(0)), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newColumn(1)), + newFunctionWithMockCtx(ast.EQ, newColumn(1), newColumn(2)), + newFunctionWithMockCtx(ast.LogicOr, newLonglong(1), newColumn(0)), } remained, filtered := FilterOutInPlace(conditions, isLogicOrFunction) require.Equal(t, 2, len(remained)) @@ -459,11 +459,11 @@ func TestSQLDigestTextRetriever(t *testing.T) { func BenchmarkExtractColumns(b *testing.B) { conditions := []Expression{ - newFunction(ast.EQ, newColumn(0), newColumn(1)), - newFunction(ast.EQ, newColumn(1), newColumn(2)), - newFunction(ast.EQ, newColumn(2), newColumn(3)), - newFunction(ast.EQ, newColumn(3), newLonglong(1)), - newFunction(ast.LogicOr, newLonglong(1), newColumn(0)), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newColumn(1)), + newFunctionWithMockCtx(ast.EQ, newColumn(1), newColumn(2)), + newFunctionWithMockCtx(ast.EQ, newColumn(2), newColumn(3)), + newFunctionWithMockCtx(ast.EQ, newColumn(3), newLonglong(1)), + newFunctionWithMockCtx(ast.LogicOr, newLonglong(1), newColumn(0)), } expr := ComposeCNFCondition(mock.NewContext(), conditions...) @@ -476,11 +476,11 @@ func BenchmarkExtractColumns(b *testing.B) { func BenchmarkExprFromSchema(b *testing.B) { conditions := []Expression{ - newFunction(ast.EQ, newColumn(0), newColumn(1)), - newFunction(ast.EQ, newColumn(1), newColumn(2)), - newFunction(ast.EQ, newColumn(2), newColumn(3)), - newFunction(ast.EQ, newColumn(3), newLonglong(1)), - newFunction(ast.LogicOr, newLonglong(1), newColumn(0)), + newFunctionWithMockCtx(ast.EQ, newColumn(0), newColumn(1)), + newFunctionWithMockCtx(ast.EQ, newColumn(1), newColumn(2)), + newFunctionWithMockCtx(ast.EQ, newColumn(2), newColumn(3)), + newFunctionWithMockCtx(ast.EQ, newColumn(3), newLonglong(1)), + newFunctionWithMockCtx(ast.LogicOr, newLonglong(1), newColumn(0)), } expr := ComposeCNFCondition(mock.NewContext(), conditions...) schema := &Schema{Columns: ExtractColumns(expr)} diff --git a/pkg/planner/cascades/transformation_rules.go b/pkg/planner/cascades/transformation_rules.go index c3d6360cee8a9..c2c0d9de0149b 100644 --- a/pkg/planner/cascades/transformation_rules.go +++ b/pkg/planner/cascades/transformation_rules.go @@ -549,8 +549,9 @@ func (*PushSelDownProjection) OnTransform(old *memo.ExprIter) (newExprs []*memo. } canBePushed := make([]expression.Expression, 0, len(sel.Conditions)) canNotBePushed := make([]expression.Expression, 0, len(sel.Conditions)) + ctx := sel.SCtx() for _, cond := range sel.Conditions { - substituted, hasFailed, newFilter := expression.ColumnSubstituteImpl(cond, projSchema, proj.Exprs, true) + substituted, hasFailed, newFilter := expression.ColumnSubstituteImpl(ctx, cond, projSchema, proj.Exprs, true) if substituted && !hasFailed && !expression.HasGetSetVarFunc(newFilter) { canBePushed = append(canBePushed, newFilter) } else { @@ -1303,15 +1304,16 @@ func (*PushTopNDownProjection) OnTransform(old *memo.ExprIter) (newExprs []*memo proj := old.Children[0].GetExpr().ExprNode.(*plannercore.LogicalProjection) childGroup := old.Children[0].GetExpr().Children[0] + ctx := topN.SCtx() newTopN := plannercore.LogicalTopN{ Offset: topN.Offset, Count: topN.Count, - }.Init(topN.SCtx(), topN.SelectBlockOffset()) + }.Init(ctx, topN.SelectBlockOffset()) newTopN.ByItems = make([]*util.ByItems, 0, len(topN.ByItems)) for _, by := range topN.ByItems { newTopN.ByItems = append(newTopN.ByItems, &util.ByItems{ - Expr: expression.ColumnSubstitute(by.Expr, old.Children[0].Group.Prop.Schema, proj.Exprs), + Expr: expression.ColumnSubstitute(ctx, by.Expr, old.Children[0].Group.Prop.Schema, proj.Exprs), Desc: by.Desc, }) } @@ -1522,9 +1524,10 @@ func (*MergeAggregationProjection) OnTransform(old *memo.ExprIter) (newExprs []* proj := old.Children[0].GetExpr().ExprNode.(*plannercore.LogicalProjection) projSchema := old.Children[0].GetExpr().Schema() + ctx := oldAgg.SCtx() groupByItems := make([]expression.Expression, len(oldAgg.GroupByItems)) for i, item := range oldAgg.GroupByItems { - groupByItems[i] = expression.ColumnSubstitute(item, projSchema, proj.Exprs) + groupByItems[i] = expression.ColumnSubstitute(ctx, item, projSchema, proj.Exprs) } aggFuncs := make([]*aggregation.AggFuncDesc, len(oldAgg.AggFuncs)) @@ -1532,7 +1535,7 @@ func (*MergeAggregationProjection) OnTransform(old *memo.ExprIter) (newExprs []* aggFuncs[i] = aggFunc.Clone() newArgs := make([]expression.Expression, len(aggFunc.Args)) for j, arg := range aggFunc.Args { - newArgs[j] = expression.ColumnSubstitute(arg, projSchema, proj.Exprs) + newArgs[j] = expression.ColumnSubstitute(ctx, arg, projSchema, proj.Exprs) } aggFuncs[i].Args = newArgs } @@ -1540,7 +1543,7 @@ func (*MergeAggregationProjection) OnTransform(old *memo.ExprIter) (newExprs []* newAgg := plannercore.LogicalAggregation{ GroupByItems: groupByItems, AggFuncs: aggFuncs, - }.Init(oldAgg.SCtx(), oldAgg.SelectBlockOffset()) + }.Init(ctx, oldAgg.SelectBlockOffset()) newAggExpr := memo.NewGroupExpr(newAgg) newAggExpr.SetChildren(old.Children[0].GetExpr().Children...) diff --git a/pkg/planner/core/logical_plan_builder.go b/pkg/planner/core/logical_plan_builder.go index edc04269de2cf..36e0178a4bad5 100644 --- a/pkg/planner/core/logical_plan_builder.go +++ b/pkg/planner/core/logical_plan_builder.go @@ -626,6 +626,7 @@ func (p *LogicalJoin) ExtractOnCondition( deriveLeft bool, deriveRight bool) (eqCond []*expression.ScalarFunction, leftCond []expression.Expression, rightCond []expression.Expression, otherCond []expression.Expression) { + ctx := p.SCtx() for _, expr := range conditions { // For queries like `select a in (select a from s where s.b = t.b) from t`, // if subquery is empty caused by `s.b = t.b`, the result should always be @@ -638,7 +639,6 @@ func (p *LogicalJoin) ExtractOnCondition( } binop, ok := expr.(*expression.ScalarFunction) if ok && len(binop.GetArgs()) == 2 { - ctx := binop.GetCtx() arg0, lOK := binop.GetArgs()[0].(*expression.Column) arg1, rOK := binop.GetArgs()[1].(*expression.Column) if lOK && rOK { @@ -695,13 +695,13 @@ func (p *LogicalJoin) ExtractOnCondition( // `expr AND leftRelaxedCond AND rightRelaxedCond`. Motivation is to push filters down to // children as much as possible. if deriveLeft { - leftRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(expr, leftSchema) + leftRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(ctx, expr, leftSchema) if leftRelaxedCond != nil { leftCond = append(leftCond, leftRelaxedCond) } } if deriveRight { - rightRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(expr, rightSchema) + rightRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(ctx, expr, rightSchema) if rightRelaxedCond != nil { rightCond = append(rightCond, rightRelaxedCond) } diff --git a/pkg/planner/core/logical_plans.go b/pkg/planner/core/logical_plans.go index da3ad0f3bf44f..30d028cab5739 100644 --- a/pkg/planner/core/logical_plans.go +++ b/pkg/planner/core/logical_plans.go @@ -452,28 +452,29 @@ func (p *LogicalJoin) columnSubstituteAll(schema *expression.Schema, exprs []exp copy(cpOtherConditions, p.OtherConditions) copy(cpEqualConditions, p.EqualConditions) + ctx := p.SCtx() // try to substitute columns in these condition. for i, cond := range cpLeftConditions { - if hasFail, cpLeftConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail { + if hasFail, cpLeftConditions[i] = expression.ColumnSubstituteAll(ctx, cond, schema, exprs); hasFail { return } } for i, cond := range cpRightConditions { - if hasFail, cpRightConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail { + if hasFail, cpRightConditions[i] = expression.ColumnSubstituteAll(ctx, cond, schema, exprs); hasFail { return } } for i, cond := range cpOtherConditions { - if hasFail, cpOtherConditions[i] = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail { + if hasFail, cpOtherConditions[i] = expression.ColumnSubstituteAll(ctx, cond, schema, exprs); hasFail { return } } for i, cond := range cpEqualConditions { var tmp expression.Expression - if hasFail, tmp = expression.ColumnSubstituteAll(cond, schema, exprs); hasFail { + if hasFail, tmp = expression.ColumnSubstituteAll(ctx, cond, schema, exprs); hasFail { return } cpEqualConditions[i] = tmp.(*expression.ScalarFunction) diff --git a/pkg/planner/core/optimizer_test.go b/pkg/planner/core/optimizer_test.go index 5e876d29caecc..2b1088ded1ef6 100644 --- a/pkg/planner/core/optimizer_test.go +++ b/pkg/planner/core/optimizer_test.go @@ -431,6 +431,7 @@ func TestPrunePhysicalColumns(t *testing.T) { ExchangeType: tipb.ExchangeType_PassThrough, } hashJoin := &PhysicalHashJoin{} + hashJoin = hashJoin.Init(sctx, nil, 0) recv := &PhysicalExchangeReceiver{} recv1 := &PhysicalExchangeReceiver{} hashSender := &PhysicalExchangeSender{ diff --git a/pkg/planner/core/physical_plans.go b/pkg/planner/core/physical_plans.go index c40ce39b04524..8626ef7e94b3c 100644 --- a/pkg/planner/core/physical_plans.go +++ b/pkg/planner/core/physical_plans.go @@ -937,18 +937,19 @@ func (ts *PhysicalTableScan) IsPartition() (bool, int64) { // mem usage when rebuilding ranges during the execution phase. func (ts *PhysicalTableScan) ResolveCorrelatedColumns() ([]*ranger.Range, error) { access := ts.AccessCondition + ctx := ts.SCtx() if ts.Table.IsCommonHandle { pkIdx := tables.FindPrimaryIndex(ts.Table) idxCols, idxColLens := expression.IndexInfo2PrefixCols(ts.Columns, ts.Schema().Columns, pkIdx) for _, cond := range access { - newCond, err := expression.SubstituteCorCol2Constant(cond) + newCond, err := expression.SubstituteCorCol2Constant(ctx, cond) if err != nil { return nil, err } access = append(access, newCond) } // All of access conditions must be used to build ranges, so we don't limit range memory usage. - res, err := ranger.DetachCondAndBuildRangeForIndex(ts.SCtx(), access, idxCols, idxColLens, 0) + res, err := ranger.DetachCondAndBuildRangeForIndex(ctx, access, idxCols, idxColLens, 0) if err != nil { return nil, err } @@ -957,7 +958,7 @@ func (ts *PhysicalTableScan) ResolveCorrelatedColumns() ([]*ranger.Range, error) var err error pkTP := ts.Table.GetPkColInfo().FieldType // All of access conditions must be used to build ranges, so we don't limit range memory usage. - ts.Ranges, _, _, err = ranger.BuildTableRange(access, ts.SCtx(), &pkTP, 0) + ts.Ranges, _, _, err = ranger.BuildTableRange(access, ctx, &pkTP, 0) if err != nil { return nil, err } diff --git a/pkg/planner/core/resolve_indices.go b/pkg/planner/core/resolve_indices.go index 77e4e8d163171..e40a107a88612 100644 --- a/pkg/planner/core/resolve_indices.go +++ b/pkg/planner/core/resolve_indices.go @@ -83,6 +83,7 @@ func refine4NeighbourProj(p, childProj *PhysicalProjection) { func (p *PhysicalHashJoin) ResolveIndicesItself() (err error) { lSchema := p.children[0].Schema() rSchema := p.children[1].Schema() + ctx := p.SCtx() for i, fun := range p.EqualConditions { lArg, err := fun.GetArgs()[0].ResolveIndices(lSchema) if err != nil { @@ -94,7 +95,7 @@ func (p *PhysicalHashJoin) ResolveIndicesItself() (err error) { return err } p.RightJoinKeys[i] = rArg.(*expression.Column) - p.EqualConditions[i] = expression.NewFunctionInternal(fun.GetCtx(), fun.FuncName.L, fun.GetType(), lArg, rArg).(*expression.ScalarFunction) + p.EqualConditions[i] = expression.NewFunctionInternal(ctx, fun.FuncName.L, fun.GetType(), lArg, rArg).(*expression.ScalarFunction) } for i, fun := range p.NAEqualConditions { lArg, err := fun.GetArgs()[0].ResolveIndices(lSchema) @@ -107,7 +108,7 @@ func (p *PhysicalHashJoin) ResolveIndicesItself() (err error) { return err } p.RightNAJoinKeys[i] = rArg.(*expression.Column) - p.NAEqualConditions[i] = expression.NewFunctionInternal(fun.GetCtx(), fun.FuncName.L, fun.GetType(), lArg, rArg).(*expression.ScalarFunction) + p.NAEqualConditions[i] = expression.NewFunctionInternal(ctx, fun.FuncName.L, fun.GetType(), lArg, rArg).(*expression.ScalarFunction) } for i, expr := range p.LeftConditions { p.LeftConditions[i], err = expr.ResolveIndices(lSchema) diff --git a/pkg/planner/core/rule_aggregation_push_down.go b/pkg/planner/core/rule_aggregation_push_down.go index a1296d41f1dfd..3b7fca21fe3ce 100644 --- a/pkg/planner/core/rule_aggregation_push_down.go +++ b/pkg/planner/core/rule_aggregation_push_down.go @@ -393,13 +393,13 @@ func (*aggregationPushDownSolver) pushAggCrossUnion(agg *LogicalAggregation, uni newAggFunc := aggFunc.Clone() newArgs := make([]expression.Expression, 0, len(newAggFunc.Args)) for _, arg := range newAggFunc.Args { - newArgs = append(newArgs, expression.ColumnSubstitute(arg, unionSchema, expression.Column2Exprs(unionChild.Schema().Columns))) + newArgs = append(newArgs, expression.ColumnSubstitute(ctx, arg, unionSchema, expression.Column2Exprs(unionChild.Schema().Columns))) } newAggFunc.Args = newArgs newAgg.AggFuncs = append(newAgg.AggFuncs, newAggFunc) } for _, gbyExpr := range agg.GroupByItems { - newExpr := expression.ColumnSubstitute(gbyExpr, unionSchema, expression.Column2Exprs(unionChild.Schema().Columns)) + newExpr := expression.ColumnSubstitute(ctx, gbyExpr, unionSchema, expression.Column2Exprs(unionChild.Schema().Columns)) newAgg.GroupByItems = append(newAgg.GroupByItems, newExpr) // TODO: if there is a duplicated first_row function, we can delete it. firstRow, err := aggregation.NewAggFuncDesc(agg.SCtx(), ast.AggFuncFirstRow, []expression.Expression{gbyExpr}, false) @@ -551,10 +551,11 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan, opt *logicalOptim // push aggregation across projection // TODO: This optimization is not always reasonable. We have not supported pushing projection to kv layer yet, // so we must do this optimization. + ctx := p.SCtx() noSideEffects := true newGbyItems := make([]expression.Expression, 0, len(agg.GroupByItems)) for _, gbyItem := range agg.GroupByItems { - newGbyItems = append(newGbyItems, expression.ColumnSubstitute(gbyItem, proj.schema, proj.Exprs)) + newGbyItems = append(newGbyItems, expression.ColumnSubstitute(ctx, gbyItem, proj.schema, proj.Exprs)) if ExprsHasSideEffects(newGbyItems) { noSideEffects = false break @@ -569,7 +570,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan, opt *logicalOptim oldAggFuncsArgs = append(oldAggFuncsArgs, aggFunc.Args) newArgs := make([]expression.Expression, 0, len(aggFunc.Args)) for _, arg := range aggFunc.Args { - newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs)) + newArgs = append(newArgs, expression.ColumnSubstitute(ctx, arg, proj.schema, proj.Exprs)) } if ExprsHasSideEffects(newArgs) { noSideEffects = false @@ -581,7 +582,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan, opt *logicalOptim oldAggOrderItems = append(oldAggOrderItems, aggFunc.OrderByItems) newOrderByItems := make([]expression.Expression, 0, len(aggFunc.OrderByItems)) for _, oby := range aggFunc.OrderByItems { - newOrderByItems = append(newOrderByItems, expression.ColumnSubstitute(oby.Expr, proj.schema, proj.Exprs)) + newOrderByItems = append(newOrderByItems, expression.ColumnSubstitute(ctx, oby.Expr, proj.schema, proj.Exprs)) } if ExprsHasSideEffects(newOrderByItems) { noSideEffects = false diff --git a/pkg/planner/core/rule_eliminate_projection.go b/pkg/planner/core/rule_eliminate_projection.go index 46b2d761e6095..36d7dbc82fd8d 100644 --- a/pkg/planner/core/rule_eliminate_projection.go +++ b/pkg/planner/core/rule_eliminate_projection.go @@ -209,9 +209,10 @@ func (pe *projectionEliminator) eliminate(p LogicalPlan, replace map[string]*exp // eliminate duplicate projection: projection with child projection if isProj { if child, ok := p.Children()[0].(*LogicalProjection); ok && !ExprsHasSideEffects(child.Exprs) { + ctx := p.SCtx() for i := range proj.Exprs { proj.Exprs[i] = ReplaceColumnOfExpr(proj.Exprs[i], child, child.Schema()) - foldedExpr := expression.FoldConstant(proj.Exprs[i]) + foldedExpr := expression.FoldConstant(ctx, proj.Exprs[i]) // the folded expr should have the same null flag with the original expr, especially for the projection under union, so forcing it here. foldedExpr.GetType().SetFlag((foldedExpr.GetType().GetFlag() & ^mysql.NotNullFlag) | (proj.Exprs[i].GetType().GetFlag() & mysql.NotNullFlag)) proj.Exprs[i] = foldedExpr diff --git a/pkg/planner/core/rule_predicate_push_down.go b/pkg/planner/core/rule_predicate_push_down.go index 2597727ac6d4a..1a64fe1418766 100644 --- a/pkg/planner/core/rule_predicate_push_down.go +++ b/pkg/planner/core/rule_predicate_push_down.go @@ -350,7 +350,7 @@ func (p *LogicalProjection) appendExpr(expr expression.Expression) *expression.C if col, ok := expr.(*expression.Column); ok { return col } - expr = expression.ColumnSubstitute(expr, p.schema, p.Exprs) + expr = expression.ColumnSubstitute(p.SCtx(), expr, p.schema, p.Exprs) p.Exprs = append(p.Exprs, expr) col := &expression.Column{ @@ -481,8 +481,9 @@ func (p *LogicalProjection) PredicatePushDown(predicates []expression.Expression return predicates, p } } + ctx := p.SCtx() for _, cond := range predicates { - substituted, hasFailed, newFilter := expression.ColumnSubstituteImpl(cond, p.Schema(), p.Exprs, true) + substituted, hasFailed, newFilter := expression.ColumnSubstituteImpl(ctx, cond, p.Schema(), p.Exprs, true) if substituted && !hasFailed && !expression.HasGetSetVarFunc(newFilter) { canBePushed = append(canBePushed, newFilter) } else { @@ -525,7 +526,7 @@ func (la *LogicalAggregation) pushDownPredicatesForAggregation(cond expression.E } } if ok { - newFunc := expression.ColumnSubstitute(cond, la.Schema(), exprsOriginal) + newFunc := expression.ColumnSubstitute(la.SCtx(), cond, la.Schema(), exprsOriginal) condsToPush = append(condsToPush, newFunc) } else { ret = append(ret, cond) @@ -635,19 +636,20 @@ func DeriveOtherConditions( deriveLeft bool, deriveRight bool) ( leftCond []expression.Expression, rightCond []expression.Expression) { isOuterSemi := (p.JoinType == LeftOuterSemiJoin) || (p.JoinType == AntiLeftOuterSemiJoin) + ctx := p.SCtx() for _, expr := range p.OtherConditions { if deriveLeft { - leftRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(expr, leftSchema) + leftRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(ctx, expr, leftSchema) if leftRelaxedCond != nil { leftCond = append(leftCond, leftRelaxedCond) } - notNullExpr := deriveNotNullExpr(expr, leftSchema) + notNullExpr := deriveNotNullExpr(ctx, expr, leftSchema) if notNullExpr != nil { leftCond = append(leftCond, notNullExpr) } } if deriveRight { - rightRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(expr, rightSchema) + rightRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(ctx, expr, rightSchema) if rightRelaxedCond != nil { rightCond = append(rightCond, rightRelaxedCond) } @@ -661,7 +663,7 @@ func DeriveOtherConditions( if isOuterSemi { continue } - notNullExpr := deriveNotNullExpr(expr, rightSchema) + notNullExpr := deriveNotNullExpr(ctx, expr, rightSchema) if notNullExpr != nil { rightCond = append(rightCond, notNullExpr) } @@ -673,12 +675,11 @@ func DeriveOtherConditions( // deriveNotNullExpr generates a new expression `not(isnull(col))` given `col1 op col2`, // in which `col` is in specified schema. Caller guarantees that only one of `col1` or // `col2` is in schema. -func deriveNotNullExpr(expr expression.Expression, schema *expression.Schema) expression.Expression { +func deriveNotNullExpr(ctx sessionctx.Context, expr expression.Expression, schema *expression.Schema) expression.Expression { binop, ok := expr.(*expression.ScalarFunction) if !ok || len(binop.GetArgs()) != 2 { return nil } - ctx := binop.GetCtx() arg0, lOK := binop.GetArgs()[0].(*expression.Column) arg1, rOK := binop.GetArgs()[1].(*expression.Column) if !lOK || !rOK { diff --git a/pkg/planner/core/rule_predicate_simplification.go b/pkg/planner/core/rule_predicate_simplification.go index 00f65638423d2..501c78767012e 100644 --- a/pkg/planner/core/rule_predicate_simplification.go +++ b/pkg/planner/core/rule_predicate_simplification.go @@ -81,7 +81,7 @@ func (s *baseLogicalPlan) predicateSimplification(opt *logicalOptimizeOp) Logica // updateInPredicate applies intersection of an in list with <> value. It returns updated In list and a flag for // a special case if an element in the inlist is not removed to keep the list not empty. -func updateInPredicate(inPredicate expression.Expression, notEQPredicate expression.Expression) (expression.Expression, bool) { +func updateInPredicate(ctx sessionctx.Context, inPredicate expression.Expression, notEQPredicate expression.Expression) (expression.Expression, bool) { _, inPredicateType := findPredicateType(inPredicate) _, notEQPredicateType := findPredicateType(notEQPredicate) if inPredicateType != inListPredicate || notEQPredicateType != notEqualPredicate { @@ -97,7 +97,7 @@ func updateInPredicate(inPredicate expression.Expression, notEQPredicate express var lastValue *expression.Constant for _, element := range v.GetArgs() { value, valueOK := element.(*expression.Constant) - redundantValue := valueOK && value.Equal(v.GetCtx(), notEQValue) + redundantValue := valueOK && value.Equal(ctx, notEQValue) if !redundantValue { newValues = append(newValues, element) } @@ -113,7 +113,7 @@ func updateInPredicate(inPredicate expression.Expression, notEQPredicate express newValues = append(newValues, lastValue) specialCase = true } - newPred := expression.NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, newValues...) + newPred := expression.NewFunctionInternal(ctx, v.FuncName.L, v.RetType, newValues...) return newPred, specialCase } @@ -131,13 +131,13 @@ func applyPredicateSimplification(sctx sessionctx.Context, predicates []expressi jCol, jType := findPredicateType(jthPredicate) if iCol == jCol { if iType == notEqualPredicate && jType == inListPredicate { - predicates[j], specialCase = updateInPredicate(jthPredicate, ithPredicate) + predicates[j], specialCase = updateInPredicate(sctx, jthPredicate, ithPredicate) sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.New("NE/INList simplification is triggered")) if !specialCase { removeValues = append(removeValues, i) } } else if iType == inListPredicate && jType == notEqualPredicate { - predicates[i], specialCase = updateInPredicate(ithPredicate, jthPredicate) + predicates[i], specialCase = updateInPredicate(sctx, ithPredicate, jthPredicate) sctx.GetSessionVars().StmtCtx.SetSkipPlanCache(errors.New("NE/INList simplification is triggered")) if !specialCase { removeValues = append(removeValues, j) diff --git a/pkg/planner/core/rule_topn_push_down.go b/pkg/planner/core/rule_topn_push_down.go index 7f6716cc17948..7574e741ee6ec 100644 --- a/pkg/planner/core/rule_topn_push_down.go +++ b/pkg/planner/core/rule_topn_push_down.go @@ -133,8 +133,9 @@ func (p *LogicalProjection) pushDownTopN(topN *LogicalTopN, opt *logicalOptimize } } if topN != nil { + ctx := p.SCtx() for _, by := range topN.ByItems { - by.Expr = expression.FoldConstant(expression.ColumnSubstitute(by.Expr, p.schema, p.Exprs)) + by.Expr = expression.FoldConstant(ctx, expression.ColumnSubstitute(ctx, by.Expr, p.schema, p.Exprs)) } // remove meaningless constant sort items.