Skip to content

Commit

Permalink
planner: avoid to use ScalarFunction.GetCtx in some planner codes (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao authored Nov 24, 2023
1 parent c771e8b commit 8243680
Show file tree
Hide file tree
Showing 26 changed files with 309 additions and 273 deletions.
14 changes: 7 additions & 7 deletions pkg/executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag
case ast.WindowFuncCumeDist:
return buildCumeDist(ordinal, orderByCols)
case ast.WindowFuncNthValue:
return buildNthValue(windowFuncDesc, ordinal)
return buildNthValue(ctx, windowFuncDesc, ordinal)
case ast.WindowFuncNtile:
return buildNtile(windowFuncDesc, ordinal)
return buildNtile(ctx, windowFuncDesc, ordinal)
case ast.WindowFuncPercentRank:
return buildPercentRank(ordinal, orderByCols)
case ast.WindowFuncLead:
Expand Down Expand Up @@ -668,22 +668,22 @@ func buildCumeDist(ordinal int, orderByCols []*expression.Column) AggFunc {
return r
}

func buildNthValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
func buildNthValue(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
// Already checked when building the function description.
nth, _, _ := expression.GetUint64FromConstant(aggFuncDesc.Args[1])
nth, _, _ := expression.GetUint64FromConstant(ctx, aggFuncDesc.Args[1])
return &nthValue{baseAggFunc: base, tp: aggFuncDesc.RetTp, nth: nth}
}

func buildNtile(aggFuncDes *aggregation.AggFuncDesc, ordinal int) AggFunc {
func buildNtile(ctx sessionctx.Context, aggFuncDes *aggregation.AggFuncDesc, ordinal int) AggFunc {
base := baseAggFunc{
args: aggFuncDes.Args,
ordinal: ordinal,
}
n, _, _ := expression.GetUint64FromConstant(aggFuncDes.Args[0])
n, _, _ := expression.GetUint64FromConstant(ctx, aggFuncDes.Args[0])
return &ntile{baseAggFunc: base, n: n}
}

Expand All @@ -697,7 +697,7 @@ func buildPercentRank(ordinal int, orderByCols []*expression.Column) AggFunc {
func buildLeadLag(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal int) baseLeadLag {
offset := uint64(1)
if len(aggFuncDesc.Args) >= 2 {
offset, _, _ = expression.GetUint64FromConstant(aggFuncDesc.Args[1])
offset, _, _ = expression.GetUint64FromConstant(ctx, aggFuncDesc.Args[1])
}
var defaultExpr expression.Expression
defaultExpr = expression.NewNull()
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/distsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func closeAll(objs ...Closeable) error {
func rebuildIndexRanges(ctx sessionctx.Context, is *plannercore.PhysicalIndexScan, idxCols []*expression.Column, colLens []int) (ranges []*ranger.Range, err error) {
access := make([]expression.Expression, 0, len(is.AccessCondition))
for _, cond := range is.AccessCondition {
newCond, err1 := expression.SubstituteCorCol2Constant(cond)
newCond, err1 := expression.SubstituteCorCol2Constant(ctx, cond)
if err1 != nil {
return nil, err1
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/expression/aggregation/window_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Ex
if !skipCheckArgs {
switch strings.ToLower(name) {
case ast.WindowFuncNthValue:
val, isNull, ok := expression.GetUint64FromConstant(args[1])
val, isNull, ok := expression.GetUint64FromConstant(ctx, args[1])
// nth_value does not allow `0`, but allows `null`.
if !ok || (val == 0 && !isNull) {
return nil, nil
}
case ast.WindowFuncNtile:
val, isNull, ok := expression.GetUint64FromConstant(args[0])
val, isNull, ok := expression.GetUint64FromConstant(ctx, args[0])
// ntile does not allow `0`, but allows `null`.
if !ok || (val == 0 && !isNull) {
return nil, nil
Expand All @@ -51,7 +51,7 @@ func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Ex
if len(args) < 2 {
break
}
_, isNull, ok := expression.GetUint64FromConstant(args[1])
_, isNull, ok := expression.GetUint64FromConstant(ctx, args[1])
if !ok || isNull {
return nil, nil
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2153,7 +2153,7 @@ func BuildCastFunctionWithCheck(ctx sessionctx.Context, expr Expression, tp *typ
// since we may reset the flag of the field type of CastAsJson later which
// would affect the evaluation of it.
if tp.EvalType() != types.ETJson && err == nil {
res = FoldConstant(res)
res = FoldConstant(ctx, res)
}
return res, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/builtin_convert_charset.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func BuildToBinaryFunction(ctx sessionctx.Context, expr Expression) (res Express
Function: f,
ctx: ctx,
}
return FoldConstant(res)
return FoldConstant(ctx, res)
}

// BuildFromBinaryFunction builds from_binary function.
Expand All @@ -246,7 +246,7 @@ func BuildFromBinaryFunction(ctx sessionctx.Context, expr Expression, tp *types.
Function: f,
ctx: ctx,
}
return FoldConstant(res)
return FoldConstant(ctx, res)
}

type funcProp int8
Expand Down
55 changes: 28 additions & 27 deletions pkg/expression/constant_fold.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@ package expression
import (
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/logutil"
"go.uber.org/zap"
)

// specialFoldHandler stores functions for special UDF to constant fold
var specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){}
var specialFoldHandler = map[string]func(sessionctx.Context, *ScalarFunction) (Expression, bool){}

func init() {
specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){
specialFoldHandler = map[string]func(sessionctx.Context, *ScalarFunction) (Expression, bool){
ast.If: ifFoldHandler,
ast.Ifnull: ifNullFoldHandler,
ast.Case: caseWhenHandler,
Expand All @@ -36,8 +37,8 @@ func init() {
}

// FoldConstant does constant folding optimization on an expression excluding deferred ones.
func FoldConstant(expr Expression) Expression {
e, _ := foldConstant(expr)
func FoldConstant(ctx sessionctx.Context, expr Expression) Expression {
e, _ := foldConstant(ctx, expr)
// keep the original coercibility, charset, collation and repertoire values after folding
e.SetCoercibility(expr.Coercibility())

Expand All @@ -48,11 +49,11 @@ func FoldConstant(expr Expression) Expression {
return e
}

func isNullHandler(expr *ScalarFunction) (Expression, bool) {
func isNullHandler(ctx sessionctx.Context, expr *ScalarFunction) (Expression, bool) {
arg0 := expr.GetArgs()[0]
if constArg, isConst := arg0.(*Constant); isConst {
isDeferredConst := constArg.DeferredExpr != nil || constArg.ParamMarker != nil
value, err := expr.EvalWithInnerCtx(chunk.Row{})
value, err := expr.Eval(ctx, chunk.Row{})
if err != nil {
// Failed to fold this expr to a constant, print the DEBUG log and
// return the original expression to let the error to be evaluated
Expand All @@ -71,11 +72,11 @@ func isNullHandler(expr *ScalarFunction) (Expression, bool) {
return expr, false
}

func ifFoldHandler(expr *ScalarFunction) (Expression, bool) {
func ifFoldHandler(ctx sessionctx.Context, expr *ScalarFunction) (Expression, bool) {
args := expr.GetArgs()
foldedArg0, _ := foldConstant(args[0])
foldedArg0, _ := foldConstant(ctx, args[0])
if constArg, isConst := foldedArg0.(*Constant); isConst {
arg0, isNull0, err := constArg.EvalInt(expr.GetCtx(), chunk.Row{})
arg0, isNull0, err := constArg.EvalInt(ctx, chunk.Row{})
if err != nil {
// Failed to fold this expr to a constant, print the DEBUG log and
// return the original expression to let the error to be evaluated
Expand All @@ -84,35 +85,35 @@ func ifFoldHandler(expr *ScalarFunction) (Expression, bool) {
return expr, false
}
if !isNull0 && arg0 != 0 {
return foldConstant(args[1])
return foldConstant(ctx, args[1])
}
return foldConstant(args[2])
return foldConstant(ctx, args[2])
}
// if the condition is not const, which branch is unknown to run, so directly return.
return expr, false
}

func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) {
func ifNullFoldHandler(ctx sessionctx.Context, expr *ScalarFunction) (Expression, bool) {
args := expr.GetArgs()
foldedArg0, isDeferred := foldConstant(args[0])
foldedArg0, isDeferred := foldConstant(ctx, args[0])
if constArg, isConst := foldedArg0.(*Constant); isConst {
// Only check constArg.Value here. Because deferred expression is
// evaluated to constArg.Value after foldConstant(args[0]), it's not
// needed to be checked.
if constArg.Value.IsNull() {
return foldConstant(args[1])
return foldConstant(ctx, args[1])
}
return constArg, isDeferred
}
// if the condition is not const, which branch is unknown to run, so directly return.
return expr, false
}

func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
func caseWhenHandler(ctx sessionctx.Context, expr *ScalarFunction) (Expression, bool) {
args, l := expr.GetArgs(), len(expr.GetArgs())
var isDeferred, isDeferredConst bool
for i := 0; i < l-1; i += 2 {
expr.GetArgs()[i], isDeferred = foldConstant(args[i])
expr.GetArgs()[i], isDeferred = foldConstant(ctx, args[i])
isDeferredConst = isDeferredConst || isDeferred
if _, isConst := expr.GetArgs()[i].(*Constant); !isConst {
// for no-const, here should return directly, because the following branches are unknown to be run or not
Expand All @@ -121,12 +122,12 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
// If the condition is const and true, and the previous conditions
// has no expr, then the folded execution body is returned, otherwise
// the arguments of the casewhen are folded and replaced.
val, isNull, err := args[i].EvalInt(expr.GetCtx(), chunk.Row{})
val, isNull, err := args[i].EvalInt(ctx, chunk.Row{})
if err != nil {
return expr, false
}
if val != 0 && !isNull {
foldedExpr, isDeferred := foldConstant(args[i+1])
foldedExpr, isDeferred := foldConstant(ctx, args[i+1])
isDeferredConst = isDeferredConst || isDeferred
if _, isConst := foldedExpr.(*Constant); isConst {
foldedExpr.GetType().SetDecimal(expr.GetType().GetDecimal())
Expand All @@ -139,7 +140,7 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
// is false, then the folded else execution body is returned. otherwise
// the execution body of the else are folded and replaced.
if l%2 == 1 {
foldedExpr, isDeferred := foldConstant(args[l-1])
foldedExpr, isDeferred := foldConstant(ctx, args[l-1])
isDeferredConst = isDeferredConst || isDeferred
if _, isConst := foldedExpr.(*Constant); isConst {
foldedExpr.GetType().SetDecimal(expr.GetType().GetDecimal())
Expand All @@ -150,18 +151,18 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
return expr, isDeferredConst
}

func foldConstant(expr Expression) (Expression, bool) {
func foldConstant(ctx sessionctx.Context, expr Expression) (Expression, bool) {
switch x := expr.(type) {
case *ScalarFunction:
if _, ok := unFoldableFunctions[x.FuncName.L]; ok {
return expr, false
}
if function := specialFoldHandler[x.FuncName.L]; function != nil && !MaybeOverOptimized4PlanCache(x.GetCtx(), []Expression{expr}) {
return function(x)
if function := specialFoldHandler[x.FuncName.L]; function != nil && !MaybeOverOptimized4PlanCache(ctx, []Expression{expr}) {
return function(ctx, x)
}

args := x.GetArgs()
sc := x.GetCtx().GetSessionVars().StmtCtx
sc := ctx.GetSessionVars().StmtCtx
argIsConst := make([]bool, len(args))
hasNullArg := false
allConstArg := true
Expand Down Expand Up @@ -193,11 +194,11 @@ func foldConstant(expr Expression) (Expression, bool) {
constArgs[i] = NewOne()
}
}
dummyScalarFunc, err := NewFunctionBase(x.GetCtx(), x.FuncName.L, x.GetType(), constArgs...)
dummyScalarFunc, err := NewFunctionBase(ctx, x.FuncName.L, x.GetType(), constArgs...)
if err != nil {
return expr, isDeferredConst
}
value, err := dummyScalarFunc.EvalWithInnerCtx(chunk.Row{})
value, err := dummyScalarFunc.Eval(ctx, chunk.Row{})
if err != nil {
return expr, isDeferredConst
}
Expand All @@ -217,7 +218,7 @@ func foldConstant(expr Expression) (Expression, bool) {
}
return expr, isDeferredConst
}
value, err := x.EvalWithInnerCtx(chunk.Row{})
value, err := x.Eval(ctx, chunk.Row{})
retType := x.RetType.Clone()
if !hasNullArg {
// set right not null flag for constant value
Expand Down Expand Up @@ -245,7 +246,7 @@ func foldConstant(expr Expression) (Expression, bool) {
ParamMarker: x.ParamMarker,
}, true
} else if x.DeferredExpr != nil {
value, err := x.DeferredExpr.EvalWithInnerCtx(chunk.Row{})
value, err := x.DeferredExpr.Eval(ctx, chunk.Row{})
if err != nil {
logutil.BgLogger().Debug("fold expression to constant", zap.String("expression", x.ExplainInfo()), zap.Error(err))
return expr, true
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/constant_propagation.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (s *propConstSolver) propagateConstantEQ() {
}
for i, cond := range s.conditions {
if !visited[i] {
s.conditions[i] = ColumnSubstitute(cond, NewSchema(cols...), cons)
s.conditions[i] = ColumnSubstitute(s.ctx, cond, NewSchema(cols...), cons)
}
}
}
Expand Down Expand Up @@ -470,7 +470,7 @@ func (s *propOuterJoinConstSolver) propagateConstantEQ() {
}
for i, cond := range s.joinConds {
if !visited[i+lenFilters] {
s.joinConds[i] = ColumnSubstitute(cond, NewSchema(cols...), cons)
s.joinConds[i] = ColumnSubstitute(s.ctx, cond, NewSchema(cols...), cons)
}
}
}
Expand Down
Loading

0 comments on commit 8243680

Please sign in to comment.