Handle repeated predicate pushdown into Hive connector
The previous implementation only considered the first attempt to
push a filter down into the Hive connector. As a result, for a
query like the following, the partition filter above the bottommost
filter would be ignored:

    SELECT * FROM (
        SELECT * FROM t WHERE a in (1, 2)
    ) u
    WHERE u.pk = 'b';
martint authored and electrum committed Jun 14, 2019
1 parent 31d663b commit 74b7a79
Showing 3 changed files with 137 additions and 23 deletions.

HiveMetadata.java
@@ -1686,10 +1686,6 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
         HiveTableHandle handle = (HiveTableHandle) tableHandle;
         checkArgument(!handle.getAnalyzePartitionValues().isPresent() || constraint.getSummary().isAll(), "Analyze should not have a constraint");

-        if (handle.getPartitions().isPresent()) {
-            return Optional.empty(); // TODO: optimize multiple calls to applyFilter
-        }
-
         HivePartitionResult partitionResult = partitionManager.getPartitions(metastore, handle, constraint);
         HiveTableHandle newHandle = partitionManager.applyPartitionResult(handle, partitionResult);
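
Removing this guard is the heart of the fix. The engine may call applyFilter
once for each filter node it tries to push down, so a query with nested
filters produces several calls against successively refined handles. A
minimal sketch of that call sequence for the query in the commit message
(illustrative only, not engine code; constraintFor is a hypothetical helper):

    ConnectorTableHandle handle = metadata.getTableHandle(session, schemaTableName);

    // First call: the bottommost filter, a IN (1, 2), is pushed down.
    handle = metadata.applyFilter(session, handle, constraintFor("a IN (1, 2)"))
            .map(ConstraintApplicationResult::getHandle)
            .orElse(handle);

    // Second call: the partition filter, pk = 'b'. Before this commit the
    // connector returned Optional.empty() here whenever partitions were
    // already computed, so pk = 'b' never contributed to partition pruning
    // (the filter still ran on rows, but every partition was scanned).
    handle = metadata.applyFilter(session, handle, constraintFor("pk = 'b'"))
            .map(ConstraintApplicationResult::getHandle)
            .orElse(handle);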

HivePartitionManager.java
@@ -60,6 +60,7 @@
 import java.util.Map;
 import java.util.Optional;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Predicate;

 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Predicates.not;
@@ -69,7 +70,6 @@
 import static io.prestosql.plugin.hive.HiveUtil.parsePartitionValue;
 import static io.prestosql.plugin.hive.metastore.MetastoreUtil.toPartitionName;
 import static io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED;
-import static io.prestosql.spi.connector.Constraint.alwaysTrue;
 import static io.prestosql.spi.predicate.TupleDomain.all;
 import static io.prestosql.spi.predicate.TupleDomain.none;
 import static io.prestosql.spi.type.Chars.padSpaces;
@@ -119,7 +119,8 @@ public HivePartitionManager(
     public HivePartitionResult getPartitions(SemiTransactionalHiveMetastore metastore, ConnectorTableHandle tableHandle, Constraint constraint)
     {
         HiveTableHandle hiveTableHandle = (HiveTableHandle) tableHandle;
-        TupleDomain<ColumnHandle> effectivePredicate = constraint.getSummary();
+        TupleDomain<ColumnHandle> effectivePredicate = constraint.getSummary()
+                .intersect(hiveTableHandle.getEnforcedConstraint());

         SchemaTableName tableName = hiveTableHandle.getSchemaTableName();
         Optional<HiveBucketHandle> hiveBucketHandle = hiveTableHandle.getBucketHandle();
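
getPartitions now folds the constraint enforced by earlier pushdowns into
the effective predicate, so a later call sees the conjunction of all filters
pushed so far. A minimal sketch of the TupleDomain.intersect semantics, with
assumed column names and values:

    // Assumes the usual SPI imports (TupleDomain, Domain, VARCHAR, BIGINT,
    // Slices.utf8Slice) plus Guava's ImmutableMap and ImmutableList.
    // Domain enforced by an earlier pushdown, e.g. a partition filter:
    TupleDomain<String> enforced = TupleDomain.withColumnDomains(
            ImmutableMap.of("pk", Domain.singleValue(VARCHAR, utf8Slice("b"))));
    // Summary of the constraint arriving in the current call:
    TupleDomain<String> summary = TupleDomain.withColumnDomains(
            ImmutableMap.of("a", Domain.multipleValues(BIGINT, ImmutableList.of(1L, 2L))));
    // The intersection constrains both columns at once.
    TupleDomain<String> effective = summary.intersect(enforced);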
@@ -150,14 +151,22 @@ public HivePartitionResult getPartitions(SemiTransactionalHiveMetastor
                 .map(column -> typeManager.getType(column.getTypeSignature()))
                 .collect(toList());

-        List<String> partitionNames = getFilteredPartitionNames(metastore, tableName, partitionColumns, effectivePredicate);
-
-        Iterable<HivePartition> partitionsIterable = () -> partitionNames.stream()
-                // Apply extra filters which could not be done by getFilteredPartitionNames
-                .map(partitionName -> parseValuesAndFilterPartition(tableName, partitionName, partitionColumns, partitionTypes, constraint))
-                .filter(Optional::isPresent)
-                .map(Optional::get)
-                .iterator();
+        Iterable<HivePartition> partitionsIterable;
+        Predicate<Map<ColumnHandle, NullableValue>> predicate = constraint.predicate().orElse(value -> true);
+        if (hiveTableHandle.getPartitions().isPresent()) {
+            partitionsIterable = hiveTableHandle.getPartitions().get().stream()
+                    .filter(partition -> partitionMatches(partitionColumns, effectivePredicate, predicate, partition))
+                    .collect(toImmutableList());
+        }
+        else {
+            List<String> partitionNames = getFilteredPartitionNames(metastore, tableName, partitionColumns, effectivePredicate);
+            partitionsIterable = () -> partitionNames.stream()
+                    // Apply extra filters which could not be done by getFilteredPartitionNames
+                    .map(partitionName -> parseValuesAndFilterPartition(tableName, partitionName, partitionColumns, partitionTypes, effectivePredicate, predicate))
+                    .filter(Optional::isPresent)
+                    .map(Optional::get)
+                    .iterator();
+        }

         // All partition key domains will be fully evaluated, so we don't need to include those
         TupleDomain<ColumnHandle> remainingTupleDomain = TupleDomain.withColumnDomains(Maps.filterKeys(effectivePredicate.getDomains().get(), not(Predicates.in(partitionColumns))));
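
When the handle already carries a partition list from an earlier pushdown,
the new branch narrows that in-memory list with partitionMatches instead of
querying the metastore again; repeated applyFilter calls therefore only ever
shrink the partition set, at no extra metastore cost.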
@@ -182,7 +191,7 @@ public HivePartitionResult getPartitions(ConnectorTableHandle tableHandle, List<

         List<HivePartition> partitionList = partitionValuesList.stream()
                 .map(partitionValues -> toPartitionName(partitionColumnNames, partitionValues))
-                .map(partitionName -> parseValuesAndFilterPartition(tableName, partitionName, partitionColumns, partitionColumnTypes, alwaysTrue()))
+                .map(partitionName -> parseValuesAndFilterPartition(tableName, partitionName, partitionColumns, partitionColumnTypes, TupleDomain.all(), value -> true))
                 .map(partition -> partition.orElseThrow(() -> new VerifyException("partition must exist")))
                 .collect(toImmutableList());

@@ -252,24 +261,29 @@ private Optional<HivePartition> parseValuesAndFilterPartition(
             String partitionId,
             List<HiveColumnHandle> partitionColumns,
             List<Type> partitionColumnTypes,
-            Constraint constraint)
+            TupleDomain<ColumnHandle> constraintSummary,
+            Predicate<Map<ColumnHandle, NullableValue>> constraint)
     {
         HivePartition partition = parsePartition(tableName, partitionId, partitionColumns, partitionColumnTypes, timeZone);

-        Map<ColumnHandle, Domain> domains = constraint.getSummary().getDomains().get();
+        if (partitionMatches(partitionColumns, constraintSummary, constraint, partition)) {
+            return Optional.of(partition);
+        }
+        return Optional.empty();
+    }
+
+    private boolean partitionMatches(List<HiveColumnHandle> partitionColumns, TupleDomain<ColumnHandle> constraintSummary, Predicate<Map<ColumnHandle, NullableValue>> constraint, HivePartition partition)
+    {
+        Map<ColumnHandle, Domain> domains = constraintSummary.getDomains().get();
         for (HiveColumnHandle column : partitionColumns) {
             NullableValue value = partition.getKeys().get(column);
             Domain allowedDomain = domains.get(column);
             if (allowedDomain != null && !allowedDomain.includesNullableValue(value.getValue())) {
-                return Optional.empty();
+                return false;
             }
         }

-        if (constraint.predicate().isPresent() && !constraint.predicate().get().test(partition.getKeys())) {
-            return Optional.empty();
-        }
-
-        return Optional.of(partition);
+        return constraint.test(partition.getKeys());
     }

     private List<String> getFilteredPartitionNames(SemiTransactionalHiveMetastore metastore, SchemaTableName tableName, List<HiveColumnHandle> partitionKeys, TupleDomain<ColumnHandle> effectivePredicate)
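
Both branches funnel into partitionMatches: a partition survives only if
every partition key falls within that column's domain and the constraint's
row-level predicate (if any) accepts the key values. A minimal sketch of the
domain check for a single key, with assumed values and the usual SPI imports:

    // Partition key k = 'a' tested against a pushed-down domain for k >= 'b':
    Domain allowed = Domain.create(
            ValueSet.ofRanges(Range.greaterThanOrEqual(VARCHAR, utf8Slice("b"))),
            false); // nulls not allowed
    NullableValue key = NullableValue.of(VARCHAR, utf8Slice("a"));
    boolean matches = allowed.includesNullableValue(key.getValue()); // false -> pruned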

TestHiveIntegrationSmokeTest.java
@@ -69,6 +69,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Function;
@@ -2541,6 +2542,98 @@ public void testPredicatePushDownToTableScan()
         }
     }

+    @Test
+    public void testPartitionPruning()
+    {
+        assertUpdate("CREATE TABLE test_partition_pruning (v bigint, k varchar) WITH (partitioned_by = array['k'])");
+        assertUpdate("INSERT INTO test_partition_pruning (v, k) VALUES (1, 'a'), (2, 'b'), (3, 'c'), (4, 'e')", 4);
+
+        try {
+            String query = "SELECT * FROM test_partition_pruning WHERE k = 'a'";
+            assertQuery(query, "VALUES (1, 'a')");
+            assertConstraints(
+                    query,
+                    ImmutableSet.of(
+                            new ColumnConstraint(
+                                    "k",
+                                    VARCHAR.getTypeSignature(),
+                                    new FormattedDomain(
+                                            false,
+                                            ImmutableSet.of(
+                                                    new FormattedRange(
+                                                            new FormattedMarker(Optional.of("a"), EXACTLY),
+                                                            new FormattedMarker(Optional.of("a"), EXACTLY)))))));
+
+            query = "SELECT * FROM test_partition_pruning WHERE k IN ('a', 'b')";
+            assertQuery(query, "VALUES (1, 'a'), (2, 'b')");
+            assertConstraints(
+                    query,
+                    ImmutableSet.of(
+                            new ColumnConstraint(
+                                    "k",
+                                    VARCHAR.getTypeSignature(),
+                                    new FormattedDomain(
+                                            false,
+                                            ImmutableSet.of(
+                                                    new FormattedRange(
+                                                            new FormattedMarker(Optional.of("a"), EXACTLY),
+                                                            new FormattedMarker(Optional.of("a"), EXACTLY)),
+                                                    new FormattedRange(
+                                                            new FormattedMarker(Optional.of("b"), EXACTLY),
+                                                            new FormattedMarker(Optional.of("b"), EXACTLY)))))));
+
+            query = "SELECT * FROM test_partition_pruning WHERE k >= 'b'";
+            assertQuery(query, "VALUES (2, 'b'), (3, 'c'), (4, 'e')");
+            assertConstraints(
+                    query,
+                    ImmutableSet.of(
+                            new ColumnConstraint(
+                                    "k",
+                                    VARCHAR.getTypeSignature(),
+                                    new FormattedDomain(
+                                            false,
+                                            ImmutableSet.of(
+                                                    new FormattedRange(
+                                                            new FormattedMarker(Optional.of("b"), EXACTLY),
+                                                            new FormattedMarker(Optional.of("b"), EXACTLY)),
+                                                    new FormattedRange(
+                                                            new FormattedMarker(Optional.of("c"), EXACTLY),
+                                                            new FormattedMarker(Optional.of("c"), EXACTLY)),
+                                                    new FormattedRange(
+                                                            new FormattedMarker(Optional.of("e"), EXACTLY),
+                                                            new FormattedMarker(Optional.of("e"), EXACTLY)))))));
+
+            query = "SELECT * FROM (" +
+                    " SELECT * " +
+                    " FROM test_partition_pruning " +
+                    " WHERE v IN (1, 2, 4) " +
+                    ") t " +
+                    "WHERE t.k >= 'b'";
+            assertQuery(query, "VALUES (2, 'b'), (4, 'e')");
+            assertConstraints(
+                    query,
+                    ImmutableSet.of(
+                            new ColumnConstraint(
+                                    "k",
+                                    VARCHAR.getTypeSignature(),
+                                    new FormattedDomain(
+                                            false,
+                                            ImmutableSet.of(
+                                                    new FormattedRange(
+                                                            new FormattedMarker(Optional.of("b"), EXACTLY),
+                                                            new FormattedMarker(Optional.of("b"), EXACTLY)),
+                                                    new FormattedRange(
+                                                            new FormattedMarker(Optional.of("c"), EXACTLY),
+                                                            new FormattedMarker(Optional.of("c"), EXACTLY)),
+                                                    new FormattedRange(
+                                                            new FormattedMarker(Optional.of("e"), EXACTLY),
+                                                            new FormattedMarker(Optional.of("e"), EXACTLY)))))));
+        }
+        finally {
+            assertUpdate("DROP TABLE test_partition_pruning");
+        }
+    }
+
     @Test
     public void testMismatchedBucketing()
     {
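
Note that the expected constraint for the nested query still lists 'b', 'c',
and 'e': only the partition filter t.k >= 'b' prunes partitions, while
v IN (1, 2, 4) is a non-partition filter applied to rows afterwards, so
partition 'c' survives pruning even though it contributes no rows to the
result.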
@@ -4006,6 +4099,17 @@ private void assertColumnType(TableMetadata tableMetadata, String columnName, Ty
         assertEquals(tableMetadata.getColumn(columnName).getType(), canonicalizeType(expectedType));
     }

+    private void assertConstraints(@Language("SQL") String query, Set<ColumnConstraint> expected)
+    {
+        MaterializedResult result = computeActual("EXPLAIN (TYPE IO, FORMAT JSON) " + query);
+        Set<ColumnConstraint> constraints = jsonCodec(IoPlan.class).fromJson((String) getOnlyElement(result.getOnlyColumnAsSet()))
+                .getInputTableColumnInfos().stream()
+                .findFirst().get()
+                .getColumnConstraints();
+
+        assertEquals(constraints, expected);
+    }
+
     private void verifyPartition(boolean hasPartition, TableMetadata tableMetadata, List<String> partitionKeys)
     {
         Object partitionByProperty = tableMetadata.getMetadata().getProperties().get(PARTITIONED_BY_PROPERTY);
