Handle repeated predicate pushdown into Hive connector #984

Closed · wants to merge 3 commits
HiveMetadata.java
@@ -1686,10 +1686,6 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
HiveTableHandle handle = (HiveTableHandle) tableHandle;
checkArgument(!handle.getAnalyzePartitionValues().isPresent() || constraint.getSummary().isAll(), "Analyze should not have a constraint");

-if (handle.getPartitions().isPresent()) {
-return Optional.empty(); // TODO: optimize multiple calls to applyFilter
-}
-
HivePartitionResult partitionResult = partitionManager.getPartitions(metastore, handle, constraint);
HiveTableHandle newHandle = partitionManager.applyPartitionResult(handle, partitionResult);

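The deleted guard above is the `TODO` that made a second `applyFilter` call a no-op once partitions had been computed. Below is a minimal toy model of the fixed-point behavior this change enables: each round intersects the new constraint with what the handle already enforces, and returns empty once a round adds nothing new. All types and names here are illustrative stand-ins, not Presto's SPI.

```java
import java.util.Optional;

// Toy model of repeated predicate pushdown (stand-in types, not Presto's SPI).
public final class PushdownSketch
{
    // Stand-in for TupleDomain: an inclusive upper bound on a partition key.
    record Constraint(int maxKey)
    {
        Constraint intersect(Constraint other)
        {
            return new Constraint(Math.min(maxKey, other.maxKey));
        }
    }

    record TableHandle(Constraint enforced) {}

    // Stand-in for the connector's applyFilter: keeps refining the handle
    // instead of bailing out when a constraint was already pushed down.
    static Optional<TableHandle> applyFilter(TableHandle handle, Constraint constraint)
    {
        Constraint combined = handle.enforced().intersect(constraint);
        if (combined.equals(handle.enforced())) {
            return Optional.empty(); // nothing new to enforce
        }
        return Optional.of(new TableHandle(combined));
    }

    public static void main(String[] args)
    {
        TableHandle handle = new TableHandle(new Constraint(Integer.MAX_VALUE));
        // Two pushdown rounds, e.g. produced by a coercion splitting one filter:
        for (Constraint c : new Constraint[] {new Constraint(10), new Constraint(5)}) {
            handle = applyFilter(handle, c).orElse(handle);
        }
        System.out.println(handle.enforced().maxKey()); // prints 5
    }
}
```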
HivePartitionManager.java
@@ -60,6 +60,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
+import java.util.function.Predicate;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Predicates.not;
@@ -69,7 +70,6 @@
import static io.prestosql.plugin.hive.HiveUtil.parsePartitionValue;
import static io.prestosql.plugin.hive.metastore.MetastoreUtil.toPartitionName;
import static io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED;
-import static io.prestosql.spi.connector.Constraint.alwaysTrue;
import static io.prestosql.spi.predicate.TupleDomain.all;
import static io.prestosql.spi.predicate.TupleDomain.none;
import static io.prestosql.spi.type.Chars.padSpaces;
@@ -119,7 +119,8 @@ public HivePartitionManager(
public HivePartitionResult getPartitions(SemiTransactionalHiveMetastore metastore, ConnectorTableHandle tableHandle, Constraint constraint)
{
HiveTableHandle hiveTableHandle = (HiveTableHandle) tableHandle;
-TupleDomain<ColumnHandle> effectivePredicate = constraint.getSummary();
+TupleDomain<ColumnHandle> effectivePredicate = constraint.getSummary()
+.intersect(hiveTableHandle.getEnforcedConstraint());

SchemaTableName tableName = hiveTableHandle.getSchemaTableName();
Optional<HiveBucketHandle> hiveBucketHandle = hiveTableHandle.getBucketHandle();
@@ -141,7 +142,7 @@ public HivePartitionResult getPartitions(SemiTransactionalHiveMetastor
ImmutableList.of(new HivePartition(tableName)),
compactEffectivePredicate,
effectivePredicate,
-none(),
+all(),
hiveBucketHandle,
bucketFilter);
}
@@ -150,14 +151,22 @@ public HivePartitionResult getPartitions(SemiTransactionalHiveMetastor
.map(column -> typeManager.getType(column.getTypeSignature()))
.collect(toList());

-List<String> partitionNames = getFilteredPartitionNames(metastore, tableName, partitionColumns, effectivePredicate);
-
-Iterable<HivePartition> partitionsIterable = () -> partitionNames.stream()
-// Apply extra filters which could not be done by getFilteredPartitionNames
-.map(partitionName -> parseValuesAndFilterPartition(tableName, partitionName, partitionColumns, partitionTypes, constraint))
-.filter(Optional::isPresent)
-.map(Optional::get)
-.iterator();
+Iterable<HivePartition> partitionsIterable;
+Predicate<Map<ColumnHandle, NullableValue>> predicate = constraint.predicate().orElse(value -> true);
+if (hiveTableHandle.getPartitions().isPresent()) {
+partitionsIterable = hiveTableHandle.getPartitions().get().stream()
+.filter(partition -> partitionMatches(partitionColumns, effectivePredicate, predicate, partition))
+.collect(toImmutableList());
+}
+else {
+List<String> partitionNames = getFilteredPartitionNames(metastore, tableName, partitionColumns, effectivePredicate);
+partitionsIterable = () -> partitionNames.stream()
+// Apply extra filters which could not be done by getFilteredPartitionNames
+.map(partitionName -> parseValuesAndFilterPartition(tableName, partitionName, partitionColumns, partitionTypes, effectivePredicate, predicate))
+.filter(Optional::isPresent)
+.map(Optional::get)
+.iterator();
+}

// All partition key domains will be fully evaluated, so we don't need to include those
TupleDomain<ColumnHandle> remainingTupleDomain = TupleDomain.withColumnDomains(Maps.filterKeys(effectivePredicate.getDomains().get(), not(Predicates.in(partitionColumns))));
@@ -182,11 +191,11 @@ public HivePartitionResult getPartitions(ConnectorTableHandle tableHandle, List<

List<HivePartition> partitionList = partitionValuesList.stream()
.map(partitionValues -> toPartitionName(partitionColumnNames, partitionValues))
-.map(partitionName -> parseValuesAndFilterPartition(tableName, partitionName, partitionColumns, partitionColumnTypes, alwaysTrue()))
+.map(partitionName -> parseValuesAndFilterPartition(tableName, partitionName, partitionColumns, partitionColumnTypes, TupleDomain.all(), value -> true))
.map(partition -> partition.orElseThrow(() -> new VerifyException("partition must exist")))
.collect(toImmutableList());

-return new HivePartitionResult(partitionColumns, partitionList, all(), all(), none(), bucketHandle, Optional.empty());
+return new HivePartitionResult(partitionColumns, partitionList, all(), all(), all(), bucketHandle, Optional.empty());
}

public List<HivePartition> getPartitionsAsList(HivePartitionResult partitionResult)
@@ -252,24 +261,29 @@ private Optional<HivePartition> parseValuesAndFilterPartition(
String partitionId,
List<HiveColumnHandle> partitionColumns,
List<Type> partitionColumnTypes,
-Constraint constraint)
+TupleDomain<ColumnHandle> constraintSummary,
+Predicate<Map<ColumnHandle, NullableValue>> constraint)
{
HivePartition partition = parsePartition(tableName, partitionId, partitionColumns, partitionColumnTypes, timeZone);

-Map<ColumnHandle, Domain> domains = constraint.getSummary().getDomains().get();
+if (partitionMatches(partitionColumns, constraintSummary, constraint, partition)) {
+return Optional.of(partition);
+}
+return Optional.empty();
+}
+
+private boolean partitionMatches(List<HiveColumnHandle> partitionColumns, TupleDomain<ColumnHandle> constraintSummary, Predicate<Map<ColumnHandle, NullableValue>> constraint, HivePartition partition)
+{
+Map<ColumnHandle, Domain> domains = constraintSummary.getDomains().get();
for (HiveColumnHandle column : partitionColumns) {
NullableValue value = partition.getKeys().get(column);
Domain allowedDomain = domains.get(column);
if (allowedDomain != null && !allowedDomain.includesNullableValue(value.getValue())) {
-return Optional.empty();
+return false;
}
}

-if (constraint.predicate().isPresent() && !constraint.predicate().get().test(partition.getKeys())) {
-return Optional.empty();
-}
-
-return Optional.of(partition);
+return constraint.test(partition.getKeys());
}

private List<String> getFilteredPartitionNames(SemiTransactionalHiveMetastore metastore, SchemaTableName tableName, List<HiveColumnHandle> partitionKeys, TupleDomain<ColumnHandle> effectivePredicate)
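The new branch in `getPartitions` above is the other half of the change: when the handle already carries a materialized partition list from an earlier pushdown round, later rounds prune that list in memory instead of asking the metastore again. A minimal sketch of that shape, using simplified stand-in types rather than the connector's real classes:

```java
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import static java.util.stream.Collectors.toList;

// Toy version of the added branch: filter an already-materialized partition
// list instead of re-running a metastore query for each pushdown round.
public final class PartitionPruningSketch
{
    record Partition(Map<String, String> keys) {}

    static List<Partition> prune(List<Partition> alreadyLoaded, Predicate<Partition> newConstraint)
    {
        return alreadyLoaded.stream()
                .filter(newConstraint)
                .collect(toList());
    }

    public static void main(String[] args)
    {
        List<Partition> fromFirstRound = List.of(
                new Partition(Map.of("k", "b")),
                new Partition(Map.of("k", "e")));
        // Second round: keep only partitions with k >= "c".
        System.out.println(prune(fromFirstRound, p -> p.keys().get("k").compareTo("c") >= 0));
    }
}
```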
TestHiveIntegrationSmokeTest.java
@@ -69,6 +69,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
@@ -2541,6 +2542,113 @@ public void testPredicatePushDownToTableScan()
}
}

+@Test
+public void testPartitionPruning()
+{
+// We need the types of the columns to be different from the types of the values used to select them in the
+// queries below (i.e., `varchar` vs `varchar(1)`) so that the planner inserts implicit coercions between
+// filters and causes pushdown to happen iteratively
+assertUpdate("" +
+"CREATE TABLE test_partition_pruning (v, k) " +
+"WITH (partitioned_by = ARRAY['k']) AS (" +
+" VALUES (BIGINT '1', VARCHAR 'a'), " +
+" (BIGINT '2', VARCHAR 'b'), " +
+" (BIGINT '3', VARCHAR 'c'), " +
+" (BIGINT '4', VARCHAR 'e'))",
+4);
+
+try {
+String query = "SELECT * FROM test_partition_pruning WHERE k = 'a'";
+assertQuery(query, "VALUES (1, 'a')");
+assertConstraints(
+query,
+ImmutableSet.of(
+new ColumnConstraint(
+"k",
+VARCHAR.getTypeSignature(),
+new FormattedDomain(
+false,
+ImmutableSet.of(
+new FormattedRange(
+new FormattedMarker(Optional.of("a"), EXACTLY),
+new FormattedMarker(Optional.of("a"), EXACTLY)))))));
+
+query = "SELECT * FROM test_partition_pruning WHERE k IN ('a', 'b')";
+assertQuery(query, "VALUES (1, 'a'), (2, 'b')");
+assertConstraints(
+query,
+ImmutableSet.of(
+new ColumnConstraint(
+"k",
+VARCHAR.getTypeSignature(),
+new FormattedDomain(
+false,
+ImmutableSet.of(
+new FormattedRange(
+new FormattedMarker(Optional.of("a"), EXACTLY),
+new FormattedMarker(Optional.of("a"), EXACTLY)),
+new FormattedRange(
+new FormattedMarker(Optional.of("b"), EXACTLY),
+new FormattedMarker(Optional.of("b"), EXACTLY)))))));
+
+query = "SELECT * FROM test_partition_pruning WHERE k >= 'b'";
+assertQuery(query, "VALUES (2, 'b'), (3, 'c'), (4, 'e')");
+assertConstraints(
+query,
+ImmutableSet.of(
+new ColumnConstraint(
+"k",
+VARCHAR.getTypeSignature(),
+new FormattedDomain(
+false,
+ImmutableSet.of(
+new FormattedRange(
+new FormattedMarker(Optional.of("b"), EXACTLY),
+new FormattedMarker(Optional.of("b"), EXACTLY)),
+new FormattedRange(
+new FormattedMarker(Optional.of("c"), EXACTLY),
+new FormattedMarker(Optional.of("c"), EXACTLY)),
+new FormattedRange(
+new FormattedMarker(Optional.of("e"), EXACTLY),
+new FormattedMarker(Optional.of("e"), EXACTLY)))))));
+
+query = "SELECT * FROM (" +
+" SELECT * " +
+" FROM test_partition_pruning " +
+" WHERE v IN (1, 2, 4) " +
+")" +
+"WHERE k >= 'b'";
+
+assertQuery(query, "VALUES (2, 'b'), (4, 'e')");
+assertConstraints(
+"SELECT * FROM (" +
+" SELECT * " +
+" FROM test_partition_pruning " +
+" WHERE v IN (1, 2, 4) " +
+") t " +
+"WHERE t.k >= 'b'",
+ImmutableSet.of(
+new ColumnConstraint(
+"k",
+VARCHAR.getTypeSignature(),
+new FormattedDomain(
+false,
+ImmutableSet.of(
+new FormattedRange(
+new FormattedMarker(Optional.of("b"), EXACTLY),
+new FormattedMarker(Optional.of("b"), EXACTLY)),
+new FormattedRange(
+new FormattedMarker(Optional.of("c"), EXACTLY),
+new FormattedMarker(Optional.of("c"), EXACTLY)),
+new FormattedRange(
+new FormattedMarker(Optional.of("e"), EXACTLY),
+new FormattedMarker(Optional.of("e"), EXACTLY)))))));
+}
+finally {
+assertUpdate("DROP TABLE test_partition_pruning");
+}
+}
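A note on the expected constraints above: `k >= 'b'` shows up as three `EXACTLY`..`EXACTLY` point ranges rather than one half-open range because, after pruning, the connector knows exactly which partition values survive and enforces their discrete union. A small illustration of that effect, with plain strings standing in for partition keys:

```java
import java.util.List;
import static java.util.stream.Collectors.toList;

// The range predicate k >= 'b' collapses to the discrete values of the
// partitions that survive pruning, hence point ranges for "b", "c" and "e".
public final class DomainSketch
{
    public static void main(String[] args)
    {
        List<String> partitionKeys = List.of("a", "b", "c", "e");
        List<String> enforced = partitionKeys.stream()
                .filter(k -> k.compareTo("b") >= 0)
                .collect(toList());
        System.out.println(enforced); // [b, c, e]
    }
}
```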

@Test
public void testMismatchedBucketing()
{
@@ -4006,6 +4114,17 @@ private void assertColumnType(TableMetadata tableMetadata, String columnName, Ty
assertEquals(tableMetadata.getColumn(columnName).getType(), canonicalizeType(expectedType));
}

+private void assertConstraints(@Language("SQL") String query, Set<ColumnConstraint> expected)
+{
+MaterializedResult result = computeActual("EXPLAIN (TYPE IO, FORMAT JSON) " + query);
+Set<ColumnConstraint> constraints = jsonCodec(IoPlan.class).fromJson((String) getOnlyElement(result.getOnlyColumnAsSet()))
+.getInputTableColumnInfos().stream()
+.findFirst().get()
+.getColumnConstraints();
+
+assertEquals(constraints, expected);
+}

private void verifyPartition(boolean hasPartition, TableMetadata tableMetadata, List<String> partitionKeys)
{
Object partitionByProperty = tableMetadata.getMetadata().getProperties().get(PARTITIONED_BY_PROPERTY);
PlanOptimizers.java
@@ -319,8 +319,6 @@ public PlanOptimizers(
new PushLimitThroughOuterJoin(),
new PushLimitThroughSemiJoin(),
new PushLimitThroughUnion(),
-new PushLimitIntoTableScan(metadata),
-new PushPredicateIntoTableScan(metadata, typeAnalyzer),
new RemoveTrivialFilters(),
new RemoveRedundantLimit(),
new RemoveRedundantSort(),
@@ -339,15 +337,6 @@
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(
-new PushSampleIntoTableScan(metadata))),
-new IterativeOptimizer(
-ruleStats,
-statsCalculator,
-estimatedExchangesCostCalculator,
-// Temporary hack: separate optimizer step to avoid the sample node being replaced by filter before pushing
-// it to table scan node
-ImmutableSet.of(
-new ImplementBernoulliSampleAsFilter(),
new ImplementOffset(),
new ImplementLimitWithTies())),
simplifyOptimizer,
@@ -408,7 +397,17 @@ public PlanOptimizers(
ruleStats,
statsCalculator,
estimatedExchangesCostCalculator,
-ImmutableSet.of(new PushPredicateIntoTableScan(metadata, typeAnalyzer))),
+ImmutableSet.of(
+new PushLimitIntoTableScan(metadata),
+new PushPredicateIntoTableScan(metadata, typeAnalyzer),
+new PushSampleIntoTableScan(metadata))),
+new IterativeOptimizer(
+ruleStats,
+statsCalculator,
+estimatedExchangesCostCalculator,
+// Temporary hack: separate optimizer step to avoid the sample node being replaced by filter before pushing
+// it to table scan node
+ImmutableSet.of(new ImplementBernoulliSampleAsFilter())),
new PruneUnreferencedOutputs(),
new IterativeOptimizer(
ruleStats,
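Grouping `PushLimitIntoTableScan`, `PushPredicateIntoTableScan`, and `PushSampleIntoTableScan` into this later `IterativeOptimizer` lets scan pushdown iterate to a fixed point alongside predicate inference, which is what produces the repeated `applyFilter` calls handled above. A toy fixed-point loop showing the idea; Presto's `IterativeOptimizer` is memo-based and far richer, and all names here are illustrative:

```java
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

// Toy iterative optimizer: keep applying rules until none of them fires.
public final class IterativeOptimizerSketch
{
    // A rule rewrites the plan, or returns empty when it does not apply.
    interface Rule<P> extends Function<P, Optional<P>> {}

    static <P> P optimize(P plan, List<Rule<P>> rules)
    {
        boolean changed = true;
        while (changed) {
            changed = false;
            for (Rule<P> rule : rules) {
                Optional<P> rewritten = rule.apply(plan);
                if (rewritten.isPresent()) {
                    plan = rewritten.get();
                    changed = true;
                }
            }
        }
        return plan;
    }

    public static void main(String[] args)
    {
        // Stand-in "plan": an integer; the rule "pushes down" until it reaches 0.
        Rule<Integer> pushdown = p -> p > 0 ? Optional.of(p - 1) : Optional.empty();
        System.out.println(optimize(3, List.of(pushdown))); // prints 0
    }
}
```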