Skip to content

Commit

Permalink
Fix handling of complex aggregate expressions in Pinot passthrough queries
Browse files Browse the repository at this point in the history
  • Loading branch information
elonazoulay authored and hashhar committed Mar 19, 2022
1 parent 5a2c9e9 commit 041e630
Show file tree
Hide file tree
Showing 4 changed files with 331 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,24 @@
import org.apache.pinot.common.request.context.OrderByExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.reduce.PostAggregationHandler;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.core.query.request.context.utils.BrokerRequestToQueryContextConverter;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.sql.parsers.CalciteSqlCompiler;

import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.plugin.pinot.PinotColumnHandle.fromNonAggregateColumnHandle;
import static io.trino.plugin.pinot.PinotColumnHandle.getTrinoTypeFromPinotType;
import static io.trino.plugin.pinot.PinotErrorCode.PINOT_EXCEPTION;
import static io.trino.plugin.pinot.PinotErrorCode.PINOT_UNSUPPORTED_COLUMN_TYPE;
import static io.trino.plugin.pinot.query.PinotExpressionRewriter.rewriteExpression;
import static io.trino.plugin.pinot.query.PinotPatterns.WILDCARD;
Expand All @@ -57,12 +62,17 @@
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
import static org.apache.pinot.segment.spi.AggregationFunctionType.COUNT;
import static org.apache.pinot.segment.spi.AggregationFunctionType.DISTINCTCOUNT;
import static org.apache.pinot.segment.spi.AggregationFunctionType.DISTINCTCOUNTHLL;
import static org.apache.pinot.segment.spi.AggregationFunctionType.getAggregationFunctionType;

public final class DynamicTableBuilder
{
// Single compiler instance shared by the class.
// NOTE(review): presumably used to compile the passthrough query text; the call site is outside this view - confirm.
private static final CalciteSqlCompiler REQUEST_COMPILER = new CalciteSqlCompiler();
// Table name suffixes; REALTIME_SUFFIX is matched against the upper-cased table name further down in this class.
// NOTE(review): OFFLINE_SUFFIX is presumably matched the same way; that check is outside this view - confirm.
public static final String OFFLINE_SUFFIX = "_OFFLINE";
public static final String REALTIME_SUFFIX = "_REALTIME";
// Aggregation function types consulted by isReturnNullOnEmptyGroup: a top level aggregate of one of these
// types is flagged as not returning null on an empty group.
private static final Set<AggregationFunctionType> NON_NULL_ON_EMPTY_AGGREGATIONS = EnumSet.of(COUNT, DISTINCTCOUNT, DISTINCTCOUNTHLL);

private DynamicTableBuilder()
{
Expand All @@ -84,14 +94,12 @@ public static DynamicTable buildFromPql(PinotMetadata pinotMetadata, SchemaTable
PinotTypeResolver pinotTypeResolver = new PinotTypeResolver(pinotClient, pinotTableName);
List<PinotColumnHandle> selectColumns = ImmutableList.of();

ImmutableMap.Builder<String, Type> aggregateTypesBuilder = ImmutableMap.builder();
Map<String, PinotColumnNameAndTrinoType> aggregateTypes = ImmutableMap.of();
if (queryContext.getAggregationFunctions() != null) {
checkState(queryContext.getAggregationFunctions().length > 0, "Aggregation Functions is empty");
for (AggregationFunction<?, ?> aggregationFunction : queryContext.getAggregationFunctions()) {
aggregateTypesBuilder.put(aggregationFunction.getResultColumnName(), toTrinoType(aggregationFunction.getFinalResultColumnType()));
}
aggregateTypes = getAggregateTypes(schemaTableName, queryContext, columnHandles);
}
Map<String, Type> aggregateTypes = aggregateTypesBuilder.buildOrThrow();

if (queryContext.getSelectExpressions() != null) {
checkState(!queryContext.getSelectExpressions().isEmpty(), "Pinot selections is empty");
selectColumns = getPinotColumns(schemaTableName, queryContext.getSelectExpressions(), queryContext.getAliasList(), columnHandles, pinotTypeResolver, aggregateTypes);
Expand Down Expand Up @@ -150,7 +158,7 @@ private static Type toTrinoType(DataSchema.ColumnDataType columnDataType)
throw new PinotException(PINOT_UNSUPPORTED_COLUMN_TYPE, Optional.empty(), "Unsupported column data type: " + columnDataType);
}

private static List<PinotColumnHandle> getPinotColumns(SchemaTableName schemaTableName, List<ExpressionContext> expressions, List<String> aliases, Map<String, ColumnHandle> columnHandles, PinotTypeResolver pinotTypeResolver, Map<String, Type> aggregateTypes)
private static List<PinotColumnHandle> getPinotColumns(SchemaTableName schemaTableName, List<ExpressionContext> expressions, List<String> aliases, Map<String, ColumnHandle> columnHandles, PinotTypeResolver pinotTypeResolver, Map<String, PinotColumnNameAndTrinoType> aggregateTypes)
{
ImmutableList.Builder<PinotColumnHandle> pinotColumnsBuilder = ImmutableList.builder();
for (int index = 0; index < expressions.size(); index++) {
Expand All @@ -170,22 +178,24 @@ private static List<PinotColumnHandle> getPinotColumns(SchemaTableName schemaTab
return pinotColumnsBuilder.build();
}

private static PinotColumnHandle getPinotColumnHandle(SchemaTableName schemaTableName, ExpressionContext expressionContext, Optional<String> alias, Map<String, ColumnHandle> columnHandles, PinotTypeResolver pinotTypeResolver, Map<String, Type> aggregateTypes)
private static PinotColumnHandle getPinotColumnHandle(SchemaTableName schemaTableName, ExpressionContext expressionContext, Optional<String> alias, Map<String, ColumnHandle> columnHandles, PinotTypeResolver pinotTypeResolver, Map<String, PinotColumnNameAndTrinoType> aggregateTypes)
{
ExpressionContext rewritten = rewriteExpression(schemaTableName, expressionContext, columnHandles);
// If there is no alias, pinot autogenerates the column name:
String columnName = rewritten.toString();
String pinotExpression = formatExpression(schemaTableName, rewritten);
Type trinoType;
boolean isAggregate = isAggregate(rewritten);
boolean isAggregate = hasAggregate(rewritten);
if (isAggregate) {
trinoType = requireNonNull(aggregateTypes.get(columnName.toLowerCase(ENGLISH)), format("Unexpected aggregate expression: '%s'", rewritten));
trinoType = requireNonNull(aggregateTypes.get(columnName).getTrinoType(), format("Unexpected aggregate expression: '%s'", rewritten));
// For aggregation queries, the column name is set by the schema returned from PostAggregationHandler, see getAggregateTypes
columnName = aggregateTypes.get(columnName).getPinotColumnName();
}
else {
trinoType = getTrinoTypeFromPinotType(pinotTypeResolver.resolveExpressionType(rewritten, schemaTableName, columnHandles));
}

return new PinotColumnHandle(alias.orElse(columnName), trinoType, pinotExpression, alias.isPresent(), isAggregate, true, Optional.empty(), Optional.empty());
return new PinotColumnHandle(alias.orElse(columnName), trinoType, pinotExpression, alias.isPresent(), isAggregate, isReturnNullOnEmptyGroup(expressionContext), Optional.empty(), Optional.empty());
}

private static Optional<String> getAlias(List<String> aliases, int index)
Expand All @@ -202,6 +212,81 @@ private static boolean isAggregate(ExpressionContext expressionContext)
return expressionContext.getType() == ExpressionContext.Type.FUNCTION && expressionContext.getFunction().getType() == FunctionContext.Type.AGGREGATION;
}

/**
 * Reports whether the expression is an aggregation or has an aggregation nested anywhere in its function arguments.
 *
 * @param expressionContext expression to inspect
 * @return true if an aggregation is found, false for identifiers, literals and functions without any aggregation
 * @throws PinotException if the expression type is not one of the handled types
 */
private static boolean hasAggregate(ExpressionContext expressionContext)
{
    switch (expressionContext.getType()) {
        case FUNCTION:
            // Either the function itself aggregates, or one of its arguments does (checked recursively)
            return isAggregate(expressionContext)
                    || expressionContext.getFunction().getArguments().stream().anyMatch(DynamicTableBuilder::hasAggregate);
        case IDENTIFIER:
        case LITERAL:
            return false;
        default:
            throw new PinotException(PINOT_EXCEPTION, Optional.empty(), format("Unsupported expression type '%s'", expressionContext.getType()));
    }
}

/**
 * Builds a mapping from pinot expression to the returned pinot column name and trino type.
 * Only select expressions that contain an aggregate are included, and the key is the string form
 * of the rewritten expression.
 *
 * @param schemaTableName table the query runs against
 * @param queryContext query whose select expressions and alias list are read
 * @param columnHandles column handles of the table, used to rewrite the expressions
 * @return immutable mapping keyed by rewritten expression
 */
private static Map<String, PinotColumnNameAndTrinoType> getAggregateTypes(SchemaTableName schemaTableName, QueryContext queryContext, Map<String, ColumnHandle> columnHandles)
{
    List<ExpressionContext> aggregateExpressions = queryContext.getSelectExpressions().stream()
            .filter(expression -> hasAggregate(expression))
            .collect(toImmutableList());

    // Query context restricted to the aggregate select expressions
    QueryContext aggregateOnlyContext = new QueryContext.Builder()
            .setAliasList(queryContext.getAliasList())
            .setSelectExpressions(aggregateExpressions)
            .build();

    // Note: the column name is set by the PostAggregationHandler
    DataSchema postAggregationSchema = new PostAggregationHandler(aggregateOnlyContext, getPreAggregationDataSchema(aggregateOnlyContext)).getResultDataSchema();

    ImmutableMap.Builder<String, PinotColumnNameAndTrinoType> typesByExpression = ImmutableMap.builder();
    for (int position = 0; position < postAggregationSchema.size(); position++) {
        // ExpressionContext#toString performs quoting of literals
        // Quoting of identifiers is not done to match the corresponding column name in the ResultTable returned from Pinot. Quoting will be done by `DynamicTablePqlExtractor`.
        String expressionKey = rewriteExpression(schemaTableName, aggregateExpressions.get(position), columnHandles).toString();
        String pinotColumnName = postAggregationSchema.getColumnName(position);
        Type trinoType = toTrinoType(postAggregationSchema.getColumnDataType(position));
        typesByExpression.put(expressionKey, new PinotColumnNameAndTrinoType(pinotColumnName, trinoType));
    }
    return typesByExpression.buildOrThrow();
}

// Extracted from org.apache.pinot.core.query.reduce.AggregationDataTableReducer
// Produces a schema with one column per aggregation function of the query: the column name is the
// function's result column name and the column type is its final result column type.
private static DataSchema getPreAggregationDataSchema(QueryContext queryContext)
{
    AggregationFunction[] functions = queryContext.getAggregationFunctions();
    String[] names = new String[functions.length];
    DataSchema.ColumnDataType[] types = new DataSchema.ColumnDataType[functions.length];
    int position = 0;
    for (AggregationFunction function : functions) {
        names[position] = function.getResultColumnName();
        types[position] = function.getFinalResultColumnType();
        position++;
    }
    return new DataSchema(names, types);
}

// Consistent with pushed down aggregates: an empty group is reported as non null only when
// the top level function is in NON_NULL_ON_EMPTY_AGGREGATIONS.
// Every other expression keeps the same behavior as Pinot, since likely the same results are expected.
private static boolean isReturnNullOnEmptyGroup(ExpressionContext expressionContext)
{
    if (!isAggregate(expressionContext)) {
        return true;
    }
    AggregationFunctionType topLevelFunction = getAggregationFunctionType(expressionContext.getFunction().getFunctionName());
    return !NON_NULL_ON_EMPTY_AGGREGATIONS.contains(topLevelFunction);
}

private static OptionalLong getOffset(QueryContext queryContext)
{
if (queryContext.getOffset() > 0) {
Expand Down Expand Up @@ -239,4 +324,26 @@ else if (tableName.toUpperCase(ENGLISH).endsWith(REALTIME_SUFFIX)) {
return Optional.empty();
}
}

/**
 * Immutable pair of a Pinot column name and a Trino type. Neither component may be null.
 */
private static class PinotColumnNameAndTrinoType
{
    private final String pinotColumnName;
    private final Type trinoType;

    public PinotColumnNameAndTrinoType(String pinotColumnName, Type trinoType)
    {
        // Validate in declaration order (name first, then type) before assigning any field
        String checkedColumnName = requireNonNull(pinotColumnName, "pinotColumnName is null");
        Type checkedType = requireNonNull(trinoType, "trinoType is null");
        this.pinotColumnName = checkedColumnName;
        this.trinoType = checkedType;
    }

    /** Returns the Pinot column name given at construction. */
    public String getPinotColumnName()
    {
        return this.pinotColumnName;
    }

    /** Returns the Trino type given at construction. */
    public Type getTrinoType()
    {
        return this.trinoType;
    }
}
}
Loading

0 comments on commit 041e630

Please sign in to comment.