diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index 30586fc16935c..ce973a8a45aac 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -15,6 +15,7 @@ package executor_test import ( . "github.com/pingcap/check" + "github.com/pingcap/errors" "github.com/pingcap/parser/terror" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/util/testkit" @@ -338,6 +339,15 @@ func (s *testSuite) TestAggregation(c *C) { tk.MustExec("insert into t value(0), (-0.9871), (-0.9871)") tk.MustQuery("select 10 from t group by a").Check(testkit.Rows("10", "10")) tk.MustQuery("select sum(a) from (select a from t union all select a from t) tmp").Check(testkit.Rows("-3.9484")) + _, err = tk.Exec("select std(a) from t") + c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: std") + _, err = tk.Exec("select stddev(a) from t") + c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: stddev") + _, err = tk.Exec("select stddev_pop(a) from t") + c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: stddev_pop") + _, err = tk.Exec("select std_samp(a) from t") + // TODO: Fix this error message. + c.Assert(errors.Cause(err).Error(), Equals, "[expression:1305]FUNCTION std_samp does not exist") } func (s *testSuite) TestStreamAggPushDown(c *C) { diff --git a/executor/executor_required_rows_test.go b/executor/executor_required_rows_test.go index ec1846d754d7d..5121ed4dea2f9 100644 --- a/executor/executor_required_rows_test.go +++ b/executor/executor_required_rows_test.go @@ -685,7 +685,8 @@ func (s *testExecSuite) TestStreamAggRequiredRows(c *C) { childCols := ds.Schema().Columns schema := expression.NewSchema(childCols...) groupBy := []expression.Expression{childCols[1]} - aggFunc := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, true) + aggFunc, err := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, true) + c.Assert(err, IsNil) aggFuncs := []*aggregation.AggFuncDesc{aggFunc} exec := buildStreamAggExecutor(sctx, ds, schema, aggFuncs, groupBy) c.Assert(exec.Open(ctx), IsNil) @@ -744,7 +745,8 @@ func (s *testExecSuite) TestHashAggParallelRequiredRows(c *C) { childCols := ds.Schema().Columns schema := expression.NewSchema(childCols...) groupBy := []expression.Expression{childCols[1]} - aggFunc := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, hasDistinct) + aggFunc, err := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, hasDistinct) + c.Assert(err, IsNil) aggFuncs := []*aggregation.AggFuncDesc{aggFunc} exec := buildHashAggExecutor(sctx, ds, schema, aggFuncs, groupBy) c.Assert(exec.Open(ctx), IsNil) diff --git a/expression/aggregation/aggregation_test.go b/expression/aggregation/aggregation_test.go index 0ebbe8a330eef..307382e08f46c 100644 --- a/expression/aggregation/aggregation_test.go +++ b/expression/aggregation/aggregation_test.go @@ -58,7 +58,9 @@ func (s *testAggFuncSuit) TestAvg(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - avgFunc := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, false) + c.Assert(err, IsNil) + avgFunc := desc.GetAggFunc(ctx) evalCtx := avgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := avgFunc.GetResult(evalCtx) @@ -71,12 +73,14 @@ func (s *testAggFuncSuit) TestAvg(c *C) { result = avgFunc.GetResult(evalCtx) needed := types.NewDecFromStringForTest("67.000000000000000000000000000000") c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) - err := avgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) + err = avgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) c.Assert(err, IsNil) result = avgFunc.GetResult(evalCtx) c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) - distinctAvgFunc := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, true).GetAggFunc(ctx) + desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, true) + c.Assert(err, IsNil) + distinctAvgFunc := desc.GetAggFunc(ctx) evalCtx = distinctAvgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) for _, row := range s.rows { err := distinctAvgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) @@ -105,7 +109,8 @@ func (s *testAggFuncSuit) TestAvgFinalMode(c *C) { Index: 1, RetType: types.NewFieldType(mysql.TypeNewDecimal), } - aggFunc := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{cntCol, sumCol}, false) + aggFunc, err := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{cntCol, sumCol}, false) + c.Assert(err, IsNil) aggFunc.Mode = FinalMode avgFunc := aggFunc.GetAggFunc(ctx) evalCtx := avgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) @@ -125,7 +130,9 @@ func (s *testAggFuncSuit) TestSum(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - sumFunc := NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, false) + c.Assert(err, IsNil) + sumFunc := desc.GetAggFunc(ctx) evalCtx := sumFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := sumFunc.GetResult(evalCtx) @@ -138,14 +145,16 @@ func (s *testAggFuncSuit) TestSum(c *C) { result = sumFunc.GetResult(evalCtx) needed := types.NewDecFromStringForTest("338350") c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) - err := sumFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) + err = sumFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) c.Assert(err, IsNil) result = sumFunc.GetResult(evalCtx) c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) partialResult := sumFunc.GetPartialResult(evalCtx) c.Assert(partialResult[0].GetMysqlDecimal().Compare(needed) == 0, IsTrue) - distinctSumFunc := NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, true).GetAggFunc(ctx) + desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, true) + c.Assert(err, IsNil) + distinctSumFunc := desc.GetAggFunc(ctx) evalCtx = distinctSumFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) for _, row := range s.rows { err := distinctSumFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) @@ -162,14 +171,16 @@ func (s *testAggFuncSuit) TestBitAnd(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - bitAndFunc := NewAggFuncDesc(s.ctx, ast.AggFuncBitAnd, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncBitAnd, []expression.Expression{col}, false) + c.Assert(err, IsNil) + bitAndFunc := desc.GetAggFunc(ctx) evalCtx := bitAndFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := bitAndFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(math.MaxUint64)) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() - err := bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) + err = bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) c.Assert(err, IsNil) result = bitAndFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(1)) @@ -238,14 +249,16 @@ func (s *testAggFuncSuit) TestBitOr(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - bitOrFunc := NewAggFuncDesc(s.ctx, ast.AggFuncBitOr, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncBitOr, []expression.Expression{col}, false) + c.Assert(err, IsNil) + bitOrFunc := desc.GetAggFunc(ctx) evalCtx := bitOrFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := bitOrFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(0)) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() - err := bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) + err = bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) c.Assert(err, IsNil) result = bitOrFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(1)) @@ -322,14 +335,16 @@ func (s *testAggFuncSuit) TestBitXor(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - bitXorFunc := NewAggFuncDesc(s.ctx, ast.AggFuncBitXor, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncBitXor, []expression.Expression{col}, false) + c.Assert(err, IsNil) + bitXorFunc := desc.GetAggFunc(ctx) evalCtx := bitXorFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := bitXorFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(0)) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() - err := bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) + err = bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) c.Assert(err, IsNil) result = bitXorFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(1)) @@ -398,7 +413,9 @@ func (s *testAggFuncSuit) TestCount(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - countFunc := NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, false) + c.Assert(err, IsNil) + countFunc := desc.GetAggFunc(ctx) evalCtx := countFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := countFunc.GetResult(evalCtx) @@ -410,14 +427,16 @@ func (s *testAggFuncSuit) TestCount(c *C) { } result = countFunc.GetResult(evalCtx) c.Assert(result.GetInt64(), Equals, int64(5050)) - err := countFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) + err = countFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) c.Assert(err, IsNil) result = countFunc.GetResult(evalCtx) c.Assert(result.GetInt64(), Equals, int64(5050)) partialResult := countFunc.GetPartialResult(evalCtx) c.Assert(partialResult[0].GetInt64(), Equals, int64(5050)) - distinctCountFunc := NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, true).GetAggFunc(ctx) + desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, true) + c.Assert(err, IsNil) + distinctCountFunc := desc.GetAggFunc(ctx) evalCtx = distinctCountFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) for _, row := range s.rows { @@ -438,14 +457,16 @@ func (s *testAggFuncSuit) TestConcat(c *C) { RetType: types.NewFieldType(mysql.TypeVarchar), } ctx := mock.NewContext() - concatFunc := NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, false) + c.Assert(err, IsNil) + concatFunc := desc.GetAggFunc(ctx) evalCtx := concatFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := concatFunc.GetResult(evalCtx) c.Assert(result.IsNull(), IsTrue) row := chunk.MutRowFromDatums(types.MakeDatums(1, "x")) - err := concatFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) + err = concatFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) c.Assert(err, IsNil) result = concatFunc.GetResult(evalCtx) c.Assert(result.GetString(), Equals, "1") @@ -464,7 +485,9 @@ func (s *testAggFuncSuit) TestConcat(c *C) { partialResult := concatFunc.GetPartialResult(evalCtx) c.Assert(partialResult[0].GetString(), Equals, "1x2") - distinctConcatFunc := NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, true).GetAggFunc(ctx) + desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, true) + c.Assert(err, IsNil) + distinctConcatFunc := desc.GetAggFunc(ctx) evalCtx = distinctConcatFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) row.SetDatum(0, types.NewIntDatum(1)) @@ -487,11 +510,13 @@ func (s *testAggFuncSuit) TestFirstRow(c *C) { } ctx := mock.NewContext() - firstRowFunc := NewAggFuncDesc(s.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + c.Assert(err, IsNil) + firstRowFunc := desc.GetAggFunc(ctx) evalCtx := firstRowFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() - err := firstRowFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) + err = firstRowFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) c.Assert(err, IsNil) result := firstRowFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(1)) @@ -512,8 +537,12 @@ func (s *testAggFuncSuit) TestMaxMin(c *C) { } ctx := mock.NewContext() - maxFunc := NewAggFuncDesc(s.ctx, ast.AggFuncMax, []expression.Expression{col}, false).GetAggFunc(ctx) - minFunc := NewAggFuncDesc(s.ctx, ast.AggFuncMin, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncMax, []expression.Expression{col}, false) + c.Assert(err, IsNil) + maxFunc := desc.GetAggFunc(ctx) + desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncMin, []expression.Expression{col}, false) + c.Assert(err, IsNil) + minFunc := desc.GetAggFunc(ctx) maxEvalCtx := maxFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) minEvalCtx := minFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) @@ -523,7 +552,7 @@ func (s *testAggFuncSuit) TestMaxMin(c *C) { c.Assert(result.IsNull(), IsTrue) row := chunk.MutRowFromDatums(types.MakeDatums(2)) - err := maxFunc.Update(maxEvalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) + err = maxFunc.Update(maxEvalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) c.Assert(err, IsNil) result = maxFunc.GetResult(maxEvalCtx) c.Assert(result.GetInt64(), Equals, int64(2)) diff --git a/expression/aggregation/bench_test.go b/expression/aggregation/bench_test.go index e49deebe00da3..c3f72695709cd 100644 --- a/expression/aggregation/bench_test.go +++ b/expression/aggregation/bench_test.go @@ -29,7 +29,11 @@ func BenchmarkCreateContext(b *testing.B) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false) + if err != nil { + b.Fatal(err) + } + fun := desc.GetAggFunc(ctx) b.StartTimer() for i := 0; i < b.N; i++ { fun.CreateContext(ctx.GetSessionVars().StmtCtx) @@ -43,7 +47,11 @@ func BenchmarkResetContext(b *testing.B) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false) + if err != nil { + b.Fatal(err) + } + fun := desc.GetAggFunc(ctx) evalCtx := fun.CreateContext(ctx.GetSessionVars().StmtCtx) b.StartTimer() for i := 0; i < b.N; i++ { @@ -58,7 +66,11 @@ func BenchmarkCreateDistinctContext(b *testing.B) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true) + if err != nil { + b.Fatal(err) + } + fun := desc.GetAggFunc(ctx) b.StartTimer() for i := 0; i < b.N; i++ { fun.CreateContext(ctx.GetSessionVars().StmtCtx) @@ -72,7 +84,11 @@ func BenchmarkResetDistinctContext(b *testing.B) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true) + if err != nil { + b.Fatal(err) + } + fun := desc.GetAggFunc(ctx) evalCtx := fun.CreateContext(ctx.GetSessionVars().StmtCtx) b.StartTimer() for i := 0; i < b.N; i++ { diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index 3487fd389260b..fa3059b275e55 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/cznic/mathutil" + "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/charset" "github.com/pingcap/parser/model" @@ -46,14 +47,14 @@ type AggFuncDesc struct { } // NewAggFuncDesc creates an aggregation function signature descriptor. -func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, hasDistinct bool) *AggFuncDesc { +func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, hasDistinct bool) (*AggFuncDesc, error) { a := &AggFuncDesc{ Name: strings.ToLower(name), Args: args, HasDistinct: hasDistinct, } - a.typeInfer(ctx) - return a + err := a.typeInfer(ctx) + return a, err } // Equal checks whether two aggregation function signatures are equal. @@ -143,7 +144,7 @@ func (a *AggFuncDesc) String() string { } // typeInfer infers the arguments and return types of an aggregation function. -func (a *AggFuncDesc) typeInfer(ctx sessionctx.Context) { +func (a *AggFuncDesc) typeInfer(ctx sessionctx.Context) error { switch a.Name { case ast.AggFuncCount: a.typeInfer4Count(ctx) @@ -158,8 +159,9 @@ func (a *AggFuncDesc) typeInfer(ctx sessionctx.Context) { case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor: a.typeInfer4BitFuncs(ctx) default: - panic("unsupported agg function: " + a.Name) + return errors.Errorf("unsupported agg function: %s", a.Name) } + return nil } // EvalNullValueInOuterJoin gets the null value when the aggregation is upon an outer join, diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index dbcd954a895d7..7c87d8b3c3c43 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -415,7 +415,11 @@ func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression. if useMin { funcName = ast.AggFuncMin } - funcMaxOrMin := aggregation.NewAggFuncDesc(er.ctx, funcName, []expression.Expression{rexpr}, false) + funcMaxOrMin, err := aggregation.NewAggFuncDesc(er.ctx, funcName, []expression.Expression{rexpr}, false) + if err != nil { + er.err = err + return + } // Create a column and append it to the schema of that aggregation. colMaxOrMin := &expression.Column{ @@ -437,7 +441,11 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, innerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr) outerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), lexpr) - funcSum := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false) + funcSum, err := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false) + if err != nil { + er.err = err + return + } colSum := &expression.Column{ ColName: model.NewCIStr("agg_col_sum"), UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), @@ -448,7 +456,11 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, innerHasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero) // Build `count(1)` aggregation to check if subquery is empty. - funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{expression.One}, false) + funcCount, err := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{expression.One}, false) + if err != nil { + er.err = err + return + } colCount := &expression.Column{ ColName: model.NewCIStr("agg_col_cnt"), UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), @@ -509,8 +521,16 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, // t.id != s.id or count(distinct s.id) > 1 or [any checker]. If there are two different values in s.id , // there must exist a s.id that doesn't equal to t.id. func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np LogicalPlan) { - firstRowFunc := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false) - countFunc := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{rexpr}, true) + firstRowFunc, err := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false) + if err != nil { + er.err = err + return + } + countFunc, err := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{rexpr}, true) + if err != nil { + er.err = err + return + } plan4Agg := LogicalAggregation{ AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc}, }.init(er.ctx) @@ -535,8 +555,16 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np // handleEQAll handles the case of = all. For example, if the query is t.id = all (select s.id from s), it will be rewrote to // t.id = (select s.id from s having count(distinct s.id) <= 1 and [all checker]). func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np LogicalPlan) { - firstRowFunc := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false) - countFunc := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{rexpr}, true) + firstRowFunc, err := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false) + if err != nil { + er.err = err + return + } + countFunc, err := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{rexpr}, true) + if err != nil { + er.err = err + return + } plan4Agg := LogicalAggregation{ AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc}, }.init(er.ctx) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index ed26c200a4b8f..976f452a1391f 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -92,7 +92,10 @@ func (b *planBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.Aggrega p = np newArgList = append(newArgList, newArg) } - newFunc := aggregation.NewAggFuncDesc(b.ctx, aggFunc.F, newArgList, aggFunc.Distinct) + newFunc, err := aggregation.NewAggFuncDesc(b.ctx, aggFunc.F, newArgList, aggFunc.Distinct) + if err != nil { + return nil, nil, err + } combined := false for j, oldFunc := range plan4Agg.AggFuncs { if oldFunc.Equal(b.ctx, newFunc) { @@ -114,7 +117,10 @@ func (b *planBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.Aggrega } } for _, col := range p.Schema().Columns { - newFunc := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + newFunc, err := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + if err != nil { + return nil, nil, err + } plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, newFunc) newCol, _ := col.Clone().(*expression.Column) newCol.RetType = newFunc.RetTp @@ -649,7 +655,7 @@ func (b *planBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, return proj, oldLen, nil } -func (b *planBuilder) buildDistinct(child LogicalPlan, length int) *LogicalAggregation { +func (b *planBuilder) buildDistinct(child LogicalPlan, length int) (*LogicalAggregation, error) { b.optFlag = b.optFlag | flagBuildKeyInfo b.optFlag = b.optFlag | flagAggregationOptimize plan4Agg := LogicalAggregation{ @@ -658,7 +664,10 @@ func (b *planBuilder) buildDistinct(child LogicalPlan, length int) *LogicalAggre }.init(b.ctx) plan4Agg.collectGroupByColumns() for _, col := range child.Schema().Columns { - aggDesc := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + aggDesc, err := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + if err != nil { + return nil, err + } plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, aggDesc) } plan4Agg.SetChildren(child) @@ -668,7 +677,7 @@ func (b *planBuilder) buildDistinct(child LogicalPlan, length int) *LogicalAggre for i, col := range plan4Agg.schema.Columns { col.RetType = plan4Agg.AggFuncs[i].RetTp } - return plan4Agg + return plan4Agg, nil } // unionJoinFieldType finds the type which can carry the given types in Union. @@ -740,7 +749,10 @@ func (b *planBuilder) buildUnion(union *ast.UnionStmt) (LogicalPlan, error) { unionDistinctPlan := b.buildUnionAll(distinctSelectPlans) if unionDistinctPlan != nil { - unionDistinctPlan = b.buildDistinct(unionDistinctPlan, unionDistinctPlan.Schema().Len()) + unionDistinctPlan, err = b.buildDistinct(unionDistinctPlan, unionDistinctPlan.Schema().Len()) + if err != nil { + return nil, err + } if len(allSelectPlans) > 0 { // Can't change the statements order in order to get the correct column info. allSelectPlans = append([]LogicalPlan{unionDistinctPlan}, allSelectPlans...) @@ -1791,7 +1803,10 @@ func (b *planBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error } if sel.Distinct { - p = b.buildDistinct(p, oldLen) + p, err = b.buildDistinct(p, oldLen) + if err != nil { + return nil, err + } } if sel.OrderBy != nil { diff --git a/planner/core/rule_aggregation_push_down.go b/planner/core/rule_aggregation_push_down.go index 7b26d5454540e..0c3ae08baf207 100644 --- a/planner/core/rule_aggregation_push_down.go +++ b/planner/core/rule_aggregation_push_down.go @@ -187,22 +187,25 @@ func (a *aggregationOptimizer) decompose(ctx sessionctx.Context, aggFunc *aggreg // tryToPushDownAgg tries to push down an aggregate function into a join path. If all aggFuncs are first row, we won't // process it temporarily. If not, We will add additional group by columns and first row functions. We make a new aggregation operator. // If the pushed aggregation is grouped by unique key, it's no need to push it down. -func (a *aggregationOptimizer) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int) LogicalPlan { +func (a *aggregationOptimizer) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int) (LogicalPlan, error) { child := join.children[childIdx] if aggregation.IsAllFirstRow(aggFuncs) { - return child + return child, nil } // If the join is multiway-join, we forbid pushing down. if _, ok := join.children[childIdx].(*LogicalJoin); ok { - return child + return child, nil } tmpSchema := expression.NewSchema(gbyCols...) for _, key := range child.Schema().Keys { if tmpSchema.ColumnsIndices(key) != nil { - return child + return child, nil } } - agg := a.makeNewAgg(join.ctx, aggFuncs, gbyCols) + agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols) + if err != nil { + return nil, err + } agg.SetChildren(child) // If agg has no group-by item, it will return a default value, which may cause some bugs. // So here we add a group-by item forcely. @@ -215,10 +218,10 @@ func (a *aggregationOptimizer) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncD var existsDefaultValues bool join.DefaultValues, existsDefaultValues = a.getDefaultValues(agg) if !existsDefaultValues { - return child + return child, nil } } - return agg + return agg, nil } func (a *aggregationOptimizer) getDefaultValues(agg *LogicalAggregation) ([]types.Datum, bool) { @@ -242,7 +245,7 @@ func (a *aggregationOptimizer) checkAnyCountAndSum(aggFuncs []*aggregation.AggFu return false } -func (a *aggregationOptimizer) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column) *LogicalAggregation { +func (a *aggregationOptimizer) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column) (*LogicalAggregation, error) { agg := LogicalAggregation{ GroupByItems: expression.Column2Exprs(gbyCols), groupByCols: gbyCols, @@ -256,7 +259,10 @@ func (a *aggregationOptimizer) makeNewAgg(ctx sessionctx.Context, aggFuncs []*ag newAggFuncDescs = append(newAggFuncDescs, newFuncs...) } for _, gbyCol := range gbyCols { - firstRow := aggregation.NewAggFuncDesc(agg.ctx, ast.AggFuncFirstRow, []expression.Expression{gbyCol}, false) + firstRow, err := aggregation.NewAggFuncDesc(agg.ctx, ast.AggFuncFirstRow, []expression.Expression{gbyCol}, false) + if err != nil { + return nil, err + } newCol, _ := gbyCol.Clone().(*expression.Column) newCol.RetType = firstRow.RetTp newAggFuncDescs = append(newAggFuncDescs, firstRow) @@ -266,7 +272,7 @@ func (a *aggregationOptimizer) makeNewAgg(ctx sessionctx.Context, aggFuncs []*ag agg.SetSchema(schema) // TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions. // agg.buildProjectionIfNecessary() - return agg + return agg, nil } // pushAggCrossUnion will try to push the agg down to the union. If the new aggregation's group-by columns doesn't contain unique key. @@ -311,12 +317,11 @@ func (a *aggregationOptimizer) optimize(p LogicalPlan) (LogicalPlan, error) { if !p.context().GetSessionVars().AllowAggPushDown { return p, nil } - a.aggPushDown(p) - return p, nil + return a.aggPushDown(p) } // aggPushDown tries to push down aggregate functions to join paths. -func (a *aggregationOptimizer) aggPushDown(p LogicalPlan) LogicalPlan { +func (a *aggregationOptimizer) aggPushDown(p LogicalPlan) (_ LogicalPlan, err error) { if agg, ok := p.(*LogicalAggregation); ok { proj := a.tryToEliminateAggregation(agg) if proj != nil { @@ -333,12 +338,18 @@ func (a *aggregationOptimizer) aggPushDown(p LogicalPlan) LogicalPlan { if rightInvalid { rChild = join.children[1] } else { - rChild = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1) + rChild, err = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1) + if err != nil { + return nil, err + } } if leftInvalid { lChild = join.children[0] } else { - lChild = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0) + lChild, err = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0) + if err != nil { + return nil, err + } } join.SetChildren(lChild, rChild) join.SetSchema(expression.MergeSchema(lChild.Schema(), rChild.Schema())) @@ -367,7 +378,10 @@ func (a *aggregationOptimizer) aggPushDown(p LogicalPlan) LogicalPlan { } else if union, ok1 := child.(*LogicalUnionAll); ok1 { var gbyCols []*expression.Column gbyCols = expression.ExtractColumnsFromExpressions(gbyCols, agg.GroupByItems, nil) - pushedAgg := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols) + pushedAgg, err := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols) + if err != nil { + return nil, err + } newChildren := make([]LogicalPlan, 0, len(union.children)) for _, child := range union.children { newChild := a.pushAggCrossUnion(pushedAgg, union.Schema(), child) @@ -379,11 +393,14 @@ func (a *aggregationOptimizer) aggPushDown(p LogicalPlan) LogicalPlan { } newChildren := make([]LogicalPlan, 0, len(p.Children())) for _, child := range p.Children() { - newChild := a.aggPushDown(child) + newChild, err := a.aggPushDown(child) + if err != nil { + return nil, err + } newChildren = append(newChildren, newChild) } p.SetChildren(newChildren...) - return p + return p, nil } // tryToEliminateAggregation will eliminate aggregation grouped by unique key. diff --git a/planner/core/rule_decorrelate.go b/planner/core/rule_decorrelate.go index 2cd88c5231eda..1bf565ff04528 100644 --- a/planner/core/rule_decorrelate.go +++ b/planner/core/rule_decorrelate.go @@ -185,7 +185,10 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { outerColsInSchema := make([]*expression.Column, 0, outerPlan.Schema().Len()) for i, col := range outerPlan.Schema().Columns { - first := aggregation.NewAggFuncDesc(agg.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + first, err := aggregation.NewAggFuncDesc(agg.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + if err != nil { + return nil, err + } newAggFuncs = append(newAggFuncs, first) outerCol, _ := outerPlan.Schema().Columns[i].Clone().(*expression.Column) @@ -232,7 +235,10 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { clonedCol := eqCond.GetArgs()[1] // If the join key is not in the aggregation's schema, add first row function. if agg.schema.ColumnIndex(eqCond.GetArgs()[1].(*expression.Column)) == -1 { - newFunc := aggregation.NewAggFuncDesc(apply.ctx, ast.AggFuncFirstRow, []expression.Expression{clonedCol}, false) + newFunc, err := aggregation.NewAggFuncDesc(apply.ctx, ast.AggFuncFirstRow, []expression.Expression{clonedCol}, false) + if err != nil { + return nil, err + } agg.AggFuncs = append(agg.AggFuncs, newFunc) agg.schema.Append(clonedCol.(*expression.Column)) agg.schema.Columns[agg.schema.Len()-1].RetType = newFunc.RetTp