Skip to content

Commit

Permalink
Add AggregationNodeBuilder
Browse files Browse the repository at this point in the history
AggregationNode is created in multiple places
base on existing AggregationNode with some
fields changed.
AggregationNode.Builder makes that semantics
explicit plus it makes it easier to add new
fields to AggregationNode.
  • Loading branch information
lukasz-stec authored and sopel39 committed May 27, 2022
1 parent 111c917 commit a262611
Show file tree
Hide file tree
Showing 22 changed files with 225 additions and 206 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ protected PlanNode visitPlan(PlanNode node, RewriteContext<Void> context)
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> context)
{
return new AggregationNode(idAllocator.getNextId(), context.rewrite(node.getSource()), node.getAggregations(), node.getGroupingSets(), node.getPreGroupedSymbols(), node.getStep(), node.getHashSymbol(), node.getGroupIdSymbol());
return AggregationNode.builderFrom(node)
.setId(idAllocator.getNextId())
.setSource(context.rewrite(node.getSource()))
.build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,12 @@ private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeI
{
verify(aggregation.getGroupingKeys().isEmpty(), "Should be an un-grouped aggregation");
ExchangeNode gatheringExchange = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, aggregation);
return new AggregationNode(
idAllocator.getNextId(),
gatheringExchange,
outputsAsInputs(aggregation.getAggregations()),
aggregation.getGroupingSets(),
aggregation.getPreGroupedSymbols(),
AggregationNode.Step.INTERMEDIATE,
aggregation.getHashSymbol(),
aggregation.getGroupIdSymbol());
return AggregationNode.builderFrom(aggregation)
.setId(idAllocator.getNextId())
.setSource(gatheringExchange)
.setAggregations(outputsAsInputs(aggregation.getAggregations()))
.setStep(AggregationNode.Step.INTERMEDIATE)
.build();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,9 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context
}
}
if (anyRewritten) {
return Result.ofPlanNode(new AggregationNode(
aggregationNode.getId(),
aggregationNode.getSource(),
aggregations.buildOrThrow(),
aggregationNode.getGroupingSets(),
aggregationNode.getPreGroupedSymbols(),
aggregationNode.getStep(),
aggregationNode.getHashSymbol(),
aggregationNode.getGroupIdSymbol()));
return Result.ofPlanNode(AggregationNode.builderFrom(aggregationNode)
.setAggregations(aggregations.buildOrThrow())
.build());
}
return Result.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,17 @@ else if (mask.isPresent()) {
newAssignments.putIdentities(aggregationNode.getSource().getOutputSymbols());

return Result.ofPlanNode(
new AggregationNode(
context.getIdAllocator().getNextId(),
new FilterNode(
AggregationNode.builderFrom(aggregationNode)
.setId(context.getIdAllocator().getNextId())
.setSource(new FilterNode(
context.getIdAllocator().getNextId(),
new ProjectNode(
context.getIdAllocator().getNextId(),
aggregationNode.getSource(),
newAssignments.build()),
predicate),
aggregations.buildOrThrow(),
aggregationNode.getGroupingSets(),
ImmutableList.of(),
aggregationNode.getStep(),
aggregationNode.getHashSymbol(),
aggregationNode.getGroupIdSymbol()));
predicate))
.setAggregations(aggregations.buildOrThrow())
.setPreGroupedSymbols(ImmutableList.of())
.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,10 @@ public Result apply(AggregationNode parent, Captures captures, Context context)
}

return Result.ofPlanNode(
new AggregationNode(
parent.getId(),
subPlan,
newAggregations,
parent.getGroupingSets(),
ImmutableList.of(),
parent.getStep(),
parent.getHashSymbol(),
parent.getGroupIdSymbol()));
AggregationNode.builderFrom(parent)
.setSource(subPlan)
.setAggregations(newAggregations)
.setPreGroupedSymbols(ImmutableList.of())
.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,8 @@ protected Optional<PlanNode> pushDownProjectOff(

// PruneAggregationSourceColumns will subsequently project off any newly unused inputs.
return Optional.of(
new AggregationNode(
aggregationNode.getId(),
aggregationNode.getSource(),
prunedAggregations,
aggregationNode.getGroupingSets(),
aggregationNode.getPreGroupedSymbols(),
aggregationNode.getStep(),
aggregationNode.getHashSymbol(),
aggregationNode.getGroupIdSymbol()));
AggregationNode.builderFrom(aggregationNode)
.setAggregations(prunedAggregations)
.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,10 @@ public PlanNode visitAggregation(AggregationNode node, Boolean context)
return rewrittenNode;
}

return new AggregationNode(
node.getId(),
rewrittenNode,
node.getAggregations(),
node.getGroupingSets(),
ImmutableList.of(),
node.getStep(),
node.getHashSymbol(),
node.getGroupIdSymbol());
return AggregationNode.builderFrom(node)
.setSource(rewrittenNode)
.setPreGroupedSymbols(ImmutableList.of())
.build();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,8 @@ else if (metadata.getAggregationFunctionMetadata(context.getSession(), aggregati
if (!anyRewritten) {
return Result.empty();
}
return Result.ofPlanNode(new AggregationNode(
node.getId(),
node.getSource(),
aggregations.buildOrThrow(),
node.getGroupingSets(),
node.getPreGroupedSymbols(),
node.getStep(),
node.getHashSymbol(),
node.getGroupIdSymbol()));
return Result.ofPlanNode(AggregationNode.builderFrom(node)
.setAggregations(aggregations.buildOrThrow())
.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,11 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont
List<Symbol> groupingKeys = join.getCriteria().stream()
.map(join.getType() == JoinNode.Type.RIGHT ? JoinNode.EquiJoinClause::getLeft : JoinNode.EquiJoinClause::getRight)
.collect(toImmutableList());
AggregationNode rewrittenAggregation = new AggregationNode(
aggregation.getId(),
getInnerTable(join),
aggregation.getAggregations(),
singleGroupingSet(groupingKeys),
ImmutableList.of(),
aggregation.getStep(),
aggregation.getHashSymbol(),
aggregation.getGroupIdSymbol());
AggregationNode rewrittenAggregation = AggregationNode.builderFrom(aggregation)
.setSource(getInnerTable(join))
.setGroupingSets(singleGroupingSet(groupingKeys))
.setPreGroupedSymbols(ImmutableList.of())
.build();

JoinNode rewrittenJoin;
if (join.getType() == JoinNode.Type.LEFT) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,10 @@ private static Result pushFilter(FilterNode filterNode, AggregationNode aggregat
aggregation.getOrderingScheme(),
Optional.empty());

AggregationNode newAggregationNode = new AggregationNode(
aggregationNode.getId(),
source,
ImmutableMap.of(countSymbol, newAggregation),
aggregationNode.getGroupingSets(),
aggregationNode.getPreGroupedSymbols(),
aggregationNode.getStep(),
aggregationNode.getHashSymbol(),
aggregationNode.getGroupIdSymbol());
AggregationNode newAggregationNode = AggregationNode.builderFrom(aggregationNode)
.setSource(source)
.setAggregations(ImmutableMap.of(countSymbol, newAggregation))
.build();

// Restore identity projection if it is present in the original plan.
PlanNode filterSource = projectNode.map(project -> project.replaceChildren(ImmutableList.of(newAggregationNode))).orElse(newAggregationNode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,11 @@ private AggregationNode replaceAggregationSource(
PlanNode source,
List<Symbol> groupingKeys)
{
return new AggregationNode(
aggregation.getId(),
source,
aggregation.getAggregations(),
singleGroupingSet(groupingKeys),
ImmutableList.of(),
aggregation.getStep(),
aggregation.getHashSymbol(),
aggregation.getGroupIdSymbol());
return AggregationNode.builderFrom(aggregation)
.setSource(source)
.setGroupingSets(singleGroupingSet(groupingKeys))
.setPreGroupedSymbols(ImmutableList.of())
.build();
}

private PlanNode pushPartialToJoin(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,22 +133,17 @@ public Result apply(AggregationNode node, Captures captures, Context context)
partitionCount = getHashPartitionCount(context.getSession());
}
return Result.ofPlanNode(
new AggregationNode(
node.getId(),
new ProjectNode(
AggregationNode.builderFrom(node)
.setSource(new ProjectNode(
context.getIdAllocator().getNextId(),
node.getSource(),
Assignments.builder()
.putIdentities(node.getSource().getOutputSymbols())
.put(partitionCountSymbol, new LongLiteral(Integer.toString(partitionCount)))
.putAll(envelopeAssignments.buildOrThrow())
.build()),
aggregations.buildOrThrow(),
node.getGroupingSets(),
node.getPreGroupedSymbols(),
node.getStep(),
node.getHashSymbol(),
node.getGroupIdSymbol()));
.build()))
.setAggregations(aggregations.buildOrThrow())
.build());
}

private boolean isStEnvelopeFunctionCall(Expression expression, ResolvedFunction stEnvelopeFunction)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,11 @@ public Result apply(AggregationNode parent, Captures captures, Context context)
return Result.empty();
}

return Result.ofPlanNode(new AggregationNode(
parent.getId(),
child,
aggregations,
parent.getGroupingSets(),
ImmutableList.of(),
parent.getStep(),
parent.getHashSymbol(),
parent.getGroupIdSymbol()));
return Result.ofPlanNode(AggregationNode.builderFrom(parent)
.setSource(child)
.setAggregations(aggregations)
.setPreGroupedSymbols(ImmutableList.of())
.build());
}

private boolean isCountOverConstant(Session session, AggregationNode.Aggregation aggregation, Assignments inputs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import static io.trino.sql.planner.plan.AggregationNode.singleAggregation;
import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet;
import static io.trino.sql.planner.plan.Patterns.aggregation;
import static java.util.Collections.emptyList;

/**
* Implements distinct aggregations with similar inputs by transforming plans of the following shape:
Expand Down Expand Up @@ -122,27 +121,25 @@ public Result apply(AggregationNode aggregation, Captures captures, Context cont
.collect(Collectors.toSet());

return Result.ofPlanNode(
new AggregationNode(
aggregation.getId(),
singleAggregation(
context.getIdAllocator().getNextId(),
aggregation.getSource(),
ImmutableMap.of(),
singleGroupingSet(ImmutableList.<Symbol>builder()
.addAll(aggregation.getGroupingKeys())
.addAll(symbols)
.build())),
// remove DISTINCT flag from function calls
aggregation.getAggregations()
.entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
e -> removeDistinct(e.getValue()))),
aggregation.getGroupingSets(),
emptyList(),
aggregation.getStep(),
aggregation.getHashSymbol(),
aggregation.getGroupIdSymbol()));
AggregationNode.builderFrom(aggregation)
.setSource(
singleAggregation(
context.getIdAllocator().getNextId(),
aggregation.getSource(),
ImmutableMap.of(),
singleGroupingSet(ImmutableList.<Symbol>builder()
.addAll(aggregation.getGroupingKeys())
.addAll(symbols)
.build())))
.setAggregations(
// remove DISTINCT flag from function calls
aggregation.getAggregations()
.entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
e -> removeDistinct(e.getValue()))))
.setPreGroupedSymbols(ImmutableList.of())
.build());
}

private static Aggregation removeDistinct(Aggregation aggregation)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,19 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co

// restore aggregation
AggregationNode aggregation = captures.get(AGGREGATION);
aggregation = new AggregationNode(
aggregation.getId(),
join,
aggregation.getAggregations(),
singleGroupingSet(ImmutableList.<Symbol>builder()
.addAll(join.getLeftOutputSymbols())
.addAll(aggregation.getGroupingKeys())
.build()),
ImmutableList.of(),
aggregation.getStep(),
Optional.empty(),
Optional.empty());
aggregation = AggregationNode.builderFrom(aggregation)
.setSource(join)
.setGroupingSets(
singleGroupingSet(ImmutableList.<Symbol>builder()
.addAll(join.getLeftOutputSymbols())
.addAll(aggregation.getGroupingKeys())
.build()))
.setPreGroupedSymbols(
ImmutableList.of())
.setHashSymbol(
Optional.empty())
.setGroupIdSymbol(Optional.empty())
.build();

// restrict outputs
Optional<PlanNode> project = restrictOutputs(context.getIdAllocator(), aggregation, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,18 +194,16 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co

// restore grouped aggregation
AggregationNode groupedAggregation = captures.get(AGGREGATION);
groupedAggregation = new AggregationNode(
groupedAggregation.getId(),
distinct != null ? distinct : join,
groupedAggregation.getAggregations(),
singleGroupingSet(ImmutableList.<Symbol>builder()
groupedAggregation = AggregationNode.builderFrom(groupedAggregation)
.setSource(distinct != null ? distinct : join)
.setGroupingSets(singleGroupingSet(ImmutableList.<Symbol>builder()
.addAll(join.getLeftOutputSymbols())
.addAll(groupedAggregation.getGroupingKeys())
.build()),
ImmutableList.of(),
groupedAggregation.getStep(),
Optional.empty(),
Optional.empty());
.build()))
.setPreGroupedSymbols(ImmutableList.of())
.setHashSymbol(Optional.empty())
.setGroupIdSymbol(Optional.empty())
.build();

// restrict outputs and apply projection
Set<Symbol> outputSymbols = new HashSet<>(correlatedJoinNode.getOutputSymbols());
Expand Down
Loading

0 comments on commit a262611

Please sign in to comment.