From 151118d8df9f89caefdc13f43ad337e292b3b8e7 Mon Sep 17 00:00:00 2001 From: Yingjie Luan <1275963@gmail.com> Date: Wed, 19 Apr 2023 14:44:03 -0700 Subject: [PATCH] Implement predicate pushdown for ROW sub fields in parquet for hive --- .../parquet/ParquetPageSourceFactory.java | 33 ++-- ...veParquetComplexTypePredicatePushDown.java | 21 ++- .../predicate/TestParquetPredicateUtils.java | 120 ++++++++++++++- ...leFormatComplexTypesPredicatePushDown.java | 141 ++++++++++++++++++ ...stParquetComplexTypePredicatePushDown.java | 49 ------ 5 files changed, 301 insertions(+), 63 deletions(-) create mode 100644 testing/trino-testing/src/main/java/io/trino/testing/BaseTestFileFormatComplexTypesPredicatePushDown.java delete mode 100644 testing/trino-testing/src/main/java/io/trino/testing/BaseTestParquetComplexTypePredicatePushDown.java diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java index e9577865b789..0d977ddb2d37 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/parquet/ParquetPageSourceFactory.java @@ -57,6 +57,7 @@ import org.apache.parquet.io.MessageColumnIO; import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.Type; import org.joda.time.DateTimeZone; import javax.inject.Inject; @@ -72,6 +73,7 @@ import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static io.trino.parquet.ParquetTypeUtils.constructField; @@ -217,9 +219,6 @@ public static ReaderPageSource 
createPageSource( Optional parquetWriteValidation, int domainCompactionThreshold) { - // Ignore predicates on partial columns for now. - effectivePredicate = effectivePredicate.filter((column, domain) -> column.isBaseColumn()); - MessageType fileSchema; MessageType requestedSchema; MessageColumnIO messageColumn; @@ -434,18 +433,32 @@ public static TupleDomain getParquetTupleDomain( } ColumnDescriptor descriptor; - if (useColumnNames) { - descriptor = descriptorsByPath.get(ImmutableList.of(columnHandle.getName())); + + Optional baseColumnType = getBaseColumnParquetType(columnHandle, fileSchema, useColumnNames); + // Parquet file has fewer columns than the partition + if (baseColumnType.isEmpty()) { + continue; + } + + if (baseColumnType.get().isPrimitive()) { + descriptor = descriptorsByPath.get(ImmutableList.of(baseColumnType.get().getName())); } else { - Optional parquetField = getBaseColumnParquetType(columnHandle, fileSchema, false); - if (parquetField.isEmpty() || !parquetField.get().isPrimitive()) { - // Parquet file has fewer column than partition - // Or the field is a complex type + if (columnHandle.getHiveColumnProjectionInfo().isEmpty()) { continue; } - descriptor = descriptorsByPath.get(ImmutableList.of(parquetField.get().getName())); + Optional> subfieldTypes = dereferenceSubFieldTypes(baseColumnType.get().asGroupType(), columnHandle.getHiveColumnProjectionInfo().get()); + // failed to look up subfields from the file schema + if (subfieldTypes.isEmpty()) { + continue; + } + + descriptor = descriptorsByPath.get(ImmutableList.builder() + .add(baseColumnType.get().getName()) + .addAll(subfieldTypes.get().stream().map(Type::getName).collect(toImmutableList())) + .build()); } + if (descriptor != null) { predicate.put(descriptor, entry.getValue()); } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestHiveParquetComplexTypePredicatePushDown.java
b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestHiveParquetComplexTypePredicatePushDown.java index 04c1821590ea..5d119938ba42 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestHiveParquetComplexTypePredicatePushDown.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/TestHiveParquetComplexTypePredicatePushDown.java @@ -14,16 +14,31 @@ package io.trino.plugin.hive.parquet; import io.trino.plugin.hive.HiveQueryRunner; -import io.trino.testing.BaseTestParquetComplexTypePredicatePushDown; +import io.trino.testing.BaseTestFileFormatComplexTypesPredicatePushDown; import io.trino.testing.QueryRunner; +import org.testng.annotations.Test; + +import static io.trino.testing.TestingNames.randomNameSuffix; +import static org.assertj.core.api.Assertions.assertThat; public class TestHiveParquetComplexTypePredicatePushDown - extends BaseTestParquetComplexTypePredicatePushDown + extends BaseTestFileFormatComplexTypesPredicatePushDown { @Override protected QueryRunner createQueryRunner() throws Exception { - return HiveQueryRunner.builder().build(); + return HiveQueryRunner.builder() + .addHiveProperty("hive.storage-format", "PARQUET") + .build(); + } + + @Test + public void ensureFormatParquet() + { + String tableName = "test_table_" + randomNameSuffix(); + assertUpdate("CREATE TABLE " + tableName + " (colTest BIGINT)"); + assertThat(((String) computeScalar("SHOW CREATE TABLE " + tableName))).contains("PARQUET"); + assertUpdate("DROP TABLE " + tableName); } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java index ba10daa102c1..05806e27aa3b 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java +++ 
b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/parquet/predicate/TestParquetPredicateUtils.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import io.trino.plugin.hive.HiveColumnHandle; +import io.trino.plugin.hive.HiveColumnProjectionInfo; import io.trino.plugin.hive.HiveType; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; @@ -122,12 +123,129 @@ public void testParquetTupleDomainStruct(boolean useColumnNames) MessageType fileSchema = new MessageType("hive_schema", new GroupType(OPTIONAL, "my_struct", new PrimitiveType(OPTIONAL, INT32, "a"), - new PrimitiveType(OPTIONAL, INT32, "b"))); + new PrimitiveType(OPTIONAL, INT32, "b"), + new PrimitiveType(OPTIONAL, INT32, "c"))); Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); TupleDomain tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, useColumnNames); assertTrue(tupleDomain.isAll()); } + @Test(dataProvider = "useColumnNames") + public void testParquetTupleDomainStructWithPrimitiveColumnPredicate(boolean useColumNames) + { + RowType baseType = rowType( + RowType.field("a", INTEGER), + RowType.field("b", INTEGER), + RowType.field("c", INTEGER)); + + HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo( + ImmutableList.of(1), + ImmutableList.of("b"), + HiveType.HIVE_INT, + INTEGER); + + HiveColumnHandle projectedColumn = new HiveColumnHandle( + "row_field", + 0, + HiveType.toHiveType(baseType), + baseType, + Optional.of(columnProjectionInfo), + REGULAR, + Optional.empty()); + + Domain predicateDomain = Domain.singleValue(INTEGER, 123L); + TupleDomain tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain)); + + MessageType fileSchema = new MessageType("hive_schema", + new GroupType(OPTIONAL, "row_field", + new PrimitiveType(OPTIONAL, INT32, "a"), + new PrimitiveType(OPTIONAL, INT32, "b"), + new 
PrimitiveType(OPTIONAL, INT32, "c"))); + Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); + TupleDomain calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumNames); + assertEquals(calculatedTupleDomain.getDomains().get().size(), 1); + ColumnDescriptor selectedColumnDescriptor = descriptorsByPath.get(ImmutableList.of("row_field", "b")); + assertEquals(calculatedTupleDomain.getDomains().get().get(selectedColumnDescriptor), predicateDomain); + } + + @Test(dataProvider = "useColumnNames") + public void testParquetTupleDomainStructWithComplexColumnPredicate(boolean useColumNames) + { + RowType c1Type = rowType( + RowType.field("c1", INTEGER), + RowType.field("c2", INTEGER)); + RowType baseType = rowType( + RowType.field("a", INTEGER), + RowType.field("b", INTEGER), + RowType.field("c", c1Type)); + + HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo( + ImmutableList.of(2), + ImmutableList.of("C"), + HiveType.toHiveType(c1Type), + c1Type); + + HiveColumnHandle projectedColumn = new HiveColumnHandle( + "row_field", + 0, + HiveType.toHiveType(baseType), + baseType, + Optional.of(columnProjectionInfo), + REGULAR, + Optional.empty()); + + Domain predicateDomain = Domain.onlyNull(c1Type); + TupleDomain tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain)); + + MessageType fileSchema = new MessageType("hive_schema", + new GroupType(OPTIONAL, "row_field", + new PrimitiveType(OPTIONAL, INT32, "a"), + new PrimitiveType(OPTIONAL, INT32, "b"), + new GroupType(OPTIONAL, + "c", + new PrimitiveType(OPTIONAL, INT32, "c1"), + new PrimitiveType(OPTIONAL, INT32, "c2")))); + Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); + // skip looking up predicates for complex types as Parquet only stores stats for primitives + TupleDomain calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, 
useColumNames); + assertTrue(calculatedTupleDomain.isAll()); + } + + @Test(dataProvider = "useColumnNames") + public void testParquetTupleDomainStructWithMissingPrimitiveColumn(boolean useColumnNames) + { + RowType baseType = rowType( + RowType.field("a", INTEGER), + RowType.field("b", INTEGER), + RowType.field("non_exist", INTEGER)); + + HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo( + ImmutableList.of(2), + ImmutableList.of("non_exist"), + HiveType.HIVE_INT, + INTEGER); + + HiveColumnHandle projectedColumn = new HiveColumnHandle( + "row_field", + 0, + HiveType.toHiveType(baseType), + baseType, + Optional.of(columnProjectionInfo), + REGULAR, + Optional.empty()); + + Domain predicateDomain = Domain.singleValue(INTEGER, 123L); + TupleDomain tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain)); + + MessageType fileSchema = new MessageType("hive_schema", + new GroupType(OPTIONAL, "row_field", + new PrimitiveType(OPTIONAL, INT32, "a"), + new PrimitiveType(OPTIONAL, INT32, "b"))); + Map, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema); + TupleDomain calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumnNames); + assertTrue(calculatedTupleDomain.isAll()); + } + @Test(dataProvider = "useColumnNames") public void testParquetTupleDomainMap(boolean useColumnNames) { diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseTestFileFormatComplexTypesPredicatePushDown.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseTestFileFormatComplexTypesPredicatePushDown.java new file mode 100644 index 000000000000..fe8cfe5ad7df --- /dev/null +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseTestFileFormatComplexTypesPredicatePushDown.java @@ -0,0 +1,141 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.testing; + +import org.testng.annotations.Test; + +import static io.trino.testing.TestingNames.randomNameSuffix; +import static org.assertj.core.api.Assertions.assertThat; + +public abstract class BaseTestFileFormatComplexTypesPredicatePushDown + extends AbstractTestQueryFramework +{ + @Test + public void testRowTypeOnlyNullsRowGroupPruning() + { + String tableName = "test_primitive_column_nulls_pruning_" + randomNameSuffix(); + assertUpdate("CREATE TABLE " + tableName + " (col BIGINT)"); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(repeat(NULL, 4096))", 4096); + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col IS NOT NULL"); + + tableName = "test_nested_column_nulls_pruning_" + randomNameSuffix(); + // Nested column `a` has nulls count of 4096 and contains only nulls + // Nested column `b` also has nulls count of 4096, but it contains non nulls as well + assertUpdate("CREATE TABLE " + tableName + " (col ROW(a BIGINT, b ARRAY(DOUBLE)))"); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(repeat(1, 4096), x -> ROW(ROW(NULL, ARRAY [NULL, rand()]))))", 4096); + + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col.a IS NOT NULL"); + + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col.b IS NOT NULL", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(4096)); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void 
testRowTypeRowGroupPruning() + { + String tableName = "test_nested_column_pruning_" + randomNameSuffix(); + assertUpdate("CREATE TABLE " + tableName + " (col1Row ROW(a BIGINT, b BIGINT, c ROW(c1 BIGINT, c2 ROW(c21 BIGINT, c22 BIGINT))))"); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(SEQUENCE(1, 10000), x -> ROW(ROW(x*2, 100, ROW(x, ROW(x*5, x*6))))))", 10000); + + // no data read since the row dereference predicate is pushed down + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.a = -1"); + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.c.c2.c22 = -1"); + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.a = -1 AND col1ROW.b = -1 AND col1ROW.c.c1 = -1 AND col1Row.c.c2.c22 = -1"); + + // all rows are read since the predicate matches the data + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col1Row.b = 100", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(10000)); + + // no predicate push down when matching against a ROW type, as the file format only stores stats for primitives + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col1Row.c = ROW(-1, ROW(-1, -1))", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col1Row.c = ROW(-1, ROW(-1, -1)) OR col1Row.a = -1 ", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + // no data read since the row groups get filtered by the primitive columns in the predicate + assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.c = ROW(-1, ROW(-1, -1)) AND col1Row.a = -1 "); + + // no predicate push down for entire ROW, as
file format only stores stats for primitives + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE col1Row = ROW(-1, -1, ROW(-1, ROW(-1, -1)))", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testMapTypeRowGroupPruning() + { + String tableName = "test_nested_column_pruning_" + randomNameSuffix(); + assertUpdate("CREATE TABLE " + tableName + " (colMap Map(VARCHAR, BIGINT))"); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(SEQUENCE(1, 10000), x -> ROW(MAP(ARRAY['FOO', 'BAR'], ARRAY[100, 200]))))", 10000); + + // no predicate push down for MAP type dereference + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE colMap['FOO'] = -1", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + // no predicate push down for entire Map type + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE colMap = MAP(ARRAY['FOO', 'BAR'], ARRAY[-1, -1])", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testArrayTypeRowGroupPruning() + { + String tableName = "test_nested_column_pruning_" + randomNameSuffix(); + assertUpdate("CREATE TABLE " + tableName + " (colArray ARRAY(BIGINT))"); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(SEQUENCE(1, 10000), x -> ROW(ARRAY[100, 200])))", 10000); + + // no predicate push down for ARRAY type dereference + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE colArray[1] = -1", + queryStats -> 
assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + // no predicate push down for entire ARRAY type + assertQueryStats( + getSession(), + "SELECT * FROM " + tableName + " WHERE colArray = ARRAY[-1, -1]", + queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), + results -> assertThat(results.getRowCount()).isEqualTo(0)); + + assertUpdate("DROP TABLE " + tableName); + } +} diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseTestParquetComplexTypePredicatePushDown.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseTestParquetComplexTypePredicatePushDown.java deleted file mode 100644 index b4e3c0141b5d..000000000000 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseTestParquetComplexTypePredicatePushDown.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.trino.testing; - -import org.testng.annotations.Test; - -import static io.trino.testing.TestingNames.randomNameSuffix; -import static org.assertj.core.api.Assertions.assertThat; - -public abstract class BaseTestParquetComplexTypePredicatePushDown - extends AbstractTestQueryFramework -{ - @Test - public void testParquetOnlyNullsRowGroupPruning() - { - String tableName = "test_primitive_column_nulls_pruning_" + randomNameSuffix(); - assertUpdate("CREATE TABLE " + tableName + " (col BIGINT) WITH (format = 'PARQUET')"); - assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(repeat(NULL, 4096))", 4096); - assertNoDataRead("SELECT * FROM " + tableName + " WHERE col IS NOT NULL"); - - tableName = "test_nested_column_nulls_pruning_" + randomNameSuffix(); - // Nested column `a` has nulls count of 4096 and contains only nulls - // Nested column `b` also has nulls count of 4096, but it contains non nulls as well - assertUpdate("CREATE TABLE " + tableName + " (col ROW(a BIGINT, b ARRAY(DOUBLE))) WITH (format = 'PARQUET')"); - assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(repeat(1, 4096), x -> ROW(ROW(NULL, ARRAY [NULL, rand()]))))", 4096); - // TODO replace with assertNoDataRead after nested column predicate pushdown - assertQueryStats( - getSession(), - "SELECT * FROM " + tableName + " WHERE col.a IS NOT NULL", - queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), - results -> assertThat(results.getRowCount()).isEqualTo(0)); - assertQueryStats( - getSession(), - "SELECT * FROM " + tableName + " WHERE col.b IS NOT NULL", - queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0), - results -> assertThat(results.getRowCount()).isEqualTo(4096)); - } -}