Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: avoid to use ScalarFunction.GetCtx in some planner codes #48794

Merged
merged 3 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions pkg/executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag
case ast.WindowFuncCumeDist:
return buildCumeDist(ordinal, orderByCols)
case ast.WindowFuncNthValue:
return buildNthValue(windowFuncDesc, ordinal)
return buildNthValue(ctx, windowFuncDesc, ordinal)
case ast.WindowFuncNtile:
return buildNtile(windowFuncDesc, ordinal)
return buildNtile(ctx, windowFuncDesc, ordinal)
case ast.WindowFuncPercentRank:
return buildPercentRank(ordinal, orderByCols)
case ast.WindowFuncLead:
Expand Down Expand Up @@ -668,22 +668,22 @@ func buildCumeDist(ordinal int, orderByCols []*expression.Column) AggFunc {
return r
}

func buildNthValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
func buildNthValue(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
// Already checked when building the function description.
nth, _, _ := expression.GetUint64FromConstant(aggFuncDesc.Args[1])
nth, _, _ := expression.GetUint64FromConstant(ctx, aggFuncDesc.Args[1])
return &nthValue{baseAggFunc: base, tp: aggFuncDesc.RetTp, nth: nth}
}

func buildNtile(aggFuncDes *aggregation.AggFuncDesc, ordinal int) AggFunc {
func buildNtile(ctx sessionctx.Context, aggFuncDes *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDes.Args,
ordinal: ordinal,
}
n, _, _ := expression.GetUint64FromConstant(aggFuncDes.Args[0])
n, _, _ := expression.GetUint64FromConstant(ctx, aggFuncDes.Args[0])
return &ntile{baseAggFunc: base, n: n}
}

Expand All @@ -697,7 +697,7 @@ func buildPercentRank(ordinal int, orderByCols []*expression.Column) AggFunc {
func buildLeadLag(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal int) baseLeadLag {
offset := uint64(1)
if len(aggFuncDesc.Args) >= 2 {
offset, _, _ = expression.GetUint64FromConstant(aggFuncDesc.Args[1])
offset, _, _ = expression.GetUint64FromConstant(ctx, aggFuncDesc.Args[1])
}
var defaultExpr expression.Expression
defaultExpr = expression.NewNull()
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/distsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func closeAll(objs ...Closeable) error {
func rebuildIndexRanges(ctx sessionctx.Context, is *plannercore.PhysicalIndexScan, idxCols []*expression.Column, colLens []int) (ranges []*ranger.Range, err error) {
access := make([]expression.Expression, 0, len(is.AccessCondition))
for _, cond := range is.AccessCondition {
newCond, err1 := expression.SubstituteCorCol2Constant(cond)
newCond, err1 := expression.SubstituteCorCol2Constant(ctx, cond)
if err1 != nil {
return nil, err1
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/expression/aggregation/window_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Ex
if !skipCheckArgs {
switch strings.ToLower(name) {
case ast.WindowFuncNthValue:
val, isNull, ok := expression.GetUint64FromConstant(args[1])
val, isNull, ok := expression.GetUint64FromConstant(ctx, args[1])
// nth_value does not allow `0`, but allows `null`.
if !ok || (val == 0 && !isNull) {
return nil, nil
}
case ast.WindowFuncNtile:
val, isNull, ok := expression.GetUint64FromConstant(args[0])
val, isNull, ok := expression.GetUint64FromConstant(ctx, args[0])
// ntile does not allow `0`, but allows `null`.
if !ok || (val == 0 && !isNull) {
return nil, nil
Expand All @@ -51,7 +51,7 @@ func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Ex
if len(args) < 2 {
break
}
_, isNull, ok := expression.GetUint64FromConstant(args[1])
_, isNull, ok := expression.GetUint64FromConstant(ctx, args[1])
if !ok || isNull {
return nil, nil
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2153,7 +2153,7 @@ func BuildCastFunctionWithCheck(ctx sessionctx.Context, expr Expression, tp *typ
// since we may reset the flag of the field type of CastAsJson later which
// would affect the evaluation of it.
if tp.EvalType() != types.ETJson && err == nil {
res = FoldConstant(res)
res = FoldConstant(ctx, res)
}
return res, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/builtin_convert_charset.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func BuildToBinaryFunction(ctx sessionctx.Context, expr Expression) (res Express
Function: f,
ctx: ctx,
}
return FoldConstant(res)
return FoldConstant(ctx, res)
}

// BuildFromBinaryFunction builds from_binary function.
Expand All @@ -246,7 +246,7 @@ func BuildFromBinaryFunction(ctx sessionctx.Context, expr Expression, tp *types.
Function: f,
ctx: ctx,
}
return FoldConstant(res)
return FoldConstant(ctx, res)
}

type funcProp int8
Expand Down
55 changes: 28 additions & 27 deletions pkg/expression/constant_fold.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@ package expression
import (
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/logutil"
"go.uber.org/zap"
)

// specialFoldHandler stores functions for special UDF to constant fold
var specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){}
var specialFoldHandler = map[string]func(sessionctx.Context, *ScalarFunction) (Expression, bool){}

func init() {
specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){
specialFoldHandler = map[string]func(sessionctx.Context, *ScalarFunction) (Expression, bool){
ast.If: ifFoldHandler,
ast.Ifnull: ifNullFoldHandler,
ast.Case: caseWhenHandler,
Expand All @@ -36,8 +37,8 @@ func init() {
}

// FoldConstant does constant folding optimization on an expression excluding deferred ones.
func FoldConstant(expr Expression) Expression {
e, _ := foldConstant(expr)
func FoldConstant(ctx sessionctx.Context, expr Expression) Expression {
e, _ := foldConstant(ctx, expr)
// keep the original coercibility, charset, collation and repertoire values after folding
e.SetCoercibility(expr.Coercibility())

Expand All @@ -48,11 +49,11 @@ func FoldConstant(expr Expression) Expression {
return e
}

func isNullHandler(expr *ScalarFunction) (Expression, bool) {
func isNullHandler(ctx sessionctx.Context, expr *ScalarFunction) (Expression, bool) {
arg0 := expr.GetArgs()[0]
if constArg, isConst := arg0.(*Constant); isConst {
isDeferredConst := constArg.DeferredExpr != nil || constArg.ParamMarker != nil
value, err := expr.EvalWithInnerCtx(chunk.Row{})
value, err := expr.Eval(ctx, chunk.Row{})
if err != nil {
// Failed to fold this expr to a constant, print the DEBUG log and
// return the original expression to let the error to be evaluated
Expand All @@ -71,11 +72,11 @@ func isNullHandler(expr *ScalarFunction) (Expression, bool) {
return expr, false
}

func ifFoldHandler(expr *ScalarFunction) (Expression, bool) {
func ifFoldHandler(ctx sessionctx.Context, expr *ScalarFunction) (Expression, bool) {
args := expr.GetArgs()
foldedArg0, _ := foldConstant(args[0])
foldedArg0, _ := foldConstant(ctx, args[0])
if constArg, isConst := foldedArg0.(*Constant); isConst {
arg0, isNull0, err := constArg.EvalInt(expr.GetCtx(), chunk.Row{})
arg0, isNull0, err := constArg.EvalInt(ctx, chunk.Row{})
if err != nil {
// Failed to fold this expr to a constant, print the DEBUG log and
// return the original expression to let the error to be evaluated
Expand All @@ -84,35 +85,35 @@ func ifFoldHandler(expr *ScalarFunction) (Expression, bool) {
return expr, false
}
if !isNull0 && arg0 != 0 {
return foldConstant(args[1])
return foldConstant(ctx, args[1])
}
return foldConstant(args[2])
return foldConstant(ctx, args[2])
}
// if the condition is not const, which branch is unknown to run, so directly return.
return expr, false
}

func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) {
func ifNullFoldHandler(ctx sessionctx.Context, expr *ScalarFunction) (Expression, bool) {
args := expr.GetArgs()
foldedArg0, isDeferred := foldConstant(args[0])
foldedArg0, isDeferred := foldConstant(ctx, args[0])
if constArg, isConst := foldedArg0.(*Constant); isConst {
// Only check constArg.Value here. Because deferred expression is
// evaluated to constArg.Value after foldConstant(args[0]), it's not
// needed to be checked.
if constArg.Value.IsNull() {
return foldConstant(args[1])
return foldConstant(ctx, args[1])
}
return constArg, isDeferred
}
// if the condition is not const, which branch is unknown to run, so directly return.
return expr, false
}

func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
func caseWhenHandler(ctx sessionctx.Context, expr *ScalarFunction) (Expression, bool) {
args, l := expr.GetArgs(), len(expr.GetArgs())
var isDeferred, isDeferredConst bool
for i := 0; i < l-1; i += 2 {
expr.GetArgs()[i], isDeferred = foldConstant(args[i])
expr.GetArgs()[i], isDeferred = foldConstant(ctx, args[i])
isDeferredConst = isDeferredConst || isDeferred
if _, isConst := expr.GetArgs()[i].(*Constant); !isConst {
// for no-const, here should return directly, because the following branches are unknown to be run or not
Expand All @@ -121,12 +122,12 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
// If the condition is const and true, and the previous conditions
// has no expr, then the folded execution body is returned, otherwise
// the arguments of the casewhen are folded and replaced.
val, isNull, err := args[i].EvalInt(expr.GetCtx(), chunk.Row{})
val, isNull, err := args[i].EvalInt(ctx, chunk.Row{})
if err != nil {
return expr, false
}
if val != 0 && !isNull {
foldedExpr, isDeferred := foldConstant(args[i+1])
foldedExpr, isDeferred := foldConstant(ctx, args[i+1])
isDeferredConst = isDeferredConst || isDeferred
if _, isConst := foldedExpr.(*Constant); isConst {
foldedExpr.GetType().SetDecimal(expr.GetType().GetDecimal())
Expand All @@ -139,7 +140,7 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
// is false, then the folded else execution body is returned. otherwise
// the execution body of the else are folded and replaced.
if l%2 == 1 {
foldedExpr, isDeferred := foldConstant(args[l-1])
foldedExpr, isDeferred := foldConstant(ctx, args[l-1])
isDeferredConst = isDeferredConst || isDeferred
if _, isConst := foldedExpr.(*Constant); isConst {
foldedExpr.GetType().SetDecimal(expr.GetType().GetDecimal())
Expand All @@ -150,18 +151,18 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
return expr, isDeferredConst
}

func foldConstant(expr Expression) (Expression, bool) {
func foldConstant(ctx sessionctx.Context, expr Expression) (Expression, bool) {
switch x := expr.(type) {
case *ScalarFunction:
if _, ok := unFoldableFunctions[x.FuncName.L]; ok {
return expr, false
}
if function := specialFoldHandler[x.FuncName.L]; function != nil && !MaybeOverOptimized4PlanCache(x.GetCtx(), []Expression{expr}) {
return function(x)
if function := specialFoldHandler[x.FuncName.L]; function != nil && !MaybeOverOptimized4PlanCache(ctx, []Expression{expr}) {
return function(ctx, x)
}

args := x.GetArgs()
sc := x.GetCtx().GetSessionVars().StmtCtx
sc := ctx.GetSessionVars().StmtCtx
argIsConst := make([]bool, len(args))
hasNullArg := false
allConstArg := true
Expand Down Expand Up @@ -193,11 +194,11 @@ func foldConstant(expr Expression) (Expression, bool) {
constArgs[i] = NewOne()
}
}
dummyScalarFunc, err := NewFunctionBase(x.GetCtx(), x.FuncName.L, x.GetType(), constArgs...)
dummyScalarFunc, err := NewFunctionBase(ctx, x.FuncName.L, x.GetType(), constArgs...)
if err != nil {
return expr, isDeferredConst
}
value, err := dummyScalarFunc.EvalWithInnerCtx(chunk.Row{})
value, err := dummyScalarFunc.Eval(ctx, chunk.Row{})
if err != nil {
return expr, isDeferredConst
}
Expand All @@ -217,7 +218,7 @@ func foldConstant(expr Expression) (Expression, bool) {
}
return expr, isDeferredConst
}
value, err := x.EvalWithInnerCtx(chunk.Row{})
value, err := x.Eval(ctx, chunk.Row{})
retType := x.RetType.Clone()
if !hasNullArg {
// set right not null flag for constant value
Expand Down Expand Up @@ -245,7 +246,7 @@ func foldConstant(expr Expression) (Expression, bool) {
ParamMarker: x.ParamMarker,
}, true
} else if x.DeferredExpr != nil {
value, err := x.DeferredExpr.EvalWithInnerCtx(chunk.Row{})
value, err := x.DeferredExpr.Eval(ctx, chunk.Row{})
if err != nil {
logutil.BgLogger().Debug("fold expression to constant", zap.String("expression", x.ExplainInfo()), zap.Error(err))
return expr, true
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/constant_propagation.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (s *propConstSolver) propagateConstantEQ() {
}
for i, cond := range s.conditions {
if !visited[i] {
s.conditions[i] = ColumnSubstitute(cond, NewSchema(cols...), cons)
s.conditions[i] = ColumnSubstitute(s.ctx, cond, NewSchema(cols...), cons)
}
}
}
Expand Down Expand Up @@ -470,7 +470,7 @@ func (s *propOuterJoinConstSolver) propagateConstantEQ() {
}
for i, cond := range s.joinConds {
if !visited[i+lenFilters] {
s.joinConds[i] = ColumnSubstitute(cond, NewSchema(cols...), cons)
s.joinConds[i] = ColumnSubstitute(s.ctx, cond, NewSchema(cols...), cons)
}
}
}
Expand Down
Loading