diff --git a/be/src/exec/cross_join_node.cpp b/be/src/exec/cross_join_node.cpp index 6fe706bd0fbd0..55de2cade1f17 100644 --- a/be/src/exec/cross_join_node.cpp +++ b/be/src/exec/cross_join_node.cpp @@ -73,6 +73,10 @@ Status CrossJoinNode::init(const TPlanNode& tnode, RuntimeState* state) { RETURN_IF_ERROR( Expr::create_expr_trees(_pool, tnode.nestloop_join_node.join_conjuncts, &_join_conjuncts, state)); } + + if (tnode.nestloop_join_node.__isset.interpolate_passthrough) { + _interpolate_passthrough = tnode.nestloop_join_node.interpolate_passthrough; + } if (tnode.nestloop_join_node.__isset.sql_join_conjuncts) { _sql_join_conjuncts = tnode.nestloop_join_node.sql_join_conjuncts; } @@ -619,6 +623,11 @@ std::vector> CrossJoinNode::_decompos left_ops.emplace_back(std::make_shared(context->next_operator_id(), id(), limit())); } + if (_interpolate_passthrough && !context->is_colocate_group()) { + left_ops = context->maybe_interpolate_local_passthrough_exchange(runtime_state(), id(), left_ops, + context->degree_of_parallelism(), true); + } + if constexpr (std::is_same_v) { may_add_chunk_accumulate_operator(left_ops, context, id()); } diff --git a/be/src/exec/cross_join_node.h b/be/src/exec/cross_join_node.h index f5c5f33442d9b..01f7e7e6c6dd2 100644 --- a/be/src/exec/cross_join_node.h +++ b/be/src/exec/cross_join_node.h @@ -127,6 +127,7 @@ class CrossJoinNode final : public ExecNode { std::vector _buf_selective; std::vector _build_runtime_filters; + bool _interpolate_passthrough = false; }; } // namespace starrocks diff --git a/fe/fe-core/src/main/java/com/starrocks/planner/NestLoopJoinNode.java b/fe/fe-core/src/main/java/com/starrocks/planner/NestLoopJoinNode.java index 6e6a29e316468..27d10f593a5e9 100644 --- a/fe/fe-core/src/main/java/com/starrocks/planner/NestLoopJoinNode.java +++ b/fe/fe-core/src/main/java/com/starrocks/planner/NestLoopJoinNode.java @@ -126,6 +126,11 @@ protected void toThrift(TPlanNode msg) { String sqlJoinPredicate = otherJoinConjuncts.stream().map(Expr::toSql).collect(Collectors.joining(",")); msg.nestloop_join_node.setSql_join_conjuncts(sqlJoinPredicate); } + SessionVariable sv = ConnectContext.get().getSessionVariable(); + if (getCanLocalShuffle()) { + msg.nestloop_join_node.setInterpolate_passthrough(sv.isHashJoinInterpolatePassthrough()); + } + if (!buildRuntimeFilters.isEmpty()) { msg.nestloop_join_node.setBuild_runtime_filters( diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/physical/PhysicalJoinOperator.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/physical/PhysicalJoinOperator.java index 433b32762239e..6a082db2cbdf8 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/physical/PhysicalJoinOperator.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/physical/PhysicalJoinOperator.java @@ -33,7 +33,7 @@ public abstract class PhysicalJoinOperator extends PhysicalOperator { protected final JoinOperator joinType; protected final ScalarOperator onPredicate; protected final String joinHint; - protected boolean canLocalShuffle; + protected boolean outputRequireHashPartition = true; protected PhysicalJoinOperator(OperatorType operatorType, JoinOperator joinType, ScalarOperator onPredicate, @@ -131,11 +131,11 @@ public void fillDisableDictOptimizeColumns(ColumnRefSet columnRefSet) { } } - public void setCanLocalShuffle(boolean v) { - canLocalShuffle = v; + public void setOutputRequireHashPartition(boolean v) { + outputRequireHashPartition = v; } - public boolean getCanLocalShuffle() { - return canLocalShuffle; + public boolean getOutputRequireHashPartition() { + return outputRequireHashPartition; } } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/JoinLocalShuffleRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/JoinLocalShuffleRule.java index 7f110114e3ae7..04b0c2d639c5d 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/JoinLocalShuffleRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/JoinLocalShuffleRule.java @@ -42,16 +42,22 @@ public Void visit(OptExpression opt, TaskContext context) { @Override public Void visitPhysicalDistribution(OptExpression opt, TaskContext context) { Operator op = opt.getInputs().get(0).getOp(); - // exchange + local agg + join, then this join can use local shuffle. + // 1. exchange + local agg + join, then this join can use local shuffle. if ((op instanceof PhysicalHashAggregateOperator) && ((PhysicalHashAggregateOperator) op).getType().isLocal()) { Operator childOp = opt.getInputs().get(0).getInputs().get(0).getOp(); if (childOp instanceof PhysicalJoinOperator) { PhysicalJoinOperator joinOperator = (PhysicalJoinOperator) childOp; - joinOperator.setCanLocalShuffle(true); + joinOperator.setOutputRequireHashPartition(false); } } + // 2. exchange + join + if (op instanceof PhysicalJoinOperator) { + PhysicalJoinOperator joinOperator = (PhysicalJoinOperator) op; + joinOperator.setOutputRequireHashPartition(false); + } + for (OptExpression input : opt.getInputs()) { input.getOp().accept(this, input, context); } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java index e5b7480d789b6..6173f938be68d 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java @@ -2730,7 +2730,7 @@ private PlanFragment visitPhysicalJoin(PlanFragment leftFragment, PlanFragment r throw new StarRocksPlannerException("unknown join operator: " + node, INTERNAL_ERROR); } - if (node.getCanLocalShuffle()) { + if (!node.getOutputRequireHashPartition()) { joinNode.setCanLocalShuffle(true); } diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinLocalShuffleTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinLocalShuffleTest.java index 51376a10c39bb..b1f7aceac7a15 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinLocalShuffleTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinLocalShuffleTest.java @@ -18,6 +18,7 @@ import com.starrocks.qe.SessionVariable; import org.junit.AfterClass; import org.junit.BeforeClass; +import org.junit.Test; public class JoinLocalShuffleTest extends PlanTestBase { @@ -49,4 +50,13 @@ public void joinWithAgg() throws Exception { } sv.setNewPlanerAggStage(0); } + + @Test + public void joinUnderExchange() throws Exception { + SessionVariable sv = connectContext.getSessionVariable(); + sv.setInterpolatePassthrough(true); + String sql = "select l.* from t0 l join [shuffle] t1 on upper(v1) = v5 join [shuffle] t2 on lower(v1) = v9"; + String plan = getVerboseExplain(sql); + assertContains(plan, "can local shuffle: true"); + } } diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift index 754242809ed98..ee81f434b7eed 100644 --- a/gensrc/thrift/PlanNodes.thrift +++ b/gensrc/thrift/PlanNodes.thrift @@ -706,6 +706,7 @@ struct TNestLoopJoinNode { 2: optional list build_runtime_filters; 3: optional list join_conjuncts 4: optional string sql_join_conjuncts + 5: optional bool interpolate_passthrough = false } enum TAggregationOp {