Skip to content

Commit

Permalink
[fix](nereids)need substitute agg function using agg node's output if…
Browse files Browse the repository at this point in the history
… it's in order by key
  • Loading branch information
starocean999 committed Feb 2, 2024
1 parent b039141 commit fc5cd34
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,14 @@ static class Resolver {
private final List<Expression> groupByExpressions;
private final Map<Expression, Slot> substitution = Maps.newHashMap();
private final List<NamedExpression> newOutputSlots = Lists.newArrayList();
private final Map<Slot, Expression> outputSubstitutionMap;

Resolver(Aggregate aggregate) {
outputExpressions = aggregate.getOutputExpressions();
groupByExpressions = aggregate.getGroupByExpressions();
outputSubstitutionMap = outputExpressions.stream().filter(Alias.class::isInstance)
.collect(Collectors.toMap(alias -> alias.toSlot(), alias -> alias.child(0),
(k1, k2) -> k1));
}

public void resolve(Expression expression) {
Expand Down Expand Up @@ -273,7 +277,8 @@ private boolean checkWhetherNestedAggregateFunctionsExist(AggregateFunction aggr
}

private void generateAliasForNewOutputSlots(Expression expression) {
Alias alias = new Alias(expression);
Expression replacedExpr = ExpressionUtils.replace(expression, outputSubstitutionMap);
Alias alias = new Alias(replacedExpr);
newOutputSlots.add(alias);
substitution.put(expression, alias.toSlot());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.types.BigIntType;
Expand Down Expand Up @@ -513,6 +514,28 @@ public void testSortAggregateFunction() {
).when(FieldChecker.check("outputExpressions", Lists.newArrayList(a1, countStar))))
).when(FieldChecker.check("orderKeys", ImmutableList.of(new OrderKey(countStar.toSlot(), true, true))))
).when(FieldChecker.check("projects", Lists.newArrayList(a1.toSlot()))));
sql = "SELECT abs(a1) xx, sum(a2) FROM t1 GROUP BY xx ORDER BY MIN(xx)";
a1 = new SlotReference(
new ExprId(1), "a1", TinyIntType.INSTANCE, true,
ImmutableList.of("test_resolve_aggregate_functions", "t1")
);
Alias xx = new Alias(new ExprId(3), new Abs(a1), "xx");
a2 = new SlotReference(
new ExprId(2), "a2", TinyIntType.INSTANCE, true,
ImmutableList.of("test_resolve_aggregate_functions", "t1")
);
sumA2 = new Alias(new ExprId(4), new Sum(a2), "sum(a2)");

Alias minXX = new Alias(new ExprId(5), new Min(xx.toSlot()), "min(xx)");
PlanChecker.from(connectContext).analyze(sql).printlnTree().matches(logicalProject(
logicalSort(logicalProject(logicalAggregate(logicalProject(logicalOlapScan())
.when(FieldChecker.check("projects", Lists.newArrayList(xx, a2, a1))))))
.when(FieldChecker.check("orderKeys",
ImmutableList
.of(new OrderKey(minXX.toSlot(), true, true)))))
.when(FieldChecker.check("projects",
Lists.newArrayList(xx.toSlot(),
sumA2.toSlot()))));
}

@Test
Expand Down

0 comments on commit fc5cd34

Please sign in to comment.