Skip to content

Commit

Permalink
Fix Pinot aggregation pushdown for subqueries
Browse files Browse the repository at this point in the history
  • Loading branch information
elonazoulay committed Apr 17, 2023
1 parent f8190ad commit 2defb1b
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,8 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
// can be pushed down: there are currently no subqueries in pinot.
// If there is an offset then do not push the aggregation down as the results will not be correct
if (tableHandle.getQuery().isPresent() &&
(!tableHandle.getQuery().get().getAggregateColumns().isEmpty() ||
(!isAggregationPushdownSupported(session, tableHandle.getQuery(), aggregates, assignments) ||
!tableHandle.getQuery().get().getAggregateColumns().isEmpty() ||
tableHandle.getQuery().get().isAggregateInProjections() ||
tableHandle.getQuery().get().getOffset().isPresent())) {
return Optional.empty();
Expand All @@ -368,10 +369,12 @@ public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggrega
projections.add(new Variable(pinotColumnHandle.getColumnName(), pinotColumnHandle.getDataType()));
resultAssignments.add(new Assignment(pinotColumnHandle.getColumnName(), pinotColumnHandle, pinotColumnHandle.getDataType()));
}

List<PinotColumnHandle> groupingColumns = getOnlyElement(groupingSets).stream()
.map(PinotColumnHandle.class::cast)
.map(PinotMetadata::toNonAggregateColumnHandle)
.collect(toImmutableList());

OptionalLong limitForDynamicTable = OptionalLong.empty();
// Ensure that pinot default limit of 10 rows is not used
// By setting the limit to maxRowsPerBrokerQuery + 1 the connector will
Expand Down Expand Up @@ -421,6 +424,25 @@ public static PinotColumnHandle toNonAggregateColumnHandle(PinotColumnHandle col
return new PinotColumnHandle(columnHandle.getColumnName(), columnHandle.getDataType(), quoteIdentifier(columnHandle.getColumnName()), false, false, true, Optional.empty(), Optional.empty());
}

private boolean isAggregationPushdownSupported(ConnectorSession session, Optional<DynamicTable> dynamicTable, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments)
{
if (dynamicTable.isEmpty()) {
return true;
}
List<PinotColumnHandle> groupingColumns = dynamicTable.get().getGroupingColumns();
if (groupingColumns.isEmpty()) {
return true;
}
// Either second pass of applyAggregation or dynamic table exists
if (aggregates.size() != 1) {
return false;
}
AggregateFunction aggregate = getOnlyElement(aggregates);
AggregateFunctionRule.RewriteContext<Void> context = new CountDistinctContext(assignments, session);

return implementCountDistinct.getPattern().matches(aggregate, context);
}

private Optional<AggregateExpression> applyCountDistinct(ConnectorSession session, AggregateFunction aggregate, Map<String, ColumnHandle> assignments, PinotTableHandle tableHandle, Optional<AggregateExpression> rewriteResult)
{
AggregateFunctionRule.RewriteContext<Void> context = new CountDistinctContext(assignments, session);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1864,6 +1864,91 @@ public void testAggregationPushdown()
.isThrownBy(() -> query("SELECT bool_col, COUNT(long_col) FROM \"SELECT bool_col, long_col FROM " + ALL_TYPES_TABLE + " GROUP BY bool_col, long_col\""))
.withRootCauseInstanceOf(RuntimeException.class)
.withMessage("Operation not supported for DISTINCT aggregation function");

// Verify that count(<column name>) is pushed down only when it matches a COUNT(DISTINCT <column name>) query
assertThat(query("""
SELECT COUNT(bool_col) FROM
(SELECT bool_col FROM alltypes GROUP BY bool_col)
"""))
.matches("VALUES (BIGINT '2')")
.isFullyPushedDown();
assertThat(query("""
SELECT bool_col, COUNT(long_col) FROM
(SELECT bool_col, long_col FROM alltypes GROUP BY bool_col, long_col)
GROUP BY bool_col
"""))
.matches("""
VALUES (FALSE, BIGINT '1'),
(TRUE, BIGINT '9')
""")
.isFullyPushedDown();
// Verify that count(1) is not pushed down when the subquery selects distinct values for a single column
assertThat(query("""
SELECT COUNT(1) FROM
(SELECT bool_col FROM alltypes GROUP BY bool_col)
"""))
.matches("VALUES (BIGINT '2')")
.isNotFullyPushedDown(AggregationNode.class);
// Verify that count(*) is not pushed down when the subquery selects distinct values for a single column
assertThat(query("""
SELECT COUNT(*) FROM
(SELECT bool_col FROM alltypes GROUP BY bool_col)
"""))
.matches("VALUES (BIGINT '2')")
.isNotFullyPushedDown(AggregationNode.class);
// Verify that other aggregation types are not pushed down when the subquery selects distinct values for a single column
assertThat(query("""
SELECT SUM(long_col) FROM
(SELECT long_col FROM alltypes GROUP BY long_col)
"""))
.matches("VALUES (BIGINT '-28327352787')")
.isNotFullyPushedDown(AggregationNode.class);
assertThat(query("""
SELECT bool_col, SUM(long_col) FROM
(SELECT bool_col, long_col FROM alltypes GROUP BY bool_col, long_col)
GROUP BY bool_col
"""))
.matches("VALUES (TRUE, BIGINT '-28327352787'), (FALSE, BIGINT '0')")
.isNotFullyPushedDown(ProjectNode.class, AggregationNode.class, ExchangeNode.class, ExchangeNode.class, AggregationNode.class, ProjectNode.class);
assertThat(query("""
SELECT AVG(long_col) FROM
(SELECT long_col FROM alltypes GROUP BY long_col)
"""))
.matches("VALUES (DOUBLE '-2.8327352787E9')")
.isNotFullyPushedDown(AggregationNode.class);
assertThat(query("""
SELECT bool_col, AVG(long_col) FROM
(SELECT bool_col, long_col FROM alltypes GROUP BY bool_col, long_col)
GROUP BY bool_col
"""))
.matches("VALUES (TRUE, DOUBLE '-3.147483643E9'), (FALSE, DOUBLE '0.0')")
.isNotFullyPushedDown(ProjectNode.class, AggregationNode.class, ExchangeNode.class, ExchangeNode.class, AggregationNode.class, ProjectNode.class);
assertThat(query("""
SELECT MIN(long_col) FROM
(SELECT long_col FROM alltypes GROUP BY long_col)
"""))
.matches("VALUES (BIGINT '-3147483647')")
.isNotFullyPushedDown(AggregationNode.class);
assertThat(query("""
SELECT bool_col, MIN(long_col) FROM
(SELECT bool_col, long_col FROM alltypes GROUP BY bool_col, long_col)
GROUP BY bool_col
"""))
.matches("VALUES (TRUE, BIGINT '-3147483647'), (FALSE, BIGINT '0')")
.isNotFullyPushedDown(ProjectNode.class, AggregationNode.class, ExchangeNode.class, ExchangeNode.class, AggregationNode.class, ProjectNode.class);
assertThat(query("""
SELECT MAX(long_col) FROM
(SELECT long_col FROM alltypes GROUP BY long_col)
"""))
.matches("VALUES (BIGINT '0')")
.isNotFullyPushedDown(AggregationNode.class);
assertThat(query("""
SELECT bool_col, MAX(long_col) FROM
(SELECT bool_col, long_col FROM alltypes GROUP BY bool_col, long_col)
GROUP BY bool_col
"""))
.matches("VALUES (TRUE, BIGINT '-3147483639'), (FALSE, BIGINT '0')")
.isNotFullyPushedDown(ProjectNode.class, AggregationNode.class, ExchangeNode.class, ExchangeNode.class, AggregationNode.class, ProjectNode.class);
}

@Test
Expand Down

0 comments on commit 2defb1b

Please sign in to comment.