From 47814ecd331a9eb6e0632e58e63afd7e7e058dee Mon Sep 17 00:00:00 2001 From: stdpain <34912776+stdpain@users.noreply.github.com> Date: Wed, 23 Oct 2024 19:46:50 +0800 Subject: [PATCH 1/2] [Enhancement] support insert local exchange in parttern exchange->join (#52021) Signed-off-by: stdpain (cherry picked from commit 9c1647404aecc95767add8c350543be5268b0dfe) --- be/src/exec/cross_join_node.cpp | 9 +++++++++ be/src/exec/cross_join_node.h | 1 + .../java/com/starrocks/planner/NestLoopJoinNode.java | 5 +++++ .../operator/physical/PhysicalJoinOperator.java | 10 +++++----- .../sql/optimizer/rule/tree/JoinLocalShuffleRule.java | 10 ++++++++-- .../com/starrocks/sql/plan/PlanFragmentBuilder.java | 2 +- .../com/starrocks/sql/plan/JoinLocalShuffleTest.java | 10 ++++++++++ gensrc/thrift/PlanNodes.thrift | 1 + 8 files changed, 40 insertions(+), 8 deletions(-) diff --git a/be/src/exec/cross_join_node.cpp b/be/src/exec/cross_join_node.cpp index 563cab46eafc2..b348016fc2ac4 100644 --- a/be/src/exec/cross_join_node.cpp +++ b/be/src/exec/cross_join_node.cpp @@ -78,6 +78,10 @@ Status CrossJoinNode::init(const TPlanNode& tnode, RuntimeState* state) { RETURN_IF_ERROR( Expr::create_expr_trees(_pool, tnode.nestloop_join_node.join_conjuncts, &_join_conjuncts, state)); } + + if (tnode.nestloop_join_node.__isset.interpolate_passthrough) { + _interpolate_passthrough = tnode.nestloop_join_node.interpolate_passthrough; + } if (tnode.nestloop_join_node.__isset.sql_join_conjuncts) { _sql_join_conjuncts = tnode.nestloop_join_node.sql_join_conjuncts; } @@ -697,6 +701,11 @@ std::vector> CrossJoinNode::_decompos left_ops.emplace_back(std::make_shared(context->next_operator_id(), id(), limit())); } + if (_interpolate_passthrough && !context->is_colocate_group()) { + left_ops = context->maybe_interpolate_local_passthrough_exchange(runtime_state(), id(), left_ops, + context->degree_of_parallelism(), true); + } + if constexpr (std::is_same_v) { may_add_chunk_accumulate_operator(left_ops, context, id()); } diff --git a/be/src/exec/cross_join_node.h b/be/src/exec/cross_join_node.h index 48656f01a1913..cc21e68291a51 100644 --- a/be/src/exec/cross_join_node.h +++ b/be/src/exec/cross_join_node.h @@ -130,6 +130,7 @@ class CrossJoinNode final : public ExecNode { std::vector _buf_selective; std::vector _build_runtime_filters; + bool _interpolate_passthrough = false; }; } // namespace starrocks diff --git a/fe/fe-core/src/main/java/com/starrocks/planner/NestLoopJoinNode.java b/fe/fe-core/src/main/java/com/starrocks/planner/NestLoopJoinNode.java index 09049e8aa4a7d..3e608ee56608b 100644 --- a/fe/fe-core/src/main/java/com/starrocks/planner/NestLoopJoinNode.java +++ b/fe/fe-core/src/main/java/com/starrocks/planner/NestLoopJoinNode.java @@ -123,6 +123,11 @@ protected void toThrift(TPlanNode msg) { String sqlJoinPredicate = otherJoinConjuncts.stream().map(Expr::toSql).collect(Collectors.joining(",")); msg.nestloop_join_node.setSql_join_conjuncts(sqlJoinPredicate); } + SessionVariable sv = ConnectContext.get().getSessionVariable(); + if (getCanLocalShuffle()) { + msg.nestloop_join_node.setInterpolate_passthrough(sv.isHashJoinInterpolatePassthrough()); + } + if (!buildRuntimeFilters.isEmpty()) { msg.nestloop_join_node.setBuild_runtime_filters( diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/physical/PhysicalJoinOperator.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/physical/PhysicalJoinOperator.java index 433b32762239e..6a082db2cbdf8 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/physical/PhysicalJoinOperator.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/physical/PhysicalJoinOperator.java @@ -33,7 +33,7 @@ public abstract class PhysicalJoinOperator extends PhysicalOperator { protected final JoinOperator joinType; protected final ScalarOperator onPredicate; protected final String joinHint; - protected boolean canLocalShuffle; + protected boolean outputRequireHashPartition = true; protected PhysicalJoinOperator(OperatorType operatorType, JoinOperator joinType, ScalarOperator onPredicate, @@ -131,11 +131,11 @@ public void fillDisableDictOptimizeColumns(ColumnRefSet columnRefSet) { } } - public void setCanLocalShuffle(boolean v) { - canLocalShuffle = v; + public void setOutputRequireHashPartition(boolean v) { + outputRequireHashPartition = v; } - public boolean getCanLocalShuffle() { - return canLocalShuffle; + public boolean getOutputRequireHashPartition() { + return outputRequireHashPartition; } } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/JoinLocalShuffleRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/JoinLocalShuffleRule.java index 7f110114e3ae7..04b0c2d639c5d 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/JoinLocalShuffleRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/JoinLocalShuffleRule.java @@ -42,16 +42,22 @@ public Void visit(OptExpression opt, TaskContext context) { @Override public Void visitPhysicalDistribution(OptExpression opt, TaskContext context) { Operator op = opt.getInputs().get(0).getOp(); - // exchange + local agg + join, then this join can use local shuffle. + // 1. exchange + local agg + join, then this join can use local shuffle. if ((op instanceof PhysicalHashAggregateOperator) && ((PhysicalHashAggregateOperator) op).getType().isLocal()) { Operator childOp = opt.getInputs().get(0).getInputs().get(0).getOp(); if (childOp instanceof PhysicalJoinOperator) { PhysicalJoinOperator joinOperator = (PhysicalJoinOperator) childOp; - joinOperator.setCanLocalShuffle(true); + joinOperator.setOutputRequireHashPartition(false); } } + // 2. exchange + join + if (op instanceof PhysicalJoinOperator) { + PhysicalJoinOperator joinOperator = (PhysicalJoinOperator) op; + joinOperator.setOutputRequireHashPartition(false); + } + for (OptExpression input : opt.getInputs()) { input.getOp().accept(this, input, context); } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java index b520eaaac9182..cb02c75d6d0b3 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java @@ -2463,7 +2463,7 @@ private PlanFragment visitPhysicalJoin(PlanFragment leftFragment, PlanFragment r throw new StarRocksPlannerException("unknown join operator: " + node, INTERNAL_ERROR); } - if (node.getCanLocalShuffle()) { + if (!node.getOutputRequireHashPartition()) { joinNode.setCanLocalShuffle(true); } diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinLocalShuffleTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinLocalShuffleTest.java index 51376a10c39bb..b1f7aceac7a15 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinLocalShuffleTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinLocalShuffleTest.java @@ -18,6 +18,7 @@ import com.starrocks.qe.SessionVariable; import org.junit.AfterClass; import org.junit.BeforeClass; +import org.junit.Test; public class JoinLocalShuffleTest extends PlanTestBase { @@ -49,4 +50,13 @@ public void joinWithAgg() throws Exception { } sv.setNewPlanerAggStage(0); } + + @Test + public void joinUnderExchange() throws Exception { + SessionVariable sv = connectContext.getSessionVariable(); + sv.setInterpolatePassthrough(true); + String sql = "select l.* from t0 l join [shuffle] t1 on upper(v1) = v5 join [shuffle] t2 on lower(v1) = v9"; + String plan = getVerboseExplain(sql); + assertContains(plan, "can local shuffle: true"); + } } diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift index 6ae628b71185e..72f028e04af4d 100644 --- a/gensrc/thrift/PlanNodes.thrift +++ b/gensrc/thrift/PlanNodes.thrift @@ -682,6 +682,7 @@ struct TNestLoopJoinNode { 2: optional list build_runtime_filters; 3: optional list join_conjuncts 4: optional string sql_join_conjuncts + 5: optional bool interpolate_passthrough = false } enum TAggregationOp { From e0649024182d2006a63e104f7bc70c8cfc7c42c0 Mon Sep 17 00:00:00 2001 From: stdpain <34912776+stdpain@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:20:00 +0800 Subject: [PATCH 2/2] Update cross_join_node.cpp Signed-off-by: stdpain <34912776+stdpain@users.noreply.github.com> --- be/src/exec/cross_join_node.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/be/src/exec/cross_join_node.cpp b/be/src/exec/cross_join_node.cpp index b348016fc2ac4..6773ab69c5e8e 100644 --- a/be/src/exec/cross_join_node.cpp +++ b/be/src/exec/cross_join_node.cpp @@ -701,7 +701,7 @@ std::vector> CrossJoinNode::_decompos left_ops.emplace_back(std::make_shared(context->next_operator_id(), id(), limit())); } - if (_interpolate_passthrough && !context->is_colocate_group()) { + if (_interpolate_passthrough) { left_ops = context->maybe_interpolate_local_passthrough_exchange(runtime_state(), id(), left_ops, context->degree_of_parallelism(), true); }