diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java index e9577865b78..10a72f57a4c 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -57,6 +57,7 @@ import org.apache.parquet.io.MessageColumnIO; import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.Type; import org.joda.time.DateTimeZone; import javax.inject.Inject; @@ -217,9 +218,6 @@ public static ReaderPageSource createPageSource( Optional parquetWriteValidation, int domainCompactionThreshold) { - // Ignore predicates on partial columns for now. - effectivePredicate = effectivePredicate.filter((column, domain) -> column.isBaseColumn()); - MessageType fileSchema; MessageType requestedSchema; MessageColumnIO messageColumn; @@ -433,19 +431,30 @@ public static TupleDomain getParquetTupleDomain( continue; } - ColumnDescriptor descriptor; - if (useColumnNames) { - descriptor = descriptorsByPath.get(ImmutableList.of(columnHandle.getName())); + ColumnDescriptor descriptor = null; + + Optional baseColumnType = getBaseColumnParquetType(columnHandle, fileSchema, useColumnNames); + // failed to look up the column from the file schema + if (baseColumnType.isEmpty()) { + continue; + } + else if (columnHandle.getHiveColumnProjectionInfo().isEmpty() && baseColumnType.get().isPrimitive()) { + descriptor = descriptorsByPath.get(ImmutableList.of(baseColumnType.get().getName())); } - else { - Optional parquetField = getBaseColumnParquetType(columnHandle, fileSchema, false); - if (parquetField.isEmpty() || !parquetField.get().isPrimitive()) { - // Parquet file has fewer column than partition - // Or the field is a complex type + else if (columnHandle.getHiveColumnProjectionInfo().isPresent() && !baseColumnType.get().isPrimitive()) { + Optional> subfieldTypes = dereferenceSubFieldTypes(baseColumnType.get().asGroupType(), columnHandle.getHiveColumnProjectionInfo().get()); + // failed to look up subfields from the file schema + if (subfieldTypes.isEmpty()) { continue; } - descriptor = descriptorsByPath.get(ImmutableList.of(parquetField.get().getName())); + + ImmutableList.Builder path = ImmutableList.builder(); + path.add(baseColumnType.get().getName()); + path.addAll(subfieldTypes.get().stream().map(Type::getName).toList()); + + descriptor = descriptorsByPath.get(path.build()); } + if (descriptor != null) { predicate.put(descriptor, entry.getValue()); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java index 967ee918fd9..383cc78a94e 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java @@ -5289,11 +5289,11 @@ public void testParquetOnlyNullsRowGroupPruning() // Nested column `b` also has nulls count of 4096, but it contains non nulls as well assertUpdate("CREATE TABLE " + tableName + " (col ROW(a BIGINT, b ARRAY(DOUBLE))) WITH (format = 'PARQUET')"); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(repeat(1, 4096), x -> ROW(ROW(NULL, ARRAY [NULL, rand()]))))", 4096); - // TODO replace with assertNoDataRead after nested column predicate pushdown + assertQueryStats( getSession(), "SELECT * FROM " + tableName + " WHERE col.a IS NOT NULL", - queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isEqualTo(0), results -> assertThat(results.getRowCount()).isEqualTo(0)); assertQueryStats( getSession(), @@ -5302,6 +5302,32 @@ public void testParquetOnlyNullsRowGroupPruning() results -> assertThat(results.getRowCount()).isEqualTo(4096)); } + @Test + public void testParquetNestedRowGroupPruning() + { + String tableName = "test_primitive_column_nested_pruning_" + randomNameSuffix(); + assertUpdate("CREATE TABLE " + tableName + " (col BIGINT) WITH (format = 'PARQUET')"); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(repeat(300, 4096))", 4096); + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col != 300"); + + tableName = "test_nested_column_nulls_pruning_" + randomNameSuffix(); + // Nested column `a` has nulls count of 4096 and contains only nulls + // Nested column `b` also has nulls count of 4096, but it contains non nulls as well + assertUpdate("CREATE TABLE " + tableName + " (col ROW(a BIGINT, b BIGINT)) WITH (format = 'PARQUET')"); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(repeat(1, 4096), x -> ROW(ROW(300, 500))))", 4096); + + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col.a != 300", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isEqualTo(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col.b = 500", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(4096)); + } + private void assertNoDataRead(@Language("SQL") String sql) { assertQueryStats( diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java index ba10daa102c..d7a2054f1af 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import io.trino.plugin.hive.HiveColumnHandle; +import io.trino.plugin.hive.HiveColumnProjectionInfo; import io.trino.plugin.hive.HiveType; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; @@ -122,12 +123,86 @@ public void testParquetTupleDomainStruct(boolean useColumnNames) MessageType fileSchema = new MessageType("hive_schema", new GroupType(OPTIONAL, "my_struct", new PrimitiveType(OPTIONAL, INT32, "a"), - new PrimitiveType(OPTIONAL, INT32, "b"))); + new PrimitiveType(OPTIONAL, INT32, "b"), + new PrimitiveType(OPTIONAL, INT32, "c"))); Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); TupleDomain tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, useColumnNames); assertTrue(tupleDomain.isAll()); } + @Test(dataProvider = "useColumnNames") + public void testParquetTupleDomainStructNestedColumn(boolean useColumNames) + { + RowType baseType = rowType( + RowType.field("a", INTEGER), + RowType.field("b", INTEGER), + RowType.field("c", INTEGER)); + + HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo( + ImmutableList.of(1), + ImmutableList.of("b"), + HiveType.HIVE_INT, + INTEGER); + + HiveColumnHandle projectedColumn = new HiveColumnHandle( + "row_field", + 0, + HiveType.toHiveType(baseType), + baseType, + Optional.of(columnProjectionInfo), + REGULAR, + Optional.empty()); + + Domain predicateDomain = Domain.singleValue(INTEGER, 123L); + TupleDomain tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain)); + + MessageType fileSchema = new MessageType("hive_schema", + new GroupType(OPTIONAL, "row_field", + new PrimitiveType(OPTIONAL, INT32, "a"), + new PrimitiveType(OPTIONAL, INT32, "b"), + new PrimitiveType(OPTIONAL, INT32, "c"))); + Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); + TupleDomain calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumNames); + assertEquals(calculatedTupleDomain.getDomains().get().size(), 1); + ColumnDescriptor selectedColumnDescriptor = descriptorsByPath.get(ImmutableList.of("row_field", "b")); + assertEquals(calculatedTupleDomain.getDomains().get().get(selectedColumnDescriptor), (predicateDomain)); + } + + @Test(dataProvider = "useColumnNames") + public void testParquetTupleDomainStructNestedColumnNonExist(boolean useColumnNames) + { + RowType baseType = rowType( + RowType.field("a", INTEGER), + RowType.field("b", INTEGER), + RowType.field("non_exist", INTEGER)); + + HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo( + ImmutableList.of(2), + ImmutableList.of("non_exist"), + HiveType.HIVE_INT, + INTEGER); + + HiveColumnHandle projectedColumn = new HiveColumnHandle( + "row_field", + 0, + HiveType.toHiveType(baseType), + baseType, + Optional.of(columnProjectionInfo), + REGULAR, + Optional.empty()); + + Domain predicateDomain = Domain.singleValue(INTEGER, 123L); + TupleDomain tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain)); + + MessageType fileSchema = new MessageType("hive_schema", + new GroupType(OPTIONAL, "row_field", + new PrimitiveType(OPTIONAL, INT32, "a"), + new PrimitiveType(OPTIONAL, INT32, "b"))); + Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); + TupleDomain calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumnNames); + assertTrue(calculatedTupleDomain.isAll()); + } + @Test(dataProvider = "useColumnNames") public void testParquetTupleDomainMap(boolean useColumnNames) {