Implement predicate pushdown for ROW subfields in Parquet for Hive
leetcode-1533 authored and Yingjie Luan committed Apr 23, 2023
1 parent a1699f4 commit 151118d
Showing 5 changed files with 301 additions and 63 deletions.
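
Before this change, the Hive connector dropped predicates on projected subfields before they reached the Parquet reader (see the removed isBaseColumn() filter in the first hunk below), so a filter such as `WHERE col.a = -1` on a ROW column could not prune any row groups. With this commit, such predicates are mapped to Parquet column descriptors and checked against column statistics. The following sketch of the user-visible effect is illustrative only (table and column names are invented; it mirrors the style of the tests added by this commit):

// Illustrative sketch, not part of the commit: a Parquet-backed Hive table
// with a ROW column, filtered on a subfield.
String tableName = "test_row_subfield_pushdown_" + randomNameSuffix();
assertUpdate("CREATE TABLE " + tableName + " (col ROW(a BIGINT, b BIGINT))");
assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(SEQUENCE(1, 100), x -> ROW(ROW(x, 100))))", 100);

// The dereference predicate is now pushed down to Parquet column statistics:
// every row group has min(col.a) >= 1, so nothing matches -1 and no data is read.
assertNoDataRead("SELECT * FROM " + tableName + " WHERE col.a = -1");
assertUpdate("DROP TABLE " + tableName);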
@@ -57,6 +57,7 @@
import org.apache.parquet.io.MessageColumnIO;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.MessageType;
+import org.apache.parquet.schema.Type;
import org.joda.time.DateTimeZone;

import javax.inject.Inject;
@@ -72,6 +73,7 @@
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static io.trino.parquet.ParquetTypeUtils.constructField;
@@ -217,9 +219,6 @@ public static ReaderPageSource createPageSource(
Optional<ParquetWriteValidation> parquetWriteValidation,
int domainCompactionThreshold)
{
-// Ignore predicates on partial columns for now.
-effectivePredicate = effectivePredicate.filter((column, domain) -> column.isBaseColumn());
-
MessageType fileSchema;
MessageType requestedSchema;
MessageColumnIO messageColumn;
@@ -434,18 +433,32 @@ public static TupleDomain<ColumnDescriptor> getParquetTupleDomain(
}

ColumnDescriptor descriptor;
-if (useColumnNames) {
-    descriptor = descriptorsByPath.get(ImmutableList.of(columnHandle.getName()));
+
+Optional<org.apache.parquet.schema.Type> baseColumnType = getBaseColumnParquetType(columnHandle, fileSchema, useColumnNames);
+// Parquet file has fewer columns than the partition
+if (baseColumnType.isEmpty()) {
+    continue;
+}
+
+if (baseColumnType.get().isPrimitive()) {
+    descriptor = descriptorsByPath.get(ImmutableList.of(baseColumnType.get().getName()));
}
else {
-    Optional<org.apache.parquet.schema.Type> parquetField = getBaseColumnParquetType(columnHandle, fileSchema, false);
-    if (parquetField.isEmpty() || !parquetField.get().isPrimitive()) {
-        // Parquet file has fewer column than partition
-        // Or the field is a complex type
+    if (columnHandle.getHiveColumnProjectionInfo().isEmpty()) {
        continue;
    }
-    descriptor = descriptorsByPath.get(ImmutableList.of(parquetField.get().getName()));
+    Optional<List<Type>> subfieldTypes = dereferenceSubFieldTypes(baseColumnType.get().asGroupType(), columnHandle.getHiveColumnProjectionInfo().get());
+    // Failed to look up subfields from the file schema
+    if (subfieldTypes.isEmpty()) {
+        continue;
+    }
+
+    descriptor = descriptorsByPath.get(ImmutableList.<String>builder()
+            .add(baseColumnType.get().getName())
+            .addAll(subfieldTypes.get().stream().map(Type::getName).collect(toImmutableList()))
+            .build());
}

if (descriptor != null) {
    predicate.put(descriptor, entry.getValue());
}
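
The branch above depends on helpers that this commit adds or reworks elsewhere in ParquetPageSourceFactory and that are not visible in this excerpt: getBaseColumnParquetType and dereferenceSubFieldTypes. A minimal sketch of what the subfield resolution could look like, assuming it walks the Parquet GroupType one dereference name at a time and gives up when a level is missing or primitive (illustrative only, and ignoring details such as case-insensitive field matching):

// Illustrative sketch only; the committed helper may differ in details.
// Resolves a chain of struct field names (for example ["c", "c2", "c22"])
// against a Parquet GroupType, returning the matched types in order,
// or empty if any level is absent from the file schema.
private static Optional<List<org.apache.parquet.schema.Type>> dereferenceSubFieldTypes(
        GroupType baseType,
        HiveColumnProjectionInfo projectionInfo)
{
    ImmutableList.Builder<org.apache.parquet.schema.Type> typeBuilder = ImmutableList.builder();
    org.apache.parquet.schema.Type parentType = baseType;
    for (String name : projectionInfo.getDereferenceNames()) {
        if (parentType.isPrimitive() || !parentType.asGroupType().containsField(name)) {
            // Cannot descend further: the subfield is not in this file's schema
            return Optional.empty();
        }
        parentType = parentType.asGroupType().getType(name);
        typeBuilder.add(parentType);
    }
    return Optional.of(typeBuilder.build());
}

The resolved names are appended to the base column name to form the leaf path used as the descriptorsByPath key, which is why the primitive-subfield test below expects the descriptor at ImmutableList.of("row_field", "b").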
@@ -14,16 +14,31 @@
package io.trino.plugin.hive.parquet;

import io.trino.plugin.hive.HiveQueryRunner;
-import io.trino.testing.BaseTestParquetComplexTypePredicatePushDown;
+import io.trino.testing.BaseTestFileFormatComplexTypesPredicatePushDown;
import io.trino.testing.QueryRunner;
+import org.testng.annotations.Test;
+
+import static io.trino.testing.TestingNames.randomNameSuffix;
+import static org.assertj.core.api.Assertions.assertThat;

public class TestHiveParquetComplexTypePredicatePushDown
-extends BaseTestParquetComplexTypePredicatePushDown
+extends BaseTestFileFormatComplexTypesPredicatePushDown
{
@Override
protected QueryRunner createQueryRunner()
throws Exception
{
-return HiveQueryRunner.builder().build();
+return HiveQueryRunner.builder()
+        .addHiveProperty("hive.storage-format", "PARQUET")
+        .build();
}

+@Test
+public void ensureFormatParquet()
+{
+    String tableName = "test_table_" + randomNameSuffix();
+    assertUpdate("CREATE TABLE " + tableName + " (colTest BIGINT)");
+    assertThat((String) computeScalar("SHOW CREATE TABLE " + tableName)).contains("PARQUET");
+    assertUpdate("DROP TABLE " + tableName);
+}
}
@@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.trino.plugin.hive.HiveColumnHandle;
+import io.trino.plugin.hive.HiveColumnProjectionInfo;
import io.trino.plugin.hive.HiveType;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
@@ -122,12 +123,129 @@ public void testParquetTupleDomainStruct(boolean useColumnNames)
MessageType fileSchema = new MessageType("hive_schema",
new GroupType(OPTIONAL, "my_struct",
new PrimitiveType(OPTIONAL, INT32, "a"),
new PrimitiveType(OPTIONAL, INT32, "b")));
new PrimitiveType(OPTIONAL, INT32, "b"),
new PrimitiveType(OPTIONAL, INT32, "c")));
Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema);
TupleDomain<ColumnDescriptor> tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, useColumnNames);
assertTrue(tupleDomain.isAll());
}

@Test(dataProvider = "useColumnNames")
public void testParquetTupleDomainStructWithPrimitiveColumnPredicate(boolean useColumnNames)
{
RowType baseType = rowType(
RowType.field("a", INTEGER),
RowType.field("b", INTEGER),
RowType.field("c", INTEGER));

HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo(
ImmutableList.of(1),
ImmutableList.of("b"),
HiveType.HIVE_INT,
INTEGER);

HiveColumnHandle projectedColumn = new HiveColumnHandle(
"row_field",
0,
HiveType.toHiveType(baseType),
baseType,
Optional.of(columnProjectionInfo),
REGULAR,
Optional.empty());

Domain predicateDomain = Domain.singleValue(INTEGER, 123L);
TupleDomain<HiveColumnHandle> tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain));

MessageType fileSchema = new MessageType("hive_schema",
new GroupType(OPTIONAL, "row_field",
new PrimitiveType(OPTIONAL, INT32, "a"),
new PrimitiveType(OPTIONAL, INT32, "b"),
new PrimitiveType(OPTIONAL, INT32, "c")));
Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema);
TupleDomain<ColumnDescriptor> calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumNames);
assertEquals(calculatedTupleDomain.getDomains().get().size(), 1);
ColumnDescriptor selectedColumnDescriptor = descriptorsByPath.get(ImmutableList.of("row_field", "b"));
assertEquals(calculatedTupleDomain.getDomains().get().get(selectedColumnDescriptor), predicateDomain);
}

@Test(dataProvider = "useColumnNames")
public void testParquetTupleDomainStructWithComplexColumnPredicate(boolean useColumnNames)
{
RowType c1Type = rowType(
RowType.field("c1", INTEGER),
RowType.field("c2", INTEGER));
RowType baseType = rowType(
RowType.field("a", INTEGER),
RowType.field("b", INTEGER),
RowType.field("c", c1Type));

HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo(
ImmutableList.of(2),
ImmutableList.of("C"),
HiveType.toHiveType(c1Type),
c1Type);

HiveColumnHandle projectedColumn = new HiveColumnHandle(
"row_field",
0,
HiveType.toHiveType(baseType),
baseType,
Optional.of(columnProjectionInfo),
REGULAR,
Optional.empty());

Domain predicateDomain = Domain.onlyNull(c1Type);
TupleDomain<HiveColumnHandle> tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain));

MessageType fileSchema = new MessageType("hive_schema",
new GroupType(OPTIONAL, "row_field",
new PrimitiveType(OPTIONAL, INT32, "a"),
new PrimitiveType(OPTIONAL, INT32, "b"),
new GroupType(OPTIONAL,
"c",
new PrimitiveType(OPTIONAL, INT32, "c1"),
new PrimitiveType(OPTIONAL, INT32, "c2"))));
Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema);
// skip looking up predicates for complex types as Parquet only stores stats for primitives
TupleDomain<ColumnDescriptor> calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumNames);
assertTrue(calculatedTupleDomain.isAll());
}

@Test(dataProvider = "useColumnNames")
public void testParquetTupleDomainStructWithMissingPrimitiveColumn(boolean useColumnNames)
{
RowType baseType = rowType(
RowType.field("a", INTEGER),
RowType.field("b", INTEGER),
RowType.field("non_exist", INTEGER));

HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo(
ImmutableList.of(2),
ImmutableList.of("non_exist"),
HiveType.HIVE_INT,
INTEGER);

HiveColumnHandle projectedColumn = new HiveColumnHandle(
"row_field",
0,
HiveType.toHiveType(baseType),
baseType,
Optional.of(columnProjectionInfo),
REGULAR,
Optional.empty());

Domain predicateDomain = Domain.singleValue(INTEGER, 123L);
TupleDomain<HiveColumnHandle> tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain));

MessageType fileSchema = new MessageType("hive_schema",
new GroupType(OPTIONAL, "row_field",
new PrimitiveType(OPTIONAL, INT32, "a"),
new PrimitiveType(OPTIONAL, INT32, "b")));
Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema);
TupleDomain<ColumnDescriptor> calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumnNames);
assertTrue(calculatedTupleDomain.isAll());
}

@Test(dataProvider = "useColumnNames")
public void testParquetTupleDomainMap(boolean useColumnNames)
{
@@ -0,0 +1,141 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.testing;

import org.testng.annotations.Test;

import static io.trino.testing.TestingNames.randomNameSuffix;
import static org.assertj.core.api.Assertions.assertThat;

public abstract class BaseTestFileFormatComplexTypesPredicatePushDown
extends AbstractTestQueryFramework
{
@Test
public void testRowTypeOnlyNullsRowGroupPruning()
{
String tableName = "test_primitive_column_nulls_pruning_" + randomNameSuffix();
assertUpdate("CREATE TABLE " + tableName + " (col BIGINT)");
assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(repeat(NULL, 4096))", 4096);
assertNoDataRead("SELECT * FROM " + tableName + " WHERE col IS NOT NULL");

tableName = "test_nested_column_nulls_pruning_" + randomNameSuffix();
// Nested column `a` has a null count of 4096 and contains only nulls
// Nested column `b` also has a null count of 4096, but it contains non-null values as well
assertUpdate("CREATE TABLE " + tableName + " (col ROW(a BIGINT, b ARRAY(DOUBLE)))");
assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(repeat(1, 4096), x -> ROW(ROW(NULL, ARRAY [NULL, rand()]))))", 4096);

assertNoDataRead("SELECT * FROM " + tableName + " WHERE col.a IS NOT NULL");

assertQueryStats(
getSession(),
"SELECT * FROM " + tableName + " WHERE col.b IS NOT NULL",
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0),
results -> assertThat(results.getRowCount()).isEqualTo(4096));

assertUpdate("DROP TABLE " + tableName);
}

@Test
public void testRowTypeRowGroupPruning()
{
String tableName = "test_nested_column_pruning_" + randomNameSuffix();
assertUpdate("CREATE TABLE " + tableName + " (col1Row ROW(a BIGINT, b BIGINT, c ROW(c1 BIGINT, c2 ROW(c21 BIGINT, c22 BIGINT))))");
assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(SEQUENCE(1, 10000), x -> ROW(ROW(x*2, 100, ROW(x, ROW(x*5, x*6))))))", 10000);

// no data read since the row dereference predicate is pushed down
assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.a = -1");
assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.c.c2.c22 = -1");
assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.a = -1 AND col1ROW.b = -1 AND col1ROW.c.c1 = -1 AND col1Row.c.c2.c22 = -1");

// all data read since the predicate matches the data
assertQueryStats(
getSession(),
"SELECT * FROM " + tableName + " WHERE col1Row.b = 100",
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0),
results -> assertThat(results.getRowCount()).isEqualTo(10000));

// no predicate pushdown when matching against a full ROW value, as the file format only stores stats for primitives
assertQueryStats(
getSession(),
"SELECT * FROM " + tableName + " WHERE col1Row.c = ROW(-1, ROW(-1, -1))",
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0),
results -> assertThat(results.getRowCount()).isEqualTo(0));

assertQueryStats(
getSession(),
"SELECT * FROM " + tableName + " WHERE col1Row.c = ROW(-1, ROW(-1, -1)) OR col1Row.a = -1 ",
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0),
results -> assertThat(results.getRowCount()).isEqualTo(0));

// no data read since the row groups get filtered out by the primitive conjuncts in the predicate
assertNoDataRead("SELECT * FROM " + tableName + " WHERE col1Row.c = ROW(-1, ROW(-1, -1)) AND col1Row.a = -1 ");

// no predicate pushdown for the entire ROW, as the file format only stores stats for primitives
assertQueryStats(
getSession(),
"SELECT * FROM " + tableName + " WHERE col1Row = ROW(-1, -1, ROW(-1, ROW(-1, -1)))",
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0),
results -> assertThat(results.getRowCount()).isEqualTo(0));

assertUpdate("DROP TABLE " + tableName);
}

@Test
public void testMapTypeRowGroupPruning()
{
String tableName = "test_nested_column_pruning_" + randomNameSuffix();
assertUpdate("CREATE TABLE " + tableName + " (colMap Map(VARCHAR, BIGINT))");
assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(SEQUENCE(1, 10000), x -> ROW(MAP(ARRAY['FOO', 'BAR'], ARRAY[100, 200]))))", 10000);

// no predicate pushdown for MAP type dereference
assertQueryStats(
getSession(),
"SELECT * FROM " + tableName + " WHERE colMap['FOO'] = -1",
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0),
results -> assertThat(results.getRowCount()).isEqualTo(0));

// no predicate pushdown for the entire MAP type
assertQueryStats(
getSession(),
"SELECT * FROM " + tableName + " WHERE colMap = MAP(ARRAY['FOO', 'BAR'], ARRAY[-1, -1])",
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0),
results -> assertThat(results.getRowCount()).isEqualTo(0));

assertUpdate("DROP TABLE " + tableName);
}

@Test
public void testArrayTypeRowGroupPruning()
{
String tableName = "test_nested_column_pruning_" + randomNameSuffix();
assertUpdate("CREATE TABLE " + tableName + " (colArray ARRAY(BIGINT))");
assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(SEQUENCE(1, 10000), x -> ROW(ARRAY[100, 200])))", 10000);

// no predicate pushdown for ARRAY type dereference
assertQueryStats(
getSession(),
"SELECT * FROM " + tableName + " WHERE colArray[1] = -1",
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0),
results -> assertThat(results.getRowCount()).isEqualTo(0));

// no predicate pushdown for the entire ARRAY type
assertQueryStats(
getSession(),
"SELECT * FROM " + tableName + " WHERE colArray = ARRAY[-1, -1]",
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0),
results -> assertThat(results.getRowCount()).isEqualTo(0));

assertUpdate("DROP TABLE " + tableName);
}
}
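
The rename from BaseTestParquetComplexTypePredicatePushDown to BaseTestFileFormatComplexTypesPredicatePushDown moves these assertions into a format-agnostic base class. A subclass for another Hive storage format could look like the following sketch (hypothetical, not part of this commit, and only meaningful for formats whose readers implement the same row-group or stripe pruning):

// Hypothetical subclass, shown only to illustrate reuse of the base class;
// ORC support is not claimed by this commit.
public class TestHiveOrcComplexTypesPredicatePushDown
        extends BaseTestFileFormatComplexTypesPredicatePushDown
{
    @Override
    protected QueryRunner createQueryRunner()
            throws Exception
    {
        return HiveQueryRunner.builder()
                .addHiveProperty("hive.storage-format", "ORC")
                .build();
    }
}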