Implement predicate push down for parquet dereference column #15163

Merged: 5 commits, Apr 25, 2023
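This change lets filters on dereferenced subfields of Parquet struct columns (for example `col.a`) survive page-source creation: instead of being dropped up front, they are now translated into the Parquet tuple domain so row groups can be pruned using column statistics. A minimal sketch of the behavior this enables, modeled on the pruning test this PR relocates further down (table name illustrative):

// Nested field `a` contains only nulls, so its column chunk statistics report
// a null count equal to the row count; with dereference pushdown the predicate
// eliminates every row group and the scan reads no data.
assertUpdate("CREATE TABLE t (col ROW(a BIGINT, b ARRAY(DOUBLE))) WITH (format = 'PARQUET')");
assertUpdate("INSERT INTO t SELECT * FROM unnest(transform(repeat(1, 4096), x -> ROW(ROW(NULL, ARRAY [NULL, rand()]))))", 4096);
assertNoDataRead("SELECT * FROM t WHERE col.a IS NOT NULL");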
@@ -33,6 +33,7 @@
import io.trino.plugin.hive.AcidInfo;
import io.trino.plugin.hive.FileFormatDataSourceStats;
import io.trino.plugin.hive.HiveColumnHandle;
+import io.trino.plugin.hive.HiveColumnProjectionInfo;
import io.trino.plugin.hive.HiveConfig;
import io.trino.plugin.hive.HivePageSourceFactory;
import io.trino.plugin.hive.HiveType;
@@ -56,6 +57,7 @@
import org.apache.parquet.io.MessageColumnIO;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.MessageType;
+import org.apache.parquet.schema.Type;
import org.joda.time.DateTimeZone;

import javax.inject.Inject;
Expand All @@ -71,6 +73,7 @@
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static io.trino.parquet.ParquetTypeUtils.constructField;
@@ -216,9 +219,6 @@ public static ReaderPageSource createPageSource(
Optional<ParquetWriteValidation> parquetWriteValidation,
int domainCompactionThreshold)
{
-// Ignore predicates on partial columns for now.
-effectivePredicate = effectivePredicate.filter((column, domain) -> column.isBaseColumn());

MessageType fileSchema;
MessageType requestedSchema;
MessageColumnIO messageColumn;
@@ -331,24 +331,19 @@ public static Optional<MessageType> getParquetMessageType(List<HiveColumnHandle>

public static Optional<org.apache.parquet.schema.Type> getColumnType(HiveColumnHandle column, MessageType messageType, boolean useParquetColumnNames)
{
-Optional<org.apache.parquet.schema.Type> columnType = getBaseColumnParquetType(column, messageType, useParquetColumnNames);
-if (columnType.isEmpty() || column.getHiveColumnProjectionInfo().isEmpty()) {
-return columnType;
+Optional<org.apache.parquet.schema.Type> baseColumnType = getBaseColumnParquetType(column, messageType, useParquetColumnNames);
+if (baseColumnType.isEmpty() || column.getHiveColumnProjectionInfo().isEmpty()) {
+return baseColumnType;
}
-GroupType baseType = columnType.get().asGroupType();
-ImmutableList.Builder<org.apache.parquet.schema.Type> typeBuilder = ImmutableList.builder();
-org.apache.parquet.schema.Type parentType = baseType;
+GroupType baseType = baseColumnType.get().asGroupType();
+Optional<List<org.apache.parquet.schema.Type>> subFieldTypesOptional = dereferenceSubFieldTypes(baseType, column.getHiveColumnProjectionInfo().get());

-for (String name : column.getHiveColumnProjectionInfo().get().getDereferenceNames()) {
-org.apache.parquet.schema.Type childType = getParquetTypeByName(name, parentType.asGroupType());
-if (childType == null) {
-return Optional.empty();
-}
-typeBuilder.add(childType);
-parentType = childType;
+// there is a mismatch between the parquet schema and the hive schema, and the column cannot be dereferenced
+if (subFieldTypesOptional.isEmpty()) {
+return Optional.empty();
}

-List<org.apache.parquet.schema.Type> subfieldTypes = typeBuilder.build();
+List<org.apache.parquet.schema.Type> subfieldTypes = subFieldTypesOptional.get();
org.apache.parquet.schema.Type type = subfieldTypes.get(subfieldTypes.size() - 1);
for (int i = subfieldTypes.size() - 2; i >= 0; --i) {
GroupType groupType = subfieldTypes.get(i).asGroupType();
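The rest of getColumnType is truncated in this diff view. Its job is to re-wrap the resolved leaf type in each ancestor group so that only the dereferenced path survives in the requested schema. A sketch of the idea; the hidden loop body is assumed, not shown in this diff:

// Rebuild a single-path schema from the resolved subfield chain, innermost type first.
// E.g. row_field(a, b, c) dereferenced at "b" becomes row_field(b).
org.apache.parquet.schema.Type type = subfieldTypes.get(subfieldTypes.size() - 1);
for (int i = subfieldTypes.size() - 2; i >= 0; --i) {
    GroupType groupType = subfieldTypes.get(i).asGroupType();
    // keep only the child on the dereference path, dropping its sibling fields
    type = new GroupType(groupType.getRepetition(), groupType.getName(), ImmutableList.of(type));
}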
@@ -437,18 +432,32 @@ public static TupleDomain<ColumnDescriptor> getParquetTupleDomain(
}

ColumnDescriptor descriptor;
-if (useColumnNames) {
-descriptor = descriptorsByPath.get(ImmutableList.of(columnHandle.getName()));
+
+Optional<org.apache.parquet.schema.Type> baseColumnType = getBaseColumnParquetType(columnHandle, fileSchema, useColumnNames);
+// Parquet file has fewer columns than the partition
+if (baseColumnType.isEmpty()) {
+continue;
+}
+
+if (baseColumnType.get().isPrimitive()) {
+descriptor = descriptorsByPath.get(ImmutableList.of(baseColumnType.get().getName()));
}
else {
-Optional<org.apache.parquet.schema.Type> parquetField = getBaseColumnParquetType(columnHandle, fileSchema, false);
-if (parquetField.isEmpty() || !parquetField.get().isPrimitive()) {
-// Parquet file has fewer column than partition
-// Or the field is a complex type
+if (columnHandle.getHiveColumnProjectionInfo().isEmpty()) {
continue;
}
-descriptor = descriptorsByPath.get(ImmutableList.of(parquetField.get().getName()));
+Optional<List<Type>> subfieldTypes = dereferenceSubFieldTypes(baseColumnType.get().asGroupType(), columnHandle.getHiveColumnProjectionInfo().get());
+// failed to look up subfields from the file schema
+if (subfieldTypes.isEmpty()) {
+continue;
+}
+
+descriptor = descriptorsByPath.get(ImmutableList.<String>builder()
+.add(baseColumnType.get().getName())
+.addAll(subfieldTypes.get().stream().map(Type::getName).collect(toImmutableList()))
+.build());
}

if (descriptor != null) {
predicate.put(descriptor, entry.getValue());
}
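For a projected column, the descriptor lookup key is the full path from the base column through each dereferenced subfield, matching how leaf columns are keyed in the file schema. Illustratively, with names from the new unit test below:

// A predicate on row_field.b resolves to the leaf descriptor at path ["row_field", "b"].
ColumnDescriptor expected = descriptorsByPath.get(ImmutableList.of("row_field", "b"));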
@@ -509,4 +518,32 @@ private static Optional<org.apache.parquet.schema.Type> getBaseColumnParquetType

return Optional.empty();
}

/**
 * Dereferences the base parquet type following the projection info's dereference names.
 * For example, dereferencing baseType(level1Field0, level1Field1, level1Field2(level2Field0, level2Field1))
 * with a projection info whose dereferenceNames list is (level1Field2, level2Field1)
 * returns the parquet types (level1Field2, level2Field1), in that order.
 *
 * @return the child field at each level of dereferencing, or Optional.empty() when the lookup fails
 */
private static Optional<List<org.apache.parquet.schema.Type>> dereferenceSubFieldTypes(GroupType baseType, HiveColumnProjectionInfo projectionInfo)
{
checkArgument(baseType != null, "base type cannot be null when dereferencing");
checkArgument(projectionInfo != null, "hive column projection info cannot be null when dereferencing");

ImmutableList.Builder<org.apache.parquet.schema.Type> typeBuilder = ImmutableList.builder();
org.apache.parquet.schema.Type parentType = baseType;

for (String name : projectionInfo.getDereferenceNames()) {
org.apache.parquet.schema.Type childType = getParquetTypeByName(name, parentType.asGroupType());
if (childType == null) {
return Optional.empty();
}
typeBuilder.add(childType);
parentType = childType;
}

return Optional.of(typeBuilder.build());
}
}
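A self-contained sketch of the walk dereferenceSubFieldTypes performs, using the Javadoc's example schema and the parquet-mr types that also appear in the unit tests below (static imports of OPTIONAL and INT32 assumed):

GroupType baseType = new GroupType(OPTIONAL, "baseType",
        new PrimitiveType(OPTIONAL, INT32, "level1Field0"),
        new PrimitiveType(OPTIONAL, INT32, "level1Field1"),
        new GroupType(OPTIONAL, "level1Field2",
                new PrimitiveType(OPTIONAL, INT32, "level2Field0"),
                new PrimitiveType(OPTIONAL, INT32, "level2Field1")));
// Each dereference name resolves one level down; a miss at any level yields Optional.empty().
org.apache.parquet.schema.Type level1 = getParquetTypeByName("level1Field2", baseType);
org.apache.parquet.schema.Type level2 = getParquetTypeByName("level2Field1", level1.asGroupType());
// Resulting list, outermost first: (level1Field2, level2Field1)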
@@ -5276,41 +5276,6 @@ private void testParquetDictionaryPredicatePushdown(Session session)
assertNoDataRead("SELECT * FROM " + tableName + " WHERE n = 3");
}

-@Test
-public void testParquetOnlyNullsRowGroupPruning()
-{
-String tableName = "test_primitive_column_nulls_pruning_" + randomNameSuffix();
-assertUpdate("CREATE TABLE " + tableName + " (col BIGINT) WITH (format = 'PARQUET')");
-assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(repeat(NULL, 4096))", 4096);
-assertNoDataRead("SELECT * FROM " + tableName + " WHERE col IS NOT NULL");
-
-tableName = "test_nested_column_nulls_pruning_" + randomNameSuffix();
-// Nested column `a` has nulls count of 4096 and contains only nulls
-// Nested column `b` also has nulls count of 4096, but it contains non nulls as well
-assertUpdate("CREATE TABLE " + tableName + " (col ROW(a BIGINT, b ARRAY(DOUBLE))) WITH (format = 'PARQUET')");
-assertUpdate("INSERT INTO " + tableName + " SELECT * FROM unnest(transform(repeat(1, 4096), x -> ROW(ROW(NULL, ARRAY [NULL, rand()]))))", 4096);
-// TODO replace with assertNoDataRead after nested column predicate pushdown
-assertQueryStats(
-getSession(),
-"SELECT * FROM " + tableName + " WHERE col.a IS NOT NULL",
-queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0),
-results -> assertThat(results.getRowCount()).isEqualTo(0));
-assertQueryStats(
-getSession(),
-"SELECT * FROM " + tableName + " WHERE col.b IS NOT NULL",
-queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isGreaterThan(0),
-results -> assertThat(results.getRowCount()).isEqualTo(4096));
-}
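The TODO in the removed lines above is what this PR delivers: the pruning tests move into the shared BaseTestFileFormatComplexTypesPredicatePushDown (extended by the new test class later in this diff), where the all-nulls nested column can presumably use the stronger assertion directly:

// With dereference pushdown, the predicate on the all-nulls nested field `a`
// eliminates every row group (assumed form of the relocated test).
assertNoDataRead("SELECT * FROM " + tableName + " WHERE col.a IS NOT NULL");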

private void assertNoDataRead(@Language("SQL") String sql)
{
assertQueryStats(
getSession(),
sql,
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isEqualTo(0),
results -> assertThat(results.getRowCount()).isEqualTo(0));
}

private QueryInfo getQueryInfo(DistributedQueryRunner queryRunner, MaterializedResultWithQueryId queryResult)
{
return queryRunner.getCoordinator().getQueryManager().getFullQueryInfo(queryResult.getQueryId());
@@ -0,0 +1,44 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.hive.parquet;

import io.trino.plugin.hive.HiveQueryRunner;
import io.trino.testing.BaseTestFileFormatComplexTypesPredicatePushDown;
import io.trino.testing.QueryRunner;
import org.testng.annotations.Test;

import static io.trino.testing.TestingNames.randomNameSuffix;
import static org.assertj.core.api.Assertions.assertThat;

public class TestHiveParquetComplexTypePredicatePushDown
extends BaseTestFileFormatComplexTypesPredicatePushDown
{
@Override
protected QueryRunner createQueryRunner()
throws Exception
{
return HiveQueryRunner.builder()
.addHiveProperty("hive.storage-format", "PARQUET")
.build();
}

@Test
public void ensureFormatParquet()
{
String tableName = "test_table_" + randomNameSuffix();
assertUpdate("CREATE TABLE " + tableName + " (colTest BIGINT)");
assertThat(((String) computeScalar("SHOW CREATE TABLE " + tableName))).contains("PARQUET");
assertUpdate("DROP TABLE " + tableName);
}
}
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.trino.plugin.hive.HiveColumnHandle;
+import io.trino.plugin.hive.HiveColumnProjectionInfo;
import io.trino.plugin.hive.HiveType;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
@@ -122,12 +123,129 @@ public void testParquetTupleDomainStruct(boolean useColumnNames)
MessageType fileSchema = new MessageType("hive_schema",
new GroupType(OPTIONAL, "my_struct",
new PrimitiveType(OPTIONAL, INT32, "a"),
new PrimitiveType(OPTIONAL, INT32, "b")));
new PrimitiveType(OPTIONAL, INT32, "b"),
new PrimitiveType(OPTIONAL, INT32, "c")));
Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema);
TupleDomain<ColumnDescriptor> tupleDomain = getParquetTupleDomain(descriptorsByPath, domain, fileSchema, useColumnNames);
assertTrue(tupleDomain.isAll());
}

@Test(dataProvider = "useColumnNames")
public void testParquetTupleDomainStructWithPrimitiveColumnPredicate(boolean useColumnNames)
{
RowType baseType = rowType(
RowType.field("a", INTEGER),
RowType.field("b", INTEGER),
RowType.field("c", INTEGER));

HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo(
ImmutableList.of(1),
ImmutableList.of("b"),
HiveType.HIVE_INT,
INTEGER);

HiveColumnHandle projectedColumn = new HiveColumnHandle(
"row_field",
0,
HiveType.toHiveType(baseType),
baseType,
Optional.of(columnProjectionInfo),
REGULAR,
Optional.empty());

Domain predicateDomain = Domain.singleValue(INTEGER, 123L);
TupleDomain<HiveColumnHandle> tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain));

MessageType fileSchema = new MessageType("hive_schema",
new GroupType(OPTIONAL, "row_field",
new PrimitiveType(OPTIONAL, INT32, "a"),
new PrimitiveType(OPTIONAL, INT32, "b"),
new PrimitiveType(OPTIONAL, INT32, "c")));
Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema);
TupleDomain<ColumnDescriptor> calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumnNames);
assertEquals(calculatedTupleDomain.getDomains().get().size(), 1);
ColumnDescriptor selectedColumnDescriptor = descriptorsByPath.get(ImmutableList.of("row_field", "b"));
assertEquals(calculatedTupleDomain.getDomains().get().get(selectedColumnDescriptor), predicateDomain);
}

@Test(dataProvider = "useColumnNames")
public void testParquetTupleDomainStructWithComplexColumnPredicate(boolean useColumnNames)
{
RowType c1Type = rowType(
RowType.field("c1", INTEGER),
RowType.field("c2", INTEGER));
RowType baseType = rowType(
RowType.field("a", INTEGER),
RowType.field("b", INTEGER),
RowType.field("c", c1Type));

HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo(
ImmutableList.of(2),
ImmutableList.of("C"),
HiveType.toHiveType(c1Type),
c1Type);

HiveColumnHandle projectedColumn = new HiveColumnHandle(
"row_field",
0,
HiveType.toHiveType(baseType),
baseType,
Optional.of(columnProjectionInfo),
REGULAR,
Optional.empty());

Domain predicateDomain = Domain.onlyNull(c1Type);
TupleDomain<HiveColumnHandle> tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain));

MessageType fileSchema = new MessageType("hive_schema",
new GroupType(OPTIONAL, "row_field",
new PrimitiveType(OPTIONAL, INT32, "a"),
new PrimitiveType(OPTIONAL, INT32, "b"),
new GroupType(OPTIONAL,
"c",
new PrimitiveType(OPTIONAL, INT32, "c1"),
new PrimitiveType(OPTIONAL, INT32, "c2"))));
Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema);
// skip looking up predicates for complex types as Parquet only stores stats for primitives
TupleDomain<ColumnDescriptor> calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumnNames);
assertTrue(calculatedTupleDomain.isAll());
}

@Test(dataProvider = "useColumnNames")
public void testParquetTupleDomainStructWithMissingPrimitiveColumn(boolean useColumnNames)
{
RowType baseType = rowType(
RowType.field("a", INTEGER),
RowType.field("b", INTEGER),
RowType.field("non_exist", INTEGER));

HiveColumnProjectionInfo columnProjectionInfo = new HiveColumnProjectionInfo(
ImmutableList.of(2),
ImmutableList.of("non_exist"),
HiveType.HIVE_INT,
INTEGER);

HiveColumnHandle projectedColumn = new HiveColumnHandle(
"row_field",
0,
HiveType.toHiveType(baseType),
baseType,
Optional.of(columnProjectionInfo),
REGULAR,
Optional.empty());

Domain predicateDomain = Domain.singleValue(INTEGER, 123L);
TupleDomain<HiveColumnHandle> tupleDomain = withColumnDomains(ImmutableMap.of(projectedColumn, predicateDomain));

MessageType fileSchema = new MessageType("hive_schema",
new GroupType(OPTIONAL, "row_field",
new PrimitiveType(OPTIONAL, INT32, "a"),
new PrimitiveType(OPTIONAL, INT32, "b")));
Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, fileSchema);
TupleDomain<ColumnDescriptor> calculatedTupleDomain = getParquetTupleDomain(descriptorsByPath, tupleDomain, fileSchema, useColumnNames);
assertTrue(calculatedTupleDomain.isAll());
}

@Test(dataProvider = "useColumnNames")
public void testParquetTupleDomainMap(boolean useColumnNames)
{
@@ -562,6 +562,15 @@ protected void assertQueryStats(
resultAssertion.accept(resultWithQueryId.getResult());
}

protected void assertNoDataRead(@Language("SQL") String sql)
{
assertQueryStats(
getSession(),
sql,
queryStats -> assertThat(queryStats.getProcessedInputDataSize().toBytes()).isEqualTo(0),
results -> assertThat(results.getRowCount()).isEqualTo(0));
}
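The helper moves verbatim from the Hive connector test into the shared framework: it asserts both that zero input bytes were scanned and that zero rows came back. A typical call, as in the dictionary pushdown test earlier in this diff:

// Reads nothing when the predicate prunes all row groups via statistics.
assertNoDataRead("SELECT * FROM " + tableName + " WHERE n = 3");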

protected MaterializedResult computeExpected(@Language("SQL") String sql, List<? extends Type> resultTypes)
{
return h2QueryRunner.execute(getSession(), sql, resultTypes);