diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java index ac7384a9bd90..6317272c1657 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -33,6 +33,7 @@ import io.trino.plugin.hive.AcidInfo; import io.trino.plugin.hive.FileFormatDataSourceStats; import io.trino.plugin.hive.HiveColumnHandle; +import io.trino.plugin.hive.HiveColumnProjectionInfo; import io.trino.plugin.hive.HiveConfig; import io.trino.plugin.hive.HivePageSourceFactory; import io.trino.plugin.hive.HiveType; @@ -56,6 +57,7 @@ import org.apache.parquet.io.MessageColumnIO; import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.Type; import org.joda.time.DateTimeZone; import javax.inject.Inject; @@ -71,6 +73,7 @@ import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.trino.parquet.ParquetTypeUtils.constructField; @@ -216,9 +219,6 @@ public static ReaderPageSource createPageSource( Optional parquetWriteValidation, int domainCompactionThreshold) { - // Ignore predicates on partial columns for now. - effectivePredicate = effectivePredicate.filter((column, domain) -> column.isBaseColumn()); - MessageType fileSchema; MessageType requestedSchema; MessageColumnIO messageColumn; @@ -331,24 +331,19 @@ public static Optional getParquetMessageType(List public static Optional getColumnType(HiveColumnHandle column, MessageType messageType, boolean useParquetColumnNames) { - Optional columnType = getBaseColumnParquetType(column, messageType, useParquetColumnNames); - if (columnType.isEmpty() || column.getHiveColumnProjectionInfo().isEmpty()) { - return columnType; + Optional baseColumnType = getBaseColumnParquetType(column, messageType, useParquetColumnNames); + if (baseColumnType.isEmpty() || column.getHiveColumnProjectionInfo().isEmpty()) { + return baseColumnType; } - GroupType baseType = columnType.get().asGroupType(); - ImmutableList.Builder typeBuilder = ImmutableList.builder(); - org.apache.parquet.schema.Type parentType = baseType; + GroupType baseType = baseColumnType.get().asGroupType(); + Optional> subFieldTypesOptional = dereferenceSubFieldTypes(baseType, column.getHiveColumnProjectionInfo().get()); - for (String name : column.getHiveColumnProjectionInfo().get().getDereferenceNames()) { - org.apache.parquet.schema.Type childType = getParquetTypeByName(name, parentType.asGroupType()); - if (childType == null) { - return Optional.empty(); - } - typeBuilder.add(childType); - parentType = childType; + // if there is a mismatch between parquet schema and the hive schema and the column cannot be dereferenced + if (subFieldTypesOptional.isEmpty()) { + return Optional.empty(); } - List subfieldTypes = typeBuilder.build(); + List subfieldTypes = subFieldTypesOptional.get(); org.apache.parquet.schema.Type type = subfieldTypes.get(subfieldTypes.size() - 1); for (int i = subfieldTypes.size() - 2; i >= 0; --i) { GroupType groupType = subfieldTypes.get(i).asGroupType(); @@ -437,18 +432,32 @@ public static TupleDomain getParquetTupleDomain( } ColumnDescriptor descriptor; - if (useColumnNames) { - descriptor = descriptorsByPath.get(ImmutableList.of(columnHandle.getName())); + + Optional baseColumnType = getBaseColumnParquetType(columnHandle, fileSchema, useColumnNames); + // Parquet file has fewer column than partition + if (baseColumnType.isEmpty()) { + continue; + } + + if (baseColumnType.get().isPrimitive()) { + descriptor = descriptorsByPath.get(ImmutableList.of(baseColumnType.get().getName())); } else { - Optional parquetField = getBaseColumnParquetType(columnHandle, fileSchema, false); - if (parquetField.isEmpty() || !parquetField.get().isPrimitive()) { - // Parquet file has fewer column than partition - // Or the field is a complex type + if (columnHandle.getHiveColumnProjectionInfo().isEmpty()) { continue; } - descriptor = descriptorsByPath.get(ImmutableList.of(parquetField.get().getName())); + Optional> subfieldTypes = dereferenceSubFieldTypes(baseColumnType.get().asGroupType(), columnHandle.getHiveColumnProjectionInfo().get()); + // failed to look up subfields from the file schema + if (subfieldTypes.isEmpty()) { + continue; + } + + descriptor = descriptorsByPath.get(ImmutableList.builder() + .add(baseColumnType.get().getName()) + .addAll(subfieldTypes.get().stream().map(Type::getName).collect(toImmutableList())) + .build()); } + if (descriptor != null) { predicate.put(descriptor, entry.getValue()); } @@ -509,4 +518,32 @@ private static Optional getBaseColumnParquetType return Optional.empty(); } + + /** + * Dereferencing base parquet type based on projection info's dereference names. + * For example, when dereferencing baseType(level1Field0, level1Field1, Level1Field2(Level2Field0, Level2Field1)) + * with a projection info's dereferenceNames list as (basetype, Level1Field2, Level2Field1). + * It would return a list of parquet types in the order of (level1Field2, Level2Field1) + * + * @return child fields on each level of dereferencing. Return Optional.empty when failed to do the lookup. + */ + private static Optional> dereferenceSubFieldTypes(GroupType baseType, HiveColumnProjectionInfo projectionInfo) + { + checkArgument(baseType != null, "base type cannot be null when dereferencing"); + checkArgument(projectionInfo != null, "hive column projection info cannot be null when doing dereferencing"); + + ImmutableList.Builder typeBuilder = ImmutableList.builder(); + org.apache.parquet.schema.Type parentType = baseType; + + for (String name : projectionInfo.getDereferenceNames()) { + org.apache.parquet.schema.Type childType = getParquetTypeByName(name, parentType.asGroupType()); + if (childType == null) { + return Optional.empty(); + } + typeBuilder.add(childType); + parentType = childType; + } + + return Optional.of(typeBuilder.build()); + } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java index 967ee918fd9d..9a3c9504fbce 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java @@ -5276,41 +5276,6 @@ private void testParquetDictionaryPredicatePushdown(Session session) assertNoDataRead("SELECT * FROM " + tableName + " WHERE n = 3"); } - @Test - public void testParquetOnlyNullsRowGroupPruning() - { - String tableName = "test_primitive_column_nulls_pruning_" + randomNameSuffix(); - assertUpdate("CREATE TABLE " + tableName + " (col BIGINT) WITH (format = 'PARQUET')"); - assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(repeat(NULL, 4096))", 4096); - assertNoDataRead("SELECT * FROM " + tableName + " WHERE col IS NOT NULL"); - - tableName = "test_nested_column_nulls_pruning_" + randomNameSuffix(); - // Nested column `a` has nulls count of 4096 and contains only nulls - // Nested column `b` also has nulls count of 4096, but it contains non nulls as well - assertUpdate("CREATE TABLE " + tableName + " (col ROW(a BIGINT, b ARRAY(DOUBLE))) WITH (format = 'PARQUET')"); - assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(repeat(1, 4096), x -> ROW(ROW(NULL, ARRAY [NULL, rand()]))))", 4096); - // TODO replace with assertNoDataRead after nested column predicate pushdown - assertQueryStats( - getSession(), - "SELECT * FROM " + tableName + " WHERE col.a IS NOT NULL", - queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), - results -> assertThat(results.getRowCount()).isEqualTo(0)); - assertQueryStats( - getSession(), - "SELECT * FROM " + tableName + " WHERE col.b IS NOT NULL", - queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), - results -> assertThat(results.getRowCount()).isEqualTo(4096)); - } - - private void assertNoDataRead(@Language("SQL") String sql) - { - assertQueryStats( - getSession(), - sql, - queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isEqualTo(0), - results -> assertThat(results.getRowCount()).isEqualTo(0)); - } - private QueryInfo getQueryInfo(DistributedQueryRunner queryRunner, MaterializedResultWithQueryId queryResult) { return queryRunner.getCoordinator().getQueryManager().getFullQueryInfo(queryResult.getQueryId()); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestHiveParquetComplexTypePredicatePushDown.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestHiveParquetComplexTypePredicatePushDown.java new file mode 100644 index 000000000000..5d119938ba42 --- /dev/null +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestHiveParquetComplexTypePredicatePushDown.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.hive.parquet; + +import io.trino.plugin.hive.HiveQueryRunner; +import io.trino.testing.BaseTestFileFormatComplexTypesPredicatePushDown; +import io.trino.testing.QueryRunner; +import org.testng.annotations.Test; + +import static io.trino.testing.TestingNames.randomNameSuffix; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestHiveParquetComplexTypePredicatePushDown + extends BaseTestFileFormatComplexTypesPredicatePushDown +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return HiveQueryRunner.builder() + .addHiveProperty("hive.storage-format", "PARQUET") + .build(); + } + + @Test + public void ensureFormatParquet() + { + String tableName = "test_table_" + randomNameSuffix(); + assertUpdate("CREATE TABLE " + tableName + " (colTest BIGINT)"); + assertThat(((String) computeScalar("SHOW CREATE TABLE " + tableName))).contains("PARQUET"); + assertUpdate("DROP TABLE " + tableName); + } +} diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java index ba10daa102c1..05806e27aa3b 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import io.trino.plugin.hive.HiveColumnHandle; +import io.trino.plugin.hive.HiveColumnProjectionInfo; import io.trino.plugin.hive.HiveType; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; @@ -122,12 +123,129 @@ public void testParquetTupleDomainStruct(boolean useColumnNames) MessageType fileSchema = new MessageType("hive_schema", new GroupType(OPTIONAL, "my_struct", new PrimitiveType(OPTIONAL, INT32, "a"), - new PrimitiveType(OPTIONAL, INT32, "b"))); + new PrimitiveType(OPTIONAL, INT32, "b"), + new PrimitiveType(OPTIONAL, INT32, "c"))); Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); TupleDomain tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, useColumnNames); assertTrue(tupleDomain.isAll()); } + @Test(dataProvider = "useColumnNames") + public void testParquetTupleDomainStructWithPrimitiveColumnPredicate(boolean useColumNames) + { + RowType baseType = rowType( + RowType.field("a", INTEGER), + RowType.field("b", INTEGER), + RowType.field("c", INTEGER)); + + HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo( + ImmutableList.of(1), + ImmutableList.of("b"), + HiveType.HIVE_INT, + INTEGER); + + HiveColumnHandle projectedColumn = new HiveColumnHandle( + "row_field", + 0, + HiveType.toHiveType(baseType), + baseType, + Optional.of(columnProjectionInfo), + REGULAR, + Optional.empty()); + + Domain predicateDomain = Domain.singleValue(INTEGER, 123L); + TupleDomain tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain)); + + MessageType fileSchema = new MessageType("hive_schema", + new GroupType(OPTIONAL, "row_field", + new PrimitiveType(OPTIONAL, INT32, "a"), + new PrimitiveType(OPTIONAL, INT32, "b"), + new PrimitiveType(OPTIONAL, INT32, "c"))); + Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); + TupleDomain calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumNames); + assertEquals(calculatedTupleDomain.getDomains().get().size(), 1); + ColumnDescriptor selectedColumnDescriptor = descriptorsByPath.get(ImmutableList.of("row_field", "b")); + assertEquals(calculatedTupleDomain.getDomains().get().get(selectedColumnDescriptor), predicateDomain); + } + + @Test(dataProvider = "useColumnNames") + public void testParquetTupleDomainStructWithComplexColumnPredicate(boolean useColumNames) + { + RowType c1Type = rowType( + RowType.field("c1", INTEGER), + RowType.field("c2", INTEGER)); + RowType baseType = rowType( + RowType.field("a", INTEGER), + RowType.field("b", INTEGER), + RowType.field("c", c1Type)); + + HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo( + ImmutableList.of(2), + ImmutableList.of("C"), + HiveType.toHiveType(c1Type), + c1Type); + + HiveColumnHandle projectedColumn = new HiveColumnHandle( + "row_field", + 0, + HiveType.toHiveType(baseType), + baseType, + Optional.of(columnProjectionInfo), + REGULAR, + Optional.empty()); + + Domain predicateDomain = Domain.onlyNull(c1Type); + TupleDomain tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain)); + + MessageType fileSchema = new MessageType("hive_schema", + new GroupType(OPTIONAL, "row_field", + new PrimitiveType(OPTIONAL, INT32, "a"), + new PrimitiveType(OPTIONAL, INT32, "b"), + new GroupType(OPTIONAL, + "c", + new PrimitiveType(OPTIONAL, INT32, "c1"), + new PrimitiveType(OPTIONAL, INT32, "c2")))); + Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); + // skip looking up predicates for complex types as Parquet only stores stats for primitives + TupleDomain calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumNames); + assertTrue(calculatedTupleDomain.isAll()); + } + + @Test(dataProvider = "useColumnNames") + public void testParquetTupleDomainStructWithMissingPrimitiveColumn(boolean useColumnNames) + { + RowType baseType = rowType( + RowType.field("a", INTEGER), + RowType.field("b", INTEGER), + RowType.field("non_exist", INTEGER)); + + HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo( + ImmutableList.of(2), + ImmutableList.of("non_exist"), + HiveType.HIVE_INT, + INTEGER); + + HiveColumnHandle projectedColumn = new HiveColumnHandle( + "row_field", + 0, + HiveType.toHiveType(baseType), + baseType, + Optional.of(columnProjectionInfo), + REGULAR, + Optional.empty()); + + Domain predicateDomain = Domain.singleValue(INTEGER, 123L); + TupleDomain tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain)); + + MessageType fileSchema = new MessageType("hive_schema", + new GroupType(OPTIONAL, "row_field", + new PrimitiveType(OPTIONAL, INT32, "a"), + new PrimitiveType(OPTIONAL, INT32, "b"))); + Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); + TupleDomain calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumnNames); + assertTrue(calculatedTupleDomain.isAll()); + } + @Test(dataProvider = "useColumnNames") public void testParquetTupleDomainMap(boolean useColumnNames) { diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestQueryFramework.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestQueryFramework.java index aa616c95319b..7d0594d865c0 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestQueryFramework.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestQueryFramework.java @@ -562,6 +562,15 @@ protected void assertQueryStats( resultAssertion.accept(resultWithQueryId.getResult()); } + protected void assertNoDataRead(@Language("SQL") String sql) + { + assertQueryStats( + getSession(), + sql, + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isEqualTo(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + } + protected MaterializedResult computeExpected(@Language("SQL") String sql, List resultTypes) { return h2QueryRunner.execute(getSession(), sql, resultTypes); diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseTestFileFormatComplexTypesPredicatePushDown.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseTestFileFormatComplexTypesPredicatePushDown.java new file mode 100644 index 000000000000..4df67c41f10c --- /dev/null +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseTestFileFormatComplexTypesPredicatePushDown.java @@ -0,0 +1,168 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.testing; + +import org.testng.annotations.Test; + +import static io.trino.testing.TestingNames.randomNameSuffix; +import static org.assertj.core.api.Assertions.assertThat; + +public abstract class BaseTestFileFormatComplexTypesPredicatePushDown + extends AbstractTestQueryFramework +{ + @Test + public void testRowTypeOnlyNullsRowGroupPruning() + { + String tableName = "test_primitive_column_nulls_pruning_" + randomNameSuffix(); + assertUpdate("CREATE TABLE " + tableName + " (col BIGINT)"); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(repeat(NULL, 4096))", 4096); + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col IS NOT NULL"); + + tableName = "test_nested_column_nulls_pruning_" + randomNameSuffix(); + // Nested column `a` has nulls count of 4096 and contains only nulls + // Nested column `b` also has nulls count of 4096, but it contains non nulls as well + assertUpdate("CREATE TABLE " + tableName + " (col ROW(a BIGINT, b ARRAY(DOUBLE)))"); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(repeat(1, 4096), x -> ROW(ROW(NULL, ARRAY [NULL, rand()]))))", 4096); + + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col.a IS NOT NULL"); + + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col.a IS NULL", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(4096)); + + // no predicate push down for the entire array type + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col.b IS NOT NULL", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(4096)); + + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col.b IS NULL", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + // no predicate push down for entire ROW + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col IS NOT NULL", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(4096)); + + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col IS NULL", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testRowTypeRowGroupPruning() + { + String tableName = "test_nested_column_pruning_" + randomNameSuffix(); + assertUpdate("CREATE TABLE " + tableName + " (col1Row ROW(a BIGINT, b BIGINT, c ROW(c1 BIGINT, c2 ROW(c21 BIGINT, c22 BIGINT))))"); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(SEQUENCE(1, 10000), x -> ROW(ROW(x*2, 100, ROW(x, ROW(x*5, x*6))))))", 10000); + + // no data read since the row dereference predicate is pushed down + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.a = -1"); + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.a IS NULL"); + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.c.c2.c22 = -1"); + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.a = -1 AND col1ROW.b = -1 AND col1ROW.c.c1 = -1 AND col1Row.c.c2.c22 = -1"); + + // read all since predicate case matches with the data + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col1Row.b = 100", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(10000)); + + // no predicate push down for matching with ROW type, as file format only stores stats for primitives + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col1Row.c = ROW(-1, ROW(-1, -1))", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col1Row.c = ROW(-1, ROW(-1, -1)) OR col1Row.a = -1 ", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + // no data read since the row group get filtered by primitives in the predicate + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.c = ROW(-1, ROW(-1, -1)) AND col1Row.a = -1 "); + + // no predicate push down for entire ROW, as file format only stores stats for primitives + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col1Row = ROW(-1, -1, ROW(-1, ROW(-1, -1)))", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testMapTypeRowGroupPruning() + { + String tableName = "test_nested_column_pruning_" + randomNameSuffix(); + assertUpdate("CREATE TABLE " + tableName + " (colMap Map(VARCHAR, BIGINT))"); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(SEQUENCE(1, 10000), x -> ROW(MAP(ARRAY['FOO', 'BAR'], ARRAY[100, 200]))))", 10000); + + // no predicate push down for MAP type dereference + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE colMap['FOO'] = -1", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + // no predicate push down for entire Map type + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE colMap = MAP(ARRAY['FOO', 'BAR'], ARRAY[-1, -1])", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testArrayTypeRowGroupPruning() + { + String tableName = "test_nested_column_pruning_" + randomNameSuffix(); + assertUpdate("CREATE TABLE " + tableName + " (colArray ARRAY(BIGINT))"); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(SEQUENCE(1, 10000), x -> ROW(ARRAY[100, 200])))", 10000); + + // no predicate push down for ARRAY type dereference + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE colArray[1] = -1", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + // no predicate push down for entire ARRAY type + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE colArray = ARRAY[-1, -1]", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + assertUpdate("DROP TABLE " + tableName); + } +}