Skip to content

Commit

Permalink
executor: support window function lead and lag (#9672)
Browse files Browse the repository at this point in the history
  • Loading branch information
alivxxx authored Mar 14, 2019
1 parent fdca44c commit 4422a23
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 5 deletions.
29 changes: 29 additions & 0 deletions executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag
return buildNthValue(windowFuncDesc, ordinal)
case ast.WindowFuncPercentRank:
return buildPercenRank(ordinal, orderByCols)
case ast.WindowFuncLead:
return buildLead(windowFuncDesc, ordinal)
case ast.WindowFuncLag:
return buildLag(windowFuncDesc, ordinal)
default:
return Build(ctx, windowFuncDesc, ordinal)
}
Expand Down Expand Up @@ -395,3 +399,28 @@ func buildPercenRank(ordinal int, orderByCols []*expression.Column) AggFunc {
}
return &percentRank{baseAggFunc: base, rowComparer: buildRowComparer(orderByCols)}
}

func buildLeadLag(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) baseLeadLag {
offset := uint64(1)
if len(aggFuncDesc.Args) >= 2 {
offset, _, _ = expression.GetUint64FromConstant(aggFuncDesc.Args[1])
}
var defaultExpr expression.Expression
defaultExpr = expression.Null
if len(aggFuncDesc.Args) == 3 {
defaultExpr = aggFuncDesc.Args[2]
}
base := baseAggFunc{
args: aggFuncDesc.Args,
ordinal: ordinal,
}
return baseLeadLag{baseAggFunc: base, offset: offset, defaultExpr: defaultExpr, valueEvaluator: buildValueEvaluator(aggFuncDesc.RetTp)}
}

func buildLead(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
return &lead{buildLeadLag(aggFuncDesc, ordinal)}
}

func buildLag(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
return &lag{buildLeadLag(aggFuncDesc, ordinal)}
}
89 changes: 89 additions & 0 deletions executor/aggfuncs/func_lead_lag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package aggfuncs

import (
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
)

type baseLeadLag struct {
baseAggFunc
valueEvaluator // TODO: move it to partial result when parallel execution is supported.

defaultExpr expression.Expression
offset uint64
}

type partialResult4LeadLag struct {
rows []chunk.Row
curIdx uint64
}

func (v *baseLeadLag) AllocPartialResult() PartialResult {
return PartialResult(&partialResult4LeadLag{})
}

func (v *baseLeadLag) ResetPartialResult(pr PartialResult) {
p := (*partialResult4LeadLag)(pr)
p.rows = p.rows[:0]
p.curIdx = 0
}

func (v *baseLeadLag) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
p := (*partialResult4LeadLag)(pr)
p.rows = append(p.rows, rowsInGroup...)
return nil
}

type lead struct {
baseLeadLag
}

func (v *lead) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4LeadLag)(pr)
var err error
if p.curIdx+v.offset < uint64(len(p.rows)) {
err = v.evaluateRow(sctx, v.args[0], p.rows[p.curIdx+v.offset])
} else {
err = v.evaluateRow(sctx, v.defaultExpr, p.rows[p.curIdx])
}
if err != nil {
return err
}
v.appendResult(chk, v.ordinal)
p.curIdx++
return nil
}

type lag struct {
baseLeadLag
}

func (v *lag) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4LeadLag)(pr)
var err error
if p.curIdx >= v.offset {
err = v.evaluateRow(sctx, v.args[0], p.rows[p.curIdx-v.offset])
} else {
err = v.evaluateRow(sctx, v.defaultExpr, p.rows[p.curIdx])
}
if err != nil {
return err
}
v.appendResult(chk, v.ordinal)
p.curIdx++
return nil
}
9 changes: 9 additions & 0 deletions executor/window_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,13 @@ func (s *testSuite2) TestWindowFunctions(c *C) {
result.Check(testkit.Rows("1 0", "1 0", "2 0.6666666666666666", "2 0.6666666666666666"))
result = tk.MustQuery("select a, b, percent_rank() over(order by a, b) from t")
result.Check(testkit.Rows("1 1 0", "1 2 0.3333333333333333", "2 1 0.6666666666666666", "2 2 1"))

result = tk.MustQuery("select a, lead(a) over (), lag(a) over() from t")
result.Check(testkit.Rows("1 1 <nil>", "1 2 1", "2 2 1", "2 <nil> 2"))
result = tk.MustQuery("select a, lead(a, 0) over(), lag(a, 0) over() from t")
result.Check(testkit.Rows("1 1 1", "1 1 1", "2 2 2", "2 2 2"))
result = tk.MustQuery("select a, lead(a, 1, a) over(), lag(a, 1, a) over() from t")
result.Check(testkit.Rows("1 1 1", "1 2 1", "2 2 1", "2 2 2"))
result = tk.MustQuery("select a, lead(a, 1, 'lead') over(), lag(a, 1, 'lag') over() from t")
result.Check(testkit.Rows("1 1 lag", "1 2 1", "2 2 1", "2 lead 2"))
}
15 changes: 15 additions & 0 deletions expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) {
a.typeInfer4CumeDist()
case ast.WindowFuncPercentRank:
a.typeInfer4PercentRank()
case ast.WindowFuncLead, ast.WindowFuncLag:
a.typeInfer4LeadLag(ctx)
default:
panic("unsupported agg function: " + a.Name)
}
Expand Down Expand Up @@ -207,6 +209,15 @@ func (a *baseFuncDesc) typeInfer4PercentRank() {
a.RetTp.Flag, a.RetTp.Decimal = mysql.MaxRealWidth, mysql.NotFixedDec
}

func (a *baseFuncDesc) typeInfer4LeadLag(ctx sessionctx.Context) {
if len(a.Args) <= 2 {
a.typeInfer4MaxMin(ctx)
} else {
// Merge the type of first and third argument.
a.RetTp = expression.InferType4ControlFuncs(a.Args[0].GetType(), a.Args[2].GetType())
}
}

// GetDefaultValue gets the default value when the function's input is null.
// According to MySQL, default values of the function are listed as follows:
// e.g.
Expand Down Expand Up @@ -265,6 +276,10 @@ func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) {
panic("should never happen in baseFuncDesc.WrapCastForAggArgs")
}
for i := range a.Args {
// Do not cast the second args of these functions, as they are simply non-negative numbers.
if i == 1 && (a.Name == ast.WindowFuncLead || a.Name == ast.WindowFuncLag || a.Name == ast.WindowFuncNthValue) {
continue
}
a.Args[i] = castFunc(ctx, a.Args[i])
if a.Name != ast.AggFuncAvg && a.Name != ast.AggFuncSum {
continue
Expand Down
11 changes: 10 additions & 1 deletion expression/aggregation/window_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,21 @@ type WindowFuncDesc struct {

// NewWindowFuncDesc creates a window function signature descriptor.
func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) *WindowFuncDesc {
if strings.ToLower(name) == ast.WindowFuncNthValue {
switch strings.ToLower(name) {
case ast.WindowFuncNthValue:
val, isNull, ok := expression.GetUint64FromConstant(args[1])
// nth_value does not allow `0`, but allows `null`.
if !ok || (val == 0 && !isNull) {
return nil
}
case ast.WindowFuncLead, ast.WindowFuncLag:
if len(args) < 2 {
break
}
_, isNull, ok := expression.GetUint64FromConstant(args[1])
if !ok || isNull {
return nil
}
}
return &WindowFuncDesc{newBaseFuncDesc(ctx, name, args)}
}
Expand Down
8 changes: 4 additions & 4 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ var (
_ builtinFunc = &builtinIfJSONSig{}
)

// inferType4ControlFuncs infer result type for builtin IF, IFNULL && NULLIF.
func inferType4ControlFuncs(lhs, rhs *types.FieldType) *types.FieldType {
// InferType4ControlFuncs infer result type for builtin IF, IFNULL, NULLIF, LEAD and LAG.
func InferType4ControlFuncs(lhs, rhs *types.FieldType) *types.FieldType {
resultFieldType := &types.FieldType{}
if lhs.Tp == mysql.TypeNull {
*resultFieldType = *rhs
Expand Down Expand Up @@ -470,7 +470,7 @@ func (c *ifFunctionClass) getFunction(ctx sessionctx.Context, args []Expression)
if err = c.verifyArgs(args); err != nil {
return nil, err
}
retTp := inferType4ControlFuncs(args[1].GetType(), args[2].GetType())
retTp := InferType4ControlFuncs(args[1].GetType(), args[2].GetType())
evalTps := retTp.EvalType()
bf := newBaseBuiltinFuncWithTp(ctx, args, evalTps, types.ETInt, evalTps, evalTps)
retTp.Flag |= bf.tp.Flag
Expand Down Expand Up @@ -680,7 +680,7 @@ func (c *ifNullFunctionClass) getFunction(ctx sessionctx.Context, args []Express
return nil, err
}
lhs, rhs := args[0].GetType(), args[1].GetType()
retTp := inferType4ControlFuncs(lhs, rhs)
retTp := InferType4ControlFuncs(lhs, rhs)
retTp.Flag |= (lhs.Flag & mysql.NotNullFlag) | (rhs.Flag & mysql.NotNullFlag)
if lhs.Tp == mysql.TypeNull && rhs.Tp == mysql.TypeNull {
retTp.Tp = mysql.TypeNull
Expand Down

0 comments on commit 4422a23

Please sign in to comment.