planner: support 3 stage aggregation for single scalar distinct agg #37203

Merged · 21 commits · Sep 17, 2022
8 changes: 8 additions & 0 deletions executor/set_test.go
@@ -699,6 +699,14 @@ func TestSetVar(t *testing.T) {
tk.MustExec("set global tidb_opt_skew_distinct_agg=1")
tk.MustQuery("select @@global.tidb_opt_skew_distinct_agg").Check(testkit.Rows("1"))

// test for tidb_opt_three_stage_distinct_agg
tk.MustQuery("select @@session.tidb_opt_three_stage_distinct_agg").Check(testkit.Rows("1")) // default value is 1
tk.MustExec("set session tidb_opt_three_stage_distinct_agg=0")
tk.MustQuery("select @@session.tidb_opt_three_stage_distinct_agg").Check(testkit.Rows("0"))
tk.MustQuery("select @@global.tidb_opt_three_stage_distinct_agg").Check(testkit.Rows("1")) // default value is 1
tk.MustExec("set global tidb_opt_three_stage_distinct_agg=0")
tk.MustQuery("select @@global.tidb_opt_three_stage_distinct_agg").Check(testkit.Rows("0"))

// the value of max_allowed_packet should be a multiple of 1024
tk.MustExec("set @@global.max_allowed_packet=16385")
tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1292|Truncated incorrect max_allowed_packet value: '16385'"))
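
For context, the new switch can also be exercised from any client. Below is a minimal sketch (not part of this PR) of toggling it in Go via database/sql; the DSN, port, and driver choice are assumptions, and only the variable name tidb_opt_three_stage_distinct_agg comes from this change.

// toggle.go: hedged example; assumes a local TiDB on 127.0.0.1:4000 and the
// github.com/go-sql-driver/mysql driver.
package main

import (
	"database/sql"
	"fmt"

	_ "github.com/go-sql-driver/mysql"
)

func main() {
	db, err := sql.Open("mysql", "root@tcp(127.0.0.1:4000)/test")
	if err != nil {
		panic(err)
	}
	defer db.Close()

	// Disable the 3-stage distinct-agg rewrite for this session only.
	if _, err := db.Exec("set session tidb_opt_three_stage_distinct_agg = 0"); err != nil {
		panic(err)
	}

	var v string
	if err := db.QueryRow("select @@session.tidb_opt_three_stage_distinct_agg").Scan(&v); err != nil {
		panic(err)
	}
	fmt.Println(v) // prints 0; @@global keeps its default of 1
}
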
21 changes: 21 additions & 0 deletions expression/column.go
@@ -202,6 +202,18 @@ func (col *CorrelatedColumn) MemoryUsage() (sum int64) {
return sum
}

// RemapColumn remaps the columns referenced by the expression via the provided UniqueID mapping and returns a new expression.
func (col *CorrelatedColumn) RemapColumn(m map[int64]*Column) (Expression, error) {
mapped := m[(&col.Column).UniqueID]
if mapped == nil {
return nil, errors.Errorf("Can't remap column for %s", col)
}
return &CorrelatedColumn{
Column: *mapped,
Data: col.Data,
}, nil
}

// Column represents a column.
type Column struct {
RetType *types.FieldType
@@ -537,6 +549,15 @@ func (col *Column) resolveIndicesByVirtualExpr(schema *Schema) bool {
return false
}

// RemapColumn remaps the columns referenced by the expression via the provided UniqueID mapping and returns a new expression.
func (col *Column) RemapColumn(m map[int64]*Column) (Expression, error) {
mapped := m[col.UniqueID]
if mapped == nil {
return nil, errors.Errorf("Can't remap column for %s", col)
}
return mapped, nil
}

// Vectorized returns if this expression supports vectorized evaluation.
func (col *Column) Vectorized() bool {
return true
5 changes: 5 additions & 0 deletions expression/constant.go
@@ -422,6 +422,11 @@ func (c *Constant) resolveIndicesByVirtualExpr(_ *Schema) bool {
return true
}

// RemapColumn returns the Constant unchanged, since a constant references no columns.
func (c *Constant) RemapColumn(_ map[int64]*Column) (Expression, error) {
return c, nil
}

// Vectorized returns if this expression supports vectorized evaluation.
func (c *Constant) Vectorized() bool {
if c.DeferredExpr != nil {
3 changes: 3 additions & 0 deletions expression/expression.go
@@ -164,6 +164,9 @@ type Expression interface {
// resolveIndicesByVirtualExpr is called inside `ResolveIndicesByVirtualExpr`; it performs the resolution on the expression itself.
resolveIndicesByVirtualExpr(schema *Schema) bool

// RemapColumn remaps the columns referenced by the expression via the provided UniqueID mapping and returns a new expression.
RemapColumn(map[int64]*Column) (Expression, error)

// ExplainInfo returns operator information to be explained.
ExplainInfo() string

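
To make the contract of the new interface method concrete, here is a hedged caller-side sketch; remapAll is a hypothetical helper, not part of this PR, and it assumes the expression package import. Per the implementations in this diff, a Column is looked up by UniqueID, a Constant passes through unchanged, and a ScalarFunction recurses into its arguments.

// remapAll rewrites a slice of expressions so that every reference to an old
// column points at its replacement; it fails if a referenced column is
// missing from the map, mirroring the Column implementation in this diff.
func remapAll(exprs []expression.Expression, m map[int64]*expression.Column) ([]expression.Expression, error) {
	out := make([]expression.Expression, 0, len(exprs))
	for _, e := range exprs {
		ne, err := e.RemapColumn(m)
		if err != nil {
			return nil, err
		}
		out = append(out, ne)
	}
	return out, nil
}
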
18 changes: 18 additions & 0 deletions expression/scalar_function.go
@@ -488,6 +488,24 @@ func (sf *ScalarFunction) resolveIndicesByVirtualExpr(schema *Schema) bool {
return true
}

// RemapColumn remaps the columns referenced by the expression via the provided UniqueID mapping and returns a new expression.
func (sf *ScalarFunction) RemapColumn(m map[int64]*Column) (Expression, error) {
newSf, ok := sf.Clone().(*ScalarFunction)
if !ok {
return nil, errors.New("failed to cast to scalar function")
}
for i, arg := range sf.GetArgs() {
newArg, err := arg.RemapColumn(m)
if err != nil {
return nil, err
}
newSf.GetArgs()[i] = newArg
}
// the cached hash code was computed over the old arguments; clear it so it is recomputed on demand
newSf.hashcode = nil
return newSf, nil
}

// GetSingleColumn returns (Col, Desc) when the ScalarFunction is equivalent to (Col, Desc)
// when used as a sort key, otherwise returns (nil, false).
//
1 change: 1 addition & 0 deletions expression/util_test.go
@@ -579,6 +579,7 @@ func (m *MockExpr) ResolveIndices(schema *Schema) (Expression, error)
func (m *MockExpr) resolveIndices(schema *Schema) error { return nil }
func (m *MockExpr) ResolveIndicesByVirtualExpr(schema *Schema) (Expression, bool) { return m, true }
func (m *MockExpr) resolveIndicesByVirtualExpr(schema *Schema) bool { return true }
func (m *MockExpr) RemapColumn(_ map[int64]*Column) (Expression, error) { return m, nil }
func (m *MockExpr) ExplainInfo() string { return "" }
func (m *MockExpr) ExplainNormalizedInfo() string { return "" }
func (m *MockExpr) HashCode(sc *stmtctx.StatementContext) []byte { return nil }
51 changes: 51 additions & 0 deletions planner/core/enforce_mpp_test.go
@@ -484,3 +484,54 @@ func TestMPPSkewedGroupDistinctRewrite(t *testing.T) {
require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()))
}
}

// TestMPPSingleDistinct3Stage tests the 3-stage aggregation plan for a single count(distinct) aggregate.
func TestMPPSingleDistinct3Stage(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)

// test table
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int, b bigint not null, c bigint, d date, e varchar(20) collate utf8mb4_general_ci)")

// Create virtual tiflash replica info.
dom := domain.GetDomain(tk.Session())
is := dom.InfoSchema()
db, exists := is.SchemaByName(model.NewCIStr("test"))
require.True(t, exists)
for _, tblInfo := range db.Tables {
if tblInfo.Name.L == "t" {
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
Count: 1,
Available: true,
}
}
}

var input []string
var output []struct {
SQL string
Plan []string
Warn []string
}
enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData()
enforceMPPSuiteData.LoadTestCases(t, &input, &output)
for i, tt := range input {
testdata.OnRecord(func() {
output[i].SQL = tt
})
if strings.HasPrefix(tt, "set") || strings.HasPrefix(tt, "UPDATE") {
tk.MustExec(tt)
continue
}
testdata.OnRecord(func() {
output[i].SQL = tt
output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())
})
res := tk.MustQuery(tt)
res.Check(testkit.Rows(output[i].Plan...))
require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()))
}
}
112 changes: 112 additions & 0 deletions planner/core/task.go
@@ -1436,6 +1436,35 @@ func (p *basePhysicalAgg) newPartialAggregate(copTaskType kv.StoreType, isMPPTask
return partialAgg, finalAgg
}

// canUse3StageDistinctAgg returns true if this agg can use the 3-stage plan for distinct aggregation:
// exactly one count(distinct) over plain columns, no group-by items, and no agg with multiple args or order-by items.
func (p *basePhysicalAgg) canUse3StageDistinctAgg() bool {
num := 0
if !p.ctx.GetSessionVars().Enable3StageDistinctAgg || len(p.GroupByItems) > 0 {
return false
}
for _, fun := range p.AggFuncs {
if fun.HasDistinct {
num++
if num > 1 || fun.Name != ast.AggFuncCount {
return false
}
for _, arg := range fun.Args {
// bail out when args are not simple column, see GitHub issue #35417
if _, ok := arg.(*expression.Column); !ok {
return false
}
}
} else if len(fun.Args) > 1 {
return false
}

if len(fun.OrderByItems) > 0 {
return false
}
}
return num == 1
}
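
Restated outside the planner's types: the rewrite is allowed only when the session switch is on, there are no group-by items, exactly one aggregate is distinct and it is a count over plain columns, every non-distinct aggregate takes at most one argument, and no aggregate has order-by items. A standalone sketch of the same rule with hypothetical toy types (not the planner's structs):

// eligibility.go: toy restatement of canUse3StageDistinctAgg; the session
// switch (Enable3StageDistinctAgg) is modeled as a plain bool here.
package main

import "fmt"

type aggFunc struct {
	name        string
	hasDistinct bool
	argsAreCols bool // all args are simple columns
	numArgs     int
	hasOrderBy  bool
}

func canUse3Stage(enabled bool, numGroupBy int, funcs []aggFunc) bool {
	if !enabled || numGroupBy > 0 {
		return false
	}
	distinct := 0
	for _, f := range funcs {
		if f.hasDistinct {
			distinct++
			if distinct > 1 || f.name != "count" || !f.argsAreCols {
				return false
			}
		} else if f.numArgs > 1 {
			return false
		}
		if f.hasOrderBy {
			return false
		}
	}
	return distinct == 1
}

func main() {
	// select count(distinct a), count(b) from t -> eligible
	fmt.Println(canUse3Stage(true, 0, []aggFunc{
		{name: "count", hasDistinct: true, argsAreCols: true, numArgs: 1},
		{name: "count", numArgs: 1},
	}))
	// select count(distinct a) from t group by c -> not eligible
	fmt.Println(canUse3Stage(true, 1, []aggFunc{
		{name: "count", hasDistinct: true, argsAreCols: true, numArgs: 1},
	}))
}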

func genFirstRowAggForGroupBy(ctx sessionctx.Context, groupByItems []expression.Expression) ([]*aggregation.AggFuncDesc, error) {
aggFuncs := make([]*aggregation.AggFuncDesc, 0, len(groupByItems))
for _, groupBy := range groupByItems {
@@ -1642,15 +1671,98 @@ func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...task) task {
if !mpp.needEnforceExchanger(prop) {
return p.attach2TaskForMpp1Phase(mpp)
}
// check eligibility before the content of p is modified by convertAvgForMPP and newPartialAggregate below
canUse3StageAgg := p.canUse3StageDistinctAgg()
proj := p.convertAvgForMPP()
partialAgg, finalAgg := p.newPartialAggregate(kv.TiFlash, true)
if finalAgg == nil {
return invalidTask
}

// generate 3 stage aggregation for single count distinct if applicable.
// select count(distinct a), count(b) from foo
// will generate plan:
// HashAgg sum(#1), sum(#2) -> final agg
// +- Exchange Passthrough
// +- HashAgg count(distinct a) #1, sum(#3) #2 -> middle agg
// +- Exchange HashPartition by a
// +- HashAgg count(b) #3, group by a -> partial agg
// +- TableScan foo
var middleAgg *PhysicalHashAgg
if partialAgg != nil && canUse3StageAgg {
clonedAgg, err := finalAgg.Clone()
if err != nil {
return invalidTask
}
middleAgg = clonedAgg.(*PhysicalHashAgg)
distinctPos := 0
middleSchema := expression.NewSchema()
schemaMap := make(map[int64]*expression.Column, len(middleAgg.AggFuncs))
for i, fun := range middleAgg.AggFuncs {
col := &expression.Column{
UniqueID: p.ctx.GetSessionVars().AllocPlanColumnID(),
RetType: fun.RetTp,
}
if fun.HasDistinct {
distinctPos = i
} else {
fun.Mode = aggregation.Partial2Mode
originalCol := fun.Args[0].(*expression.Column)
schemaMap[originalCol.UniqueID] = col
}
middleSchema.Append(col)
}
middleAgg.schema = middleSchema

finalHashAgg := finalAgg.(*PhysicalHashAgg)
finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs))
for i, fun := range finalHashAgg.AggFuncs {
newArgs := make([]expression.Expression, 0, 1)
if distinctPos == i {
// change count(distinct) to sum()
fun.Name = ast.AggFuncSum
fun.HasDistinct = false
newArgs = append(newArgs, middleSchema.Columns[i])
} else {
for _, arg := range fun.Args {
newCol, err := arg.RemapColumn(schemaMap)
[review thread on the RemapColumn call above]
Member: We could use ColumnSubstitute?
Contributor Author: Ah, great. will do
Contributor Author: It turned out ColumnSubstitute is not applicable here, because the 2 schemas are different.

if err != nil {
return invalidTask
}
newArgs = append(newArgs, newCol)
}
}
fun.Args = newArgs
finalAggDescs = append(finalAggDescs, fun)
}
finalHashAgg.AggFuncs = finalAggDescs
}

// partialAgg would be nil if the scalar agg cannot run in two-phase mode
if partialAgg != nil {
attachPlan2Task(partialAgg, mpp)
}

if middleAgg != nil && canUse3StageAgg {
items := partialAgg.(*PhysicalHashAgg).GroupByItems
partitionCols := make([]*property.MPPPartitionColumn, 0, len(items))
for _, expr := range items {
col, ok := expr.(*expression.Column)
if !ok {
continue
}
partitionCols = append(partitionCols, &property.MPPPartitionColumn{
Col: col,
CollateID: property.GetCollateIDByNameForPartition(col.GetType().GetCollate()),
})
}

prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: partitionCols}
newMpp := mpp.enforceExchanger(prop)
attachPlan2Task(middleAgg, newMpp)
mpp = newMpp
}

newMpp := mpp.enforceExchanger(prop)
attachPlan2Task(finalAgg, newMpp)
if proj == nil {
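
The correctness argument for the plan above, in one sentence: hash-partitioning by the distinct column guarantees each distinct value appears in exactly one partition, so the middle stage can count distinct values locally and the final stage merely sums the per-partition counts. Below is a self-contained toy simulation of the three stages (hypothetical code, not planner code; count(b) here assumes b is NOT NULL, as in the test table):

// threestage.go: toy simulation of
//   select count(distinct a), count(b) from foo
// with the plan shape documented in attach2TaskForMpp above.
package main

import "fmt"

type row struct{ a, b int }

func main() {
	rows := []row{{1, 10}, {1, 20}, {2, 30}, {3, 40}, {3, 50}}
	const partitions = 2

	// Stage 1 (partial): group by a, count(b); shuffle by a so each
	// distinct value of a lands in exactly one partition.
	partial := make([]map[int]int64, partitions) // partition -> (a -> count(b))
	for i := range partial {
		partial[i] = map[int]int64{}
	}
	for _, r := range rows {
		p := r.a % partitions // stand-in for the hash partitioner
		partial[p][r.a]++     // b is NOT NULL, so count(b) counts rows
	}

	// Stage 2 (middle, per partition): count(distinct a) is just the number
	// of group keys; sum(count(b)) merges the partial counts.
	type mid struct{ distinctA, countB int64 }
	mids := make([]mid, 0, partitions)
	for _, groups := range partial {
		var m mid
		m.distinctA = int64(len(groups))
		for _, c := range groups {
			m.countB += c
		}
		mids = append(mids, m)
	}

	// Stage 3 (final, single node): sum the per-partition results, which is
	// why the final agg rewrites count(distinct a) into sum(#1).
	var distinctA, countB int64
	for _, m := range mids {
		distinctA += m.distinctA
		countB += m.countB
	}
	fmt.Println(distinctA, countB) // 3 5
}
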
19 changes: 19 additions & 0 deletions planner/core/testdata/enforce_mpp_suite_in.json
@@ -113,5 +113,24 @@
"EXPLAIN select a, count(b), avg(distinct c), count(distinct c) from t group by a; -- multi distinct funcs, bail out",
"EXPLAIN select count(b), count(distinct c) from t; -- single distinct func but no group key, bail out"
]
},
{
"name": "TestMPPSingleDistinct3Stage",
"cases": [
"set @@tidb_allow_mpp=1;set @@tidb_enforce_mpp=1;",
"EXPLAIN select count(distinct b) from t;",
"EXPLAIN select count(distinct c) from t;",
[review thread]
Contributor: Why does the second stage in explain show count(distinct c) instead of group by c? (json files cannot be commented, so I put the comments here)
Contributor Author: See here, e.g. explain select count(distinct a), count(b) from foo:

  HashAgg sum(#1), sum(#2)                              -> final agg
   +- Exchange Passthrough
       +- HashAgg count(distinct a) #1, sum(#3) #2      -> middle agg
          +- Exchange HashPartition by a
               +- HashAgg count(b) #3, group by a       -> partial agg
                   +- TableScan foo

After the 1st stage agg, we shuffle by the distinct key, so it is OK to compute count(distinct) in the 2nd stage.
"EXPLAIN select count(distinct e) from t;",
"EXPLAIN select count(distinct a,b,c,e) from t;",
"EXPLAIN select count(distinct c), count(a), count(*) from t;",
"EXPLAIN select sum(b), count(a), count(*), count(distinct c) from t;",
"EXPLAIN select sum(b+a), count(*), count(distinct c), count(a) from t having count(distinct c) > 2;",
"EXPLAIN select sum(b+a), count(*), count(a) from t having count(distinct c) > 2;",
"EXPLAIN select sum(b+a), max(b), count(distinct c), count(*) from t having count(a) > 2;",
"EXPLAIN select sum(b), count(distinct a, b, e), count(a+b) from t;",
"EXPLAIN select count(distinct b), json_objectagg(d,c) from t;",
"EXPLAIN select count(distinct c+a), count(a) from t;",
"EXPLAIN select sum(b), count(distinct c+a, b, e), count(a+b) from t;"
]
}
]