diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java
index a2774f2351f8..35e8a39c09cd 100644
--- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java
+++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPageSourceProvider.java
@@ -15,6 +15,9 @@
 import com.google.common.base.Suppliers;
 import com.google.common.base.VerifyException;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.BiMap;
+import com.google.common.collect.ImmutableBiMap;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
@@ -103,7 +106,9 @@
 import org.apache.parquet.hadoop.metadata.ParquetMetadata;
 import org.apache.parquet.io.ColumnIO;
 import org.apache.parquet.io.MessageColumnIO;
+import org.apache.parquet.schema.GroupType;
 import org.apache.parquet.schema.MessageType;
+import org.apache.parquet.schema.PrimitiveType;
 import org.roaringbitmap.longlong.LongBitmapDataProvider;
 import org.roaringbitmap.longlong.Roaring64Bitmap;
@@ -118,7 +123,6 @@
 import java.util.Optional;
 import java.util.OptionalLong;
 import java.util.Set;
-import java.util.function.Function;
 import java.util.function.Supplier;
 
 import static com.google.common.base.Preconditions.checkState;
@@ -568,21 +572,21 @@ private static ReaderPageSourceWithRowPositions createOrcPageSource(
         Map<IcebergColumnHandle, Domain> effectivePredicateDomains = effectivePredicate.getDomains()
                 .orElseThrow(() -> new IllegalArgumentException("Effective predicate is none"));
 
-        Optional<ReaderColumns> columnProjections = projectColumns(columns);
+        Optional<ReaderColumns> baseColumnProjections = projectBaseColumns(columns);
         Map<Integer, List<List<Integer>>> projectionsByFieldId = columns.stream()
                 .collect(groupingBy(
                         column -> column.getBaseColumnIdentity().getId(),
                         mapping(IcebergColumnHandle::getPath, toUnmodifiableList())));
 
-        List<IcebergColumnHandle> readColumns = columnProjections
+        List<IcebergColumnHandle> readBaseColumns = baseColumnProjections
                 .map(readerColumns -> (List<IcebergColumnHandle>) readerColumns.get().stream().map(IcebergColumnHandle.class::cast).collect(toImmutableList()))
                 .orElse(columns);
-        List<OrcColumn> fileReadColumns = new ArrayList<>(readColumns.size());
-        List<Type> fileReadTypes = new ArrayList<>(readColumns.size());
-        List<ProjectedLayout> projectedLayouts = new ArrayList<>(readColumns.size());
-        List<ColumnAdaptation> columnAdaptations = new ArrayList<>(readColumns.size());
+        List<OrcColumn> fileReadColumns = new ArrayList<>(readBaseColumns.size());
+        List<Type> fileReadTypes = new ArrayList<>(readBaseColumns.size());
+        List<ProjectedLayout> projectedLayouts = new ArrayList<>(readBaseColumns.size());
+        List<ColumnAdaptation> columnAdaptations = new ArrayList<>(readBaseColumns.size());
 
-        for (IcebergColumnHandle column : readColumns) {
+        for (IcebergColumnHandle column : readBaseColumns) {
             verify(column.isBaseColumn(), "Column projections must be based from a root column");
 
             OrcColumn orcColumn = fileColumnsByIcebergId.get(column.getId());
@@ -659,7 +663,7 @@ else if (orcColumn != null) {
                     memoryUsage,
                     INITIAL_BATCH_SIZE,
                     exception -> handleException(orcDataSourceId, exception),
-                    new IdBasedFieldMapperFactory(readColumns));
+                    new IdBasedFieldMapperFactory(readBaseColumns));
 
             return new ReaderPageSourceWithRowPositions(
                     new ReaderPageSource(
@@ -672,7 +676,7 @@ else if (orcColumn != null) {
                                     memoryUsage,
                                     stats,
                                     reader.getCompressionKind()),
-                            columnProjections),
+                            baseColumnProjections),
                     recordReader.getStartRowPosition(),
                     recordReader.getEndRowPosition());
         }
@@ -902,20 +906,18 @@ private static ReaderPageSourceWithRowPositions createParquetPageSource(
         }
 
         // Mapping from Iceberg field ID to Parquet fields.
-        Map<Integer, org.apache.parquet.schema.Type> parquetIdToField = fileSchema.getFields().stream()
-                .filter(field -> field.getId() != null)
-                .collect(toImmutableMap(field -> field.getId().intValue(), Function.identity()));
+        Map<Integer, org.apache.parquet.schema.Type> parquetIdToField = createParquetIdToFieldMapping(fileSchema);
 
-        Optional<ReaderColumns> columnProjections = projectColumns(regularColumns);
-        List<IcebergColumnHandle> readColumns = columnProjections
+        Optional<ReaderColumns> baseColumnProjections = projectBaseColumns(regularColumns);
+        List<IcebergColumnHandle> readBaseColumns = baseColumnProjections
                 .map(readerColumns -> (List<IcebergColumnHandle>) readerColumns.get().stream().map(IcebergColumnHandle.class::cast).collect(toImmutableList()))
                 .orElse(regularColumns);
 
-        List<org.apache.parquet.schema.Type> parquetFields = readColumns.stream()
+        List<org.apache.parquet.schema.Type> parquetFields = readBaseColumns.stream()
                 .map(column -> parquetIdToField.get(column.getId()))
                 .collect(toList());
 
-        MessageType requestedSchema = new MessageType(fileSchema.getName(), parquetFields.stream().filter(Objects::nonNull).collect(toImmutableList()));
+        MessageType requestedSchema = getMessageType(regularColumns, fileSchema.getName(), parquetIdToField);
         Map<List<String>, ColumnDescriptor> descriptorsByPath = getDescriptors(fileSchema, requestedSchema);
         TupleDomain<ColumnDescriptor> parquetTupleDomain = getParquetTupleDomain(descriptorsByPath, effectivePredicate);
         TupleDomainParquetPredicate parquetPredicate = buildPredicate(requestedSchema, parquetTupleDomain, descriptorsByPath, UTC);
@@ -947,8 +949,8 @@ private static ReaderPageSourceWithRowPositions createParquetPageSource(
         int parquetSourceChannel = 0;
 
         ImmutableList.Builder parquetColumnFieldsBuilder = ImmutableList.builder();
-        for (int columnIndex = 0; columnIndex < readColumns.size(); columnIndex++) {
-            IcebergColumnHandle column = readColumns.get(columnIndex);
+        for (int columnIndex = 0; columnIndex < readBaseColumns.size(); columnIndex++) {
+            IcebergColumnHandle column = readBaseColumns.get(columnIndex);
             if (column.isIsDeletedColumn()) {
                 pageSourceBuilder.addConstantColumn(nativeValueToBlock(BOOLEAN, false));
             }
@@ -1014,7 +1016,7 @@ else if (column.getId() == TRINO_MERGE_PARTITION_DATA) {
         return new ReaderPageSourceWithRowPositions(
                 new ReaderPageSource(
                         pageSourceBuilder.build(parquetReader),
-                        columnProjections),
+                        baseColumnProjections),
                 startRowPosition,
                 endRowPosition);
     }
@@ -1040,6 +1042,45 @@ else if (column.getId() == TRINO_MERGE_PARTITION_DATA) {
         }
     }
 
+    private static Map<Integer, org.apache.parquet.schema.Type> createParquetIdToFieldMapping(MessageType fileSchema)
+    {
+        ImmutableMap.Builder<Integer, org.apache.parquet.schema.Type> builder = ImmutableMap.builder();
+        addParquetIdToFieldMapping(fileSchema, builder);
+        return builder.buildOrThrow();
+    }
+
+    private static void addParquetIdToFieldMapping(org.apache.parquet.schema.Type type, ImmutableMap.Builder<Integer, org.apache.parquet.schema.Type> builder)
+    {
+        if (type.getId() != null) {
+            builder.put(type.getId().intValue(), type);
+        }
+        if (type instanceof PrimitiveType) {
+            // Nothing else to do
+        }
+        else if (type instanceof GroupType groupType) {
+            for (org.apache.parquet.schema.Type field : groupType.getFields()) {
+                addParquetIdToFieldMapping(field, builder);
+            }
+        }
+        else {
+            throw new IllegalStateException("Unsupported field type: " + type);
+        }
+    }
+
+    private static MessageType getMessageType(List<IcebergColumnHandle> regularColumns, String fileSchemaName, Map<Integer, org.apache.parquet.schema.Type> parquetIdToField)
+    {
+        return projectSufficientColumns(regularColumns)
+                .map(readerColumns -> readerColumns.get().stream().map(IcebergColumnHandle.class::cast).collect(toUnmodifiableList()))
+                .orElse(regularColumns)
+                .stream()
+                .map(column -> getColumnType(column, parquetIdToField))
+                .filter(Optional::isPresent)
+                .map(Optional::get)
+                .map(type -> new MessageType(fileSchemaName, type))
+                .reduce(MessageType::union)
+                .orElse(new MessageType(fileSchemaName, ImmutableList.of()));
+    }
+
     private static ReaderPageSourceWithRowPositions createAvroPageSource(
             TrinoInputFile inputFile,
             long start,
@@ -1054,16 +1095,16 @@ private static ReaderPageSourceWithRowPositions createAvroPageSource(
         ConstantPopulatingPageSource.Builder constantPopulatingPageSourceBuilder = ConstantPopulatingPageSource.builder();
         int avroSourceChannel = 0;
 
-        Optional<ReaderColumns> columnProjections = projectColumns(columns);
+        Optional<ReaderColumns> baseColumnProjections = projectBaseColumns(columns);
 
-        List<IcebergColumnHandle> readColumns = columnProjections
+        List<IcebergColumnHandle> readBaseColumns = baseColumnProjections
                 .map(readerColumns -> (List<IcebergColumnHandle>) readerColumns.get().stream().map(IcebergColumnHandle.class::cast).collect(toImmutableList()))
                 .orElse(columns);
 
         InputFile file = new ForwardingInputFile(inputFile);
         OptionalLong fileModifiedTime = OptionalLong.empty();
         try {
-            if (readColumns.stream().anyMatch(IcebergColumnHandle::isFileModifiedTimeColumn)) {
+            if (readBaseColumns.stream().anyMatch(IcebergColumnHandle::isFileModifiedTimeColumn)) {
                 fileModifiedTime = OptionalLong.of(inputFile.lastModified().toEpochMilli());
             }
         }
@@ -1087,7 +1128,7 @@ private static ReaderPageSourceWithRowPositions createAvroPageSource(
         ImmutableList.Builder<Type> columnTypes = ImmutableList.builder();
         ImmutableList.Builder<Boolean> rowIndexChannels = ImmutableList.builder();
 
-        for (IcebergColumnHandle column : readColumns) {
+        for (IcebergColumnHandle column : readBaseColumns) {
             verify(column.isBaseColumn(), "Column projections must be based from a root column");
 
             org.apache.avro.Schema.Field field = fileColumnsByIcebergId.get(column.getId());
@@ -1138,7 +1179,7 @@ else if (field == null) {
                                 columnTypes.build(),
                                 rowIndexChannels.build(),
                                 newSimpleAggregatedMemoryContext())),
-                        columnProjections),
+                        baseColumnProjections),
                 Optional.empty(),
                 Optional.empty());
     }
@@ -1246,7 +1287,7 @@ public ProjectedLayout getFieldLayout(OrcColumn orcColumn)
     /**
      * Creates a mapping between the input {@code columns} and base columns if required.
      */
-    public static Optional<ReaderColumns> projectColumns(List<IcebergColumnHandle> columns)
+    public static Optional<ReaderColumns> projectBaseColumns(List<IcebergColumnHandle> columns)
     {
         requireNonNull(columns, "columns is null");
 
@@ -1278,6 +1319,93 @@ public static Optional<ReaderColumns> projectColumns(List<IcebergColumnHandle> c
         return Optional.of(new ReaderColumns(projectedColumns.build(), outputColumnMapping.build()));
     }
 
+    /**
+     * Creates a set of sufficient columns for the input projected columns and prepares a mapping between the two.
+     * For example, if input {@param columns} include columns "a.b" and "a.b.c", then they will be projected
+     * from a single column "a.b".
+     */
+    private static Optional<ReaderColumns> projectSufficientColumns(List<IcebergColumnHandle> columns)
+    {
+        requireNonNull(columns, "columns is null");
+
+        if (columns.stream().allMatch(IcebergColumnHandle::isBaseColumn)) {
+            return Optional.empty();
+        }
+
+        ImmutableBiMap.Builder<DereferenceChain, IcebergColumnHandle> dereferenceChainsBuilder = ImmutableBiMap.builder();
+
+        for (IcebergColumnHandle column : columns) {
+            DereferenceChain dereferenceChain = new DereferenceChain(column.getBaseColumnIdentity(), column.getPath());
+            dereferenceChainsBuilder.put(dereferenceChain, column);
+        }
+
+        BiMap<DereferenceChain, IcebergColumnHandle> dereferenceChains = dereferenceChainsBuilder.build();
+
+        List<ColumnHandle> sufficientColumns = new ArrayList<>();
+        ImmutableList.Builder<Integer> outputColumnMapping = ImmutableList.builder();
+
+        Map<DereferenceChain, Integer> pickedColumns = new HashMap<>();
+
+        // Pick a covering column for every column
+        for (IcebergColumnHandle columnHandle : columns) {
+            DereferenceChain dereferenceChain = dereferenceChains.inverse().get(columnHandle);
+            DereferenceChain chosenColumn = null;
+
+            // Shortest existing prefix is chosen as the input.
+            for (DereferenceChain prefix : dereferenceChain.orderedPrefixes()) {
+                if (dereferenceChains.containsKey(prefix)) {
+                    chosenColumn = prefix;
+                    break;
+                }
+            }
+
+            checkState(chosenColumn != null, "chosenColumn is null");
+            int inputBlockIndex;
+
+            if (pickedColumns.containsKey(chosenColumn)) {
+                // Use already picked column
+                inputBlockIndex = pickedColumns.get(chosenColumn);
+            }
+            else {
+                // Add a new column for the reader
+                sufficientColumns.add(dereferenceChains.get(chosenColumn));
+                pickedColumns.put(chosenColumn, sufficientColumns.size() - 1);
+                inputBlockIndex = sufficientColumns.size() - 1;
+            }
+
+            outputColumnMapping.add(inputBlockIndex);
+        }
+
+        return Optional.of(new ReaderColumns(sufficientColumns, outputColumnMapping.build()));
+    }
+
+    private static Optional<org.apache.parquet.schema.Type> getColumnType(IcebergColumnHandle column, Map<Integer, org.apache.parquet.schema.Type> parquetIdToField)
+    {
+        Optional<org.apache.parquet.schema.Type> baseColumnType = Optional.ofNullable(parquetIdToField.get(column.getBaseColumn().getId()));
+        if (baseColumnType.isEmpty() || column.getPath().isEmpty()) {
+            return baseColumnType;
+        }
+        GroupType baseType = baseColumnType.get().asGroupType();
+
+        List<org.apache.parquet.schema.Type> subfieldTypes = column.getPath().stream()
+                .filter(parquetIdToField::containsKey)
+                .map(parquetIdToField::get)
+                .collect(toImmutableList());
+
+        // if there is a mismatch between parquet schema and the Iceberg schema the column cannot be dereferenced
+        if (subfieldTypes.isEmpty()) {
+            return Optional.empty();
+        }
+
+        // Construct a stripped version of the original column type containing only the selected field and the hierarchy of its parents
+        org.apache.parquet.schema.Type type = subfieldTypes.get(subfieldTypes.size() - 1);
+        for (int i = subfieldTypes.size() - 2; i >= 0; --i) {
+            GroupType groupType = subfieldTypes.get(i).asGroupType();
+            type = new GroupType(groupType.getRepetition(), groupType.getName(), ImmutableList.of(type));
+        }
+        return Optional.of(new GroupType(baseType.getRepetition(), baseType.getName(), ImmutableList.of(type)));
+    }
+
     private static TupleDomain<ColumnDescriptor> getParquetTupleDomain(Map<List<String>, ColumnDescriptor> descriptorsByPath, TupleDomain<IcebergColumnHandle> effectivePredicate)
     {
         if (effectivePredicate.isNone()) {
@@ -1351,4 +1479,57 @@ public Optional<Long> getEndRowPosition()
             return endRowPosition;
         }
     }
+
+    private static class DereferenceChain
+    {
+        private final ColumnIdentity baseColumnIdentity;
+        private final List<Integer> path;
+
+        public DereferenceChain(ColumnIdentity baseColumnIdentity, List<Integer> path)
+        {
+            this.baseColumnIdentity = requireNonNull(baseColumnIdentity, "baseColumnIdentity is null");
+            this.path = ImmutableList.copyOf(requireNonNull(path, "path is null"));
+        }
+
+        /**
+         * Get prefixes of this Dereference chain in increasing order of lengths.
+         */
+        public Iterable<DereferenceChain> orderedPrefixes()
+        {
+            return () -> new AbstractIterator<>()
+            {
+                private int prefixLength;
+
+                @Override
+                public DereferenceChain computeNext()
+                {
+                    if (prefixLength > path.size()) {
+                        return endOfData();
+                    }
+                    return new DereferenceChain(baseColumnIdentity, path.subList(0, prefixLength++));
+                }
+            };
+        }
+
+        @Override
+        public boolean equals(Object o)
+        {
+            if (this == o) {
+                return true;
+            }
+            if (o == null || getClass() != o.getClass()) {
+                return false;
+            }
+
+            DereferenceChain that = (DereferenceChain) o;
+            return Objects.equals(baseColumnIdentity, that.baseColumnIdentity) &&
+                    Objects.equals(path, that.path);
+        }
+
+        @Override
+        public int hashCode()
+        {
+            return Objects.hash(baseColumnIdentity, path);
+        }
+    }
 }
diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergParquetConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergParquetConnectorTest.java
index 7764143f5c7d..16633c9f091f 100644
--- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergParquetConnectorTest.java
+++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergParquetConnectorTest.java
@@ -49,6 +49,12 @@ protected boolean supportsRowGroupStatistics(String typeName)
                 typeName.equalsIgnoreCase("timestamp(6) with time zone"));
     }
 
+    @Override
+    protected boolean supportsPhysicalPushdown()
+    {
+        return true;
+    }
+
     @Test
    public void testRowGroupResetDictionary()
    {
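Illustrative sketch (not part of the patch above): the idea behind the new projectSufficientColumns() and DereferenceChain.orderedPrefixes() is that each requested dereference such as a.b.c is served by the shortest already-requested prefix of its path (a.b here), so overlapping projections share one physical read column and an output mapping records which read column feeds each requested column. The standalone Java below uses simplified, hypothetical names (Column, a path of field names instead of Iceberg field IDs) and omits the base-column case; it only sketches the selection loop, it is not Trino code.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SufficientColumnsSketch
{
    // A requested column, identified by its dereference path, e.g. ["a", "b", "c"] for a.b.c.
    // (The real code keys on Iceberg field IDs and the base ColumnIdentity, not names.)
    record Column(List<String> path) {}

    public static void main(String[] args)
    {
        List<Column> requested = List.of(
                new Column(List.of("a", "b")),
                new Column(List.of("a", "b", "c")),
                new Column(List.of("x")));

        List<Column> sufficientColumns = new ArrayList<>();
        Map<List<String>, Integer> pickedColumns = new HashMap<>();
        List<Integer> outputColumnMapping = new ArrayList<>();

        for (Column column : requested) {
            // The shortest prefix of the path that is itself a requested column "covers" this column
            List<String> chosen = null;
            for (int length = 1; length <= column.path().size(); length++) {
                List<String> prefix = column.path().subList(0, length);
                if (requested.stream().anyMatch(other -> other.path().equals(prefix))) {
                    chosen = prefix;
                    break;
                }
            }
            // chosen is never null here: the column's own path is always a requested prefix
            int index = pickedColumns.computeIfAbsent(chosen, key -> {
                sufficientColumns.add(new Column(key));
                return sufficientColumns.size() - 1;
            });
            outputColumnMapping.add(index);
        }

        // Two physical columns are read (a.b and x); a.b.c is dereferenced from a.b
        System.out.println("read: " + sufficientColumns);       // [Column[path=[a, b]], Column[path=[x]]]
        System.out.println("mapping: " + outputColumnMapping);  // [0, 0, 1]
    }
}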
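Similarly, a minimal sketch of the Parquet-side schema pruning performed by getColumnType() and getMessageType(): the selected leaf field is rewrapped in stripped copies of its ancestor groups, and the resulting single-column message types are merged, here shown with parquet-mr's MessageType.union() (the patch reduces with MessageType::union). The schema, field names, and class name below are made up for the example and assume the standard org.apache.parquet.schema builder API.

import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;
import org.apache.parquet.schema.Types;

public class PrunedSchemaSketch
{
    public static void main(String[] args)
    {
        // File schema: message t { required group a { required int32 b; required int32 c; } required int32 x; }
        MessageType fileSchema = Types.buildMessage()
                .requiredGroup()
                        .required(PrimitiveTypeName.INT32).named("b")
                        .required(PrimitiveTypeName.INT32).named("c")
                        .named("a")
                .required(PrimitiveTypeName.INT32).named("x")
                .named("t");

        // Keep only a.b: wrap the selected leaf in a stripped copy of its parent group
        GroupType a = fileSchema.getType("a").asGroupType();
        Type leaf = a.getType("b");
        Type strippedA = new GroupType(a.getRepetition(), a.getName(), leaf);
        MessageType projectedA = new MessageType(fileSchema.getName(), strippedA);

        // A second projection (x) is merged in, as getMessageType() does by reducing with union
        MessageType projectedX = new MessageType(fileSchema.getName(), fileSchema.getType("x"));
        MessageType requested = projectedA.union(projectedX);

        // Prints: message t { required group a { required int32 b; } required int32 x; }
        System.out.println(requested);
    }
}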