Skip to content

Commit

Permalink
Apply the dereference pushdown at the physical level on parquet
Browse files Browse the repository at this point in the history
  • Loading branch information
findinpath authored and raunaqmorarka committed Jul 8, 2023
1 parent 570c6f8 commit a2d0c1b
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

import com.google.common.base.Suppliers;
import com.google.common.base.VerifyException;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -103,7 +106,9 @@
import org.apache.parquet.hadoop.metadata.ParquetMetadata;
import org.apache.parquet.io.ColumnIO;
import org.apache.parquet.io.MessageColumnIO;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType;
import org.roaringbitmap.longlong.LongBitmapDataProvider;
import org.roaringbitmap.longlong.Roaring64Bitmap;

Expand All @@ -118,7 +123,6 @@
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;

import static com.google.common.base.Preconditions.checkState;
Expand Down Expand Up @@ -902,9 +906,7 @@ private static ReaderPageSourceWithRowPositions createParquetPageSource(
}

// Mapping from Iceberg field ID to Parquet fields.
Map<Integer, org.apache.parquet.schema.Type> parquetIdToField = fileSchema.getFields().stream()
.filter(field -> field.getId() != null)
.collect(toImmutableMap(field -> field.getId().intValue(), Function.identity()));
Map<Integer, org.apache.parquet.schema.Type> parquetIdToField = createParquetIdToFieldMapping(fileSchema);

Optional<ReaderColumns> baseColumnProjections = projectBaseColumns(regularColumns);
List<IcebergColumnHandle> readBaseColumns = baseColumnProjections
Expand All @@ -915,7 +917,7 @@ private static ReaderPageSourceWithRowPositions createParquetPageSource(
.map(column -> parquetIdToField.get(column.getId()))
.collect(toList());

MessageType requestedSchema = new MessageType(fileSchema.getName(), parquetFields.stream().filter(Objects::nonNull).collect(toImmutableList()));
MessageType requestedSchema = getMessageType(regularColumns, fileSchema.getName(), parquetIdToField);
Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, requestedSchema);
TupleDomain<ColumnDescriptor> parquetTupleDomain = getParquetTupleDomain(descriptorsByPath, effectivePredicate);
TupleDomainParquetPredicate parquetPredicate = buildPredicate(requestedSchema, parquetTupleDomain, descriptorsByPath, UTC);
Expand Down Expand Up @@ -1040,6 +1042,45 @@ else if (column.getId() == TRINO_MERGE_PARTITION_DATA) {
}
}

private static Map<Integer, org.apache.parquet.schema.Type> createParquetIdToFieldMapping(MessageType fileSchema)
{
ImmutableMap.Builder<Integer, org.apache.parquet.schema.Type> builder = ImmutableMap.builder();
addParquetIdToFieldMapping(fileSchema, builder);
return builder.buildOrThrow();
}

private static void addParquetIdToFieldMapping(org.apache.parquet.schema.Type type, ImmutableMap.Builder<Integer, org.apache.parquet.schema.Type> builder)
{
if (type.getId() != null) {
builder.put(type.getId().intValue(), type);
}
if (type instanceof PrimitiveType) {
// Nothing else to do
}
else if (type instanceof GroupType groupType) {
for (org.apache.parquet.schema.Type field : groupType.getFields()) {
addParquetIdToFieldMapping(field, builder);
}
}
else {
throw new IllegalStateException("Unsupported field type: " + type);
}
}

private static MessageType getMessageType(List<IcebergColumnHandle> regularColumns, String fileSchemaName, Map<Integer, org.apache.parquet.schema.Type> parquetIdToField)
{
return projectSufficientColumns(regularColumns)
.map(readerColumns -> readerColumns.get().stream().map(IcebergColumnHandle.class::cast).collect(toUnmodifiableList()))
.orElse(regularColumns)
.stream()
.map(column -> getColumnType(column, parquetIdToField))
.filter(Optional::isPresent)
.map(Optional::get)
.map(type -> new MessageType(fileSchemaName, type))
.reduce(MessageType::union)
.orElse(new MessageType(fileSchemaName, ImmutableList.of()));
}

private static ReaderPageSourceWithRowPositions createAvroPageSource(
TrinoInputFile inputFile,
long start,
Expand Down Expand Up @@ -1278,6 +1319,93 @@ public static Optional<ReaderColumns> projectBaseColumns(List<IcebergColumnHandl
return Optional.of(new ReaderColumns(projectedColumns.build(), outputColumnMapping.build()));
}

/**
* Creates a set of sufficient columns for the input projected columns and prepares a mapping between the two.
* For example, if input {@param columns} include columns "a.b" and "a.b.c", then they will be projected
* from a single column "a.b".
*/
private static Optional<ReaderColumns> projectSufficientColumns(List<IcebergColumnHandle> columns)
{
requireNonNull(columns, "columns is null");

if (columns.stream().allMatch(IcebergColumnHandle::isBaseColumn)) {
return Optional.empty();
}

ImmutableBiMap.Builder<DereferenceChain, IcebergColumnHandle> dereferenceChainsBuilder = ImmutableBiMap.builder();

for (IcebergColumnHandle column : columns) {
DereferenceChain dereferenceChain = new DereferenceChain(column.getBaseColumnIdentity(), column.getPath());
dereferenceChainsBuilder.put(dereferenceChain, column);
}

BiMap<DereferenceChain, IcebergColumnHandle> dereferenceChains = dereferenceChainsBuilder.build();

List<ColumnHandle> sufficientColumns = new ArrayList<>();
ImmutableList.Builder<Integer> outputColumnMapping = ImmutableList.builder();

Map<DereferenceChain, Integer> pickedColumns = new HashMap<>();

// Pick a covering column for every column
for (IcebergColumnHandle columnHandle : columns) {
DereferenceChain dereferenceChain = dereferenceChains.inverse().get(columnHandle);
DereferenceChain chosenColumn = null;

// Shortest existing prefix is chosen as the input.
for (DereferenceChain prefix : dereferenceChain.orderedPrefixes()) {
if (dereferenceChains.containsKey(prefix)) {
chosenColumn = prefix;
break;
}
}

checkState(chosenColumn != null, "chosenColumn is null");
int inputBlockIndex;

if (pickedColumns.containsKey(chosenColumn)) {
// Use already picked column
inputBlockIndex = pickedColumns.get(chosenColumn);
}
else {
// Add a new column for the reader
sufficientColumns.add(dereferenceChains.get(chosenColumn));
pickedColumns.put(chosenColumn, sufficientColumns.size() - 1);
inputBlockIndex = sufficientColumns.size() - 1;
}

outputColumnMapping.add(inputBlockIndex);
}

return Optional.of(new ReaderColumns(sufficientColumns, outputColumnMapping.build()));
}

private static Optional<org.apache.parquet.schema.Type> getColumnType(IcebergColumnHandle column, Map<Integer, org.apache.parquet.schema.Type> parquetIdToField)
{
Optional<org.apache.parquet.schema.Type> baseColumnType = Optional.ofNullable(parquetIdToField.get(column.getBaseColumn().getId()));
if (baseColumnType.isEmpty() || column.getPath().isEmpty()) {
return baseColumnType;
}
GroupType baseType = baseColumnType.get().asGroupType();

List<org.apache.parquet.schema.Type> subfieldTypes = column.getPath().stream()
.filter(parquetIdToField::containsKey)
.map(parquetIdToField::get)
.collect(toImmutableList());

// if there is a mismatch between parquet schema and the Iceberg schema the column cannot be dereferenced
if (subfieldTypes.isEmpty()) {
return Optional.empty();
}

// Construct a stripped version of the original column type containing only the selected field and the hierarchy of its parents
org.apache.parquet.schema.Type type = subfieldTypes.get(subfieldTypes.size() - 1);
for (int i = subfieldTypes.size() - 2; i >= 0; --i) {
GroupType groupType = subfieldTypes.get(i).asGroupType();
type = new GroupType(groupType.getRepetition(), groupType.getName(), ImmutableList.of(type));
}
return Optional.of(new GroupType(baseType.getRepetition(), baseType.getName(), ImmutableList.of(type)));
}

private static TupleDomain<ColumnDescriptor> getParquetTupleDomain(Map<List<String>, ColumnDescriptor> descriptorsByPath, TupleDomain<IcebergColumnHandle> effectivePredicate)
{
if (effectivePredicate.isNone()) {
Expand Down Expand Up @@ -1351,4 +1479,57 @@ public Optional<Long> getEndRowPosition()
return endRowPosition;
}
}

private static class DereferenceChain
{
private final ColumnIdentity baseColumnIdentity;
private final List<Integer> path;

public DereferenceChain(ColumnIdentity baseColumnIdentity, List<Integer> path)
{
this.baseColumnIdentity = requireNonNull(baseColumnIdentity, "baseColumnIdentity is null");
this.path = ImmutableList.copyOf(requireNonNull(path, "path is null"));
}

/**
* Get prefixes of this Dereference chain in increasing order of lengths.
*/
public Iterable<DereferenceChain> orderedPrefixes()
{
return () -> new AbstractIterator<>()
{
private int prefixLength;

@Override
public DereferenceChain computeNext()
{
if (prefixLength > path.size()) {
return endOfData();
}
return new DereferenceChain(baseColumnIdentity, path.subList(0, prefixLength++));
}
};
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}

DereferenceChain that = (DereferenceChain) o;
return Objects.equals(baseColumnIdentity, that.baseColumnIdentity) &&
Objects.equals(path, that.path);
}

@Override
public int hashCode()
{
return Objects.hash(baseColumnIdentity, path);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ protected boolean supportsRowGroupStatistics(String typeName)
typeName.equalsIgnoreCase("timestamp(6) with time zone"));
}

@Override
protected boolean supportsPhysicalPushdown()
{
return true;
}

@Test
public void testRowGroupResetDictionary()
{
Expand Down

0 comments on commit a2d0c1b

Please sign in to comment.