From 467ae6d8f2c51a49272da62565bfb01391a3a3a7 Mon Sep 17 00:00:00 2001 From: Rob Marrowstone Date: Tue, 20 Aug 2024 14:30:55 -0700 Subject: [PATCH] Make work for SEQUENCE file --- .../io/trino/hive/formats/SparseRowType.java | 21 +++++++----- .../encodings/text/StructEncoding.java | 32 +++++++++++++++++-- .../text/TextColumnEncodingFactory.java | 16 ++++++++++ 3 files changed, 59 insertions(+), 10 deletions(-) diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/SparseRowType.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/SparseRowType.java index 55b19a88b530..9232cf9f2331 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/SparseRowType.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/SparseRowType.java @@ -33,9 +33,9 @@ public class SparseRowType extends RowType { private final List sparseFields; - private final int[] offsets; + private final List offsets; - private SparseRowType(List sparseFields, List denseFields, int[] offsets) + private SparseRowType(List sparseFields, List denseFields, List offsets) { super(makeSignature(denseFields), denseFields); this.sparseFields = sparseFields; @@ -49,21 +49,21 @@ public static SparseRowType from(List fields, boolean[] mask) { checkArgument(fields.size() == mask.length); - int[] offsets = new int[fields.size()]; + ImmutableList.Builder offsets = ImmutableList.builder(); ImmutableList.Builder denseFields = ImmutableList.builder(); int offset = 0; for (int i = 0; i < mask.length; i++) { if (mask[i]) { denseFields.add(fields.get(i)); - offsets[i] = offset++; + offsets.add(offset++); } else { - offsets[i] = -1; + offsets.add(-1); } } - return new SparseRowType(ImmutableList.copyOf(fields), denseFields.build(), offsets); + return new SparseRowType(ImmutableList.copyOf(fields), denseFields.build(), offsets.build()); } public static SparseRowType initial(List fields, Integer activeField) @@ -83,8 +83,13 @@ public List getSparseFields() */ public Integer getOffset(int sparsePosition) { - return offsets[sparsePosition] >= 0 - ? offsets[sparsePosition] + return offsets.get(sparsePosition) >= 0 + ? offsets.get(sparsePosition) : null; } + + public List getOffsets() + { + return offsets.stream().map(i -> i >= 0 ? i : null).toList(); + } } diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/StructEncoding.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/StructEncoding.java index fb78ce553b7a..166225a646fd 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/StructEncoding.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/StructEncoding.java @@ -23,6 +23,7 @@ import io.trino.spi.type.RowType; import java.util.List; +import java.util.stream.IntStream; public class StructEncoding extends BlockEncoding @@ -31,6 +32,7 @@ public class StructEncoding private final byte separator; private final boolean lastColumnTakesRest; private final List structFields; + private final List fieldOffsets; public StructEncoding( RowType rowType, @@ -45,6 +47,26 @@ public StructEncoding( this.separator = separator; this.lastColumnTakesRest = lastColumnTakesRest; this.structFields = structFields; + this.fieldOffsets = IntStream.range(0, structFields.size()) + .boxed() + .toList(); + } + + public StructEncoding( + RowType rowType, + Slice nullSequence, + byte separator, + Byte escapeByte, + boolean lastColumnTakesRest, + List structFields, + List fieldOffsets) + { + super(rowType, nullSequence, escapeByte); + this.rowType = rowType; + this.separator = separator; + this.lastColumnTakesRest = lastColumnTakesRest; + this.structFields = structFields; + this.fieldOffsets = fieldOffsets; } @Override @@ -80,7 +102,10 @@ public void decodeValueInto(BlockBuilder builder, Slice slice, int offset, int l while (currentOffset < end) { byte currentByte = slice.getByte(currentOffset); if (currentByte == separator) { - decodeElementValueInto(fieldIndex, fieldBuilders.get(fieldIndex), slice, elementOffset, currentOffset - elementOffset); + Integer fieldOffset = fieldOffsets.get(fieldIndex); + if (fieldOffset != null) { + decodeElementValueInto(fieldIndex, fieldBuilders.get(fieldOffset), slice, elementOffset, currentOffset - elementOffset); + } elementOffset = currentOffset + 1; fieldIndex++; if (lastColumnTakesRest && fieldIndex == structFields.size() - 1) { @@ -98,7 +123,10 @@ else if (isEscapeByte(currentByte)) { } currentOffset++; } - decodeElementValueInto(fieldIndex, fieldBuilders.get(fieldIndex), slice, elementOffset, end - elementOffset); + Integer fieldOffset = fieldOffsets.get(fieldIndex); + if (fieldOffset != null) { + decodeElementValueInto(fieldIndex, fieldBuilders.get(fieldOffset), slice, elementOffset, end - elementOffset); + } fieldIndex++; // missing fields are null diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/TextColumnEncodingFactory.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/TextColumnEncodingFactory.java index 24dbb1e33f4f..2dc0212c392e 100644 --- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/TextColumnEncodingFactory.java +++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/encodings/text/TextColumnEncodingFactory.java @@ -13,6 +13,7 @@ */ package io.trino.hive.formats.encodings.text; +import io.trino.hive.formats.SparseRowType; import io.trino.hive.formats.encodings.ColumnEncodingFactory; import io.trino.spi.TrinoException; import io.trino.spi.type.ArrayType; @@ -136,6 +137,21 @@ private TextColumnEncoding getEncoding(Type type, int depth) keyEncoding, valueEncoding); } + if (type instanceof SparseRowType sparseRowType) { + List fieldEncodings = sparseRowType.getSparseFields().stream() + .map(RowType.Field::getType) + .map(fieldType -> getEncoding(fieldType, depth + 1)) + .collect(toImmutableList()); + List fieldOffsets = sparseRowType.getOffsets(); + return new StructEncoding( + sparseRowType, + textEncodingOptions.getNullSequence(), + getSeparator(depth + 1), + textEncodingOptions.getEscapeByte(), + textEncodingOptions.isLastColumnTakesRest(), + fieldEncodings, + fieldOffsets); + } if (type instanceof RowType rowType) { List fieldEncodings = rowType.getTypeParameters().stream() .map(fieldType -> getEncoding(fieldType, depth + 1))