Skip to content

Commit

Permalink
[Enhancement] support insert local exchange in parttern exchange->join (
Browse files Browse the repository at this point in the history
#52021)

Signed-off-by: stdpain <[email protected]>
(cherry picked from commit 9c16474)
  • Loading branch information
stdpain authored and mergify[bot] committed Oct 23, 2024
1 parent b9f53be commit d407413
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 8 deletions.
9 changes: 9 additions & 0 deletions be/src/exec/cross_join_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ Status CrossJoinNode::init(const TPlanNode& tnode, RuntimeState* state) {
RETURN_IF_ERROR(
Expr::create_expr_trees(_pool, tnode.nestloop_join_node.join_conjuncts, &_join_conjuncts, state));
}

if (tnode.nestloop_join_node.__isset.interpolate_passthrough) {
_interpolate_passthrough = tnode.nestloop_join_node.interpolate_passthrough;
}
if (tnode.nestloop_join_node.__isset.sql_join_conjuncts) {
_sql_join_conjuncts = tnode.nestloop_join_node.sql_join_conjuncts;
}
Expand Down Expand Up @@ -619,6 +623,11 @@ std::vector<std::shared_ptr<pipeline::OperatorFactory>> CrossJoinNode::_decompos
left_ops.emplace_back(std::make_shared<LimitOperatorFactory>(context->next_operator_id(), id(), limit()));
}

if (_interpolate_passthrough && !context->is_colocate_group()) {
left_ops = context->maybe_interpolate_local_passthrough_exchange(runtime_state(), id(), left_ops,
context->degree_of_parallelism(), true);
}

if constexpr (std::is_same_v<BuildFactory, SpillableNLJoinBuildOperatorFactory>) {
may_add_chunk_accumulate_operator(left_ops, context, id());
}
Expand Down
1 change: 1 addition & 0 deletions be/src/exec/cross_join_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class CrossJoinNode final : public ExecNode {
std::vector<uint32_t> _buf_selective;

std::vector<RuntimeFilterBuildDescriptor*> _build_runtime_filters;
bool _interpolate_passthrough = false;
};

} // namespace starrocks
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ protected void toThrift(TPlanNode msg) {
String sqlJoinPredicate = otherJoinConjuncts.stream().map(Expr::toSql).collect(Collectors.joining(","));
msg.nestloop_join_node.setSql_join_conjuncts(sqlJoinPredicate);
}
SessionVariable sv = ConnectContext.get().getSessionVariable();
if (getCanLocalShuffle()) {
msg.nestloop_join_node.setInterpolate_passthrough(sv.isHashJoinInterpolatePassthrough());
}


if (!buildRuntimeFilters.isEmpty()) {
msg.nestloop_join_node.setBuild_runtime_filters(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public abstract class PhysicalJoinOperator extends PhysicalOperator {
protected final JoinOperator joinType;
protected final ScalarOperator onPredicate;
protected final String joinHint;
protected boolean canLocalShuffle;
protected boolean outputRequireHashPartition = true;

protected PhysicalJoinOperator(OperatorType operatorType, JoinOperator joinType,
ScalarOperator onPredicate,
Expand Down Expand Up @@ -131,11 +131,11 @@ public void fillDisableDictOptimizeColumns(ColumnRefSet columnRefSet) {
}
}

public void setCanLocalShuffle(boolean v) {
canLocalShuffle = v;
public void setOutputRequireHashPartition(boolean v) {
outputRequireHashPartition = v;
}

public boolean getCanLocalShuffle() {
return canLocalShuffle;
public boolean getOutputRequireHashPartition() {
return outputRequireHashPartition;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,22 @@ public Void visit(OptExpression opt, TaskContext context) {
@Override
public Void visitPhysicalDistribution(OptExpression opt, TaskContext context) {
Operator op = opt.getInputs().get(0).getOp();
// exchange + local agg + join, then this join can use local shuffle.
// 1. exchange + local agg + join, then this join can use local shuffle.
if ((op instanceof PhysicalHashAggregateOperator) &&
((PhysicalHashAggregateOperator) op).getType().isLocal()) {
Operator childOp = opt.getInputs().get(0).getInputs().get(0).getOp();
if (childOp instanceof PhysicalJoinOperator) {
PhysicalJoinOperator joinOperator = (PhysicalJoinOperator) childOp;
joinOperator.setCanLocalShuffle(true);
joinOperator.setOutputRequireHashPartition(false);
}
}

// 2. exchange + join
if (op instanceof PhysicalJoinOperator) {
PhysicalJoinOperator joinOperator = (PhysicalJoinOperator) op;
joinOperator.setOutputRequireHashPartition(false);
}

for (OptExpression input : opt.getInputs()) {
input.getOp().accept(this, input, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2730,7 +2730,7 @@ private PlanFragment visitPhysicalJoin(PlanFragment leftFragment, PlanFragment r
throw new StarRocksPlannerException("unknown join operator: " + node, INTERNAL_ERROR);
}

if (node.getCanLocalShuffle()) {
if (!node.getOutputRequireHashPartition()) {
joinNode.setCanLocalShuffle(true);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.starrocks.qe.SessionVariable;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

public class JoinLocalShuffleTest extends PlanTestBase {

Expand Down Expand Up @@ -49,4 +50,13 @@ public void joinWithAgg() throws Exception {
}
sv.setNewPlanerAggStage(0);
}

@Test
public void joinUnderExchange() throws Exception {
SessionVariable sv = connectContext.getSessionVariable();
sv.setInterpolatePassthrough(true);
String sql = "select l.* from t0 l join [shuffle] t1 on upper(v1) = v5 join [shuffle] t2 on lower(v1) = v9";
String plan = getVerboseExplain(sql);
assertContains(plan, "can local shuffle: true");
}
}
1 change: 1 addition & 0 deletions gensrc/thrift/PlanNodes.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,7 @@ struct TNestLoopJoinNode {
2: optional list<RuntimeFilter.TRuntimeFilterDescription> build_runtime_filters;
3: optional list<Exprs.TExpr> join_conjuncts
4: optional string sql_join_conjuncts
5: optional bool interpolate_passthrough = false
}

enum TAggregationOp {
Expand Down

0 comments on commit d407413

Please sign in to comment.