From e6025cb844952488e8896bdf796faea1ea1ba4a4 Mon Sep 17 00:00:00 2001 From: HuaiyuXu <391585975@qq.com> Date: Thu, 11 Oct 2018 16:51:06 +0800 Subject: [PATCH] executor: refine the precision for avg (#7860) (#7874) --- executor/aggfuncs/func_avg.go | 25 ++++++++++++++++++++++++- executor/builder.go | 21 +++++++++++++++++++++ expression/aggregation/descriptor.go | 19 ++++++++++++------- expression/typeinfer_test.go | 4 ++-- 4 files changed, 59 insertions(+), 10 deletions(-) diff --git a/executor/aggfuncs/func_avg.go b/executor/aggfuncs/func_avg.go index 139d60845273f..f917c60e1d044 100644 --- a/executor/aggfuncs/func_avg.go +++ b/executor/aggfuncs/func_avg.go @@ -14,6 +14,8 @@ package aggfuncs import ( + "github.com/cznic/mathutil" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -56,7 +58,19 @@ func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par finalResult := new(types.MyDecimal) err := types.DecimalDiv(&p.sum, decimalCount, finalResult, types.DivFracIncr) if err != nil { - return errors.Trace(err) + return err + } + // Make the decimal be the result of type inferring. + frac := e.args[0].GetType().Decimal + if len(e.args) == 2 { + frac = e.args[1].GetType().Decimal + } + if frac == -1 { + frac = mysql.MaxDecimalScale + } + err = finalResult.Round(finalResult, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven) + if err != nil { + return err } chk.AppendMyDecimal(e.ordinal, finalResult) return nil @@ -195,6 +209,15 @@ func (e *avgOriginal4DistinctDecimal) AppendFinalResult2Chunk(sctx sessionctx.Co if err != nil { return errors.Trace(err) } + // Make the decimal be the result of type inferring. + frac := e.args[0].GetType().Decimal + if frac == -1 { + frac = mysql.MaxDecimalScale + } + err = finalResult.Round(finalResult, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven) + if err != nil { + return err + } chk.AppendMyDecimal(e.ordinal, finalResult) return nil } diff --git a/executor/builder.go b/executor/builder.go index 1f9ef0e0cb96a..b16d39eee04fa 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -887,6 +887,27 @@ func (b *executorBuilder) wrapCastForAggArgs(funcs []*aggregation.AggFuncDesc) { } for i := range f.Args { f.Args[i] = castFunc(b.ctx, f.Args[i]) + if f.Name != ast.AggFuncAvg && f.Name != ast.AggFuncSum { + continue + } + // After wrapping cast on the argument, flen etc. may not the same + // as the type of the aggregation function. The following part set + // the type of the argument exactly as the type of the aggregation + // function. + // Note: If the `Tp` of argument is the same as the `Tp` of the + // aggregation function, it will not wrap cast function on it + // internally. The reason of the special handling for `Column` is + // that the `RetType` of `Column` refers to the `infoschema`, so we + // need to set a new variable for it to avoid modifying the + // definition in `infoschema`. + if col, ok := f.Args[i].(*expression.Column); ok { + col.RetType = types.NewFieldType(col.RetType.Tp) + } + // originTp is used when the the `Tp` of column is TypeFloat32 while + // the type of the aggregation function is TypeFloat64. + originTp := f.Args[i].GetType().Tp + *(f.Args[i].GetType()) = *(f.RetTp) + f.Args[i].GetType().Tp = originTp } } } diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index 7cbf71b46b5b0..7c47ec85c0d64 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -291,17 +291,21 @@ func (a *AggFuncDesc) typeInfer4Count(ctx sessionctx.Context) { // Because child returns integer or decimal type. func (a *AggFuncDesc) typeInfer4Sum(ctx sessionctx.Context) { switch a.Args[0].GetType().Tp { - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeNewDecimal: + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + a.RetTp = types.NewFieldType(mysql.TypeNewDecimal) + a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxDecimalWidth, 0 + case mysql.TypeNewDecimal: a.RetTp = types.NewFieldType(mysql.TypeNewDecimal) a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxDecimalWidth, a.Args[0].GetType().Decimal if a.RetTp.Decimal < 0 || a.RetTp.Decimal > mysql.MaxDecimalScale { a.RetTp.Decimal = mysql.MaxDecimalScale } - // TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0]) - default: + case mysql.TypeDouble, mysql.TypeFloat: a.RetTp = types.NewFieldType(mysql.TypeDouble) a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal - //TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) + default: + a.RetTp = types.NewFieldType(mysql.TypeDouble) + a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength } types.SetBinChsClnFlag(a.RetTp) } @@ -318,11 +322,12 @@ func (a *AggFuncDesc) typeInfer4Avg(ctx sessionctx.Context) { a.RetTp.Decimal = mathutil.Min(a.Args[0].GetType().Decimal+types.DivFracIncr, mysql.MaxDecimalScale) } a.RetTp.Flen = mysql.MaxDecimalWidth - // TODO: a.Args[0] = expression.WrapWithCastAsDecimal(ctx, a.Args[0]) - default: + case mysql.TypeDouble, mysql.TypeFloat: a.RetTp = types.NewFieldType(mysql.TypeDouble) a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, a.Args[0].GetType().Decimal - // TODO: a.Args[0] = expression.WrapWithCastAsReal(ctx, a.Args[0]) + default: + a.RetTp = types.NewFieldType(mysql.TypeDouble) + a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength } types.SetBinChsClnFlag(a.RetTp) } diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 142c0c0c9bf41..bb86b140d5caf 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -822,14 +822,14 @@ func (s *testInferTypeSuite) createTestCase4Aggregations() []typeInferTestCase { {"sum(c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 3}, {"sum(1.0)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 1}, {"sum(1.2e2)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, - {"sum(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, + {"sum(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, {"avg(c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 4}, {"avg(c_float_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, {"avg(c_double_d)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, {"avg(c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 7}, {"avg(1.0)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, 5}, {"avg(1.2e2)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, - {"avg(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0}, + {"avg(c_char)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, {"group_concat(c_int_d)", mysql.TypeVarString, charset.CharsetUTF8, 0, mysql.MaxBlobWidth, 0}, } }