From 2aa334863d48617cee38b6c8380140953098ac7a Mon Sep 17 00:00:00 2001 From: Renjie Liu Date: Thu, 7 Nov 2024 03:49:00 +0800 Subject: [PATCH] Introduce kudo writer. (#2559) * Introduce kudo writer Signed-off-by: liurenjie1024 --- .../spark/rapids/jni/Preconditions.java | 16 +- .../jni/kudo/DataOutputStreamWriter.java | 61 ++++ .../spark/rapids/jni/kudo/DataWriter.java | 44 +++ .../spark/rapids/jni/kudo/KudoSerializer.java | 293 ++++++++++++++++++ .../spark/rapids/jni/kudo/KudoTable.java | 52 ++++ .../rapids/jni/kudo/KudoTableHeader.java | 194 ++++++++++++ .../rapids/jni/kudo/KudoTableHeaderCalc.java | 186 +++++++++++ .../spark/rapids/jni/kudo/SliceInfo.java | 50 +++ .../jni/kudo/SlicedBufferSerializer.java | 210 +++++++++++++ .../jni/kudo/SlicedValidityBufferInfo.java | 74 +++++ .../rapids/jni/schema/SchemaVisitor.java | 33 +- .../spark/rapids/jni/KudoSerializerTest.java | 109 +++++++ 12 files changed, 1317 insertions(+), 5 deletions(-) create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTable.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeader.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeaderCalc.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/SliceInfo.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java create mode 100644 src/test/java/com/nvidia/spark/rapids/jni/KudoSerializerTest.java diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Preconditions.java b/src/main/java/com/nvidia/spark/rapids/jni/Preconditions.java index 67473a2e61..7f956537c9 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Preconditions.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Preconditions.java @@ -21,7 +21,7 @@ /** * This class contains utility methods for checking preconditions. */ -class Preconditions { +public class Preconditions { /** * Check if the condition is true, otherwise throw an IllegalStateException with the given message. */ @@ -39,4 +39,18 @@ public static void ensure(boolean condition, Supplier messageSupplier) { throw new IllegalStateException(messageSupplier.get()); } } + + /** + * Check if the value is non-negative, otherwise throw an IllegalArgumentException with the given message. + * @param value the value to check + * @param name the name of the value + * @return the value if it is non-negative + * @throws IllegalArgumentException if the value is negative + */ + public static int ensureNonNegative(int value, String name) { + if (value < 0) { + throw new IllegalArgumentException(name + " must be non-negative, but was " + value); + } + return value; + } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java new file mode 100644 index 0000000000..d93b91bb7a --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataOutputStreamWriter.java @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.HostMemoryBuffer; + +import java.io.DataOutputStream; +import java.io.IOException; + +/** + * Visible for testing + */ +class DataOutputStreamWriter extends DataWriter { + private final byte[] arrayBuffer = new byte[1024 * 128]; + private final DataOutputStream dout; + + public DataOutputStreamWriter(DataOutputStream dout) { + this.dout = dout; + } + + @Override + public void writeInt(int i) throws IOException { + dout.writeInt(i); + } + + @Override + public void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException { + long dataLeft = len; + while (dataLeft > 0) { + int amountToCopy = (int) Math.min(arrayBuffer.length, dataLeft); + src.getBytes(arrayBuffer, 0, srcOffset, amountToCopy); + dout.write(arrayBuffer, 0, amountToCopy); + srcOffset += amountToCopy; + dataLeft -= amountToCopy; + } + } + + @Override + public void flush() throws IOException { + dout.flush(); + } + + @Override + public void write(byte[] arr, int offset, int length) throws IOException { + dout.write(arr, offset, length); + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java new file mode 100644 index 0000000000..1f2e8f3dca --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/DataWriter.java @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.HostMemoryBuffer; + +import java.io.IOException; + +/** + * Visible for testing + */ +abstract class DataWriter { + + public abstract void writeInt(int i) throws IOException; + + /** + * Copy data from src starting at srcOffset and going for len bytes. + * + * @param src where to copy from. + * @param srcOffset offset to start at. + * @param len amount to copy. + */ + public abstract void copyDataFrom(HostMemoryBuffer src, long srcOffset, long len) throws IOException; + + public void flush() throws IOException { + // NOOP by default + } + + public abstract void write(byte[] arr, int offset, int length) throws IOException; +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java new file mode 100644 index 0000000000..b0cd4abf84 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java @@ -0,0 +1,293 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.*; +import com.nvidia.spark.rapids.jni.schema.Visitors; + +import java.io.*; +import java.util.Arrays; +import java.util.stream.IntStream; + +import static com.nvidia.spark.rapids.jni.Preconditions.ensure; +import static java.util.Objects.requireNonNull; + +/** + * This class is used to serialize/deserialize a table using the Kudo format. + * + *

<h1>Background</h1>
+ *
+ * The Kudo format is a binary format that is optimized for serializing/deserializing a table partition during Spark
+ * shuffle. The optimizations are based on two key observations:
+ *
+ * <ol>
+ *   <li>The binary format doesn't need to be self-descriptive, since the shuffle runtime can provide information
+ *   such as the schema, which allows the header to be much smaller.</li>
+ *   <li>In most cases we need to concatenate several small tables into a larger table at shuffle read time, since
+ *   the GPU's vectorized execution engine typically requires a larger batch size; this makes write-time
+ *   concatenation pointless. It also relaxes the requirement to calculate exact validity and offset buffers at
+ *   write time, which makes a write almost a pure memory copy, without sacrificing much read performance.</li>
+ * </ol>
+ *

<h1>Format</h1>
+ *
+ * Similar to {@link JCudfSerialization}, it still consists of two parts: header and body.
+ *

<h2>Header</h2>
+ *
+ * The header consists of the following fields:
+ *
+ * <table>
+ *   <tr><th>Field Name</th><th>Size</th><th>Comments</th></tr>
+ *   <tr><td>Magic Number</td><td>4</td><td>ASCII codes for "KUD0"</td></tr>
+ *   <tr><td>Offset</td><td>4</td><td>Row offset in the original table, in big-endian format</td></tr>
+ *   <tr><td>Number of rows</td><td>4</td><td>Number of rows, in big-endian format</td></tr>
+ *   <tr><td>Length of validity buffer</td><td>4</td><td>Length of the validity buffer, in big-endian format</td></tr>
+ *   <tr><td>Length of offset buffer</td><td>4</td><td>Length of the offset buffer, in big-endian format</td></tr>
+ *   <tr><td>Length of total body</td><td>4</td><td>Length of the total body, in big-endian format</td></tr>
+ *   <tr><td>Number of columns</td><td>4</td><td>Number of columns in the flattened schema, in big-endian format.
+ *   For details of the flattened schema, see {@link com.nvidia.spark.rapids.jni.schema.SchemaVisitor}.</td></tr>
+ *   <tr><td>hasValidityBuffer</td><td>(number of columns + 7) / 8</td><td>A bit set indicating whether each column
+ *   has a validity buffer. To test whether column <code>col<sub>i</sub></code> has a validity buffer, use the
+ *   following code:
+ *   <pre>
+ *   int pos = coli / 8;
+ *   int bit = coli % 8;
+ *   return (hasValidityBuffer[pos] & (1 << bit)) != 0;
+ *   </pre></td></tr>
+ * </table>
+ *
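+ * As an illustration (a sketch, not part of this patch), the fixed-width part of the header is seven 4-byte
+ * big-endian ints followed by the bit set, so a reader can parse it with a plain {@code java.io.DataInputStream}
+ * (assuming {@code din} is positioned at the start of a serialized table):
+ *
+ * <pre>{@code
+ * int magic = din.readInt();               // ASCII "KUD0", i.e. 0x4B554430
+ * int offset = din.readInt();              // row offset in the original table
+ * int numRows = din.readInt();
+ * int validityBufferLen = din.readInt();
+ * int offsetBufferLen = din.readInt();
+ * int totalDataLen = din.readInt();
+ * int numColumns = din.readInt();
+ * byte[] hasValidityBuffer = new byte[(numColumns + 7) / 8];
+ * din.readFully(hasValidityBuffer);
+ * // Header size is therefore 7 * Integer.BYTES + hasValidityBuffer.length bytes.
+ * }</pre>
+ *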

<h2>Body</h2>
+ *
+ * The body consists of three parts:
+ *
+ * <ol>
+ *   <li>Validity buffers for every column with validity, in depth-first ordering of the schema columns. Each
+ *   column's buffer is padded to a multiple of 4 bytes.</li>
+ *   <li>Offset buffers for every column with offsets, in depth-first ordering of the schema columns. Each
+ *   column's buffer is padded to a multiple of 4 bytes.</li>
+ *   <li>Data buffers for every column with data, in depth-first ordering of the schema columns. Each column's
+ *   buffer is padded to a multiple of 4 bytes.</li>
+ * </ol>
+ *
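+ * The 4-byte padding above is the {@code padForHostAlignment} rounding used throughout this writer; as a sketch of
+ * the arithmetic:
+ *
+ * <pre>{@code
+ * // Round a buffer length up to the next multiple of 4 bytes.
+ * long padded = ((len + 3) / 4) * 4;   // e.g. len = 13 -> 16, len = 16 -> 16
+ * }</pre>
+ *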

<h1>Serialization</h1>
+ *
+ * The serialization process writes the header first, then the body. There are two optimizations when writing the
+ * validity and offset buffers:
+ *
+ * <ol>
+ *   <li>For the validity buffer, it copies the buffer as-is rather than calculating an exact validity buffer. For
+ *   example, when we want to serialize rows [3, 9) of the original table, instead of calculating the exact validity
+ *   buffer it just copies the first two bytes of the validity buffer. At read time the deserializer knows that the
+ *   true validity buffer starts from the fourth bit, since the row offset is recorded in the header.</li>
+ *   <li>For the offset buffer, it copies the buffer as-is rather than calculating an exact offset buffer. For
+ *   example, when we want to serialize rows [3, 9) of the original table, instead of calculating exact offset
+ *   values by subtracting the first value, it just copies the offset buffer values of rows [3, 9).</li>
+ * </ol>
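+ *
+ * For the rows [3, 9) example above, the sliced validity copy reduces to the arithmetic below (the same calculation
+ * performed by {@code SlicedValidityBufferInfo.calc} later in this patch):
+ *
+ * <pre>{@code
+ * long rowOffset = 3, numRows = 6;
+ * long bufferOffset = rowOffset / 8;                                     // 0: copy starts at byte 0
+ * long beginBit = rowOffset % 8;                                         // 3: true data begins at the fourth bit
+ * long bufferLength = (rowOffset + numRows - 1) / 8 - bufferOffset + 1;  // 2: copy the first two bytes
+ * }</pre>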
+ */ +public class KudoSerializer { + + private static final byte[] PADDING = new byte[64]; + private static final BufferType[] ALL_BUFFER_TYPES = new BufferType[] {BufferType.VALIDITY, BufferType.OFFSET, + BufferType.DATA}; + + static { + Arrays.fill(PADDING, (byte) 0); + } + + private final int flattenedColumnCount; + + public KudoSerializer(Schema schema) { + requireNonNull(schema, "schema is null"); + this.flattenedColumnCount = schema.getFlattenedColumnNames().length; + } + + /** + * Write partition of a table to a stream. + *
+ * + * The caller should ensure that table's schema matches the schema used to create this serializer, otherwise behavior + * is undefined. + * + * @param table table to write + * @param out output stream + * @param rowOffset row offset in original table + * @param numRows number of rows to write + * @return number of bytes written + */ + public long writeToStream(Table table, OutputStream out, int rowOffset, int numRows) { + HostColumnVector[] columns = null; + try { + columns = IntStream.range(0, table.getNumberOfColumns()) + .mapToObj(table::getColumn) + .map(c -> c.copyToHostAsync(Cuda.DEFAULT_STREAM)) + .toArray(HostColumnVector[]::new); + + Cuda.DEFAULT_STREAM.sync(); + return writeToStream(columns, out, rowOffset, numRows); + } finally { + if (columns != null) { + for (HostColumnVector column : columns) { + column.close(); + } + } + } + } + + /** + * Write partition of an array of {@link HostColumnVector} to an output stream. + *
+ * + * The caller should ensure that table's schema matches the schema used to create this serializer, otherwise behavior + * is undefined. + * + * @param columns columns to write + * @param out output stream + * @param rowOffset row offset in original column vector. + * @param numRows number of rows to write + * @return number of bytes written + */ + public long writeToStream(HostColumnVector[] columns, OutputStream out, int rowOffset, int numRows) { + ensure(numRows > 0, () -> "numRows must be > 0, but was " + numRows); + ensure(columns.length > 0, () -> "columns must not be empty, for row count only records " + + "please call writeRowCountToStream"); + + try { + return writeSliced(columns, writerFrom(out), rowOffset, numRows); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Write a row count only record to an output stream. + * @param out output stream + * @param numRows number of rows to write + * @return number of bytes written + */ + public static long writeRowCountToStream(OutputStream out, int numRows) { + if (numRows <= 0) { + throw new IllegalArgumentException("Number of rows must be > 0, but was " + numRows); + } + try { + DataWriter writer = writerFrom(out); + KudoTableHeader header = new KudoTableHeader(0, numRows, 0, 0, 0 + , 0, new byte[0]); + header.writeTo(writer); + writer.flush(); + return header.getSerializedSize(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private long writeSliced(HostColumnVector[] columns, DataWriter out, int rowOffset, int numRows) throws Exception { + KudoTableHeaderCalc headerCalc = new KudoTableHeaderCalc(rowOffset, numRows, flattenedColumnCount); + Visitors.visitColumns(columns, headerCalc); + KudoTableHeader header = headerCalc.getHeader(); + header.writeTo(out); + + long bytesWritten = 0; + for (BufferType bufferType : ALL_BUFFER_TYPES) { + SlicedBufferSerializer serializer = new SlicedBufferSerializer(rowOffset, numRows, bufferType, out); + Visitors.visitColumns(columns, serializer); + bytesWritten += serializer.getTotalDataLen(); + } + + if (bytesWritten != header.getTotalDataLen()) { + throw new IllegalStateException("Header total data length: " + header.getTotalDataLen() + + " does not match actual written data length: " + bytesWritten + + ", rowOffset: " + rowOffset + " numRows: " + numRows); + } + + out.flush(); + + return header.getSerializedSize() + bytesWritten; + } + + private static DataWriter writerFrom(OutputStream out) { + if (!(out instanceof DataOutputStream)) { + out = new DataOutputStream(new BufferedOutputStream(out)); + } + return new DataOutputStreamWriter((DataOutputStream) out); + } + + + static long padForHostAlignment(long orig) { + return ((orig + 3) / 4) * 4; + } + + static long padForHostAlignment(DataWriter out, long bytes) throws IOException { + final long paddedBytes = padForHostAlignment(bytes); + if (paddedBytes > bytes) { + out.write(PADDING, 0, (int) (paddedBytes - bytes)); + } + return paddedBytes; + } + + static long padFor64byteAlignment(long orig) { + return ((orig + 63) / 64) * 64; + } + + static int safeLongToNonNegativeInt(long value) { + ensure(value >= 0, () -> "Expected non negative value, but was " + value); + ensure(value <= Integer.MAX_VALUE, () -> "Value is too large to fit in an int"); + return (int) value; + } +} \ No newline at end of file diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTable.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTable.java new file mode 100644 index 0000000000..218cc215b6 --- 
/dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTable.java @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.HostMemoryBuffer; + +public class KudoTable implements AutoCloseable { + private final KudoTableHeader header; + private final HostMemoryBuffer buffer; + + KudoTable(KudoTableHeader header, HostMemoryBuffer buffer) { + this.header = header; + this.buffer = buffer; + } + + public KudoTableHeader getHeader() { + return header; + } + + public HostMemoryBuffer getBuffer() { + return buffer; + } + + @Override + public String toString() { + return "SerializedTable{" + + "header=" + header + + ", buffer=" + buffer + + '}'; + } + + @Override + public void close() throws Exception { + if (buffer != null) { + buffer.close(); + } + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeader.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeader.java new file mode 100644 index 0000000000..93038568f5 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeader.java @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +import java.io.DataInputStream; +import java.io.EOFException; +import java.io.IOException; +import java.util.Arrays; +import java.util.Optional; + +import static com.nvidia.spark.rapids.jni.Preconditions.ensure; +import static com.nvidia.spark.rapids.jni.Preconditions.ensureNonNegative; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.safeLongToNonNegativeInt; +import static java.util.Objects.requireNonNull; + +/** + * Holds the metadata about a serialized table. If this is being read from a stream + * isInitialized will return true if the metadata was read correctly from the stream. + * It will return false if an EOF was encountered at the beginning indicating that + * there was no data to be read. + */ +public final class KudoTableHeader { + /** + * Magic number "KUD0" in ASCII. + */ + private static final int SER_FORMAT_MAGIC_NUMBER = 0x4B554430; + + // The offset in the original table where row starts. For example, if we want to serialize rows [3, 9) of the + // original table, offset would be 3, and numRows would be 6. 
+ private final int offset; + private final int numRows; + private final int validityBufferLen; + private final int offsetBufferLen; + private final int totalDataLen; + private final int numColumns; + // A bit set to indicate if a column has a validity buffer or not. Each column is represented by a single bit. + private final byte[] hasValidityBuffer; + + /** + * Reads the table header from the given input stream. + * @param din input stream + * @return the table header. If an EOFException is encountered at the beginning, returns empty result. + * @throws IOException if an I/O error occurs + */ + public static Optional readFrom(DataInputStream din) throws IOException { + int num; + try { + num= din.readInt(); + if (num != SER_FORMAT_MAGIC_NUMBER) { + throw new IllegalStateException("Kudo format error, expected magic number " + SER_FORMAT_MAGIC_NUMBER + + " found " + num); + } + } catch (EOFException e) { + // If we get an EOF at the very beginning don't treat it as an error because we may + // have finished reading everything... + return Optional.empty(); + } + + int offset = din.readInt(); + int numRows = din.readInt(); + + int validityBufferLen = din.readInt(); + int offsetBufferLen = din.readInt(); + int totalDataLen = din.readInt(); + int numColumns = din.readInt(); + int validityBufferLength = lengthOfHasValidityBuffer(numColumns); + byte[] hasValidityBuffer = new byte[validityBufferLength]; + din.readFully(hasValidityBuffer); + + return Optional.of(new KudoTableHeader(offset, numRows, validityBufferLen, offsetBufferLen, totalDataLen, numColumns, + hasValidityBuffer)); + } + + KudoTableHeader(int offset, int numRows, int validityBufferLen, int offsetBufferLen, + int totalDataLen, int numColumns, byte[] hasValidityBuffer) { + this.offset = ensureNonNegative(offset, "offset"); + this.numRows = ensureNonNegative(numRows, "numRows"); + this.validityBufferLen = ensureNonNegative(validityBufferLen, "validityBufferLen"); + this.offsetBufferLen = ensureNonNegative(offsetBufferLen, "offsetBufferLen"); + this.totalDataLen = ensureNonNegative(totalDataLen, "totalDataLen"); + this.numColumns = ensureNonNegative(numColumns, "numColumns"); + + requireNonNull(hasValidityBuffer, "hasValidityBuffer cannot be null"); + ensure(hasValidityBuffer.length == lengthOfHasValidityBuffer(numColumns), + () -> numColumns + " columns expects hasValidityBuffer with length " + lengthOfHasValidityBuffer(numColumns) + + ", but found " + hasValidityBuffer.length); + this.hasValidityBuffer = hasValidityBuffer; + } + + /** + * Returns the size of a buffer needed to read data into the stream. + */ + public int getTotalDataLen() { + return totalDataLen; + } + + /** + * Returns the number of rows stored in this table. + */ + public int getNumRows() { + return numRows; + } + + public int getOffset() { + return offset; + } + + public boolean hasValidityBuffer(int columnIndex) { + int pos = columnIndex / 8; + int bit = columnIndex % 8; + return (hasValidityBuffer[pos] & (1 << bit)) != 0; + } + + /** + * Get the size of the serialized header. + * + *

+ * It consists of the following fields:
+ *
+ * <ol>
+ *   <li>Magic Number</li>
+ *   <li>Row Offset</li>
+ *   <li>Number of rows</li>
+ *   <li>Validity buffer length</li>
+ *   <li>Offset buffer length</li>
+ *   <li>Total data length</li>
+ *   <li>Number of columns</li>
+ *   <li>hasValidityBuffer</li>
+ * </ol>
+ *
+ * For more details of each field, please refer to {@link KudoSerializer}.
+ *
+ * + * @return the size of the serialized header. + */ + public int getSerializedSize() { + return 7 * Integer.BYTES + hasValidityBuffer.length; + } + + public int getNumColumns() { + return numColumns; + } + + public int getValidityBufferLen() { + return validityBufferLen; + } + + public int getOffsetBufferLen() { + return offsetBufferLen; + } + + public void writeTo(DataWriter dout) throws IOException { + // Now write out the data + dout.writeInt(SER_FORMAT_MAGIC_NUMBER); + + dout.writeInt(offset); + dout.writeInt(numRows); + dout.writeInt(validityBufferLen); + dout.writeInt(offsetBufferLen); + dout.writeInt(totalDataLen); + dout.writeInt(numColumns); + dout.write(hasValidityBuffer, 0, hasValidityBuffer.length); + } + + @Override + public String toString() { + return "SerializedTableHeader{" + + "offset=" + offset + + ", numRows=" + numRows + + ", validityBufferLen=" + validityBufferLen + + ", offsetBufferLen=" + offsetBufferLen + + ", totalDataLen=" + totalDataLen + + ", numColumns=" + numColumns + + ", hasValidityBuffer=" + Arrays.toString(hasValidityBuffer) + + '}'; + } + + private static int lengthOfHasValidityBuffer(int numColumns) { + return (numColumns + 7) / 8; + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeaderCalc.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeaderCalc.java new file mode 100644 index 0000000000..d9d478ca1a --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableHeaderCalc.java @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.DType; +import ai.rapids.cudf.HostColumnVectorCore; +import com.nvidia.spark.rapids.jni.schema.HostColumnsVisitor; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; + +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment; +import static java.lang.Math.toIntExact; + +/** + * This class visits a list of columns and calculates the serialized table header. + * + *

+ * The columns are visited in post order, and for more details about the visiting process, please refer to + * {@link HostColumnsVisitor}. + *

+ */ +class KudoTableHeaderCalc implements HostColumnsVisitor { + private final SliceInfo root; + private final int numFlattenedCols; + private final byte[] bitset; + private long validityBufferLen; + private long offsetBufferLen; + private long totalDataLen; + private int nextColIdx; + + private Deque sliceInfos = new ArrayDeque<>(); + + KudoTableHeaderCalc(long rowOffset, long numRows, int numFlattenedCols) { + this.root = new SliceInfo(rowOffset, numRows); + this.totalDataLen = 0; + sliceInfos.addLast(this.root); + this.bitset = new byte[(numFlattenedCols + 7) / 8]; + this.numFlattenedCols = numFlattenedCols; + this.nextColIdx = 0; + } + + public KudoTableHeader getHeader() { + return new KudoTableHeader(toIntExact(root.offset), + toIntExact(root.rowCount), + toIntExact(validityBufferLen), + toIntExact(offsetBufferLen), + toIntExact(totalDataLen), + numFlattenedCols, + bitset); + } + + @Override + public Void visitStruct(HostColumnVectorCore col, List children) { + SliceInfo parent = sliceInfos.getLast(); + + long validityBufferLength = 0; + if (col.hasValidityVector()) { + validityBufferLength = padForHostAlignment(parent.getValidityBufferInfo().getBufferLength()); + } + + this.validityBufferLen += validityBufferLength; + + totalDataLen += validityBufferLength; + this.setHasValidity(col.hasValidityVector()); + return null; + } + + @Override + public Void preVisitList(HostColumnVectorCore col) { + SliceInfo parent = sliceInfos.getLast(); + + + long validityBufferLength = 0; + if (col.hasValidityVector() && parent.rowCount > 0) { + validityBufferLength = padForHostAlignment(parent.getValidityBufferInfo().getBufferLength()); + } + + long offsetBufferLength = 0; + if (col.getOffsets() != null && parent.rowCount > 0) { + offsetBufferLength = padForHostAlignment((parent.rowCount + 1) * Integer.BYTES); + } + + this.validityBufferLen += validityBufferLength; + this.offsetBufferLen += offsetBufferLength; + this.totalDataLen += validityBufferLength + offsetBufferLength; + + this.setHasValidity(col.hasValidityVector()); + + SliceInfo current; + + if (col.getOffsets() != null) { + long start = col.getOffsets().getInt(parent.offset * Integer.BYTES); + long end = col.getOffsets().getInt((parent.offset + parent.rowCount) * Integer.BYTES); + long rowCount = end - start; + current = new SliceInfo(start, rowCount); + } else { + current = new SliceInfo(0, 0); + } + + sliceInfos.addLast(current); + return null; + } + + @Override + public Void visitList(HostColumnVectorCore col, Void preVisitResult, Void childResult) { + sliceInfos.removeLast(); + + return null; + } + + + @Override + public Void visit(HostColumnVectorCore col) { + SliceInfo parent = sliceInfos.peekLast(); + long validityBufferLen = dataLenOfValidityBuffer(col, parent); + long offsetBufferLen = dataLenOfOffsetBuffer(col, parent); + long dataBufferLen = dataLenOfDataBuffer(col, parent); + + this.validityBufferLen += validityBufferLen; + this.offsetBufferLen += offsetBufferLen; + this.totalDataLen += validityBufferLen + offsetBufferLen + dataBufferLen; + + this.setHasValidity(col.hasValidityVector()); + + return null; + } + + private void setHasValidity(boolean hasValidityBuffer) { + if (hasValidityBuffer) { + int bytePos = nextColIdx / 8; + int bitPos = nextColIdx % 8; + bitset[bytePos] = (byte) (bitset[bytePos] | (1 << bitPos)); + } + nextColIdx++; + } + + private static long dataLenOfValidityBuffer(HostColumnVectorCore col, SliceInfo info) { + if (col.hasValidityVector() && info.getRowCount() > 0) { + return 
padForHostAlignment(info.getValidityBufferInfo().getBufferLength()); + } else { + return 0; + } + } + + private static long dataLenOfOffsetBuffer(HostColumnVectorCore col, SliceInfo info) { + if (DType.STRING.equals(col.getType()) && info.getRowCount() > 0) { + return padForHostAlignment((info.rowCount + 1) * Integer.BYTES); + } else { + return 0; + } + } + + private static long dataLenOfDataBuffer(HostColumnVectorCore col, SliceInfo info) { + if (DType.STRING.equals(col.getType())) { + if (col.getOffsets() != null) { + long startByteOffset = col.getOffsets().getInt(info.offset * Integer.BYTES); + long endByteOffset = col.getOffsets().getInt((info.offset + info.rowCount) * Integer.BYTES); + return padForHostAlignment(endByteOffset - startByteOffset); + } else { + return 0; + } + } else { + if (col.getType().getSizeInBytes() > 0) { + return padForHostAlignment(col.getType().getSizeInBytes() * info.rowCount); + } else { + return 0; + } + } + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SliceInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SliceInfo.java new file mode 100644 index 0000000000..fd2f1df599 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SliceInfo.java @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +public class SliceInfo { + final long offset; + final long rowCount; + private final SlicedValidityBufferInfo validityBufferInfo; + + SliceInfo(long offset, long rowCount) { + this.offset = offset; + this.rowCount = rowCount; + this.validityBufferInfo = SlicedValidityBufferInfo.calc(offset, rowCount); + } + + SlicedValidityBufferInfo getValidityBufferInfo() { + return validityBufferInfo; + } + + public long getOffset() { + return offset; + } + + public long getRowCount() { + return rowCount; + } + + @Override + public String toString() { + return "SliceInfo{" + + "offset=" + offset + + ", rowCount=" + rowCount + + ", validityBufferInfo=" + validityBufferInfo + + '}'; + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java new file mode 100644 index 0000000000..6f912f4b61 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java @@ -0,0 +1,210 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.jni.kudo; + +import ai.rapids.cudf.BufferType; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.HostColumnVectorCore; +import ai.rapids.cudf.HostMemoryBuffer; +import com.nvidia.spark.rapids.jni.schema.HostColumnsVisitor; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; + +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment; + +/** + * This class visits a list of columns and serialize one of the buffers (validity, offset, or data) into with kudo + * format. + * + *

+ * The host columns are visited in post order. For more details about the visiting process, please refer to
+ * {@link HostColumnsVisitor}.
+ *

+ * + *

+ * For more details about the kudo format, please refer to {@link KudoSerializer}. + *
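+ *
+ * As a sketch (not part of this patch), the writer drives one instance of this class per buffer type, in the order
+ * used by {@code KudoSerializer.writeSliced}; {@code rowOffset}, {@code numRows}, {@code columns}, and {@code out}
+ * are assumed to come from the caller:
+ *
+ * <pre>{@code
+ * long bytesWritten = 0;
+ * for (BufferType bufferType : new BufferType[]{BufferType.VALIDITY, BufferType.OFFSET, BufferType.DATA}) {
+ *   SlicedBufferSerializer serializer = new SlicedBufferSerializer(rowOffset, numRows, bufferType, out);
+ *   Visitors.visitColumns(columns, serializer);
+ *   bytesWritten += serializer.getTotalDataLen();
+ * }
+ * }</pre>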

+ */ +class SlicedBufferSerializer implements HostColumnsVisitor { + private final SliceInfo root; + private final BufferType bufferType; + private final DataWriter writer; + + private final Deque sliceInfos = new ArrayDeque<>(); + private long totalDataLen; + + SlicedBufferSerializer(long rowOffset, long numRows, BufferType bufferType, DataWriter writer) { + this.root = new SliceInfo(rowOffset, numRows); + this.bufferType = bufferType; + this.writer = writer; + this.sliceInfos.addLast(root); + this.totalDataLen = 0; + } + + public long getTotalDataLen() { + return totalDataLen; + } + + @Override + public Void visitStruct(HostColumnVectorCore col, List children) { + SliceInfo parent = sliceInfos.peekLast(); + + try { + switch (bufferType) { + case VALIDITY: + totalDataLen += this.copySlicedValidity(col, parent); + return null; + case OFFSET: + case DATA: + return null; + default: + throw new IllegalArgumentException("Unexpected buffer type: " + bufferType); + } + + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Void preVisitList(HostColumnVectorCore col) { + SliceInfo parent = sliceInfos.getLast(); + + + long bytesCopied = 0; + try { + switch (bufferType) { + case VALIDITY: + bytesCopied = this.copySlicedValidity(col, parent); + break; + case OFFSET: + bytesCopied = this.copySlicedOffset(col, parent); + break; + case DATA: + break; + default: + throw new IllegalArgumentException("Unexpected buffer type: " + bufferType); + } + + } catch (IOException e) { + throw new RuntimeException(e); + } + + SliceInfo current; + if (col.getOffsets() != null) { + long start = col.getOffsets() + .getInt(parent.offset * Integer.BYTES); + long end = col.getOffsets().getInt((parent.offset + parent.rowCount) * Integer.BYTES); + long rowCount = end - start; + + current = new SliceInfo(start, rowCount); + } else { + current = new SliceInfo(0, 0); + } + + sliceInfos.addLast(current); + + totalDataLen += bytesCopied; + return null; + } + + @Override + public Void visitList(HostColumnVectorCore col, Void preVisitResult, Void childResult) { + sliceInfos.removeLast(); + return null; + } + + @Override + public Void visit(HostColumnVectorCore col) { + SliceInfo parent = sliceInfos.getLast(); + try { + switch (bufferType) { + case VALIDITY: + totalDataLen += this.copySlicedValidity(col, parent); + return null; + case OFFSET: + totalDataLen += this.copySlicedOffset(col, parent); + return null; + case DATA: + totalDataLen += this.copySlicedData(col, parent); + return null; + default: + throw new IllegalArgumentException("Unexpected buffer type: " + bufferType); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private long copySlicedValidity(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException { + if (column.getValidity() != null && sliceInfo.getRowCount() > 0) { + HostMemoryBuffer buff = column.getValidity(); + long len = sliceInfo.getValidityBufferInfo().getBufferLength(); + writer.copyDataFrom(buff, sliceInfo.getValidityBufferInfo().getBufferOffset(), + len); + return padForHostAlignment(writer, len); + } else { + return 0; + } + } + + private long copySlicedOffset(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException { + if (sliceInfo.rowCount <= 0 || column.getOffsets() == null) { + // Don't copy anything, there are no rows + return 0; + } + long bytesToCopy = (sliceInfo.rowCount + 1) * Integer.BYTES; + long srcOffset = sliceInfo.offset * Integer.BYTES; + HostMemoryBuffer buff = column.getOffsets(); + 
writer.copyDataFrom(buff, srcOffset, bytesToCopy); + return padForHostAlignment(writer, bytesToCopy); + } + + private long copySlicedData(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException { + if (sliceInfo.rowCount > 0) { + DType type = column.getType(); + if (type.equals(DType.STRING)) { + long startByteOffset = column.getOffsets().getInt(sliceInfo.offset * Integer.BYTES); + long endByteOffset = column.getOffsets().getInt((sliceInfo.offset + sliceInfo.rowCount) * Integer.BYTES); + long bytesToCopy = endByteOffset - startByteOffset; + if (column.getData() == null) { + if (bytesToCopy != 0) { + throw new IllegalStateException("String column has no data buffer, " + + "but bytes to copy is not zero: " + bytesToCopy); + } + + return 0; + } else { + writer.copyDataFrom(column.getData(), startByteOffset, bytesToCopy); + return padForHostAlignment(writer, bytesToCopy); + } + } else if (type.getSizeInBytes() > 0) { + long bytesToCopy = sliceInfo.rowCount * type.getSizeInBytes(); + long srcOffset = sliceInfo.offset * type.getSizeInBytes(); + writer.copyDataFrom(column.getData(), srcOffset, bytesToCopy); + return padForHostAlignment(writer, bytesToCopy); + } else { + return 0; + } + } else { + return 0; + } + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java new file mode 100644 index 0000000000..865971b56c --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedValidityBufferInfo.java @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni.kudo; + +/** + * A simple utility class to hold information about serializing/deserializing sliced validity buffer. 
+ */ +class SlicedValidityBufferInfo { + private final long bufferOffset; + private final long bufferLength; + /// The bit offset within the buffer where the slice starts + private final long beginBit; + private final long endBit; // Exclusive + + SlicedValidityBufferInfo(long bufferOffset, long bufferLength, long beginBit, long endBit) { + this.bufferOffset = bufferOffset; + this.bufferLength = bufferLength; + this.beginBit = beginBit; + this.endBit = endBit; + } + + @Override + public String toString() { + return "SlicedValidityBufferInfo{" + "bufferOffset=" + bufferOffset + ", bufferLength=" + bufferLength + + ", beginBit=" + beginBit + ", endBit=" + endBit + '}'; + } + + public long getBufferOffset() { + return bufferOffset; + } + + public long getBufferLength() { + return bufferLength; + } + + public long getBeginBit() { + return beginBit; + } + + public long getEndBit() { + return endBit; + } + + static SlicedValidityBufferInfo calc(long rowOffset, long numRows) { + if (rowOffset < 0) { + throw new IllegalArgumentException("rowOffset must be >= 0, but was " + rowOffset); + } + if (numRows < 0) { + throw new IllegalArgumentException("numRows must be >= 0, but was " + numRows); + } + long bufferOffset = rowOffset / 8; + long beginBit = rowOffset % 8; + long bufferLength = 0; + if (numRows > 0) { + bufferLength = (rowOffset + numRows - 1) / 8 - bufferOffset + 1; + } + long endBit = beginBit + numRows; + return new SlicedValidityBufferInfo(bufferOffset, bufferLength, beginBit, endBit); + } +} diff --git a/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java b/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java index f75e9a2b11..4986a5e473 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/schema/SchemaVisitor.java @@ -25,14 +25,39 @@ /** * A post order visitor for schemas. * - *

+ *

<h1>Flattened Schema</h1>
+ *
+ * A flattened schema is a schema where all fields with nested types are flattened into an array of fields. For
+ * example, for a schema with the following fields:
+ *
+ * <ul>
+ *   <li>A: struct { int a1; long a2}</li>
+ *   <li>B: list { int b1}</li>
+ *   <li>C: string</li>
+ *   <li>D: long</li>
+ * </ul>
+ *
+ * The flattened schema will be:
+ *
+ * <ul>
+ *   <li>A: struct</li>
+ *   <li>A.a1: int</li>
+ *   <li>A.a2: long</li>
+ *   <li>B: list</li>
+ *   <li>B.b1: int</li>
+ *   <li>C: string</li>
+ *   <li>D: long</li>
+ * </ul>
+ *
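+ * As a sketch (not part of this patch), the flattened column count for the example above can be computed from
+ * cudf's {@code Schema}; this is how {@code KudoSerializer} sizes its {@code hasValidityBuffer} bit set:
+ *
+ * <pre>{@code
+ * Schema.Builder builder = Schema.builder();
+ * Schema.Builder a = builder.addColumn(DType.STRUCT, "A");
+ * a.addColumn(DType.INT32, "a1");
+ * a.addColumn(DType.INT64, "a2");
+ * Schema.Builder b = builder.addColumn(DType.LIST, "B");
+ * b.addColumn(DType.INT32, "b1");
+ * builder.addColumn(DType.STRING, "C");
+ * builder.addColumn(DType.INT64, "D");
+ * int flattened = builder.build().getFlattenedColumnNames().length;  // 7: A, a1, a2, B, b1, C, D
+ * }</pre>
+ *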
* - * For example, if our schema consists of three fields A, B, and C with following types: + *

<h1>Example</h1>
+ *
+ * This visitor visits each field in the flattened schema in post order. For example, if our schema consists of
+ * three fields A, B, and C with the following types:
+ *
 * <ul>
 *   <li>A: struct { int a1; long a2}</li>
 *   <li>B: list { int b1}</li>
- *   <li>C: string c1</li>
+ *   <li>C: string</li>
 * </ul>
 *
 * The order of visiting will be:
@@ -43,7 +68,7 @@
 *   <li>Previsit list field B</li>
 *   <li>Visit primitive field b1</li>
 *   <li>Visit list field B with results from b1 and previsit result.</li>
- *   <li>Visit primitive field c1</li>
+ *   <li>Visit primitive field C</li>
 *   <li>Visit top schema with results from fields A, B, and C</li>
  • * * diff --git a/src/test/java/com/nvidia/spark/rapids/jni/KudoSerializerTest.java b/src/test/java/com/nvidia/spark/rapids/jni/KudoSerializerTest.java new file mode 100644 index 0000000000..34f03557d7 --- /dev/null +++ b/src/test/java/com/nvidia/spark/rapids/jni/KudoSerializerTest.java @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +import ai.rapids.cudf.DType; +import ai.rapids.cudf.HostColumnVector; +import ai.rapids.cudf.Schema; +import ai.rapids.cudf.Table; +import com.nvidia.spark.rapids.jni.kudo.KudoSerializer; +import com.nvidia.spark.rapids.jni.kudo.KudoTableHeader; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; + +import static org.junit.jupiter.api.Assertions.*; + +public class KudoSerializerTest { + + @Test + public void testRowCountOnly() throws Exception { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + long bytesWritten = KudoSerializer.writeRowCountToStream(out, 5); + assertEquals(28, bytesWritten); + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + KudoTableHeader header = KudoTableHeader.readFrom(new DataInputStream(in)).get(); + + assertEquals(0, header.getNumColumns()); + assertEquals(0, header.getOffset()); + assertEquals(5, header.getNumRows()); + assertEquals(0, header.getValidityBufferLen()); + assertEquals(0, header.getOffsetBufferLen()); + assertEquals(0, header.getTotalDataLen()); + } + + @Test + public void testWriteSimple() throws Exception { + KudoSerializer serializer = new KudoSerializer(buildSimpleTestSchema()); + + try(Table t = buildSimpleTable()) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + long bytesWritten = serializer.writeToStream(t, out, 0, 4); + assertEquals(189, bytesWritten); + + ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); + + KudoTableHeader header = KudoTableHeader.readFrom(new DataInputStream(in)).get(); + assertEquals(7, header.getNumColumns()); + assertEquals(0, header.getOffset()); + assertEquals(4, header.getNumRows()); + assertEquals(24, header.getValidityBufferLen()); + assertEquals(40, header.getOffsetBufferLen()); + assertEquals(160, header.getTotalDataLen()); + + // First integer column has no validity buffer + assertFalse(header.hasValidityBuffer(0)); + for (int i=1; i<7; i++) { + assertTrue(header.hasValidityBuffer(i)); + } + } + } + + private static Schema buildSimpleTestSchema() { + Schema.Builder builder = Schema.builder(); + + builder.addColumn(DType.INT32, "a"); + builder.addColumn(DType.STRING, "b"); + Schema.Builder listBuilder = builder.addColumn(DType.LIST, "c"); + listBuilder.addColumn(DType.INT32, "c1"); + + Schema.Builder structBuilder = builder.addColumn(DType.STRUCT, "d"); + structBuilder.addColumn(DType.INT8, "d1"); + structBuilder.addColumn(DType.INT64, "d2"); + + return builder.build(); + } + + private static 
Table buildSimpleTable() { + HostColumnVector.StructType st = new HostColumnVector.StructType( + true, + new HostColumnVector.BasicType(true, DType.INT8), + new HostColumnVector.BasicType(true, DType.INT64) + ); + return new Table.TestBuilder() + .column(1, 2, 3, 4) + .column("1", "12", null, "45") + .column(new Integer[]{1, null, 3}, new Integer[]{4, 5, 6}, null, new Integer[]{7, 8, 9}) + .column(st, new HostColumnVector.StructData((byte)1, 11L), + new HostColumnVector.StructData ((byte)2, null), null, + new HostColumnVector.StructData((byte)3, 33L)) + .build(); + } +}
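For reference, a minimal end-to-end sketch of the writer introduced in this patch, assuming the
`buildSimpleTestSchema()` and `buildSimpleTable()` helpers from the test above are available:

```java
import ai.rapids.cudf.Table;
import com.nvidia.spark.rapids.jni.kudo.KudoSerializer;

import java.io.ByteArrayOutputStream;

public class KudoWriterExample {
    public static void main(String[] args) throws Exception {
        // The serializer is bound to a schema; callers must pass tables that match it.
        KudoSerializer serializer = new KudoSerializer(buildSimpleTestSchema());
        try (Table table = buildSimpleTable()) {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            // Writes the header, then the sliced validity/offset/data buffers for rows [0, 4).
            long bytesWritten = serializer.writeToStream(table, out, 0, 4);
            System.out.println("Wrote " + bytesWritten + " bytes");
        }
    }
}
```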