Skip to content

Commit

Permalink
*: let baseFuncDesc.typeInfer return error instead of panic(#1… (#10911)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuHuaiyu authored and zz-jason committed Jun 26, 2019
1 parent 42eced8 commit 0db535a
Show file tree
Hide file tree
Showing 9 changed files with 194 additions and 69 deletions.
10 changes: 10 additions & 0 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package executor_test

import (
. "github.com/pingcap/check"
"github.com/pingcap/errors"
"github.com/pingcap/parser/terror"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/util/testkit"
Expand Down Expand Up @@ -338,6 +339,15 @@ func (s *testSuite) TestAggregation(c *C) {
tk.MustExec("insert into t value(0), (-0.9871), (-0.9871)")
tk.MustQuery("select 10 from t group by a").Check(testkit.Rows("10", "10"))
tk.MustQuery("select sum(a) from (select a from t union all select a from t) tmp").Check(testkit.Rows("-3.9484"))
_, err = tk.Exec("select std(a) from t")
c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: std")
_, err = tk.Exec("select stddev(a) from t")
c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: stddev")
_, err = tk.Exec("select stddev_pop(a) from t")
c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: stddev_pop")
_, err = tk.Exec("select std_samp(a) from t")
// TODO: Fix this error message.
c.Assert(errors.Cause(err).Error(), Equals, "[expression:1305]FUNCTION std_samp does not exist")
}

func (s *testSuite) TestStreamAggPushDown(c *C) {
Expand Down
6 changes: 4 additions & 2 deletions executor/executor_required_rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,8 @@ func (s *testExecSuite) TestStreamAggRequiredRows(c *C) {
childCols := ds.Schema().Columns
schema := expression.NewSchema(childCols...)
groupBy := []expression.Expression{childCols[1]}
aggFunc := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, true)
aggFunc, err := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, true)
c.Assert(err, IsNil)
aggFuncs := []*aggregation.AggFuncDesc{aggFunc}
exec := buildStreamAggExecutor(sctx, ds, schema, aggFuncs, groupBy)
c.Assert(exec.Open(ctx), IsNil)
Expand Down Expand Up @@ -744,7 +745,8 @@ func (s *testExecSuite) TestHashAggParallelRequiredRows(c *C) {
childCols := ds.Schema().Columns
schema := expression.NewSchema(childCols...)
groupBy := []expression.Expression{childCols[1]}
aggFunc := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, hasDistinct)
aggFunc, err := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, hasDistinct)
c.Assert(err, IsNil)
aggFuncs := []*aggregation.AggFuncDesc{aggFunc}
exec := buildHashAggExecutor(sctx, ds, schema, aggFuncs, groupBy)
c.Assert(exec.Open(ctx), IsNil)
Expand Down
77 changes: 53 additions & 24 deletions expression/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ func (s *testAggFuncSuit) TestAvg(c *C) {
RetType: types.NewFieldType(mysql.TypeLonglong),
}
ctx := mock.NewContext()
avgFunc := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, false).GetAggFunc(ctx)
desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, false)
c.Assert(err, IsNil)
avgFunc := desc.GetAggFunc(ctx)
evalCtx := avgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)

result := avgFunc.GetResult(evalCtx)
Expand All @@ -71,12 +73,14 @@ func (s *testAggFuncSuit) TestAvg(c *C) {
result = avgFunc.GetResult(evalCtx)
needed := types.NewDecFromStringForTest("67.000000000000000000000000000000")
c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue)
err := avgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow)
err = avgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow)
c.Assert(err, IsNil)
result = avgFunc.GetResult(evalCtx)
c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue)

distinctAvgFunc := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, true).GetAggFunc(ctx)
desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, true)
c.Assert(err, IsNil)
distinctAvgFunc := desc.GetAggFunc(ctx)
evalCtx = distinctAvgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)
for _, row := range s.rows {
err := distinctAvgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row)
Expand Down Expand Up @@ -105,7 +109,8 @@ func (s *testAggFuncSuit) TestAvgFinalMode(c *C) {
Index: 1,
RetType: types.NewFieldType(mysql.TypeNewDecimal),
}
aggFunc := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{cntCol, sumCol}, false)
aggFunc, err := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{cntCol, sumCol}, false)
c.Assert(err, IsNil)
aggFunc.Mode = FinalMode
avgFunc := aggFunc.GetAggFunc(ctx)
evalCtx := avgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)
Expand All @@ -125,7 +130,9 @@ func (s *testAggFuncSuit) TestSum(c *C) {
RetType: types.NewFieldType(mysql.TypeLonglong),
}
ctx := mock.NewContext()
sumFunc := NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, false).GetAggFunc(ctx)
desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, false)
c.Assert(err, IsNil)
sumFunc := desc.GetAggFunc(ctx)
evalCtx := sumFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)

result := sumFunc.GetResult(evalCtx)
Expand All @@ -138,14 +145,16 @@ func (s *testAggFuncSuit) TestSum(c *C) {
result = sumFunc.GetResult(evalCtx)
needed := types.NewDecFromStringForTest("338350")
c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue)
err := sumFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow)
err = sumFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow)
c.Assert(err, IsNil)
result = sumFunc.GetResult(evalCtx)
c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue)
partialResult := sumFunc.GetPartialResult(evalCtx)
c.Assert(partialResult[0].GetMysqlDecimal().Compare(needed) == 0, IsTrue)

distinctSumFunc := NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, true).GetAggFunc(ctx)
desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, true)
c.Assert(err, IsNil)
distinctSumFunc := desc.GetAggFunc(ctx)
evalCtx = distinctSumFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)
for _, row := range s.rows {
err := distinctSumFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row)
Expand All @@ -162,14 +171,16 @@ func (s *testAggFuncSuit) TestBitAnd(c *C) {
RetType: types.NewFieldType(mysql.TypeLonglong),
}
ctx := mock.NewContext()
bitAndFunc := NewAggFuncDesc(s.ctx, ast.AggFuncBitAnd, []expression.Expression{col}, false).GetAggFunc(ctx)
desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncBitAnd, []expression.Expression{col}, false)
c.Assert(err, IsNil)
bitAndFunc := desc.GetAggFunc(ctx)
evalCtx := bitAndFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)

result := bitAndFunc.GetResult(evalCtx)
c.Assert(result.GetUint64(), Equals, uint64(math.MaxUint64))

row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow()
err := bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row)
err = bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row)
c.Assert(err, IsNil)
result = bitAndFunc.GetResult(evalCtx)
c.Assert(result.GetUint64(), Equals, uint64(1))
Expand Down Expand Up @@ -238,14 +249,16 @@ func (s *testAggFuncSuit) TestBitOr(c *C) {
RetType: types.NewFieldType(mysql.TypeLonglong),
}
ctx := mock.NewContext()
bitOrFunc := NewAggFuncDesc(s.ctx, ast.AggFuncBitOr, []expression.Expression{col}, false).GetAggFunc(ctx)
desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncBitOr, []expression.Expression{col}, false)
c.Assert(err, IsNil)
bitOrFunc := desc.GetAggFunc(ctx)
evalCtx := bitOrFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)

result := bitOrFunc.GetResult(evalCtx)
c.Assert(result.GetUint64(), Equals, uint64(0))

row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow()
err := bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row)
err = bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row)
c.Assert(err, IsNil)
result = bitOrFunc.GetResult(evalCtx)
c.Assert(result.GetUint64(), Equals, uint64(1))
Expand Down Expand Up @@ -322,14 +335,16 @@ func (s *testAggFuncSuit) TestBitXor(c *C) {
RetType: types.NewFieldType(mysql.TypeLonglong),
}
ctx := mock.NewContext()
bitXorFunc := NewAggFuncDesc(s.ctx, ast.AggFuncBitXor, []expression.Expression{col}, false).GetAggFunc(ctx)
desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncBitXor, []expression.Expression{col}, false)
c.Assert(err, IsNil)
bitXorFunc := desc.GetAggFunc(ctx)
evalCtx := bitXorFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)

result := bitXorFunc.GetResult(evalCtx)
c.Assert(result.GetUint64(), Equals, uint64(0))

row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow()
err := bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row)
err = bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row)
c.Assert(err, IsNil)
result = bitXorFunc.GetResult(evalCtx)
c.Assert(result.GetUint64(), Equals, uint64(1))
Expand Down Expand Up @@ -398,7 +413,9 @@ func (s *testAggFuncSuit) TestCount(c *C) {
RetType: types.NewFieldType(mysql.TypeLonglong),
}
ctx := mock.NewContext()
countFunc := NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, false).GetAggFunc(ctx)
desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, false)
c.Assert(err, IsNil)
countFunc := desc.GetAggFunc(ctx)
evalCtx := countFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)

result := countFunc.GetResult(evalCtx)
Expand All @@ -410,14 +427,16 @@ func (s *testAggFuncSuit) TestCount(c *C) {
}
result = countFunc.GetResult(evalCtx)
c.Assert(result.GetInt64(), Equals, int64(5050))
err := countFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow)
err = countFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow)
c.Assert(err, IsNil)
result = countFunc.GetResult(evalCtx)
c.Assert(result.GetInt64(), Equals, int64(5050))
partialResult := countFunc.GetPartialResult(evalCtx)
c.Assert(partialResult[0].GetInt64(), Equals, int64(5050))

distinctCountFunc := NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, true).GetAggFunc(ctx)
desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, true)
c.Assert(err, IsNil)
distinctCountFunc := desc.GetAggFunc(ctx)
evalCtx = distinctCountFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)

for _, row := range s.rows {
Expand All @@ -438,14 +457,16 @@ func (s *testAggFuncSuit) TestConcat(c *C) {
RetType: types.NewFieldType(mysql.TypeVarchar),
}
ctx := mock.NewContext()
concatFunc := NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, false).GetAggFunc(ctx)
desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, false)
c.Assert(err, IsNil)
concatFunc := desc.GetAggFunc(ctx)
evalCtx := concatFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)

result := concatFunc.GetResult(evalCtx)
c.Assert(result.IsNull(), IsTrue)

row := chunk.MutRowFromDatums(types.MakeDatums(1, "x"))
err := concatFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow())
err = concatFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow())
c.Assert(err, IsNil)
result = concatFunc.GetResult(evalCtx)
c.Assert(result.GetString(), Equals, "1")
Expand All @@ -464,7 +485,9 @@ func (s *testAggFuncSuit) TestConcat(c *C) {
partialResult := concatFunc.GetPartialResult(evalCtx)
c.Assert(partialResult[0].GetString(), Equals, "1x2")

distinctConcatFunc := NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, true).GetAggFunc(ctx)
desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, true)
c.Assert(err, IsNil)
distinctConcatFunc := desc.GetAggFunc(ctx)
evalCtx = distinctConcatFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)

row.SetDatum(0, types.NewIntDatum(1))
Expand All @@ -487,11 +510,13 @@ func (s *testAggFuncSuit) TestFirstRow(c *C) {
}

ctx := mock.NewContext()
firstRowFunc := NewAggFuncDesc(s.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false).GetAggFunc(ctx)
desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false)
c.Assert(err, IsNil)
firstRowFunc := desc.GetAggFunc(ctx)
evalCtx := firstRowFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)

row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow()
err := firstRowFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row)
err = firstRowFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row)
c.Assert(err, IsNil)
result := firstRowFunc.GetResult(evalCtx)
c.Assert(result.GetUint64(), Equals, uint64(1))
Expand All @@ -512,8 +537,12 @@ func (s *testAggFuncSuit) TestMaxMin(c *C) {
}

ctx := mock.NewContext()
maxFunc := NewAggFuncDesc(s.ctx, ast.AggFuncMax, []expression.Expression{col}, false).GetAggFunc(ctx)
minFunc := NewAggFuncDesc(s.ctx, ast.AggFuncMin, []expression.Expression{col}, false).GetAggFunc(ctx)
desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncMax, []expression.Expression{col}, false)
c.Assert(err, IsNil)
maxFunc := desc.GetAggFunc(ctx)
desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncMin, []expression.Expression{col}, false)
c.Assert(err, IsNil)
minFunc := desc.GetAggFunc(ctx)
maxEvalCtx := maxFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)
minEvalCtx := minFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx)

Expand All @@ -523,7 +552,7 @@ func (s *testAggFuncSuit) TestMaxMin(c *C) {
c.Assert(result.IsNull(), IsTrue)

row := chunk.MutRowFromDatums(types.MakeDatums(2))
err := maxFunc.Update(maxEvalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow())
err = maxFunc.Update(maxEvalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow())
c.Assert(err, IsNil)
result = maxFunc.GetResult(maxEvalCtx)
c.Assert(result.GetInt64(), Equals, int64(2))
Expand Down
24 changes: 20 additions & 4 deletions expression/aggregation/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ func BenchmarkCreateContext(b *testing.B) {
RetType: types.NewFieldType(mysql.TypeLonglong),
}
ctx := mock.NewContext()
fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false).GetAggFunc(ctx)
desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false)
if err != nil {
b.Fatal(err)
}
fun := desc.GetAggFunc(ctx)
b.StartTimer()
for i := 0; i < b.N; i++ {
fun.CreateContext(ctx.GetSessionVars().StmtCtx)
Expand All @@ -43,7 +47,11 @@ func BenchmarkResetContext(b *testing.B) {
RetType: types.NewFieldType(mysql.TypeLonglong),
}
ctx := mock.NewContext()
fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false).GetAggFunc(ctx)
desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false)
if err != nil {
b.Fatal(err)
}
fun := desc.GetAggFunc(ctx)
evalCtx := fun.CreateContext(ctx.GetSessionVars().StmtCtx)
b.StartTimer()
for i := 0; i < b.N; i++ {
Expand All @@ -58,7 +66,11 @@ func BenchmarkCreateDistinctContext(b *testing.B) {
RetType: types.NewFieldType(mysql.TypeLonglong),
}
ctx := mock.NewContext()
fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true).GetAggFunc(ctx)
desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true)
if err != nil {
b.Fatal(err)
}
fun := desc.GetAggFunc(ctx)
b.StartTimer()
for i := 0; i < b.N; i++ {
fun.CreateContext(ctx.GetSessionVars().StmtCtx)
Expand All @@ -72,7 +84,11 @@ func BenchmarkResetDistinctContext(b *testing.B) {
RetType: types.NewFieldType(mysql.TypeLonglong),
}
ctx := mock.NewContext()
fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true).GetAggFunc(ctx)
desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true)
if err != nil {
b.Fatal(err)
}
fun := desc.GetAggFunc(ctx)
evalCtx := fun.CreateContext(ctx.GetSessionVars().StmtCtx)
b.StartTimer()
for i := 0; i < b.N; i++ {
Expand Down
12 changes: 7 additions & 5 deletions expression/aggregation/descriptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"strings"

"github.com/cznic/mathutil"
"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/charset"
"github.com/pingcap/parser/model"
Expand All @@ -46,14 +47,14 @@ type AggFuncDesc struct {
}

// NewAggFuncDesc creates an aggregation function signature descriptor.
func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, hasDistinct bool) *AggFuncDesc {
func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, hasDistinct bool) (*AggFuncDesc, error) {
a := &AggFuncDesc{
Name: strings.ToLower(name),
Args: args,
HasDistinct: hasDistinct,
}
a.typeInfer(ctx)
return a
err := a.typeInfer(ctx)
return a, err
}

// Equal checks whether two aggregation function signatures are equal.
Expand Down Expand Up @@ -143,7 +144,7 @@ func (a *AggFuncDesc) String() string {
}

// typeInfer infers the arguments and return types of an aggregation function.
func (a *AggFuncDesc) typeInfer(ctx sessionctx.Context) {
func (a *AggFuncDesc) typeInfer(ctx sessionctx.Context) error {
switch a.Name {
case ast.AggFuncCount:
a.typeInfer4Count(ctx)
Expand All @@ -158,8 +159,9 @@ func (a *AggFuncDesc) typeInfer(ctx sessionctx.Context) {
case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor:
a.typeInfer4BitFuncs(ctx)
default:
panic("unsupported agg function: " + a.Name)
return errors.Errorf("unsupported agg function: %s", a.Name)
}
return nil
}

// EvalNullValueInOuterJoin gets the null value when the aggregation is upon an outer join,
Expand Down
Loading

0 comments on commit 0db535a

Please sign in to comment.