Skip to content

Commit

Permalink
[chore](Nereids): optimize to handle enforcer in MergeGroup() (apache…
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwener authored and airborne12 committed Aug 21, 2023
1 parent f5283ad commit 5952b05
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 39 deletions.
36 changes: 21 additions & 15 deletions fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.util.TreeStringUtils;
import org.apache.doris.nereids.util.Utils;
Expand Down Expand Up @@ -59,6 +58,8 @@ public class Group {

private final List<GroupExpression> logicalExpressions = Lists.newArrayList();
private final List<GroupExpression> physicalExpressions = Lists.newArrayList();
private final List<GroupExpression> enforcers = Lists.newArrayList();

private LogicalProperties logicalProperties;

// Map of cost lower bounds
Expand Down Expand Up @@ -210,6 +211,15 @@ public GroupExpression getBestPlan(PhysicalProperties properties) {
return null;
}

public void addEnforcer(GroupExpression enforcer) {
enforcer.setOwnerGroup(this);
enforcers.add(enforcer);
}

public List<GroupExpression> getEnforcers() {
return enforcers;
}

/**
* Set or update lowestCostPlans: properties --> Pair.of(cost, expression)
*/
Expand Down Expand Up @@ -308,12 +318,12 @@ public void removeParentPhysicalExpressions() {
public void mergeTo(Group target) {
// move parentExpressions Ownership
parentExpressions.keySet().forEach(parent -> target.addParentExpression(parent));
// PhysicalEnforcer isn't in groupExpressions, so mergeGroup() can't replace its children.
// So we need to manually replace the children of PhysicalEnforcer in here.
// TODO: SortEnforcer?
parentExpressions.keySet().stream().filter(ge -> ge.getPlan() instanceof PhysicalDistribute)
.forEach(ge -> ge.children().set(0, target));
parentExpressions.clear();

// move enforcers Ownership
enforcers.forEach(ge -> ge.children().set(0, target));
// TODO: dedup?
enforcers.forEach(enforcer -> target.addEnforcer(enforcer));
enforcers.clear();

// move LogicalExpression PhysicalExpression Ownership
Map<GroupExpression, GroupExpression> logicalSet = target.getLogicalExpressions().stream()
Expand Down Expand Up @@ -345,15 +355,7 @@ public void mergeTo(Group target) {
physicalExpressions.clear();

// Above we already replaceBestPlanGroupExpr, but we still need to moveLowestCostPlansOwnership.
// Because PhysicalEnforcer don't exist in physicalExpressions, so above `replaceBestPlanGroupExpr` can't
// move PhysicalEnforcer in lowestCostPlans. Following code can move PhysicalEnforcer in lowestCostPlans.
lowestCostPlans.forEach((physicalProperties, costAndGroupExpr) -> {
GroupExpression bestGroupExpression = costAndGroupExpr.second;
if (bestGroupExpression.getOwnerGroup() == this || bestGroupExpression.getOwnerGroup() == null) {
// move PhysicalEnforcer into target
Preconditions.checkState(bestGroupExpression.getPlan() instanceof PhysicalDistribute);
bestGroupExpression.setOwnerGroup(target);
}
// move lowestCostPlans Ownership
if (!target.lowestCostPlans.containsKey(physicalProperties)) {
target.lowestCostPlans.put(physicalProperties, costAndGroupExpr);
Expand Down Expand Up @@ -425,6 +427,10 @@ public String toString() {
for (GroupExpression physicalExpression : physicalExpressions) {
str.append(" ").append(physicalExpression).append("\n");
}
str.append(" enforcers:\n");
for (GroupExpression enforcer : enforcers) {
str.append(" ").append(enforcer).append("\n");
}
return str.toString();
}

Expand Down
65 changes: 41 additions & 24 deletions fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.qe.ConnectContext;

Expand All @@ -46,10 +45,11 @@
import org.apache.logging.log4j.Logger;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -115,7 +115,7 @@ public int getGroupExpressionsSize() {
public void removePhysicalExpression() {
groupExpressions.entrySet().removeIf(entry -> entry.getValue().getPlan() instanceof PhysicalPlan);

Iterator<Entry<GroupId, Group>> iterator = groups.entrySet().iterator();
Iterator<Map.Entry<GroupId, Group>> iterator = groups.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<GroupId, Group> entry = iterator.next();
Group group = entry.getValue();
Expand Down Expand Up @@ -199,6 +199,20 @@ public Pair<Integer, Integer> countGroupJoin(Group group) {
return Pair.of(continuousJoinCount, Math.max(continuousJoinCount, maxJoinCount));
}

/**
* Add plan to Memo.
*/
public CopyInResult copyIn(Plan plan, @Nullable Group target, boolean rewrite, HashMap<Long, Group> planTable) {
CopyInResult result;
if (rewrite) {
result = doRewrite(plan, target);
} else {
result = doCopyIn(skipProject(plan, target), target, planTable);
}
maybeAddStateId(result);
return result;
}

/**
* Add plan to Memo.
*
Expand All @@ -214,7 +228,7 @@ public CopyInResult copyIn(Plan plan, @Nullable Group target, boolean rewrite) {
if (rewrite) {
result = doRewrite(plan, target);
} else {
result = doCopyIn(skipProject(plan, target), target);
result = doCopyIn(skipProject(plan, target), target, null);
}
maybeAddStateId(result);
return result;
Expand Down Expand Up @@ -402,7 +416,7 @@ private CopyInResult doRewrite(Plan plan, @Nullable Group targetGroup) {
* @return a pair, in which the first element is true if a newly generated groupExpression added into memo,
* and the second element is a reference of node in Memo
*/
private CopyInResult doCopyIn(Plan plan, @Nullable Group targetGroup) {
private CopyInResult doCopyIn(Plan plan, @Nullable Group targetGroup, @Nullable HashMap<Long, Group> planTable) {
Preconditions.checkArgument(!(plan instanceof GroupPlan), "plan can not be GroupPlan");
// check logicalproperties, must same output in a Group.
if (targetGroup != null && !plan.getLogicalProperties().equals(targetGroup.getLogicalProperties())) {
Expand All @@ -425,12 +439,12 @@ private CopyInResult doCopyIn(Plan plan, @Nullable Group targetGroup) {
} else if (child.getGroupExpression().isPresent()) {
childrenGroups.add(child.getGroupExpression().get().getOwnerGroup());
} else {
childrenGroups.add(doCopyIn(child, null).correspondingExpression.getOwnerGroup());
childrenGroups.add(doCopyIn(child, null, planTable).correspondingExpression.getOwnerGroup());
}
}
plan = replaceChildrenToGroupPlan(plan, childrenGroups);
GroupExpression newGroupExpression = new GroupExpression(plan, childrenGroups);
return insertGroupExpression(newGroupExpression, targetGroup, plan.getLogicalProperties());
return insertGroupExpression(newGroupExpression, targetGroup, plan.getLogicalProperties(), planTable);
// TODO: need to derive logical property if generate new group. currently we not copy logical plan into
}

Expand Down Expand Up @@ -474,12 +488,12 @@ private void validateRewriteChildGroup(Group childGroup, Group targetGroup) {
* @return a pair, in which the first element is true if a newly generated groupExpression added into memo,
* and the second element is a reference of node in Memo
*/
private CopyInResult insertGroupExpression(
GroupExpression groupExpression, Group target, LogicalProperties logicalProperties) {
private CopyInResult insertGroupExpression(GroupExpression groupExpression, Group target,
LogicalProperties logicalProperties, HashMap<Long, Group> planTable) {
GroupExpression existedGroupExpression = groupExpressions.get(groupExpression);
if (existedGroupExpression != null) {
if (target != null && !target.getGroupId().equals(existedGroupExpression.getOwnerGroup().getGroupId())) {
mergeGroup(existedGroupExpression.getOwnerGroup(), target);
mergeGroup(target, existedGroupExpression.getOwnerGroup(), planTable);
}
// When we create a GroupExpression, we will add it into ParentExpression of childGroup.
// But if it already exists, we should remove it from ParentExpression of childGroup.
Expand All @@ -506,7 +520,7 @@ private CopyInResult insertGroupExpression(
* @param source source group
* @param destination destination group
*/
public void mergeGroup(Group source, Group destination) {
public void mergeGroup(Group source, Group destination, HashMap<Long, Group> planTable) {
if (source.equals(destination)) {
return;
}
Expand All @@ -516,9 +530,9 @@ public void mergeGroup(Group source, Group destination) {
// cycle, we should not merge
return;
}
// PhysicalEnforcer don't exist in memo, so we need skip them.
if (parent.getPlan() instanceof PhysicalDistribute) {
// TODO: SortEnforcer.
Group parentOwnerGroup = parent.getOwnerGroup();
HashSet<GroupExpression> enforcers = new HashSet<>(parentOwnerGroup.getEnforcers());
if (enforcers.contains(parent)) {
continue;
}
needReplaceChild.add(parent);
Expand All @@ -545,25 +559,28 @@ public void mergeGroup(Group source, Group destination) {
reinsertGroupExpr.mergeTo(existGroupExpr);
} else {
// reinsertGroupExpr & existGroupExpr aren't in same group, need to merge their OwnerGroup.
mergeGroup(reinsertGroupExpr.getOwnerGroup(), existGroupExpr.getOwnerGroup());
mergeGroup(reinsertGroupExpr.getOwnerGroup(), existGroupExpr.getOwnerGroup(), planTable);
}
} else {
groupExpressions.put(reinsertGroupExpr, reinsertGroupExpr);
}
}
// replace source with destination in groups of planTable
if (planTable != null) {
planTable.forEach((bitset, group) -> {
if (group.equals(source)) {
planTable.put(bitset, destination);
}
});
}

source.mergeTo(destination);
if (source == root) {
root = destination;
}
groups.remove(source.getGroupId());
}

/**
* Add enforcer expression into the target group.
*/
public void addEnforcerPlan(GroupExpression groupExpression, Group group) {
Preconditions.checkArgument(groupExpression != null);
groupExpression.setOwnerGroup(group);
// Don't add groupExpression into group's physicalExpressions, it will cause dead loop;
}

private CopyInResult rewriteByExistedPlan(Group targetGroup, Plan existedPlan) {
GroupExpression existedLogicalExpression = existedPlan instanceof GroupPlan
? ((GroupPlan) existedPlan).getGroup().getLogicalExpression() // get first logicalGroupExpression
Expand Down

0 comments on commit 5952b05

Please sign in to comment.