diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 3274fdf888e26..e9577af147e7f 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -4657,6 +4657,52 @@ func TestMppJoinDecimal(t *testing.T) { } } +func TestMppJoinExchangeColumnPrune(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists tt") + tk.MustExec("create table t (c1 int, c2 int, c3 int NOT NULL, c4 int NOT NULL, c5 int)") + tk.MustExec("create table tt (b1 int)") + tk.MustExec("analyze table t") + tk.MustExec("analyze table tt") + + // Create virtual tiflash replica info. + dom := domain.GetDomain(tk.Session()) + is := dom.InfoSchema() + db, exists := is.SchemaByName(model.NewCIStr("test")) + require.True(t, exists) + for _, tblInfo := range db.Tables { + if tblInfo.Name.L == "t" || tblInfo.Name.L == "tt" { + tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ + Count: 1, + Available: true, + } + } + } + + tk.MustExec("set @@tidb_allow_mpp=1;") + tk.MustExec("set @@session.tidb_broadcast_join_threshold_size = 1") + tk.MustExec("set @@session.tidb_broadcast_join_threshold_count = 1") + + var input []string + var output []struct { + SQL string + Plan []string + } + integrationSuiteData := core.GetIntegrationSuiteData() + integrationSuiteData.LoadTestCases(t, &input, &output) + for i, tt := range input { + testdata.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + }) + res := tk.MustQuery(tt) + res.Check(testkit.Rows(output[i].Plan...)) + } +} + func TestMppAggTopNWithJoin(t *testing.T) { store := testkit.CreateMockStore(t) tk := testkit.NewTestKit(t, store) diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index 4623af7261104..fe3219fcdaf21 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -295,7 
+295,10 @@ func DoOptimize(ctx context.Context, sctx sessionctx.Context, flag uint64, logic if err != nil { return nil, 0, err } - finalPlan := postOptimize(sctx, physical) + finalPlan, err := postOptimize(sctx, physical) + if err != nil { + return nil, 0, err + } if sctx.GetSessionVars().StmtCtx.EnableOptimizerCETrace { refineCETrace(sctx) @@ -372,9 +375,13 @@ func mergeContinuousSelections(p PhysicalPlan) { } } -func postOptimize(sctx sessionctx.Context, plan PhysicalPlan) PhysicalPlan { +func postOptimize(sctx sessionctx.Context, plan PhysicalPlan) (PhysicalPlan, error) { // some cases from update optimize will require avoiding projection elimination. // see comments ahead of call of DoOptimize in function of buildUpdate(). + err := prunePhysicalColumns(sctx, plan) + if err != nil { + return nil, err + } plan = eliminatePhysicalProjection(plan) plan = InjectExtraProjection(plan) mergeContinuousSelections(plan) @@ -383,7 +390,145 @@ func postOptimize(sctx sessionctx.Context, plan PhysicalPlan) PhysicalPlan { handleFineGrainedShuffle(sctx, plan) checkPlanCacheable(sctx, plan) propagateProbeParents(plan, nil) - return plan + return plan, nil +} + +// prunePhysicalColumns currently only work for MPP(HashJoin<-Exchange). +// Here add projection instead of pruning columns directly for safety considerations. +// And projection is cheap here for it saves the network cost and work in memory. 
+func prunePhysicalColumns(sctx sessionctx.Context, plan PhysicalPlan) error { + if tableReader, ok := plan.(*PhysicalTableReader); ok { + if _, isExchangeSender := tableReader.tablePlan.(*PhysicalExchangeSender); isExchangeSender { + err := prunePhysicalColumnsInternal(sctx, tableReader.tablePlan) + if err != nil { + return err + } + } + } else { + for _, child := range plan.Children() { + return prunePhysicalColumns(sctx, child) + } + } + return nil +} + +func (p *PhysicalHashJoin) extractUsedCols(parentUsedCols []*expression.Column) (leftCols []*expression.Column, rightCols []*expression.Column) { + for _, eqCond := range p.EqualConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(eqCond)...) + } + for _, neCond := range p.NAEqualConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(neCond)...) + } + for _, leftCond := range p.LeftConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(leftCond)...) + } + for _, rightCond := range p.RightConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(rightCond)...) + } + for _, otherCond := range p.OtherConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(otherCond)...) 
+ } + lChild := p.children[0] + rChild := p.children[1] + for _, col := range parentUsedCols { + if lChild.Schema().Contains(col) { + leftCols = append(leftCols, col) + } else if rChild.Schema().Contains(col) { + rightCols = append(rightCols, col) + } + } + return leftCols, rightCols +} + +func prunePhysicalColumnForHashJoinChild(sctx sessionctx.Context, hashJoin *PhysicalHashJoin, joinUsedCols []*expression.Column, sender *PhysicalExchangeSender) error { + var err error + joinUsed := expression.GetUsedList(joinUsedCols, sender.Schema()) + hashCols := make([]*expression.Column, len(sender.HashCols)) + for i, mppCol := range sender.HashCols { + hashCols[i] = mppCol.Col + } + hashUsed := expression.GetUsedList(hashCols, sender.Schema()) + + needPrune := false + usedExprs := make([]expression.Expression, len(sender.Schema().Columns)) + prunedSchema := sender.Schema().Clone() + for i := len(joinUsed) - 1; i >= 0; i-- { + usedExprs[i] = sender.Schema().Columns[i] + if !joinUsed[i] && !hashUsed[i] { + needPrune = true + usedExprs = append(usedExprs[:i], usedExprs[i+1:]...) + prunedSchema.Columns = append(prunedSchema.Columns[:i], prunedSchema.Columns[i+1:]...) 
+ } + } + + if needPrune && len(sender.children) > 0 { + ch := sender.children[0] + proj := PhysicalProjection{ + Exprs: usedExprs, + }.Init(sctx, ch.statsInfo(), ch.SelectBlockOffset()) + + proj.SetSchema(prunedSchema) + proj.SetChildren(ch) + sender.children[0] = proj + + // Resolve Indices from bottom to up + err = proj.ResolveIndicesItself() + if err != nil { + return err + } + err = sender.ResolveIndicesItself() + if err != nil { + return err + } + err = hashJoin.ResolveIndicesItself() + if err != nil { + return err + } + } + return err +} + +func prunePhysicalColumnsInternal(sctx sessionctx.Context, plan PhysicalPlan) error { + var err error + switch x := plan.(type) { + case *PhysicalHashJoin: + schemaColumns := x.Schema().Clone().Columns + leftCols, rightCols := x.extractUsedCols(schemaColumns) + matchPattern := false + for i := 0; i <= 1; i++ { + // Pattern: HashJoin <- ExchangeReceiver <- ExchangeSender + matchPattern = false + var exchangeSender *PhysicalExchangeSender + if receiver, ok := x.children[i].(*PhysicalExchangeReceiver); ok { + exchangeSender, matchPattern = receiver.children[0].(*PhysicalExchangeSender) + } + + if matchPattern { + if i == 0 { + err = prunePhysicalColumnForHashJoinChild(sctx, x, leftCols, exchangeSender) + } else { + err = prunePhysicalColumnForHashJoinChild(sctx, x, rightCols, exchangeSender) + } + if err != nil { + return nil + } + } + + /// recursively travel the physical plan + err = prunePhysicalColumnsInternal(sctx, x.children[i]) + if err != nil { + return nil + } + } + default: + for _, child := range x.Children() { + err = prunePhysicalColumnsInternal(sctx, child) + if err != nil { + return err + } + } + } + return nil } // Only for MPP(Window<-[Sort]<-ExchangeReceiver<-ExchangeSender). 
diff --git a/planner/core/optimizer_test.go b/planner/core/optimizer_test.go index 6cf80c57fa5ec..0f9dc0a4050c4 100644 --- a/planner/core/optimizer_test.go +++ b/planner/core/optimizer_test.go @@ -18,6 +18,8 @@ import ( "reflect" "testing" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/types" @@ -288,3 +290,104 @@ func TestHandleFineGrainedShuffle(t *testing.T) { hashSender1.children = []PhysicalPlan{tableScan1} start(partWindow, expStreamCount, 3, 0) } + +// Test for core.prunePhysicalColumns() +func TestPrunePhysicalColumns(t *testing.T) { + sctx := MockContext() + col0 := &expression.Column{ + UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), + RetType: types.NewFieldType(mysql.TypeLonglong), + } + col1 := &expression.Column{ + UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), + RetType: types.NewFieldType(mysql.TypeLonglong), + } + col2 := &expression.Column{ + UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), + RetType: types.NewFieldType(mysql.TypeLonglong), + } + col3 := &expression.Column{ + UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), + RetType: types.NewFieldType(mysql.TypeLonglong), + } + + // Join[col2, col3; col2==col3] <- ExchangeReceiver[col0, col1, col2] <- ExchangeSender[col0, col1, col2] <- Selection[col0, col1, col2; col0 < col1] <- TableScan[col0, col1, col2] + // <- ExchangeReceiver1[col3] <- ExchangeSender1[col3] <- TableScan1[col3] + tableReader := &PhysicalTableReader{} + passSender := &PhysicalExchangeSender{ + ExchangeType: tipb.ExchangeType_PassThrough, + } + hashJoin := &PhysicalHashJoin{} + recv := &PhysicalExchangeReceiver{} + recv1 := &PhysicalExchangeReceiver{} + hashSender := &PhysicalExchangeSender{ + ExchangeType: tipb.ExchangeType_Hash, + } + hashSender1 := &PhysicalExchangeSender{ + ExchangeType: tipb.ExchangeType_Hash, + } + tableScan := &PhysicalTableScan{} + tableScan1 := 
&PhysicalTableScan{} + + tableReader.tablePlan = passSender + passSender.children = []PhysicalPlan{hashJoin} + hashJoin.children = []PhysicalPlan{recv, recv1} + selection := &PhysicalSelection{} + + cond, err := expression.NewFunction(sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), col2, col3) + require.True(t, err == nil) + sf, isSF := cond.(*expression.ScalarFunction) + require.True(t, isSF) + hashJoin.EqualConditions = append(hashJoin.EqualConditions, sf) + hashJoin.LeftJoinKeys = append(hashJoin.LeftJoinKeys, col2) + hashJoin.RightJoinKeys = append(hashJoin.RightJoinKeys, col3) + hashJoinSchema := make([]*expression.Column, 0) + hashJoinSchema = append(hashJoinSchema, col3) + hashJoin.SetSchema(expression.NewSchema(hashJoinSchema...)) + + selection.SetChildren(tableScan) + hashSender.SetChildren(selection) + var partitionCols = make([]*property.MPPPartitionColumn, 0, 1) + partitionCols = append(partitionCols, &property.MPPPartitionColumn{ + Col: col2, + CollateID: property.GetCollateIDByNameForPartition(col2.GetType().GetCollate()), + }) + hashSender.HashCols = partitionCols + recv.SetChildren(hashSender) + tableScan.Schema().Columns = append(tableScan.Schema().Columns, col0, col1, col2) + + hashSender1.SetChildren(tableScan1) + recv1.SetChildren(hashSender1) + tableScan1.Schema().Columns = append(tableScan1.Schema().Columns, col3) + + prunePhysicalColumns(sctx, tableReader) + + // Optimized Plan: + // Join[col2, col3; col2==col3] <- ExchangeReceiver[col2] <- ExchangeSender[col2;col2] <- Projection[col2] <- Selection[col0, col1, col2; col0 < col1] <- TableScan[col0, col1, col2] + // <- ExchangeReceiver1[col3] <- ExchangeSender1[col3] <- TableScan1[col3] + require.True(t, len(recv.Schema().Columns) == 1) + require.True(t, recv.Schema().Contains(col2)) + require.False(t, recv.Schema().Contains(col0)) + require.False(t, recv.Schema().Contains(col1)) + require.True(t, len(recv.children[0].Children()) == 1) + physicalProj := recv.children[0].Children()[0] + switch 
x := physicalProj.(type) { + case *PhysicalProjection: + require.True(t, x.Schema().Contains(col2)) + require.False(t, recv.Schema().Contains(col0)) + require.False(t, recv.Schema().Contains(col1)) + // Check PhysicalProj resolved index + require.True(t, len(x.Exprs) == 1) + require.True(t, x.Exprs[0].(*expression.Column).Index == 2) + default: + require.True(t, false) + } + + // Check resolved indices + require.True(t, hashJoin.LeftJoinKeys[0].Index == 0) + require.True(t, hashSender.HashCols[0].Col.Index == 0) + + // Check recv1,no changes + require.True(t, len(recv1.Schema().Columns) == 1) + require.True(t, recv1.Schema().Contains(col3)) +} diff --git a/planner/core/resolve_indices.go b/planner/core/resolve_indices.go index 483d0b9f92299..b602c46a78bd2 100644 --- a/planner/core/resolve_indices.go +++ b/planner/core/resolve_indices.go @@ -21,12 +21,8 @@ import ( "github.com/pingcap/tidb/util/disjointset" ) -// ResolveIndices implements Plan interface. -func (p *PhysicalProjection) ResolveIndices() (err error) { - err = p.physicalSchemaProducer.ResolveIndices() - if err != nil { - return err - } +// ResolveIndicesItself resolve indices for PhysicalPlan itself +func (p *PhysicalProjection) ResolveIndicesItself() (err error) { for i, expr := range p.Exprs { p.Exprs[i], err = expr.ResolveIndices(p.children[0].Schema()) if err != nil { @@ -41,6 +37,15 @@ func (p *PhysicalProjection) ResolveIndices() (err error) { return } +// ResolveIndices implements Plan interface. +func (p *PhysicalProjection) ResolveIndices() (err error) { + err = p.physicalSchemaProducer.ResolveIndices() + if err != nil { + return err + } + return p.ResolveIndicesItself() +} + // refine4NeighbourProj refines the index for p.Exprs whose type is *Column when // there is two neighbouring Projections. 
// This function is introduced because that different childProj.Expr may refer @@ -74,12 +79,8 @@ func refine4NeighbourProj(p, childProj *PhysicalProjection) { } } -// ResolveIndices implements Plan interface. -func (p *PhysicalHashJoin) ResolveIndices() (err error) { - err = p.physicalSchemaProducer.ResolveIndices() - if err != nil { - return err - } +// ResolveIndicesItself resolves indices for the PhysicalPlan itself. +func (p *PhysicalHashJoin) ResolveIndicesItself() (err error) { lSchema := p.children[0].Schema() rSchema := p.children[1].Schema() for i, fun := range p.EqualConditions { @@ -129,6 +130,15 @@ func (p *PhysicalHashJoin) ResolveIndices() (err error) { return } +// ResolveIndices implements Plan interface. +func (p *PhysicalHashJoin) ResolveIndices() (err error) { + err = p.physicalSchemaProducer.ResolveIndices() + if err != nil { + return err + } + return p.ResolveIndicesItself() +} + // ResolveIndices implements Plan interface. func (p *PhysicalMergeJoin) ResolveIndices() (err error) { err = p.physicalSchemaProducer.ResolveIndices() @@ -380,12 +390,8 @@ func (p *PhysicalSelection) ResolveIndices() (err error) { return nil } -// ResolveIndices implements Plan interface. -func (p *PhysicalExchangeSender) ResolveIndices() (err error) { - err = p.basePhysicalPlan.ResolveIndices() - if err != nil { - return err - } +// ResolveIndicesItself resolves indices for the PhysicalPlan itself. +func (p *PhysicalExchangeSender) ResolveIndicesItself() (err error) { for i, col := range p.HashCols { colExpr, err1 := col.Col.ResolveIndices(p.children[0].Schema()) if err1 != nil { @@ -393,7 +399,16 @@ func (p *PhysicalExchangeSender) ResolveIndices() (err error) { } p.HashCols[i].Col, _ = colExpr.(*expression.Column) } - return err + return +} + +// ResolveIndices implements Plan interface. 
+func (p *PhysicalExchangeSender) ResolveIndices() (err error) { + err = p.basePhysicalPlan.ResolveIndices() + if err != nil { + return err + } + return p.ResolveIndicesItself() } // ResolveIndices implements Plan interface. diff --git a/planner/core/testdata/integration_suite_in.json b/planner/core/testdata/integration_suite_in.json index 3a4e97c4cba91..6855084993514 100644 --- a/planner/core/testdata/integration_suite_in.json +++ b/planner/core/testdata/integration_suite_in.json @@ -811,6 +811,12 @@ "desc format = 'brief' SELECT STRAIGHT_JOIN t1 . col_varchar_64 , t1 . col_char_64_not_null FROM tt AS t1 INNER JOIN( tt AS t2 JOIN tt AS t3 ON(t3 . col_decimal_30_10_key = t2 . col_tinyint)) ON(t3 . col_varchar_64 = t2 . col_varchar_key) WHERE t3 . col_varchar_64 = t1 . col_char_64_not_null GROUP BY 1 , 2" ] }, + { + "name": "TestMppJoinExchangeColumnPrune", + "cases": [ + "desc format = 'brief' select * from tt t1 where exists (select * from t t2 where t1.b1 = t2.c3 and t2.c1 < t2.c2)" + ] + }, { "name": "TestPushDownAggForMPP", "cases": [ diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json index f127fec9a974c..6764637cb74b2 100644 --- a/planner/core/testdata/integration_suite_out.json +++ b/planner/core/testdata/integration_suite_out.json @@ -5574,6 +5574,28 @@ } ] }, + { + "Name": "TestMppJoinExchangeColumnPrune", + "Cases": [ + { + "SQL": "desc format = 'brief' select * from tt t1 where exists (select * from t t2 where t1.b1 = t2.c3 and t2.c1 < t2.c2)", + "Plan": [ + "TableReader 7992.00 root data:ExchangeSender", + "└─ExchangeSender 7992.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashJoin 7992.00 mpp[tiflash] semi join, equal:[eq(test.tt.b1, test.t.c3)]", + " ├─ExchangeReceiver(Build) 8000.00 mpp[tiflash] ", + " │ └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.c3, collate: binary]", + " │ └─Projection 8000.00 mpp[tiflash] test.t.c3", + " │ └─Selection 
8000.00 mpp[tiflash] lt(test.t.c1, test.t.c2)", + " │ └─TableFullScan 10000.00 mpp[tiflash] table:t2 keep order:false, stats:pseudo", + " └─ExchangeReceiver(Probe) 9990.00 mpp[tiflash] ", + " └─ExchangeSender 9990.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.tt.b1, collate: binary]", + " └─Selection 9990.00 mpp[tiflash] not(isnull(test.tt.b1))", + " └─TableFullScan 10000.00 mpp[tiflash] table:t1 keep order:false, stats:pseudo" + ] + } + ] + }, { "Name": "TestPushDownAggForMPP", "Cases": [