Skip to content

Commit

Permalink
Implement predicate push down for parquet nested columns
Browse files Browse the repository at this point in the history
  • Loading branch information
leetcode-1533 authored and yluan committed Mar 26, 2023
1 parent 9c9e951 commit cb71c5d
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.apache.parquet.io.MessageColumnIO;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.Type;
import org.joda.time.DateTimeZone;

import javax.inject.Inject;
Expand Down Expand Up @@ -217,9 +218,6 @@ public static ReaderPageSource createPageSource(
Optional<ParquetWriteValidation> parquetWriteValidation,
int domainCompactionThreshold)
{
// Ignore predicates on partial columns for now.
effectivePredicate = effectivePredicate.filter((column, domain) -> column.isBaseColumn());

MessageType fileSchema;
MessageType requestedSchema;
MessageColumnIO messageColumn;
Expand Down Expand Up @@ -433,19 +431,30 @@ public static TupleDomain<ColumnDescriptor> getParquetTupleDomain(
continue;
}

ColumnDescriptor descriptor;
if (useColumnNames) {
descriptor = descriptorsByPath.get(ImmutableList.of(columnHandle.getName()));
ColumnDescriptor descriptor = null;

Optional<org.apache.parquet.schema.Type> baseColumnType = getBaseColumnParquetType(columnHandle, fileSchema, useColumnNames);
// failed to look up the column from the file schema
if (baseColumnType.isEmpty()) {
continue;
}
else if (columnHandle.getHiveColumnProjectionInfo().isEmpty() && baseColumnType.get().isPrimitive()) {
descriptor = descriptorsByPath.get(ImmutableList.of(baseColumnType.get().getName()));
}
else {
Optional<org.apache.parquet.schema.Type> parquetField = getBaseColumnParquetType(columnHandle, fileSchema, false);
if (parquetField.isEmpty() || !parquetField.get().isPrimitive()) {
// Parquet file has fewer column than partition
// Or the field is a complex type
else if (columnHandle.getHiveColumnProjectionInfo().isPresent() && !baseColumnType.get().isPrimitive()) {
Optional<List<Type>> subfieldTypes = dereferenceSubFieldTypes(baseColumnType.get().asGroupType(), columnHandle.getHiveColumnProjectionInfo().get());
// failed to look up subfields from the file schema
if (subfieldTypes.isEmpty()) {
continue;
}
descriptor = descriptorsByPath.get(ImmutableList.of(parquetField.get().getName()));

ImmutableList.Builder<String> path = ImmutableList.builder();
path.add(baseColumnType.get().getName());
path.addAll(subfieldTypes.get().stream().map(Type::getName).toList());

descriptor = descriptorsByPath.get(path.build());
}

if (descriptor != null) {
predicate.put(descriptor, entry.getValue());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5289,11 +5289,11 @@ public void testParquetOnlyNullsRowGroupPruning()
// Nested column `b` also has nulls count of 4096, but it contains non nulls as well
assertUpdate("CREATE TABLE " + tableName + " (col ROW(a BIGINT, b ARRAY(DOUBLE))) WITH (format = 'PARQUET')");
assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(repeat(1, 4096), x -> ROW(ROW(NULL, ARRAY [NULL, rand()]))))", 4096);
// TODO replace with assertNoDataRead after nested column predicate pushdown

assertQueryStats(
getSession(),
"SELECT * FROM " + tableName + " WHERE col.a IS NOT NULL",
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0),
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isEqualTo(0),
results -> assertThat(results.getRowCount()).isEqualTo(0));
assertQueryStats(
getSession(),
Expand All @@ -5302,6 +5302,32 @@ public void testParquetOnlyNullsRowGroupPruning()
results -> assertThat(results.getRowCount()).isEqualTo(4096));
}

@Test
public void testParquetNestedRowGroupPruning()
{
String tableName = "test_primitive_column_nested_pruning_" + randomNameSuffix();
assertUpdate("CREATE TABLE " + tableName + " (col BIGINT) WITH (format = 'PARQUET')");
assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(repeat(300, 4096))", 4096);
assertNoDataRead("SELECT * FROM " + tableName + " WHERE col != 300");

tableName = "test_nested_column_nulls_pruning_" + randomNameSuffix();
// Nested column `a` has nulls count of 4096 and contains only nulls
// Nested column `b` also has nulls count of 4096, but it contains non nulls as well
assertUpdate("CREATE TABLE " + tableName + " (col ROW(a BIGINT, b BIGINT)) WITH (format = 'PARQUET')");
assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(repeat(1, 4096), x -> ROW(ROW(300, 500))))", 4096);

assertQueryStats(
getSession(),
"SELECT * FROM " + tableName + " WHERE col.a != 300",
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isEqualTo(0),
results -> assertThat(results.getRowCount()).isEqualTo(0));
assertQueryStats(
getSession(),
"SELECT * FROM " + tableName + " WHERE col.b = 500",
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0),
results -> assertThat(results.getRowCount()).isEqualTo(4096));
}

private void assertNoDataRead(@Language("SQL") String sql)
{
assertQueryStats(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.trino.plugin.hive.HiveColumnHandle;
import io.trino.plugin.hive.HiveColumnProjectionInfo;
import io.trino.plugin.hive.HiveType;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
Expand Down Expand Up @@ -122,12 +123,86 @@ public void testParquetTupleDomainStruct(boolean useColumnNames)
MessageType fileSchema = new MessageType("hive_schema",
new GroupType(OPTIONAL, "my_struct",
new PrimitiveType(OPTIONAL, INT32, "a"),
new PrimitiveType(OPTIONAL, INT32, "b")));
new PrimitiveType(OPTIONAL, INT32, "b"),
new PrimitiveType(OPTIONAL, INT32, "c")));
Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema);
TupleDomain<ColumnDescriptor> tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, useColumnNames);
assertTrue(tupleDomain.isAll());
}

@Test(dataProvider = "useColumnNames")
public void testParquetTupleDomainStructNestedColumn(boolean useColumNames)
{
RowType baseType = rowType(
RowType.field("a", INTEGER),
RowType.field("b", INTEGER),
RowType.field("c", INTEGER));

HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo(
ImmutableList.of(1),
ImmutableList.of("b"),
HiveType.HIVE_INT,
INTEGER);

HiveColumnHandle projectedColumn = new HiveColumnHandle(
"row_field",
0,
HiveType.toHiveType(baseType),
baseType,
Optional.of(columnProjectionInfo),
REGULAR,
Optional.empty());

Domain predicateDomain = Domain.singleValue(INTEGER, 123L);
TupleDomain<HiveColumnHandle> tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain));

MessageType fileSchema = new MessageType("hive_schema",
new GroupType(OPTIONAL, "row_field",
new PrimitiveType(OPTIONAL, INT32, "a"),
new PrimitiveType(OPTIONAL, INT32, "b"),
new PrimitiveType(OPTIONAL, INT32, "c")));
Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema);
TupleDomain<ColumnDescriptor> calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumNames);
assertEquals(calculatedTupleDomain.getDomains().get().size(), 1);
ColumnDescriptor selectedColumnDescriptor = descriptorsByPath.get(ImmutableList.of("row_field", "b"));
assertEquals(calculatedTupleDomain.getDomains().get().get(selectedColumnDescriptor), (predicateDomain));
}

@Test(dataProvider = "useColumnNames")
public void testParquetTupleDomainStructNestedColumnNonExist(boolean useColumnNames)
{
RowType baseType = rowType(
RowType.field("a", INTEGER),
RowType.field("b", INTEGER),
RowType.field("non_exist", INTEGER));

HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo(
ImmutableList.of(2),
ImmutableList.of("non_exist"),
HiveType.HIVE_INT,
INTEGER);

HiveColumnHandle projectedColumn = new HiveColumnHandle(
"row_field",
0,
HiveType.toHiveType(baseType),
baseType,
Optional.of(columnProjectionInfo),
REGULAR,
Optional.empty());

Domain predicateDomain = Domain.singleValue(INTEGER, 123L);
TupleDomain<HiveColumnHandle> tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain));

MessageType fileSchema = new MessageType("hive_schema",
new GroupType(OPTIONAL, "row_field",
new PrimitiveType(OPTIONAL, INT32, "a"),
new PrimitiveType(OPTIONAL, INT32, "b")));
Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema);
TupleDomain<ColumnDescriptor> calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumnNames);
assertTrue(calculatedTupleDomain.isAll());
}

@Test(dataProvider = "useColumnNames")
public void testParquetTupleDomainMap(boolean useColumnNames)
{
Expand Down

0 comments on commit cb71c5d

Please sign in to comment.