diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 6d98d946ac42e..7d5c559d26ffc 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -65,6 +65,10 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag return buildRank(ordinal, orderByCols, true) case ast.WindowFuncRowNumber: return buildRowNumber(windowFuncDesc, ordinal) + case ast.WindowFuncFirstValue: + return buildFirstValue(windowFuncDesc, ordinal) + case ast.WindowFuncLastValue: + return buildLastValue(windowFuncDesc, ordinal) default: return Build(ctx, windowFuncDesc, ordinal) } @@ -348,3 +352,19 @@ func buildRank(ordinal int, orderByCols []*expression.Column, isDense bool) AggF } return r } + +func buildFirstValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { + base := baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, + } + return &firstValue{baseAggFunc: base, tp: aggFuncDesc.RetTp} +} + +func buildLastValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { + base := baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, + } + return &lastValue{baseAggFunc: base, tp: aggFuncDesc.RetTp} +} diff --git a/executor/aggfuncs/func_value.go b/executor/aggfuncs/func_value.go new file mode 100644 index 0000000000000..18c552e7a1bed --- /dev/null +++ b/executor/aggfuncs/func_value.go @@ -0,0 +1,302 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs + +import ( + "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/json" + "github.com/pingcap/tidb/util/chunk" +) + +// valueEvaluator is used to evaluate values for `first_value`, `last_value`, `nth_value`, +// `lead` and `lag`. +type valueEvaluator interface { + // evaluateRow evaluates the expression using row and stores the result inside. + evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error + // appendResult appends the result to chunk. + appendResult(chk *chunk.Chunk, colIdx int) +} + +type value4Int struct { + val int64 + isNull bool +} + +func (v *value4Int) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { + var err error + v.val, v.isNull, err = expr.EvalInt(ctx, row) + return err +} + +func (v *value4Int) appendResult(chk *chunk.Chunk, colIdx int) { + if v.isNull { + chk.AppendNull(colIdx) + } else { + chk.AppendInt64(colIdx, v.val) + } +} + +type value4Float32 struct { + val float32 + isNull bool +} + +func (v *value4Float32) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { + var err error + var val float64 + val, v.isNull, err = expr.EvalReal(ctx, row) + v.val = float32(val) + return err +} + +func (v *value4Float32) appendResult(chk *chunk.Chunk, colIdx int) { + if v.isNull { + chk.AppendNull(colIdx) + } else { + chk.AppendFloat32(colIdx, v.val) + } +} + +type value4Decimal struct { + val *types.MyDecimal + isNull bool +} + +func (v *value4Decimal) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { + var err error + v.val, v.isNull, err = expr.EvalDecimal(ctx, row) + return err +} + +func (v *value4Decimal) appendResult(chk *chunk.Chunk, colIdx int) { + if v.isNull { + chk.AppendNull(colIdx) + } else { + chk.AppendMyDecimal(colIdx, v.val) + } +} + +type value4Float64 struct { + val float64 + isNull bool +} + +func (v *value4Float64) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { + var err error + v.val, v.isNull, err = expr.EvalReal(ctx, row) + return err +} + +func (v *value4Float64) appendResult(chk *chunk.Chunk, colIdx int) { + if v.isNull { + chk.AppendNull(colIdx) + } else { + chk.AppendFloat64(colIdx, v.val) + } +} + +type value4String struct { + val string + isNull bool +} + +func (v *value4String) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { + var err error + v.val, v.isNull, err = expr.EvalString(ctx, row) + return err +} + +func (v *value4String) appendResult(chk *chunk.Chunk, colIdx int) { + if v.isNull { + chk.AppendNull(colIdx) + } else { + chk.AppendString(colIdx, v.val) + } +} + +type value4Time struct { + val types.Time + isNull bool +} + +func (v *value4Time) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { + var err error + v.val, v.isNull, err = expr.EvalTime(ctx, row) + return err +} + +func (v *value4Time) appendResult(chk *chunk.Chunk, colIdx int) { + if v.isNull { + chk.AppendNull(colIdx) + } else { + chk.AppendTime(colIdx, v.val) + } +} + +type value4Duration struct { + val types.Duration + isNull bool +} + +func (v *value4Duration) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { + var err error + v.val, v.isNull, err = expr.EvalDuration(ctx, row) + return err +} + +func (v *value4Duration) appendResult(chk *chunk.Chunk, colIdx int) { + if v.isNull { + chk.AppendNull(colIdx) + } else { + chk.AppendDuration(colIdx, v.val) + } +} + +type value4JSON struct { + val json.BinaryJSON + isNull bool +} + +func (v *value4JSON) evaluateRow(ctx sessionctx.Context, expr expression.Expression, row chunk.Row) error { + var err error + v.val, v.isNull, err = expr.EvalJSON(ctx, row) + return err +} + +func (v *value4JSON) appendResult(chk *chunk.Chunk, colIdx int) { + if v.isNull { + chk.AppendNull(colIdx) + } else { + chk.AppendJSON(colIdx, v.val) + } +} + +func buildValueEvaluator(tp *types.FieldType) valueEvaluator { + evalType := tp.EvalType() + if tp.Tp == mysql.TypeBit { + evalType = types.ETString + } + switch evalType { + case types.ETInt: + return &value4Int{} + case types.ETReal: + switch tp.Tp { + case mysql.TypeFloat: + return &value4Float32{} + case mysql.TypeDouble: + return &value4Float64{} + } + case types.ETDecimal: + return &value4Decimal{} + case types.ETDatetime, types.ETTimestamp: + return &value4Time{} + case types.ETDuration: + return &value4Duration{} + case types.ETString: + return &value4String{} + case types.ETJson: + return &value4JSON{} + } + return nil +} + +type firstValue struct { + baseAggFunc + + tp *types.FieldType +} + +type partialResult4FirstValue struct { + gotFirstValue bool + evaluator valueEvaluator +} + +func (v *firstValue) AllocPartialResult() PartialResult { + return PartialResult(&partialResult4FirstValue{evaluator: buildValueEvaluator(v.tp)}) +} + +func (v *firstValue) ResetPartialResult(pr PartialResult) { + p := (*partialResult4FirstValue)(pr) + p.gotFirstValue = false +} + +func (v *firstValue) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4FirstValue)(pr) + if p.gotFirstValue { + return nil + } + if len(rowsInGroup) > 0 { + p.gotFirstValue = true + err := p.evaluator.evaluateRow(sctx, v.args[0], rowsInGroup[0]) + if err != nil { + return err + } + } + return nil +} + +func (v *firstValue) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4FirstValue)(pr) + if !p.gotFirstValue { + chk.AppendNull(v.ordinal) + } else { + p.evaluator.appendResult(chk, v.ordinal) + } + return nil +} + +type lastValue struct { + baseAggFunc + + tp *types.FieldType +} + +type partialResult4LastValue struct { + gotLastValue bool + evaluator valueEvaluator +} + +func (v *lastValue) AllocPartialResult() PartialResult { + return PartialResult(&partialResult4LastValue{evaluator: buildValueEvaluator(v.tp)}) +} + +func (v *lastValue) ResetPartialResult(pr PartialResult) { + p := (*partialResult4LastValue)(pr) + p.gotLastValue = false +} + +func (v *lastValue) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4LastValue)(pr) + if len(rowsInGroup) > 0 { + p.gotLastValue = true + err := p.evaluator.evaluateRow(sctx, v.args[0], rowsInGroup[len(rowsInGroup)-1]) + if err != nil { + return err + } + } + return nil +} + +func (v *lastValue) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4LastValue)(pr) + if !p.gotLastValue { + chk.AppendNull(v.ordinal) + } else { + p.evaluator.appendResult(chk, v.ordinal) + } + return nil +} diff --git a/executor/window_test.go b/executor/window_test.go index c4cd3ae812102..4870251fa3b40 100644 --- a/executor/window_test.go +++ b/executor/window_test.go @@ -93,4 +93,14 @@ func (s *testSuite2) TestWindowFunctions(c *C) { result.Check(testkit.Rows("1 2", "1 2", "2 6", "2 6")) result = tk.MustQuery("select a, sum(a) over(order by a, b) from t") result.Check(testkit.Rows("1 1", "1 2", "2 4", "2 6")) + + result = tk.MustQuery("select a, first_value(a) over(), last_value(a) over() from t") + result.Check(testkit.Rows("1 1 2", "1 1 2", "2 1 2", "2 1 2")) + result = tk.MustQuery("select a, first_value(a) over(rows between 1 preceding and 1 following), last_value(a) over(rows between 1 preceding and 1 following) from t") + result.Check(testkit.Rows("1 1 1", "1 1 2", "2 1 2", "2 2 2")) + result = tk.MustQuery("select a, first_value(a) over(rows between 1 following and 1 following), last_value(a) over(rows between 1 following and 1 following) from t") + result.Check(testkit.Rows("1 1 1", "1 2 2", "2 2 2", "2 ")) + result = tk.MustQuery("select a, first_value(rand(0)) over(), last_value(rand(0)) over() from t") + result.Check(testkit.Rows("1 0.9451961492941164 0.05434383959970039", "1 0.9451961492941164 0.05434383959970039", + "2 0.9451961492941164 0.05434383959970039", "2 0.9451961492941164 0.05434383959970039")) } diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index 717bdf2b9613f..cfd1f16b74e39 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -91,7 +91,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) { a.typeInfer4Avg(ctx) case ast.AggFuncGroupConcat: a.typeInfer4GroupConcat(ctx) - case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow: + case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow, + ast.WindowFuncFirstValue, ast.WindowFuncLastValue: a.typeInfer4MaxMin(ctx) case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor: a.typeInfer4BitFuncs(ctx)