Skip to content

Commit

Permalink
Prune Non-referenced Fields from Nested RowTypes
Browse files Browse the repository at this point in the history
This set of changes prunes nested RowTypes to only the fields that are
actually dereferenced in the users' projections.

The Parquet implementation already solves for this, but it works on
it's own abstractions so it's not fit for use in the other Hive
formats. I believe this approach could be adopted by the Parquet
PageSource as well, thereby simplifying, but I don't want to bite that
off now.

I believe the approach will work for Avro as well, but the PageSource
isn't plumbing the inferred reader schema down to the type resolver:
it is just passing the selected columns from the writer schema as both
reader and writer.

I added a test that proves it works well for OpenXJson because it
is simple to mock data for it and it supports position-based
deserialization: a JSON Array into a Row.
  • Loading branch information
rmarrowstone committed Aug 20, 2024
1 parent 8edf5b4 commit ddb2f0a
Show file tree
Hide file tree
Showing 9 changed files with 616 additions and 10 deletions.
10 changes: 10 additions & 0 deletions core/trino-spi/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,16 @@
<old>interface io.trino.spi.exchange.ExchangeManagerHandleResolver</old>
<justification>cleanup</justification>
</item>
<item>
<ignore>true</ignore>
<code>java.method.visibilityIncreased</code>
<new>method void io.trino.spi.type.RowType::&lt;init&gt;(io.trino.spi.type.TypeSignature, java.util.List&lt;io.trino.spi.type.RowType.Field&gt;)</new>
</item>
<item>
<ignore>true</ignore>
<code>java.method.visibilityIncreased</code>
<new>method io.trino.spi.type.TypeSignature io.trino.spi.type.RowType::makeSignature(java.util.List&lt;io.trino.spi.type.RowType.Field&gt;)</new>
</item>
</differences>
</revapi.differences>
</analysisConfiguration>
Expand Down
4 changes: 2 additions & 2 deletions core/trino-spi/src/main/java/io/trino/spi/type/RowType.java
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public class RowType
private final int flatFixedSize;
private final boolean flatVariableWidth;

private RowType(TypeSignature typeSignature, List<Field> originalFields)
protected RowType(TypeSignature typeSignature, List<Field> originalFields)
{
super(typeSignature, SqlRow.class, RowBlock.class);

Expand Down Expand Up @@ -188,7 +188,7 @@ public static Field field(Type type)
return new Field(Optional.empty(), type);
}

private static TypeSignature makeSignature(List<Field> fields)
protected static TypeSignature makeSignature(List<Field> fields)
{
int size = fields.size();
if (size == 0) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.hive.formats;

import com.google.common.collect.ImmutableList;
import io.trino.spi.type.RowType;

import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;

/**
* A SparseRowType is a RowType that conveys which fields are active in a row.
* <p>
* It manages the positional mapping between the "sparse" fields and the
* underlying "dense" RowType. It allows position-based deserializers to know
* the complete schema for a Row while also knowing which fields need to be
* deserialized. Name-based deserializers can simply use the dense fields
* from the underlying RowType.
*/
public class SparseRowType
extends RowType
{
private final List<Field> sparseFields;
private final int[] offsets;

private SparseRowType(List<Field> sparseFields, List<Field> denseFields, int[] offsets)
{
super(makeSignature(denseFields), denseFields);
this.sparseFields = sparseFields;
this.offsets = offsets;
}

/**
* Create a SparseRowType from a list of fields and a mask indicating which fields are active.
*/
public static SparseRowType from(List<Field> fields, boolean[] mask)
{
checkArgument(fields.size() == mask.length);

int[] offsets = new int[fields.size()];
ImmutableList.Builder<Field> denseFields = ImmutableList.builder();

int offset = 0;
for (int i = 0; i < mask.length; i++) {
if (mask[i]) {
denseFields.add(fields.get(i));
offsets[i] = offset++;
}
else {
offsets[i] = -1;
}
}

return new SparseRowType(ImmutableList.copyOf(fields), denseFields.build(), offsets);
}

public static SparseRowType initial(List<Field> fields, Integer activeField)
{
boolean[] mask = new boolean[fields.size()];
mask[activeField] = true;
return SparseRowType.from(fields, mask);
}

public List<Field> getSparseFields()
{
return sparseFields;
}

/**
* Get the offset to the dense field for the sparseField at sparsePosition.
*/
public Integer getOffset(int sparsePosition)
{
return offsets[sparsePosition] >= 0
? offsets[sparsePosition]
: null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.hive.formats.DistinctMapKeys;
import io.trino.hive.formats.SparseRowType;
import io.trino.hive.formats.line.Column;
import io.trino.hive.formats.line.LineBuffer;
import io.trino.hive.formats.line.LineDeserializer;
Expand Down Expand Up @@ -56,6 +57,7 @@
import java.util.Set;
import java.util.function.IntFunction;
import java.util.regex.Pattern;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand Down Expand Up @@ -746,6 +748,7 @@ private static class RowDecoder
{
private final List<FieldName> fieldNames;
private final List<Decoder> fieldDecoders;
private final List<Integer> decoderOffsets;
private final boolean dotsInKeyNames;

public RowDecoder(RowType rowType, OpenXJsonOptions options, List<Decoder> fieldDecoders)
Expand All @@ -755,6 +758,16 @@ public RowDecoder(RowType rowType, OpenXJsonOptions options, List<Decoder> field
.map(fieldName -> fieldName.toLowerCase(Locale.ROOT))
.map(originalValue -> new FieldName(originalValue, options))
.collect(toImmutableList());
if (rowType instanceof SparseRowType sparseRowType) {
// build an inverse mapping, from dense fields to sparse fields
decoderOffsets = IntStream.range(0, sparseRowType.getSparseFields().size())
.filter(sparsePos -> sparseRowType.getOffset(sparsePos) != null)
.boxed()
.toList();
}
else {
decoderOffsets = IntStream.range(0, fieldDecoders.size()).boxed().toList();
}
this.fieldDecoders = fieldDecoders;
this.dotsInKeyNames = options.isDotsInFieldNames();
}
Expand Down Expand Up @@ -831,7 +844,8 @@ else if (dotsInKeyNames) {
private void decodeValueFromList(List<?> jsonArray, IntFunction<BlockBuilder> fieldBuilders)
{
for (int i = 0; i < fieldDecoders.size(); i++) {
Object fieldValue = jsonArray.size() > i ? jsonArray.get(i) : null;
int position = decoderOffsets.get(i);
Object fieldValue = jsonArray.size() > position ? jsonArray.get(position) : null;
BlockBuilder blockBuilder = fieldBuilders.apply(i);
if (fieldValue == null) {
blockBuilder.appendNull();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.inject.Inject;
import io.trino.filesystem.Location;
import io.trino.hive.formats.SparseRowType;
import io.trino.metastore.HiveType;
import io.trino.metastore.HiveTypeName;
import io.trino.metastore.type.TypeInfo;
Expand All @@ -28,6 +29,7 @@
import io.trino.plugin.hive.acid.AcidTransaction;
import io.trino.plugin.hive.coercions.CoercionUtils.CoercionContext;
import io.trino.plugin.hive.util.HiveBucketing.BucketingVersion;
import io.trino.plugin.hive.util.HiveTypeTranslator;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorPageSource;
Expand All @@ -41,11 +43,14 @@
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.NullableValue;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -279,7 +284,8 @@ public static List<Integer> getProjection(ColumnHandle expected, ColumnHandle re
HiveColumnHandle expectedColumn = (HiveColumnHandle) expected;
HiveColumnHandle readColumn = (HiveColumnHandle) read;

checkArgument(expectedColumn.getBaseColumn().equals(readColumn.getBaseColumn()), "reader column is not valid for expected column");
// checkArgument(expectedColumn.getBaseColumn().equals(readColumn.getBaseColumn()), "reader column is not valid for expected column");
checkArgument(expectedColumn.getBaseHiveColumnIndex() == readColumn.getBaseHiveColumnIndex(), "reader column is not valid for expected column");

List<Integer> expectedDereferences = expectedColumn.getHiveColumnProjectionInfo()
.map(HiveColumnProjectionInfo::getDereferenceIndices)
Expand All @@ -292,7 +298,25 @@ public static List<Integer> getProjection(ColumnHandle expected, ColumnHandle re
checkArgument(readerDereferences.size() <= expectedDereferences.size(), "Field returned by the reader should include expected field");
checkArgument(expectedDereferences.subList(0, readerDereferences.size()).equals(readerDereferences), "Field returned by the reader should be a prefix of expected field");

return expectedDereferences.subList(readerDereferences.size(), expectedDereferences.size());
checkArgument(readerDereferences.isEmpty(), "reader dereferences not supported yet");

Type adapted = readColumn.getBaseColumn().getBaseType();
ImmutableList.Builder<Integer> dereferenceBuilder = ImmutableList.builder();
for (int deref : expectedDereferences) {
if (adapted instanceof SparseRowType sparseRowType) {
int found = sparseRowType.getOffset(deref);
dereferenceBuilder.add(found);
adapted = sparseRowType.getFields().get(found).getType();
}
else if (adapted instanceof RowType rowType) {
dereferenceBuilder.add(deref);
adapted = rowType.getFields().get(deref).getType();
}
else {
throw new IllegalArgumentException("Expected RowType!");
}
}
return dereferenceBuilder.build();
}

public static class ColumnMapping
Expand Down Expand Up @@ -641,7 +665,11 @@ public static Optional<ReaderColumns> projectBaseColumns(List<HiveColumnHandle>
}

/**
* Creates a mapping between the input {@code columns} and base columns based on baseHiveColumnIndex or baseColumnName if required.
* Transforms the input columns (the projections from the user) to the corresponding base columns in the
* table definition, and a mapping from the projected to the base columns.
*
* For example, given projections [foo.bar, baz, foo.qux] for a table schema of [baz, foo], the ReaderColumns
* produced would be [foo, baz] with a mapping of [0, 1, 0].
*/
public static Optional<ReaderColumns> projectBaseColumns(List<HiveColumnHandle> columns, boolean useColumnNames)
{
Expand Down Expand Up @@ -785,4 +813,131 @@ public List<DereferenceChain> getOrderedPrefixes()
return prefixes.build();
}
}

/**
* Transforms the input columns (the projections from the user) to the corresponding base columns in the table
* definition, and a mapping from the projected to the base columns.
*
* For example, given projections [foo.bar, baz, foo.qux] for a table schema of [baz, foo], the ReaderColumns
* produced would be [foo, baz] with a mapping of [0, 1, 0].
*
* The base types returned will be "sparse". For example, if foo from the example above also had a field
* named quuz that would be "pruned" from the base type.
*
* Deserializers that use the positions of fields in deserialization can ... todo complete
*/
public static Optional<ReaderColumns> sparseReaderColumns(List<HiveColumnHandle> columns)
{
requireNonNull(columns, "columns is null");

if (columns.stream().allMatch(HiveColumnHandle::isBaseColumn)) {
return Optional.empty();
}

Map<Integer, Integer> ordinalToPosition = new HashMap<>();
List<Type> sparseBaseTypes = new ArrayList<>(columns.size());
List<HiveColumnHandle> baseColumnHandles = new ArrayList<>(columns.size());
List<Integer> toReaderColMapping = new LinkedList<>();

for (HiveColumnHandle column : columns) {
checkArgument(column.isBaseColumn() == column.getHiveColumnProjectionInfo().isEmpty(), "Invalid column projection");

Type baseType = column.getBaseType();
Integer baseColumnOrdinal = column.getBaseHiveColumnIndex();
Integer position = ordinalToPosition.putIfAbsent(baseColumnOrdinal, sparseBaseTypes.size());
if (position == null) {
position = sparseBaseTypes.size();

// we will replace this with a sparse version later
sparseBaseTypes.add(null);
baseColumnHandles.add(column.getBaseColumn());
}
toReaderColMapping.add(position);

// if the expected is a base column, we need to read the whole thing anyway
if (column.isBaseColumn()) {
sparseBaseTypes.set(position, baseType);
}
// if the existing sparse type is the base, then the base "covers" the projection
else if (sparseBaseTypes.get(position) != column.getBaseType()) {
List<Integer> derefIndexes = column.getHiveColumnProjectionInfo().get().getDereferenceIndices();

RowType baseRowType = (RowType) baseType;
SparseRowType candidate = sparseTypeFor(baseRowType, derefIndexes);
RowType mergedType = switch (sparseBaseTypes.get(position)) {
case null -> candidate;
case Type t -> mergeSparseTypes(candidate, t);
};

sparseBaseTypes.set(position, mergedType);
}
}

List<HiveColumnHandle> readerColumns = new ArrayList<>(baseColumnHandles.size());
for (int i = 0; i < baseColumnHandles.size(); i++) {
HiveColumnHandle baseColumnHandle = baseColumnHandles.get(i);
Type sparseType = sparseBaseTypes.get(i);
HiveColumnHandle projectedColumnHandle = new HiveColumnHandle(
baseColumnHandle.getBaseColumnName(),
baseColumnHandle.getBaseHiveColumnIndex(),
HiveTypeTranslator.toHiveType(sparseType),
sparseType,
Optional.empty(),
baseColumnHandle.getColumnType(),
baseColumnHandle.getComment());
readerColumns.add(projectedColumnHandle);
}

return Optional.of(new ReaderColumns(readerColumns, toReaderColMapping));
}

private static SparseRowType sparseTypeFor(RowType original, List<Integer> derefIndexes)
{
if (derefIndexes.isEmpty()) {
throw new IllegalArgumentException("derefIndexes cannot be empty");
}

Integer index = derefIndexes.getFirst();
if (derefIndexes.size() == 1) {
return SparseRowType.initial(original.getFields(), index);
}

List<RowType.Field> fields = new ArrayList<>(original.getFields());
Type sparseChild = sparseTypeFor((RowType) fields.get(index).getType(), derefIndexes.subList(1, derefIndexes.size()));
fields.set(index, new RowType.Field(fields.get(index).getName(), sparseChild));

return SparseRowType.initial(fields, index);
}

private static RowType mergeSparseTypes(
SparseRowType sparseType,
Type other)
{
if (!(other instanceof RowType rowOther)) {
throw new IllegalArgumentException("other must be a RowType");
}

if (!(other instanceof SparseRowType sparseOther)) {
return rowOther;
}

List<RowType.Field> sparseFields = sparseType.getSparseFields();
List<RowType.Field> otherFields = sparseOther.getSparseFields();
checkArgument(sparseFields.size() == otherFields.size(), "Mismatched field sizes");

List<RowType.Field> mergedFields = new ArrayList<>(sparseFields.size());
boolean[] mask = new boolean[sparseFields.size()];
for (int i = 0; i < mask.length; i++) {
mask[i] = sparseType.getOffset(i) != null || sparseOther.getOffset(i) != null;
RowType.Field field = sparseFields.get(i);
if (field.getType() instanceof SparseRowType sparseChild) {
Type mergedType = mergeSparseTypes(sparseChild, otherFields.get(i).getType());
mergedFields.add(new RowType.Field(field.getName(), mergedType));
}
else {
mergedFields.add(field);
}
}
return SparseRowType.from(mergedFields, mask);
}
}
Loading

0 comments on commit ddb2f0a

Please sign in to comment.