From 6b5eb658ed1da5d787b2897881b9cf145832b632 Mon Sep 17 00:00:00 2001 From: HuaiyuXu <391585975@qq.com> Date: Mon, 24 Jun 2019 13:09:14 +0800 Subject: [PATCH] *: let baseFuncDesc.typeInfer return error instead of panic (#10910) --- cmd/explaintest/r/explain_easy.result | 6 +- executor/aggfuncs/aggfunc_test.go | 9 ++- executor/aggfuncs/window_func_test.go | 3 +- executor/aggregate_test.go | 17 ++++ executor/benchmark_test.go | 5 +- executor/builder.go | 6 +- executor/executor_required_rows_test.go | 6 +- expression/aggregation/agg_to_pb_test.go | 6 +- expression/aggregation/aggregation_test.go | 77 +++++++++++++------ expression/aggregation/base_func.go | 12 +-- expression/aggregation/base_func_test.go | 3 +- expression/aggregation/bench_test.go | 24 +++++- expression/aggregation/descriptor.go | 9 ++- expression/aggregation/window_func.go | 14 ++-- planner/core/expression_rewriter.go | 48 ++++++++++-- planner/core/logical_plan_builder.go | 34 ++++++-- planner/core/rule_aggregation_push_down.go | 53 ++++++++----- planner/core/rule_decorrelate.go | 10 ++- .../core/rule_inject_extra_projection_test.go | 3 +- 19 files changed, 252 insertions(+), 93 deletions(-) diff --git a/cmd/explaintest/r/explain_easy.result b/cmd/explaintest/r/explain_easy.result index 6583b65c47156..ba83ace8edb2b 100644 --- a/cmd/explaintest/r/explain_easy.result +++ b/cmd/explaintest/r/explain_easy.result @@ -456,10 +456,8 @@ TableReader_7 10000.00 root data:TableScan_6 └─TableScan_6 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo explain select distinct t1.a, t1.b from t1 left outer join t2 on t1.a = t2.a; id count task operator info -StreamAgg_18 8000.00 root group by:col_2, col_3, funcs:firstrow(col_0), firstrow(col_1) -└─IndexReader_19 8000.00 root index:StreamAgg_10 - └─StreamAgg_10 8000.00 cop group by:test.t1.a, test.t1.b, funcs:firstrow(test.t1.a), firstrow(test.t1.b) - └─IndexScan_17 10000.00 cop table:t1, index:a, b, range:[NULL,+inf], keep order:true, stats:pseudo +TableReader_9 10000.00 root data:TableScan_8 +└─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo drop table if exists t; create table t(a int, nb int not null, nc int not null); explain select ifnull(a, 0) from t; diff --git a/executor/aggfuncs/aggfunc_test.go b/executor/aggfuncs/aggfunc_test.go index c61bd792eebf2..7f9d582d5bf81 100644 --- a/executor/aggfuncs/aggfunc_test.go +++ b/executor/aggfuncs/aggfunc_test.go @@ -82,7 +82,8 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) { if p.funcName == ast.AggFuncGroupConcat { args = append(args, &expression.Constant{Value: types.NewStringDatum(" "), RetType: types.NewFieldType(mysql.TypeString)}) } - desc := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false) + desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false) + c.Assert(err, IsNil) partialDesc, finalDesc := desc.Split([]int{0, 1}) // build partial func for partial phase. @@ -183,7 +184,8 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) { if p.funcName == ast.AggFuncGroupConcat { args = append(args, &expression.Constant{Value: types.NewStringDatum(" "), RetType: types.NewFieldType(mysql.TypeString)}) } - desc := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false) + desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false) + c.Assert(err, IsNil) finalFunc := aggfuncs.Build(s.ctx, desc, 0) finalPr := finalFunc.AllocPartialResult() resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1) @@ -208,7 +210,8 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) { c.Assert(result, Equals, 0) // test the agg func with distinct - desc = aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, true) + desc, err = aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, true) + c.Assert(err, IsNil) finalFunc = aggfuncs.Build(s.ctx, desc, 0) finalPr = finalFunc.AllocPartialResult() diff --git a/executor/aggfuncs/window_func_test.go b/executor/aggfuncs/window_func_test.go index d6c140d596d40..1108fdc92e019 100644 --- a/executor/aggfuncs/window_func_test.go +++ b/executor/aggfuncs/window_func_test.go @@ -44,7 +44,8 @@ func (s *testSuite) testWindowFunc(c *C, p windowTest) { srcChk.AppendDatum(0, &dt) } - desc := aggregation.NewAggFuncDesc(s.ctx, p.funcName, p.args, false) + desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, p.args, false) + c.Assert(err, IsNil) finalFunc := aggfuncs.BuildWindowFunctions(s.ctx, desc, 0, p.orderByCols) finalPr := finalFunc.AllocPartialResult() resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1) diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index 98bc6d2e701d3..9332b22985aa8 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -15,6 +15,7 @@ package executor_test import ( . "github.com/pingcap/check" + "github.com/pingcap/errors" "github.com/pingcap/parser/terror" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/util/testkit" @@ -334,6 +335,22 @@ func (s *testSuite1) TestAggregation(c *C) { tk.MustExec("insert into t value(0), (-0.9871), (-0.9871)") tk.MustQuery("select 10 from t group by a").Check(testkit.Rows("10", "10")) tk.MustQuery("select sum(a) from (select a from t union all select a from t) tmp").Check(testkit.Rows("-3.9484")) + _, err = tk.Exec("select std(a) from t") + c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: std") + _, err = tk.Exec("select stddev(a) from t") + c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: stddev") + _, err = tk.Exec("select stddev_pop(a) from t") + c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: stddev_pop") + _, err = tk.Exec("select std_samp(a) from t") + // TODO: Fix this error message. + c.Assert(errors.Cause(err).Error(), Equals, "[expression:1305]FUNCTION std_samp does not exist") + _, err = tk.Exec("select variance(a) from t") + // TODO: Fix this error message. + c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: var_pop") + _, err = tk.Exec("select var_pop(a) from t") + c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: var_pop") + _, err = tk.Exec("select var_samp(a) from t") + c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: var_samp") } func (s *testSuite1) TestStreamAggPushDown(c *C) { diff --git a/executor/benchmark_test.go b/executor/benchmark_test.go index 4d38af79e887f..d58357257882f 100644 --- a/executor/benchmark_test.go +++ b/executor/benchmark_test.go @@ -228,7 +228,10 @@ func buildAggExecutor(b *testing.B, testCase *aggTestCase, child Executor) Execu childCols := testCase.columns() schema := expression.NewSchema(childCols...) groupBy := []expression.Expression{childCols[1]} - aggFunc := aggregation.NewAggFuncDesc(testCase.ctx, testCase.aggFunc, []expression.Expression{childCols[0]}, testCase.hasDistinct) + aggFunc, err := aggregation.NewAggFuncDesc(testCase.ctx, testCase.aggFunc, []expression.Expression{childCols[0]}, testCase.hasDistinct) + if err != nil { + b.Fatal(err) + } aggFuncs := []*aggregation.AggFuncDesc{aggFunc} var aggExec Executor diff --git a/executor/builder.go b/executor/builder.go index 25e340ff361fe..b4e0630507685 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -2152,7 +2152,11 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) *WindowExec partialResults := make([]aggfuncs.PartialResult, 0, len(v.WindowFuncDescs)) resultColIdx := v.Schema().Len() - len(v.WindowFuncDescs) for _, desc := range v.WindowFuncDescs { - aggDesc := aggregation.NewAggFuncDesc(b.ctx, desc.Name, desc.Args, false) + aggDesc, err := aggregation.NewAggFuncDesc(b.ctx, desc.Name, desc.Args, false) + if err != nil { + b.err = err + return nil + } agg := aggfuncs.BuildWindowFunctions(b.ctx, aggDesc, resultColIdx, orderByCols) windowFuncs = append(windowFuncs, agg) partialResults = append(partialResults, agg.AllocPartialResult()) diff --git a/executor/executor_required_rows_test.go b/executor/executor_required_rows_test.go index 70cf56031e79e..645c346f7f1dc 100644 --- a/executor/executor_required_rows_test.go +++ b/executor/executor_required_rows_test.go @@ -659,7 +659,8 @@ func (s *testExecSuite) TestStreamAggRequiredRows(c *C) { childCols := ds.Schema().Columns schema := expression.NewSchema(childCols...) groupBy := []expression.Expression{childCols[1]} - aggFunc := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, true) + aggFunc, err := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, true) + c.Assert(err, IsNil) aggFuncs := []*aggregation.AggFuncDesc{aggFunc} exec := buildStreamAggExecutor(sctx, ds, schema, aggFuncs, groupBy) c.Assert(exec.Open(ctx), IsNil) @@ -718,7 +719,8 @@ func (s *testExecSuite) TestHashAggParallelRequiredRows(c *C) { childCols := ds.Schema().Columns schema := expression.NewSchema(childCols...) groupBy := []expression.Expression{childCols[1]} - aggFunc := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, hasDistinct) + aggFunc, err := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, hasDistinct) + c.Assert(err, IsNil) aggFuncs := []*aggregation.AggFuncDesc{aggFunc} exec := buildHashAggExecutor(sctx, ds, schema, aggFuncs, groupBy) c.Assert(exec.Open(ctx), IsNil) diff --git a/expression/aggregation/agg_to_pb_test.go b/expression/aggregation/agg_to_pb_test.go index 7e80f3d2c7a90..115bb6bb81ba3 100644 --- a/expression/aggregation/agg_to_pb_test.go +++ b/expression/aggregation/agg_to_pb_test.go @@ -76,7 +76,8 @@ func (s *testEvaluatorSuite) TestAggFunc2Pb(c *C) { } for _, funcName := range funcNames { args := []expression.Expression{dg.genColumn(mysql.TypeDouble, 1)} - aggFunc := NewAggFuncDesc(s.ctx, funcName, args, true) + aggFunc, err := NewAggFuncDesc(s.ctx, funcName, args, true) + c.Assert(err, IsNil) pbExpr := AggFuncToPBExpr(sc, client, aggFunc) js, err := json.Marshal(pbExpr) c.Assert(err, IsNil) @@ -94,7 +95,8 @@ func (s *testEvaluatorSuite) TestAggFunc2Pb(c *C) { } for i, funcName := range funcNames { args := []expression.Expression{dg.genColumn(mysql.TypeDouble, 1)} - aggFunc := NewAggFuncDesc(s.ctx, funcName, args, false) + aggFunc, err := NewAggFuncDesc(s.ctx, funcName, args, false) + c.Assert(err, IsNil) aggFunc.RetTp = funcTypes[i] pbExpr := AggFuncToPBExpr(sc, client, aggFunc) js, err := json.Marshal(pbExpr) diff --git a/expression/aggregation/aggregation_test.go b/expression/aggregation/aggregation_test.go index 0ebbe8a330eef..307382e08f46c 100644 --- a/expression/aggregation/aggregation_test.go +++ b/expression/aggregation/aggregation_test.go @@ -58,7 +58,9 @@ func (s *testAggFuncSuit) TestAvg(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - avgFunc := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, false) + c.Assert(err, IsNil) + avgFunc := desc.GetAggFunc(ctx) evalCtx := avgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := avgFunc.GetResult(evalCtx) @@ -71,12 +73,14 @@ func (s *testAggFuncSuit) TestAvg(c *C) { result = avgFunc.GetResult(evalCtx) needed := types.NewDecFromStringForTest("67.000000000000000000000000000000") c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) - err := avgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) + err = avgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) c.Assert(err, IsNil) result = avgFunc.GetResult(evalCtx) c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) - distinctAvgFunc := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, true).GetAggFunc(ctx) + desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, true) + c.Assert(err, IsNil) + distinctAvgFunc := desc.GetAggFunc(ctx) evalCtx = distinctAvgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) for _, row := range s.rows { err := distinctAvgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) @@ -105,7 +109,8 @@ func (s *testAggFuncSuit) TestAvgFinalMode(c *C) { Index: 1, RetType: types.NewFieldType(mysql.TypeNewDecimal), } - aggFunc := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{cntCol, sumCol}, false) + aggFunc, err := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{cntCol, sumCol}, false) + c.Assert(err, IsNil) aggFunc.Mode = FinalMode avgFunc := aggFunc.GetAggFunc(ctx) evalCtx := avgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) @@ -125,7 +130,9 @@ func (s *testAggFuncSuit) TestSum(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - sumFunc := NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, false) + c.Assert(err, IsNil) + sumFunc := desc.GetAggFunc(ctx) evalCtx := sumFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := sumFunc.GetResult(evalCtx) @@ -138,14 +145,16 @@ func (s *testAggFuncSuit) TestSum(c *C) { result = sumFunc.GetResult(evalCtx) needed := types.NewDecFromStringForTest("338350") c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) - err := sumFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) + err = sumFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) c.Assert(err, IsNil) result = sumFunc.GetResult(evalCtx) c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) partialResult := sumFunc.GetPartialResult(evalCtx) c.Assert(partialResult[0].GetMysqlDecimal().Compare(needed) == 0, IsTrue) - distinctSumFunc := NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, true).GetAggFunc(ctx) + desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, true) + c.Assert(err, IsNil) + distinctSumFunc := desc.GetAggFunc(ctx) evalCtx = distinctSumFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) for _, row := range s.rows { err := distinctSumFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) @@ -162,14 +171,16 @@ func (s *testAggFuncSuit) TestBitAnd(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - bitAndFunc := NewAggFuncDesc(s.ctx, ast.AggFuncBitAnd, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncBitAnd, []expression.Expression{col}, false) + c.Assert(err, IsNil) + bitAndFunc := desc.GetAggFunc(ctx) evalCtx := bitAndFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := bitAndFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(math.MaxUint64)) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() - err := bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) + err = bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) c.Assert(err, IsNil) result = bitAndFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(1)) @@ -238,14 +249,16 @@ func (s *testAggFuncSuit) TestBitOr(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - bitOrFunc := NewAggFuncDesc(s.ctx, ast.AggFuncBitOr, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncBitOr, []expression.Expression{col}, false) + c.Assert(err, IsNil) + bitOrFunc := desc.GetAggFunc(ctx) evalCtx := bitOrFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := bitOrFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(0)) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() - err := bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) + err = bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) c.Assert(err, IsNil) result = bitOrFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(1)) @@ -322,14 +335,16 @@ func (s *testAggFuncSuit) TestBitXor(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - bitXorFunc := NewAggFuncDesc(s.ctx, ast.AggFuncBitXor, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncBitXor, []expression.Expression{col}, false) + c.Assert(err, IsNil) + bitXorFunc := desc.GetAggFunc(ctx) evalCtx := bitXorFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := bitXorFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(0)) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() - err := bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) + err = bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) c.Assert(err, IsNil) result = bitXorFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(1)) @@ -398,7 +413,9 @@ func (s *testAggFuncSuit) TestCount(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - countFunc := NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, false) + c.Assert(err, IsNil) + countFunc := desc.GetAggFunc(ctx) evalCtx := countFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := countFunc.GetResult(evalCtx) @@ -410,14 +427,16 @@ func (s *testAggFuncSuit) TestCount(c *C) { } result = countFunc.GetResult(evalCtx) c.Assert(result.GetInt64(), Equals, int64(5050)) - err := countFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) + err = countFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) c.Assert(err, IsNil) result = countFunc.GetResult(evalCtx) c.Assert(result.GetInt64(), Equals, int64(5050)) partialResult := countFunc.GetPartialResult(evalCtx) c.Assert(partialResult[0].GetInt64(), Equals, int64(5050)) - distinctCountFunc := NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, true).GetAggFunc(ctx) + desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, true) + c.Assert(err, IsNil) + distinctCountFunc := desc.GetAggFunc(ctx) evalCtx = distinctCountFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) for _, row := range s.rows { @@ -438,14 +457,16 @@ func (s *testAggFuncSuit) TestConcat(c *C) { RetType: types.NewFieldType(mysql.TypeVarchar), } ctx := mock.NewContext() - concatFunc := NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, false) + c.Assert(err, IsNil) + concatFunc := desc.GetAggFunc(ctx) evalCtx := concatFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := concatFunc.GetResult(evalCtx) c.Assert(result.IsNull(), IsTrue) row := chunk.MutRowFromDatums(types.MakeDatums(1, "x")) - err := concatFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) + err = concatFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) c.Assert(err, IsNil) result = concatFunc.GetResult(evalCtx) c.Assert(result.GetString(), Equals, "1") @@ -464,7 +485,9 @@ func (s *testAggFuncSuit) TestConcat(c *C) { partialResult := concatFunc.GetPartialResult(evalCtx) c.Assert(partialResult[0].GetString(), Equals, "1x2") - distinctConcatFunc := NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, true).GetAggFunc(ctx) + desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, true) + c.Assert(err, IsNil) + distinctConcatFunc := desc.GetAggFunc(ctx) evalCtx = distinctConcatFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) row.SetDatum(0, types.NewIntDatum(1)) @@ -487,11 +510,13 @@ func (s *testAggFuncSuit) TestFirstRow(c *C) { } ctx := mock.NewContext() - firstRowFunc := NewAggFuncDesc(s.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + c.Assert(err, IsNil) + firstRowFunc := desc.GetAggFunc(ctx) evalCtx := firstRowFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() - err := firstRowFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) + err = firstRowFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) c.Assert(err, IsNil) result := firstRowFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(1)) @@ -512,8 +537,12 @@ func (s *testAggFuncSuit) TestMaxMin(c *C) { } ctx := mock.NewContext() - maxFunc := NewAggFuncDesc(s.ctx, ast.AggFuncMax, []expression.Expression{col}, false).GetAggFunc(ctx) - minFunc := NewAggFuncDesc(s.ctx, ast.AggFuncMin, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncMax, []expression.Expression{col}, false) + c.Assert(err, IsNil) + maxFunc := desc.GetAggFunc(ctx) + desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncMin, []expression.Expression{col}, false) + c.Assert(err, IsNil) + minFunc := desc.GetAggFunc(ctx) maxEvalCtx := maxFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) minEvalCtx := minFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) @@ -523,7 +552,7 @@ func (s *testAggFuncSuit) TestMaxMin(c *C) { c.Assert(result.IsNull(), IsTrue) row := chunk.MutRowFromDatums(types.MakeDatums(2)) - err := maxFunc.Update(maxEvalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) + err = maxFunc.Update(maxEvalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) c.Assert(err, IsNil) result = maxFunc.GetResult(maxEvalCtx) c.Assert(result.GetInt64(), Equals, int64(2)) diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index ba0e716853476..1a4a971c02294 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -19,6 +19,7 @@ import ( "strings" "github.com/cznic/mathutil" + "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/charset" "github.com/pingcap/parser/mysql" @@ -37,10 +38,10 @@ type baseFuncDesc struct { RetTp *types.FieldType } -func newBaseFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) baseFuncDesc { +func newBaseFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) (baseFuncDesc, error) { b := baseFuncDesc{Name: strings.ToLower(name), Args: args} - b.typeInfer(ctx) - return b + err := b.typeInfer(ctx) + return b, err } func (a *baseFuncDesc) equal(ctx sessionctx.Context, other *baseFuncDesc) bool { @@ -81,7 +82,7 @@ func (a *baseFuncDesc) String() string { } // typeInfer infers the arguments and return types of an function. -func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) { +func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) error { switch a.Name { case ast.AggFuncCount: a.typeInfer4Count(ctx) @@ -107,8 +108,9 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) { case ast.WindowFuncLead, ast.WindowFuncLag: a.typeInfer4LeadLag(ctx) default: - panic("unsupported agg function: " + a.Name) + return errors.Errorf("unsupported agg function: %s", a.Name) } + return nil } func (a *baseFuncDesc) typeInfer4Count(ctx sessionctx.Context) { diff --git a/expression/aggregation/base_func_test.go b/expression/aggregation/base_func_test.go index 3002400c6ce85..ba7fd757fdbaa 100644 --- a/expression/aggregation/base_func_test.go +++ b/expression/aggregation/base_func_test.go @@ -25,7 +25,8 @@ func (s *testBaseFuncSuite) TestClone(c *check.C) { UniqueID: 0, RetType: types.NewFieldType(mysql.TypeLonglong), } - desc := newBaseFuncDesc(s.ctx, ast.AggFuncFirstRow, []expression.Expression{col}) + desc, err := newBaseFuncDesc(s.ctx, ast.AggFuncFirstRow, []expression.Expression{col}) + c.Assert(err, check.IsNil) cloned := desc.clone() c.Assert(desc.equal(s.ctx, cloned), check.IsTrue) diff --git a/expression/aggregation/bench_test.go b/expression/aggregation/bench_test.go index e49deebe00da3..c3f72695709cd 100644 --- a/expression/aggregation/bench_test.go +++ b/expression/aggregation/bench_test.go @@ -29,7 +29,11 @@ func BenchmarkCreateContext(b *testing.B) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false) + if err != nil { + b.Fatal(err) + } + fun := desc.GetAggFunc(ctx) b.StartTimer() for i := 0; i < b.N; i++ { fun.CreateContext(ctx.GetSessionVars().StmtCtx) @@ -43,7 +47,11 @@ func BenchmarkResetContext(b *testing.B) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false) + if err != nil { + b.Fatal(err) + } + fun := desc.GetAggFunc(ctx) evalCtx := fun.CreateContext(ctx.GetSessionVars().StmtCtx) b.StartTimer() for i := 0; i < b.N; i++ { @@ -58,7 +66,11 @@ func BenchmarkCreateDistinctContext(b *testing.B) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true) + if err != nil { + b.Fatal(err) + } + fun := desc.GetAggFunc(ctx) b.StartTimer() for i := 0; i < b.N; i++ { fun.CreateContext(ctx.GetSessionVars().StmtCtx) @@ -72,7 +84,11 @@ func BenchmarkResetDistinctContext(b *testing.B) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true) + if err != nil { + b.Fatal(err) + } + fun := desc.GetAggFunc(ctx) evalCtx := fun.CreateContext(ctx.GetSessionVars().StmtCtx) b.StartTimer() for i := 0; i < b.N; i++ { diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index d8e544171d8cd..66f5f3346c805 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -37,9 +37,12 @@ type AggFuncDesc struct { } // NewAggFuncDesc creates an aggregation function signature descriptor. -func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, hasDistinct bool) *AggFuncDesc { - b := newBaseFuncDesc(ctx, name, args) - return &AggFuncDesc{baseFuncDesc: b, HasDistinct: hasDistinct} +func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, hasDistinct bool) (*AggFuncDesc, error) { + b, err := newBaseFuncDesc(ctx, name, args) + if err != nil { + return nil, err + } + return &AggFuncDesc{baseFuncDesc: b, HasDistinct: hasDistinct}, nil } // Equal checks whether two aggregation function signatures are equal. diff --git a/expression/aggregation/window_func.go b/expression/aggregation/window_func.go index 28ccfed44e98d..8f963480dde16 100644 --- a/expression/aggregation/window_func.go +++ b/expression/aggregation/window_func.go @@ -27,19 +27,19 @@ type WindowFuncDesc struct { } // NewWindowFuncDesc creates a window function signature descriptor. -func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) *WindowFuncDesc { +func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) (*WindowFuncDesc, error) { switch strings.ToLower(name) { case ast.WindowFuncNthValue: val, isNull, ok := expression.GetUint64FromConstant(args[1]) // nth_value does not allow `0`, but allows `null`. if !ok || (val == 0 && !isNull) { - return nil + return nil, nil } case ast.WindowFuncNtile: val, isNull, ok := expression.GetUint64FromConstant(args[0]) // ntile does not allow `0`, but allows `null`. if !ok || (val == 0 && !isNull) { - return nil + return nil, nil } case ast.WindowFuncLead, ast.WindowFuncLag: if len(args) < 2 { @@ -47,10 +47,14 @@ func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Ex } _, isNull, ok := expression.GetUint64FromConstant(args[1]) if !ok || isNull { - return nil + return nil, nil } } - return &WindowFuncDesc{newBaseFuncDesc(ctx, name, args)} + base, err := newBaseFuncDesc(ctx, name, args) + if err != nil { + return nil, err + } + return &WindowFuncDesc{base}, nil } // noFrameWindowFuncs is the functions that operate on the entire partition, diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 91ea45fca7a26..efff3275d9f75 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -445,7 +445,11 @@ func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression. if useMin { funcName = ast.AggFuncMin } - funcMaxOrMin := aggregation.NewAggFuncDesc(er.ctx, funcName, []expression.Expression{rexpr}, false) + funcMaxOrMin, err := aggregation.NewAggFuncDesc(er.ctx, funcName, []expression.Expression{rexpr}, false) + if err != nil { + er.err = err + return + } // Create a column and append it to the schema of that aggregation. colMaxOrMin := &expression.Column{ @@ -467,7 +471,11 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, innerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr) outerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), lexpr) - funcSum := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false) + funcSum, err := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false) + if err != nil { + er.err = err + return + } colSum := &expression.Column{ ColName: model.NewCIStr("agg_col_sum"), UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), @@ -478,7 +486,11 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, innerHasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero) // Build `count(1)` aggregation to check if subquery is empty. - funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{expression.One}, false) + funcCount, err := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{expression.One}, false) + if err != nil { + er.err = err + return + } colCount := &expression.Column{ ColName: model.NewCIStr("agg_col_cnt"), UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), @@ -539,8 +551,16 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, // t.id != s.id or count(distinct s.id) > 1 or [any checker]. If there are two different values in s.id , // there must exist a s.id that doesn't equal to t.id. func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np LogicalPlan) { - firstRowFunc := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false) - countFunc := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{rexpr}, true) + firstRowFunc, err := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false) + if err != nil { + er.err = err + return + } + countFunc, err := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{rexpr}, true) + if err != nil { + er.err = err + return + } plan4Agg := LogicalAggregation{ AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc}, }.Init(er.ctx) @@ -565,8 +585,16 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np // handleEQAll handles the case of = all. For example, if the query is t.id = all (select s.id from s), it will be rewrote to // t.id = (select s.id from s having count(distinct s.id) <= 1 and [all checker]). func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np LogicalPlan) { - firstRowFunc := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false) - countFunc := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{rexpr}, true) + firstRowFunc, err := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false) + if err != nil { + er.err = err + return + } + countFunc, err := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{rexpr}, true) + if err != nil { + er.err = err + return + } plan4Agg := LogicalAggregation{ AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc}, }.Init(er.ctx) @@ -712,7 +740,11 @@ func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, er.b.optFlag |= flagEliminateProjection er.b.optFlag |= flagJoinReOrder // Build distinct for the inner query. - agg := er.b.buildDistinct(np, np.Schema().Len()) + agg, err := er.b.buildDistinct(np, np.Schema().Len()) + if err != nil { + er.err = err + return v, true + } for _, col := range agg.schema.Columns { col.IsReferenced = true } diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 66dff13966e35..2c7bf7c5572e1 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -96,7 +96,10 @@ func (b *PlanBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.Aggrega p = np newArgList = append(newArgList, newArg) } - newFunc := aggregation.NewAggFuncDesc(b.ctx, aggFunc.F, newArgList, aggFunc.Distinct) + newFunc, err := aggregation.NewAggFuncDesc(b.ctx, aggFunc.F, newArgList, aggFunc.Distinct) + if err != nil { + return nil, nil, err + } combined := false for j, oldFunc := range plan4Agg.AggFuncs { if oldFunc.Equal(b.ctx, newFunc) { @@ -118,7 +121,10 @@ func (b *PlanBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.Aggrega } } for _, col := range p.Schema().Columns { - newFunc := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + newFunc, err := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + if err != nil { + return nil, nil, err + } plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, newFunc) newCol, _ := col.Clone().(*expression.Column) newCol.RetType = newFunc.RetTp @@ -753,7 +759,7 @@ func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, return proj, oldLen, nil } -func (b *PlanBuilder) buildDistinct(child LogicalPlan, length int) *LogicalAggregation { +func (b *PlanBuilder) buildDistinct(child LogicalPlan, length int) (*LogicalAggregation, error) { b.optFlag = b.optFlag | flagBuildKeyInfo b.optFlag = b.optFlag | flagPushDownAgg plan4Agg := LogicalAggregation{ @@ -762,7 +768,10 @@ func (b *PlanBuilder) buildDistinct(child LogicalPlan, length int) *LogicalAggre }.Init(b.ctx) plan4Agg.collectGroupByColumns() for _, col := range child.Schema().Columns { - aggDesc := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + aggDesc, err := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + if err != nil { + return nil, err + } plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, aggDesc) } plan4Agg.SetChildren(child) @@ -772,7 +781,7 @@ func (b *PlanBuilder) buildDistinct(child LogicalPlan, length int) *LogicalAggre for i, col := range plan4Agg.schema.Columns { col.RetType = plan4Agg.AggFuncs[i].RetTp } - return plan4Agg + return plan4Agg, nil } // unionJoinFieldType finds the type which can carry the given types in Union. @@ -844,7 +853,10 @@ func (b *PlanBuilder) buildUnion(union *ast.UnionStmt) (LogicalPlan, error) { unionDistinctPlan := b.buildUnionAll(distinctSelectPlans) if unionDistinctPlan != nil { - unionDistinctPlan = b.buildDistinct(unionDistinctPlan, unionDistinctPlan.Schema().Len()) + unionDistinctPlan, err = b.buildDistinct(unionDistinctPlan, unionDistinctPlan.Schema().Len()) + if err != nil { + return nil, err + } if len(allSelectPlans) > 0 { // Can't change the statements order in order to get the correct column info. allSelectPlans = append([]LogicalPlan{unionDistinctPlan}, allSelectPlans...) @@ -2082,7 +2094,10 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error } if sel.Distinct { - p = b.buildDistinct(p, oldLen) + p, err = b.buildDistinct(p, oldLen) + if err != nil { + return nil, err + } } if sel.OrderBy != nil { @@ -3128,7 +3143,10 @@ func (b *PlanBuilder) buildWindowFunctions(p LogicalPlan, groupedFuncs map[*ast. descs := make([]*aggregation.WindowFuncDesc, 0, len(funcs)) preArgs := 0 for _, windowFunc := range funcs { - desc := aggregation.NewWindowFuncDesc(b.ctx, windowFunc.F, args[preArgs:preArgs+len(windowFunc.Args)]) + desc, err := aggregation.NewWindowFuncDesc(b.ctx, windowFunc.F, args[preArgs:preArgs+len(windowFunc.Args)]) + if err != nil { + return nil, nil, err + } if desc == nil { return nil, nil, ErrWrongArguments.GenWithStackByArgs(windowFunc.F) } diff --git a/planner/core/rule_aggregation_push_down.go b/planner/core/rule_aggregation_push_down.go index bca18b6305b1e..7164ad3b0c189 100644 --- a/planner/core/rule_aggregation_push_down.go +++ b/planner/core/rule_aggregation_push_down.go @@ -188,22 +188,25 @@ func (a *aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *a // tryToPushDownAgg tries to push down an aggregate function into a join path. If all aggFuncs are first row, we won't // process it temporarily. If not, We will add additional group by columns and first row functions. We make a new aggregation operator. // If the pushed aggregation is grouped by unique key, it's no need to push it down. -func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int) LogicalPlan { +func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int) (_ LogicalPlan, err error) { child := join.children[childIdx] if aggregation.IsAllFirstRow(aggFuncs) { - return child + return child, nil } // If the join is multiway-join, we forbid pushing down. if _, ok := join.children[childIdx].(*LogicalJoin); ok { - return child + return child, nil } tmpSchema := expression.NewSchema(gbyCols...) for _, key := range child.Schema().Keys { if tmpSchema.ColumnsIndices(key) != nil { - return child + return child, nil } } - agg := a.makeNewAgg(join.ctx, aggFuncs, gbyCols) + agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols) + if err != nil { + return nil, err + } agg.SetChildren(child) // If agg has no group-by item, it will return a default value, which may cause some bugs. // So here we add a group-by item forcely. @@ -216,10 +219,10 @@ func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.Agg var existsDefaultValues bool join.DefaultValues, existsDefaultValues = a.getDefaultValues(agg) if !existsDefaultValues { - return child + return child, nil } } - return agg + return agg, nil } func (a *aggregationPushDownSolver) getDefaultValues(agg *LogicalAggregation) ([]types.Datum, bool) { @@ -243,7 +246,7 @@ func (a *aggregationPushDownSolver) checkAnyCountAndSum(aggFuncs []*aggregation. return false } -func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column) *LogicalAggregation { +func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column) (*LogicalAggregation, error) { agg := LogicalAggregation{ GroupByItems: expression.Column2Exprs(gbyCols), groupByCols: gbyCols, @@ -257,7 +260,10 @@ func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs newAggFuncDescs = append(newAggFuncDescs, newFuncs...) } for _, gbyCol := range gbyCols { - firstRow := aggregation.NewAggFuncDesc(agg.ctx, ast.AggFuncFirstRow, []expression.Expression{gbyCol}, false) + firstRow, err := aggregation.NewAggFuncDesc(agg.ctx, ast.AggFuncFirstRow, []expression.Expression{gbyCol}, false) + if err != nil { + return nil, err + } newCol, _ := gbyCol.Clone().(*expression.Column) newCol.RetType = firstRow.RetTp newAggFuncDescs = append(newAggFuncDescs, firstRow) @@ -267,7 +273,7 @@ func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs agg.SetSchema(schema) // TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions. // agg.buildProjectionIfNecessary() - return agg + return agg, nil } // pushAggCrossUnion will try to push the agg down to the union. If the new aggregation's group-by columns doesn't contain unique key. @@ -312,12 +318,11 @@ func (a *aggregationPushDownSolver) optimize(p LogicalPlan) (LogicalPlan, error) if !p.context().GetSessionVars().AllowAggPushDown { return p, nil } - a.aggPushDown(p) - return p, nil + return a.aggPushDown(p) } // aggPushDown tries to push down aggregate functions to join paths. -func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) LogicalPlan { +func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, err error) { if agg, ok := p.(*LogicalAggregation); ok { proj := a.tryToEliminateAggregation(agg) if proj != nil { @@ -334,12 +339,18 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) LogicalPlan { if rightInvalid { rChild = join.children[1] } else { - rChild = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1) + rChild, err = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1) + if err != nil { + return nil, err + } } if leftInvalid { lChild = join.children[0] } else { - lChild = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0) + lChild, err = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0) + if err != nil { + return nil, err + } } join.SetChildren(lChild, rChild) join.SetSchema(expression.MergeSchema(lChild.Schema(), rChild.Schema())) @@ -368,7 +379,10 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) LogicalPlan { } else if union, ok1 := child.(*LogicalUnionAll); ok1 { var gbyCols []*expression.Column gbyCols = expression.ExtractColumnsFromExpressions(gbyCols, agg.GroupByItems, nil) - pushedAgg := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols) + pushedAgg, err := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols) + if err != nil { + return nil, err + } newChildren := make([]LogicalPlan, 0, len(union.children)) for _, child := range union.children { newChild := a.pushAggCrossUnion(pushedAgg, union.Schema(), child) @@ -381,9 +395,12 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) LogicalPlan { } newChildren := make([]LogicalPlan, 0, len(p.Children())) for _, child := range p.Children() { - newChild := a.aggPushDown(child) + newChild, err := a.aggPushDown(child) + if err != nil { + return nil, err + } newChildren = append(newChildren, newChild) } p.SetChildren(newChildren...) - return p + return p, nil } diff --git a/planner/core/rule_decorrelate.go b/planner/core/rule_decorrelate.go index fb9e1c54fe43b..4b127755582ce 100644 --- a/planner/core/rule_decorrelate.go +++ b/planner/core/rule_decorrelate.go @@ -154,7 +154,10 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { outerColsInSchema := make([]*expression.Column, 0, outerPlan.Schema().Len()) for i, col := range outerPlan.Schema().Columns { - first := aggregation.NewAggFuncDesc(agg.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + first, err := aggregation.NewAggFuncDesc(agg.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + if err != nil { + return nil, err + } newAggFuncs = append(newAggFuncs, first) outerCol, _ := outerPlan.Schema().Columns[i].Clone().(*expression.Column) @@ -201,7 +204,10 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { clonedCol := eqCond.GetArgs()[1] // If the join key is not in the aggregation's schema, add first row function. if agg.schema.ColumnIndex(eqCond.GetArgs()[1].(*expression.Column)) == -1 { - newFunc := aggregation.NewAggFuncDesc(apply.ctx, ast.AggFuncFirstRow, []expression.Expression{clonedCol}, false) + newFunc, err := aggregation.NewAggFuncDesc(apply.ctx, ast.AggFuncFirstRow, []expression.Expression{clonedCol}, false) + if err != nil { + return nil, err + } agg.AggFuncs = append(agg.AggFuncs, newFunc) agg.schema.Append(clonedCol.(*expression.Column)) agg.schema.Columns[agg.schema.Len()-1].RetType = newFunc.RetTp diff --git a/planner/core/rule_inject_extra_projection_test.go b/planner/core/rule_inject_extra_projection_test.go index 66e842837f86d..6b1f44e0e53e6 100644 --- a/planner/core/rule_inject_extra_projection_test.go +++ b/planner/core/rule_inject_extra_projection_test.go @@ -41,9 +41,10 @@ func (s *testInjectProjSuite) TestWrapCastForAggFuncs(c *C) { for _, mode := range modes { for _, retType := range retTypes { sctx := mock.NewContext() - aggFunc := aggregation.NewAggFuncDesc(sctx, name, + aggFunc, err := aggregation.NewAggFuncDesc(sctx, name, []expression.Expression{&expression.Constant{Value: types.Datum{}, RetType: types.NewFieldType(retType)}}, hasDistinct) + c.Assert(err, IsNil) aggFunc.Mode = mode aggFuncs = append(aggFuncs, aggFunc) }