From f103c4e7b0ad2b7081b20cfa81b2c15d9daacf66 Mon Sep 17 00:00:00 2001
From: Liangcai Li
Date: Mon, 6 May 2024 10:47:25 +0800
Subject: [PATCH] Support serializing tables directly for shuffle write (#37)

* Support serializing packed tables directly

---------

Signed-off-by: Firestarman
---
 .../delta/GpuOptimizeWriteExchangeExec.scala  |  12 +-
 .../rapids/PackedTableHostColumnVector.java   | 173 +++++++++++
 .../spark/rapids/GpuCoalesceBatches.scala     |   8 +-
 .../rapids/GpuColumnarBatchSerializer.scala   | 150 +++++-----
 .../nvidia/spark/rapids/GpuPartitioning.scala |  62 ++--
 .../nvidia/spark/rapids/GpuTableSerde.scala   | 270 ++++++++++++++++++
 .../spark/rapids/GpuTransitionOverrides.scala |  22 +-
 .../com/nvidia/spark/rapids/RapidsConf.scala  |  12 +
 .../spark/sql/rapids/GpuShuffleEnv.scala      |   7 +-
 .../GpuShuffleExchangeExecBase.scala          |  40 +--
 10 files changed, 630 insertions(+), 126 deletions(-)
 create mode 100644 sql-plugin/src/main/java/com/nvidia/spark/rapids/PackedTableHostColumnVector.java
 create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTableSerde.scala

diff --git a/delta-lake/common/src/main/databricks/scala/org/apache/spark/sql/rapids/delta/GpuOptimizeWriteExchangeExec.scala b/delta-lake/common/src/main/databricks/scala/org/apache/spark/sql/rapids/delta/GpuOptimizeWriteExchangeExec.scala
index 1a9936ea808..b837604b5e8 100644
--- a/delta-lake/common/src/main/databricks/scala/org/apache/spark/sql/rapids/delta/GpuOptimizeWriteExchangeExec.scala
+++ b/delta-lake/common/src/main/databricks/scala/org/apache/spark/sql/rapids/delta/GpuOptimizeWriteExchangeExec.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
  *
  * This file was derived from OptimizeWriteExchange.scala
  * in the Delta Lake project at https://github.com/delta-io/delta
@@ -97,8 +97,12 @@ case class GpuOptimizeWriteExchangeExec(
     ) ++ additionalMetrics
   }
 
-  private lazy val serializer: Serializer =
-    new GpuColumnarBatchSerializer(gpuLongMetric("dataSize"))
+  private lazy val sparkTypes: Array[DataType] = child.output.map(_.dataType).toArray
+
+  private lazy val serializer: Serializer = new GpuColumnarBatchSerializer(
+    gpuLongMetric("dataSize"), allMetrics("rapidsShuffleSerializationTime"),
+    allMetrics("rapidsShuffleDeserializationTime"),
+    partitioning.serializingOnGPU, sparkTypes)
 
   @transient lazy val inputRDD: RDD[ColumnarBatch] = child.executeColumnar()
 
@@ -116,7 +120,7 @@ case class GpuOptimizeWriteExchangeExec(
       inputRDD,
       child.output,
       partitioning,
-      child.output.map(_.dataType).toArray,
+      sparkTypes,
       serializer,
       useGPUShuffle=partitioning.usesGPUShuffle,
       useMultiThreadedShuffle=partitioning.usesMultiThreadedShuffle,
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/PackedTableHostColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/PackedTableHostColumnVector.java
new file mode 100644
index 00000000000..667eba6c853
--- /dev/null
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/PackedTableHostColumnVector.java
@@ -0,0 +1,173 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids;
+
+import ai.rapids.cudf.ContiguousTable;
+import ai.rapids.cudf.DeviceMemoryBuffer;
+import ai.rapids.cudf.HostMemoryBuffer;
+import com.nvidia.spark.rapids.format.TableMeta;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.vectorized.ColumnVector;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
+import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A column vector that tracks a packed (or compressed) table on host. Unlike a normal
+ * host column vector, the columnar data within cannot be accessed directly.
+ * This is intended to be used only during shuffle, after the data is partitioned and
+ * before it is serialized.
+ */
+public final class PackedTableHostColumnVector extends ColumnVector {
+
+  private static final String BAD_ACCESS_MSG = "Column is packed";
+
+  private final TableMeta tableMeta;
+  private final HostMemoryBuffer tableBuffer;
+
+  PackedTableHostColumnVector(TableMeta tableMeta, HostMemoryBuffer tableBuffer) {
+    super(DataTypes.NullType);
+    long rows = tableMeta.rowCount();
+    int batchRows = (int) rows;
+    if (rows != batchRows) {
+      throw new IllegalStateException("Cannot support a batch larger than MAX INT rows");
+    }
+    this.tableMeta = tableMeta;
+    this.tableBuffer = tableBuffer;
+  }
+
+  private static ColumnarBatch from(TableMeta meta, DeviceMemoryBuffer devBuf) {
+    HostMemoryBuffer tableBuf;
+    try (HostMemoryBuffer buf = HostMemoryBuffer.allocate(devBuf.getLength())) {
+      buf.copyFromDeviceBuffer(devBuf);
+      // Bump the reference count so the buffer outlives the implicit close() at the
+      // end of this try-with-resources block.
+      buf.incRefCount();
+      tableBuf = buf;
+    }
+    ColumnVector column = new PackedTableHostColumnVector(meta, tableBuf);
+    return new ColumnarBatch(new ColumnVector[] { column }, (int) meta.rowCount());
+  }
+
+  /** Both the input table and output batch should be closed. */
+  public static ColumnarBatch from(CompressedTable table) {
+    return from(table.meta(), table.buffer());
+  }
+
+  /** Both the input table and output batch should be closed. */
+  public static ColumnarBatch from(ContiguousTable table) {
+    return from(MetaUtils.buildTableMeta(0, table), table.getBuffer());
+  }
+
+  /** Returns true if this columnar batch uses a packed table on host */
+  public static boolean isBatchPackedOnHost(ColumnarBatch batch) {
+    return batch.numCols() == 1 && batch.column(0) instanceof PackedTableHostColumnVector;
+  }
+
+  public TableMeta getTableMeta() {
+    return tableMeta;
+  }
+
+  public HostMemoryBuffer getTableBuffer() {
+    return tableBuffer;
+  }
+
+  @Override
+  public void close() {
+    tableBuffer.close();
+  }
+
+  @Override
+  public boolean hasNull() {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+
+  @Override
+  public int numNulls() {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+
+  @Override
+  public boolean isNullAt(int rowId) {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+
+  @Override
+  public boolean getBoolean(int rowId) {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+
+  @Override
+  public byte getByte(int rowId) {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+
+  @Override
+  public short getShort(int rowId) {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+
+  @Override
+  public int getInt(int rowId) {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+
+  @Override
+  public long getLong(int rowId) {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+
+  @Override
+  public float getFloat(int rowId) {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+
+  @Override
+  public double getDouble(int rowId) {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+
+  @Override
+  public ColumnarArray getArray(int rowId) {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+
+  @Override
+  public ColumnarMap getMap(int rowId) {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+
+  @Override
+  public Decimal getDecimal(int rowId, int precision, int scale) {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+
+  @Override
+  public UTF8String getUTF8String(int rowId) {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+
+  @Override
+  public byte[] getBinary(int rowId) {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+
+  @Override
+  public ColumnVector getChild(int ordinal) {
+    throw new IllegalStateException(BAD_ACCESS_MSG);
+  }
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala
index e6dc216d7e6..2d3d5dbcbf7 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -462,7 +462,7 @@ abstract class AbstractGpuCoalesceIterator(
         // If we have reached the cuDF limit once, proactively filter batches
         // after that first limit is reached.
         GpuFilter.filterAndClose(cbFromIter, inputFilterTier.get,
-          NoopMetric, NoopMetric, opTime)
+          NoopMetric, NoopMetric, NoopMetric)
       } else {
         Iterator(cbFromIter)
       }
@@ -499,7 +499,7 @@ abstract class AbstractGpuCoalesceIterator(
       var filteredBytes = 0L
       if (hasAnyToConcat) {
         val filteredDowIter = GpuFilter.filterAndClose(concatAllAndPutOnGPU(),
-          filterTier, NoopMetric, NoopMetric, opTime)
+          filterTier, NoopMetric, NoopMetric, NoopMetric)
         while (filteredDowIter.hasNext) {
           closeOnExcept(filteredDowIter.next()) { filteredDownCb =>
             filteredNumRows += filteredDownCb.numRows()
@@ -512,7 +512,7 @@ abstract class AbstractGpuCoalesceIterator(
         // filterAndClose takes ownership of CB so we should not close it on a failure
         // anymore...
         val filteredCbIter = GpuFilter.filterAndClose(cb.release, filterTier,
-          NoopMetric, NoopMetric, opTime)
+          NoopMetric, NoopMetric, NoopMetric)
         while (filteredCbIter.hasNext) {
           closeOnExcept(filteredCbIter.next()) { filteredCb =>
             val filteredWouldBeRows = filteredNumRows + filteredCb.numRows()
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarBatchSerializer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarBatchSerializer.scala
index 049f3f21bcf..d785a7884fc 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarBatchSerializer.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarBatchSerializer.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -25,16 +25,16 @@ import scala.reflect.ClassTag
 import ai.rapids.cudf.{HostColumnVector, HostMemoryBuffer, JCudfSerialization, NvtxColor, NvtxRange}
 import ai.rapids.cudf.JCudfSerialization.SerializedTableHeader
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
-import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
 
 import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance}
-import org.apache.spark.sql.types.NullType
-import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.sql.types.{DataType, NullType}
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector => SparkColumnVector}
 
-class SerializedBatchIterator(dIn: DataInputStream)
-  extends Iterator[(Int, ColumnarBatch)] {
+class SerializedBatchIterator(dIn: DataInputStream, deserTime: GpuMetric
+) extends Iterator[(Int, ColumnarBatch)] {
   private[this] var nextHeader: Option[SerializedTableHeader] = None
   private[this] var toBeReturned: Option[ColumnarBatch] = None
   private[this] var streamClosed: Boolean = false
@@ -90,14 +90,16 @@ class SerializedBatchIterator(dIn: DataInputStream)
   }
 
   override def hasNext: Boolean = {
-    tryReadNextHeader()
+    deserTime.ns(tryReadNextHeader())
     nextHeader.isDefined
   }
 
   override def next(): (Int, ColumnarBatch) = {
     if (toBeReturned.isEmpty) {
-      tryReadNextHeader()
-      toBeReturned = tryReadNext()
+      deserTime.ns {
+        tryReadNextHeader()
+        toBeReturned = tryReadNext()
+      }
       if (nextHeader.isEmpty || toBeReturned.isEmpty) {
         throw new NoSuchElementException("Walked off of the end...")
       }
@@ -108,6 +110,7 @@ class SerializedBatchIterator(dIn: DataInputStream)
     (0, ret)
   }
 }
+
 /**
  * Serializer for serializing `ColumnarBatch`s for use during normal shuffle.
 *
@@ -124,67 +127,35 @@ class SerializedBatchIterator(dIn: DataInputStream)
  *
  * @note The RAPIDS shuffle does not use this code.
  */
-class GpuColumnarBatchSerializer(dataSize: GpuMetric)
-  extends Serializer with Serializable {
+class GpuColumnarBatchSerializer(dataSize: GpuMetric,
+    serTime: GpuMetric = NoopMetric,
+    deserTime: GpuMetric = NoopMetric,
+    isSerializedTable: Boolean = false,
+    sparkTypes: Array[DataType] = Array.empty) extends Serializer with Serializable {
   override def newInstance(): SerializerInstance =
-    new GpuColumnarBatchSerializerInstance(dataSize)
+    new GpuColumnarBatchSerializerInstance(dataSize, serTime, deserTime,
+      isSerializedTable, sparkTypes)
   override def supportsRelocationOfSerializedObjects: Boolean = true
 }
 
-private class GpuColumnarBatchSerializerInstance(dataSize: GpuMetric) extends SerializerInstance {
+private class GpuColumnarBatchSerializerInstance(dataSize: GpuMetric, serTime: GpuMetric,
+    deserTime: GpuMetric, isSerializedTable: Boolean, sparkTypes: Array[DataType]
+) extends SerializerInstance {
 
-  override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
-    private[this] val dOut: DataOutputStream =
-      new DataOutputStream(new BufferedOutputStream(out))
+  override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream
+      with Logging {
+    private[this] val dOut = new DataOutputStream(new BufferedOutputStream(out))
+    private[this] val tableSerializer = new SimpleTableSerializer()
+    onTaskCompletion(TaskContext.get())(tableSerializer.close())
 
-    override def writeValue[T: ClassTag](value: T): SerializationStream = {
-      val batch = value.asInstanceOf[ColumnarBatch]
-      val numColumns = batch.numCols()
-      val columns: Array[HostColumnVector] = new Array(numColumns)
-      val toClose = new ArrayBuffer[AutoCloseable]()
-      try {
-        var startRow = 0
-        val numRows = batch.numRows()
-        if (batch.numCols() > 0) {
-          val firstCol = batch.column(0)
-          if (firstCol.isInstanceOf[SlicedGpuColumnVector]) {
-            // We don't have control over ColumnarBatch to put in the slice, so we have to do it
-            // for each column. In this case we are using the first column.
-            startRow = firstCol.asInstanceOf[SlicedGpuColumnVector].getStart
-            for (i <- 0 until numColumns) {
-              columns(i) = batch.column(i).asInstanceOf[SlicedGpuColumnVector].getBase
-            }
-          } else {
-            for (i <- 0 until numColumns) {
-              batch.column(i) match {
-                case gpu: GpuColumnVector =>
-                  val cpu = gpu.copyToHost()
-                  toClose += cpu
-                  columns(i) = cpu.getBase
-                case cpu: RapidsHostColumnVector =>
-                  columns(i) = cpu.getBase
-              }
-            }
-          }
+    private lazy val serializeBatch: ColumnarBatch => Unit = if (isSerializedTable) {
+      serializeGpuBatch
+    } else {
+      serializeCpuBatch
+    }
 
-          dataSize += JCudfSerialization.getSerializedSizeInBytes(columns, startRow, numRows)
-          val range = new NvtxRange("Serialize Batch", NvtxColor.YELLOW)
-          try {
-            JCudfSerialization.writeToStream(columns, dOut, startRow, numRows)
-          } finally {
-            range.close()
-          }
-        } else {
-          val range = new NvtxRange("Serialize Row Only Batch", NvtxColor.YELLOW)
-          try {
-            JCudfSerialization.writeRowsToStream(dOut, numRows)
-          } finally {
-            range.close()
-          }
-        }
-      } finally {
-        toClose.safeClose()
-      }
+    override def writeValue[T: ClassTag](value: T): SerializationStream = {
+      serTime.ns(withResource(value.asInstanceOf[ColumnarBatch])(serializeBatch))
       this
     }
 
@@ -211,16 +182,65 @@ private class GpuColumnarBatchSerializerInstance(dataSize: GpuMetric) extends Se
 
     override def close(): Unit = {
       dOut.close()
+      tableSerializer.close()
+    }
+
+    private def serializeCpuBatch(batch: ColumnarBatch): Unit = {
+      val numRows = batch.numRows()
+      val numCols = batch.numCols()
+      if (numCols > 0) {
+        withResource(new ArrayBuffer[AutoCloseable]()) { toClose =>
+          var startRow = 0
+          val cols = closeOnExcept(batch) { _ =>
+            val toHostCol: SparkColumnVector => HostColumnVector = batch.column(0) match {
+              case sliced: SlicedGpuColumnVector =>
+                // We don't have control over ColumnarBatch to put in the slice, so we have
+                // to do it for each column. In this case we are using the first column.
+                startRow = sliced.getStart
+                col => col.asInstanceOf[SlicedGpuColumnVector].getBase
+              case _: GpuColumnVector =>
+                col => {
+                  val hCol = col.asInstanceOf[GpuColumnVector].copyToHost()
+                  toClose += hCol
+                  hCol.getBase
+                }
+              case _: RapidsHostColumnVector =>
+                col => col.asInstanceOf[RapidsHostColumnVector].getBase
+            }
+            (0 until numCols).map(i => toHostCol(batch.column(i))).toArray
+          }
+          dataSize += JCudfSerialization.getSerializedSizeInBytes(cols, startRow, numRows)
+          withResource(new NvtxRange("Serialize Batch", NvtxColor.YELLOW)) { _ =>
+            JCudfSerialization.writeToStream(cols, dOut, startRow, numRows)
+          }
+        }
+      } else { // Rows only batch
+        withResource(new NvtxRange("Serialize Row Only Batch", NvtxColor.YELLOW)) { _ =>
+          JCudfSerialization.writeRowsToStream(dOut, numRows)
+        }
+      }
     }
-  }
 
+    private def serializeGpuBatch(batch: ColumnarBatch): Unit = {
+      if (batch.numCols() > 0) {
+        val packedCol = batch.column(0).asInstanceOf[PackedTableHostColumnVector]
+        dataSize += tableSerializer.writeToStream(packedCol, dOut)
+      } else {
+        dataSize += tableSerializer.writeRowsOnlyToStream(batch.numRows(), dOut)
+      }
+    }
+  }
 
   override def deserializeStream(in: InputStream): DeserializationStream = {
     new DeserializationStream {
       private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in))
 
       override def asKeyValueIterator: Iterator[(Int, ColumnarBatch)] = {
-        new SerializedBatchIterator(dIn)
+        if (isSerializedTable) {
+          new SerializedTableIterator(dIn, sparkTypes, deserTime)
+        } else {
+          new SerializedBatchIterator(dIn, deserTime)
+        }
       }
 
       override def asIterator: Iterator[Any] = {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPartitioning.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPartitioning.scala
index 0abf0cf7609..e2081eebb8c 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPartitioning.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPartitioning.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
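Spark drives the serializer above through its standard `Serializer` API: the shuffle writer obtains a `SerializerInstance`, opens a `SerializationStream`, and feeds it one `ColumnarBatch` per partition slice. A minimal sketch of that call pattern, assuming it runs inside a Spark task (the stream registers task-completion callbacks) and substituting `NoopMetric` for the real shuffle metrics:

```scala
import java.io.ByteArrayOutputStream

import org.apache.spark.sql.types.{DataType, IntegerType}

// Illustrative wiring only; Spark's shuffle writer performs these calls internally.
val serializer = new GpuColumnarBatchSerializer(
  NoopMetric,               // dataSize, normally gpuLongMetric("dataSize")
  NoopMetric,               // serTime
  NoopMetric,               // deserTime
  isSerializedTable = true, // take the packed-table (GPU-serialized) path
  sparkTypes = Array[DataType](IntegerType))

val out = new ByteArrayOutputStream()
val stream = serializer.newInstance().serializeStream(out)
// On the packed path, each value must be a ColumnarBatch wrapping a single
// PackedTableHostColumnVector, as produced during partitioning:
//   stream.writeValue(batch)
stream.close()
```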
@@ -35,11 +35,13 @@ object GpuPartitioning {
 }
 
 trait GpuPartitioning extends Partitioning {
-  private[this] val (maxCompressionBatchSize, _useGPUShuffle, _useMultiThreadedShuffle) = {
+  private[this] val (maxCompressionBatchSize, _useGPUShuffle, _useMultiThreadedShuffle,
+      _serializingOnGPU) = {
     val rapidsConf = new RapidsConf(SQLConf.get)
     (rapidsConf.shuffleCompressionMaxBatchMemory,
       GpuShuffleEnv.useGPUShuffle(rapidsConf),
-      GpuShuffleEnv.useMultiThreadedShuffle(rapidsConf))
+      GpuShuffleEnv.useMultiThreadedShuffle(rapidsConf),
+      GpuShuffleEnv.serializingOnGpu(rapidsConf))
   }
 
   final def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
@@ -54,6 +56,32 @@ trait GpuPartitioning extends Partitioning {
 
   def usePaddingPartition: Boolean = false
 
+  def serializingOnGPU: Boolean = _serializingOnGPU
+
+  private lazy val toPackedBatch: ContiguousTable => ColumnarBatch =
+    if (_serializingOnGPU) {
+      table =>
+        withResource(new NvtxRange("Table to Host", NvtxColor.BLUE)) { _ =>
+          withResource(table) { _ =>
+            PackedTableHostColumnVector.from(table)
+          }
+        }
+    } else {
+      GpuPackedTableColumn.from
+    }
+
+  private lazy val toCompressedBatch: CompressedTable => ColumnarBatch =
+    if (_serializingOnGPU) {
+      table =>
+        withResource(new NvtxRange("Table to Host", NvtxColor.BLUE)) { _ =>
+          withResource(table) { _ =>
+            PackedTableHostColumnVector.from(table)
+          }
+        }
+    } else {
+      GpuCompressedColumnVector.from
+    }
+
   def sliceBatch(vectors: Array[RapidsHostColumnVector], start: Int, end: Int): ColumnarBatch = {
     var ret: ColumnarBatch = null
     val count = end - start
@@ -67,7 +95,7 @@
   def sliceInternalOnGpuAndClose(numRows: Int, partitionIndexes: Array[Int],
       partitionColumns: Array[GpuColumnVector]): Array[ColumnarBatch] = {
     // The first index will always be 0, so we need to skip it.
-    val batches = if (numRows > 0) {
+    if (numRows > 0) {
       val parts = partitionIndexes.slice(1, partitionIndexes.length)
       closeOnExcept(new ArrayBuffer[ColumnarBatch](numPartitions)) { splits =>
         val contiguousTables = withResource(partitionColumns) { _ =>
@@ -79,23 +107,24 @@
           case Some(codec) =>
             compressSplits(splits, codec, contiguousTables)
           case None =>
-            // GpuPackedTableColumn takes ownership of the contiguous tables
-            closeOnExcept(contiguousTables) { cts =>
-              cts.foreach { ct => splits.append(GpuPackedTableColumn.from(ct)) }
-            }
+            // ColumnarBatch takes ownership of the contiguous tables
+            closeOnExcept(contiguousTables)(_.foreach(ct => splits.append(toPackedBatch(ct))))
         }
+
         // synchronize our stream to ensure we have caught up with contiguous split
         // as downstream consumers (RapidsShuffleManager) will add hundreds of buffers
         // to the spill framework, this makes it so here we synchronize once.
         Cuda.DEFAULT_STREAM.sync()
+
+        if (_serializingOnGPU) {
+          // All the data should be on the host now for shuffle, so release the GPU for a while.
+          GpuSemaphore.releaseIfNecessary(TaskContext.get())
+        }
         splits.toArray
       }
     } else {
       Array[ColumnarBatch]()
     }
-
-    GpuSemaphore.releaseIfNecessary(TaskContext.get())
-    batches
   }
 
   private def reslice(batch: ColumnarBatch, numSlices: Int): Seq[ColumnarBatch] = {
@@ -188,7 +217,7 @@
   def sliceInternalGpuOrCpuAndClose(numRows: Int, partitionIndexes: Array[Int],
       partitionColumns: Array[GpuColumnVector]): Array[(ColumnarBatch, Int)] = {
-    val sliceOnGpu = usesGPUShuffle
+    val sliceOnGpu = usesGPUShuffle || _serializingOnGPU
     val nvtxRangeKey = if (sliceOnGpu) {
       "sliceInternalOnGpu"
     } else {
@@ -226,7 +255,7 @@
       // add each table either to the batch to be compressed or to the empty batch tracker
       contiguousTables.zipWithIndex.foreach { case (ct, i) =>
         if (ct.getRowCount == 0) {
-          emptyBatches.append((GpuPackedTableColumn.from(ct), i))
+          emptyBatches.append((toPackedBatch(ct), i))
         } else {
           compressor.addTableToCompress(ct)
         }
@@ -240,18 +269,15 @@
           // add any compressed batches that need to appear before the next empty batch
           val numCompressedToAdd = emptyOutputIndex - outputIndex
           (0 until numCompressedToAdd).foreach { _ =>
-            val compressedTable = compressedTables(compressedTableIndex)
-            outputBatches.append(GpuCompressedColumnVector.from(compressedTable))
+            outputBatches.append(toCompressedBatch(compressedTables(compressedTableIndex)))
             compressedTableIndex += 1
           }
           outputBatches.append(emptyBatch)
           outputIndex = emptyOutputIndex + 1
         }
-
         // add any compressed batches that remain after the last empty batch
         (compressedTableIndex until compressedTables.length).foreach { i =>
-          val ct = compressedTables(i)
-          outputBatches.append(GpuCompressedColumnVector.from(ct))
+          outputBatches.append(toCompressedBatch(compressedTables(i)))
         }
       }
     }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTableSerde.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTableSerde.scala
new file mode 100644
index 00000000000..b1ef58dbe60
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTableSerde.scala
@@ -0,0 +1,270 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+import java.io.{DataInputStream, DataOutputStream, EOFException}
+import java.nio.ByteBuffer
+
+import ai.rapids.cudf.{DeviceMemoryBuffer, HostMemoryBuffer, NvtxColor, NvtxRange}
+import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
+import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
+import com.nvidia.spark.rapids.format.TableMeta
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+private sealed trait TableSerde extends AutoCloseable {
+  protected val P_MAGIC_NUM: Int = 0x43554447 // "CUDF" + 1
+  protected val P_VERSION: Int = 0
+  protected val headerLen = 8 // the size in bytes of two Ints for a header
+
+  // Buffers for reuse, so there should be only one instance of this trait per thread.
+  protected val tmpBuf = new Array[Byte](1024 * 64) // 64k
+  protected var hostBuffer: HostMemoryBuffer = _
+
+  protected def getHostBuffer(len: Long): HostMemoryBuffer = {
+    assert(len >= 0)
+    if (hostBuffer != null && len <= hostBuffer.getLength) {
+      hostBuffer.slice(0, len)
+    } else { // hostBuffer is null or len is larger than the current one
+      if (hostBuffer != null) {
+        hostBuffer.close()
+      }
+      hostBuffer = HostMemoryBuffer.allocate(len)
+      hostBuffer.slice(0, len)
+    }
+  }
+
+  override def close(): Unit = {
+    if (hostBuffer != null) {
+      hostBuffer.close()
+      hostBuffer = null
+    }
+  }
+}
+
+private[rapids] class SimpleTableSerializer extends TableSerde {
+  private def writeByteBufferToStream(bBuf: ByteBuffer, dOut: DataOutputStream): Unit = {
+    // Write the buffer size first
+    val bufLen = bBuf.capacity()
+    dOut.writeLong(bufLen.toLong)
+    if (bBuf.hasArray) {
+      dOut.write(bBuf.array())
+    } else { // Probably a direct buffer
+      var leftLen = bufLen
+      while (leftLen > 0) {
+        val copyLen = Math.min(tmpBuf.length, leftLen)
+        bBuf.get(tmpBuf, 0, copyLen)
+        dOut.write(tmpBuf, 0, copyLen)
+        leftLen -= copyLen
+      }
+    }
+  }
+
+  private def writeHostBufferToStream(hBuf: HostMemoryBuffer, dOut: DataOutputStream): Unit = {
+    // Write the buffer size first
+    val bufLen = hBuf.getLength
+    dOut.writeLong(bufLen)
+    var leftLen = bufLen
+    var hOffset = 0L
+    while (leftLen > 0L) {
+      val copyLen = Math.min(tmpBuf.length, leftLen)
+      hBuf.getBytes(tmpBuf, 0, hOffset, copyLen)
+      dOut.write(tmpBuf, 0, copyLen.toInt)
+      leftLen -= copyLen
+      hOffset += copyLen
+    }
+  }
+
+  private def writeProtocolHeader(dOut: DataOutputStream): Unit = {
+    dOut.writeInt(P_MAGIC_NUM)
+    dOut.writeInt(P_VERSION)
+  }
+
+  def writeRowsOnlyToStream(numRows: Int, dOut: DataOutputStream): Long = {
+    withResource(new NvtxRange("Serialize Rows Only Table", NvtxColor.RED)) { _ =>
+      val degenBatch = new ColumnarBatch(Array.empty, numRows)
+      val tableMetaBuf = MetaUtils.buildDegenerateTableMeta(degenBatch).getByteBuffer
+      // In the order of: 1) header, 2) metadata for an empty batch
+      writeProtocolHeader(dOut)
+      writeByteBufferToStream(tableMetaBuf, dOut)
+      headerLen + tableMetaBuf.capacity()
+    }
+  }
+
+  def writeToStream(hostTbl: PackedTableHostColumnVector, dOut: DataOutputStream): Long = {
+    withResource(new NvtxRange("Serialize Host Table", NvtxColor.RED)) { _ =>
+      // In the order of: 1) header, 2) table metadata, 3) table data on host
+      val metaBuf = hostTbl.getTableMeta.getByteBuffer
+      val dataBuf = hostTbl.getTableBuffer
+      writeProtocolHeader(dOut)
+      writeByteBufferToStream(metaBuf, dOut)
+      writeHostBufferToStream(dataBuf, dOut)
+      headerLen + metaBuf.capacity() + dataBuf.getLength
+    }
+  }
+}
+
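The writer above produces a simple self-framing layout per table: an 8-byte protocol header (magic number, version), a length-prefixed flatbuffer `TableMeta`, and, for tables that carry data, a length-prefixed host buffer. A sketch that walks this framing with plain `java.io`, using the hypothetical helpers `skipFully` and `scanOneTable`; whether the data section is present is decided by the metadata in the real reader (`packedMetaAsByteBuffer() == null` means a rows-only table):

```scala
import java.io.{DataInputStream, EOFException}

def skipFully(in: DataInputStream, n: Long): Unit = {
  var left = n
  while (left > 0) {
    val skipped = in.skip(left)
    if (skipped <= 0) throw new EOFException()
    left -= skipped
  }
}

// Scans one serialized table without deserializing it.
def scanOneTable(in: DataInputStream, hasData: Boolean): Unit = {
  require(in.readInt() == 0x43554447, "bad magic number")    // P_MAGIC_NUM, "CUDF" + 1
  require(in.readInt() == 0, "unsupported protocol version") // P_VERSION
  val metaLen = in.readLong()                                // length-prefixed TableMeta
  skipFully(in, metaLen)
  if (hasData) {                                             // rows-only tables stop here
    val dataLen = in.readLong()                              // length-prefixed host buffer
    skipFully(in, dataLen)
  }
}
```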
+private[rapids] class SimpleTableDeserializer(sparkTypes: Array[DataType]) extends TableSerde {
+  private def readProtocolHeader(dIn: DataInputStream): Unit = {
+    val magicNum = dIn.readInt()
+    if (magicNum != P_MAGIC_NUM) {
+      throw new IllegalStateException(s"Expected magic number $P_MAGIC_NUM for " +
+        s"table serializer, but got $magicNum")
+    }
+    val version = dIn.readInt()
+    if (version != P_VERSION) {
+      throw new IllegalStateException(s"Version mismatch: expected $P_VERSION for " +
+        s"table serializer, but got $version")
+    }
+  }
+
+  private def readByteBufferFromStream(dIn: DataInputStream): ByteBuffer = {
+    val bufLen = dIn.readLong().toInt
+    val bufArray = new Array[Byte](bufLen)
+    var readLen = 0
+    // A single call to read(bufArray) cannot always read the expected length. So
+    // we do it here ourselves.
+    do {
+      val ret = dIn.read(bufArray, readLen, bufLen - readLen)
+      if (ret < 0) {
+        throw new EOFException()
+      }
+      readLen += ret
+    } while (readLen < bufLen)
+    ByteBuffer.wrap(bufArray)
+  }
+
+  private def readHostBufferFromStream(dIn: DataInputStream): HostMemoryBuffer = {
+    val bufLen = dIn.readLong()
+    closeOnExcept(getHostBuffer(bufLen)) { hostBuf =>
+      var leftLen = bufLen
+      var hOffset = 0L
+      while (leftLen > 0) {
+        val copyLen = Math.min(tmpBuf.length, leftLen)
+        val readLen = dIn.read(tmpBuf, 0, copyLen.toInt)
+        if (readLen < 0) {
+          throw new EOFException()
+        }
+        hostBuf.setBytes(hOffset, tmpBuf, 0, readLen)
+        hOffset += readLen
+        leftLen -= readLen
+      }
+      hostBuf
+    }
+  }
+
+  def readFromStream(dIn: DataInputStream): ColumnarBatch = {
+    // An IO operation is coming, so leave the GPU for a while.
+    GpuSemaphore.releaseIfNecessary(TaskContext.get())
+    // 1) read and check header
+    readProtocolHeader(dIn)
+    // 2) read table metadata
+    val tableMeta = TableMeta.getRootAsTableMeta(readByteBufferFromStream(dIn))
+    if (tableMeta.packedMetaAsByteBuffer() == null) {
+      // no packed metadata, must be a table with zero columns
+      // Acquire the GPU even though the incoming batch is empty, because the
+      // downstream tasks expect the GPU batch producer to acquire the semaphore
+      // and may generate GPU data from batches that are empty.
+      GpuSemaphore.acquireIfNecessary(TaskContext.get())
+      new ColumnarBatch(Array.empty, tableMeta.rowCount().toInt)
+    } else {
+      // 3) read table data
+      val hostBuf = withResource(new NvtxRange("Read Host Table", NvtxColor.ORANGE)) { _ =>
+        readHostBufferFromStream(dIn)
+      }
+      val data = withResource(hostBuf) { _ =>
+        // Begin to use the GPU
+        GpuSemaphore.acquireIfNecessary(TaskContext.get())
+        withResource(new NvtxRange("Table to Device", NvtxColor.YELLOW)) { _ =>
+          closeOnExcept(DeviceMemoryBuffer.allocate(hostBuf.getLength)) { devBuf =>
+            devBuf.copyFromHostBuffer(hostBuf)
+            devBuf
+          }
+        }
+      }
+      withResource(new NvtxRange("Deserialize Table", NvtxColor.RED)) { _ =>
+        withResource(data) { _ =>
+          val bufferMeta = tableMeta.bufferMeta()
+          if (bufferMeta == null || bufferMeta.codecBufferDescrsLength == 0) {
+            MetaUtils.getBatchFromMeta(data, tableMeta, sparkTypes)
+          } else {
+            // Compressed tables are not produced by the write side yet, but it is fine
+            // to handle them here on the read side, since compression will be supported later.
+            GpuCompressedColumnVector.from(data, tableMeta)
+          }
+        }
+      }
+    }
+  }
+}
+
+private[rapids] class SerializedTableIterator(dIn: DataInputStream,
+    sparkTypes: Array[DataType],
+    deserTime: GpuMetric) extends Iterator[(Int, ColumnarBatch)] {
+
+  private val tableDeserializer = new SimpleTableDeserializer(sparkTypes)
+  private var closed = false
+  private var onDeck: Option[SpillableColumnarBatch] = None
+  Option(TaskContext.get()).foreach { tc =>
+    onTaskCompletion(tc) {
+      onDeck.foreach(_.close())
+      onDeck = None
+      tableDeserializer.close()
+      if (!closed) {
+        dIn.close()
+      }
+    }
+  }
+
+  override def hasNext: Boolean = {
+    if (onDeck.isEmpty) {
+      tryReadNextBatch()
+    }
+    onDeck.isDefined
+  }
+
+  override def next(): (Int, ColumnarBatch) = {
+    if (!hasNext) {
+      throw new NoSuchElementException()
+    }
+    val ret = withResource(onDeck) { _ =>
+      onDeck.get.getColumnarBatch()
+    }
+    onDeck = None
+    (0, ret)
+  }
+
+  private def tryReadNextBatch(): Unit = {
+    if (closed) {
+      return
+    }
+    try {
+      onDeck = deserTime.ns(
+        Some(SpillableColumnarBatch(tableDeserializer.readFromStream(dIn),
+          SpillPriorities.ACTIVE_ON_DECK_PRIORITY))
+      )
+    } catch {
+      case _: EOFException => // we reached the end
+        dIn.close()
+        closed = true
+        onDeck.foreach(_.close())
+        onDeck = None
+    }
+  }
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
index dc6b658abd6..2757d95cc7e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
@@ -64,13 +64,16 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
         ProjectExec(exprs, c2r)
       }.getOrElse(c2r)
       p.withNewChildren(Array(newChild))
+    case exec: GpuShuffleExchangeExecBase =>
+      addPostShuffleCoalesce(
+        exec.withNewChildren(Seq(optimizeGpuPlanTransitions(exec.child))))
     case p =>
       p.withNewChildren(p.children.map(optimizeGpuPlanTransitions))
   }
 
   /** Adds the appropriate coalesce after a shuffle depending on the type of shuffle configured */
   private def addPostShuffleCoalesce(plan: SparkPlan): SparkPlan = {
-    if (GpuShuffleEnv.useGPUShuffle(rapidsConf)) {
+    if (GpuShuffleEnv.useGPUShuffle(rapidsConf) || GpuShuffleEnv.serializingOnGpu(rapidsConf)) {
       GpuCoalesceBatches(plan, TargetSize(rapidsConf.gpuTargetBatchSizeBytes))
     } else {
       GpuShuffleCoalesceExec(plan, rapidsConf.gpuTargetBatchSizeBytes)
@@ -511,19 +514,6 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
       p.withNewChildren(p.children.map(c => insertCoalesce(c, shouldDisable)))
   }
 
-  /**
-   * Inserts a shuffle coalesce after every shuffle to coalesce the serialized tables
-   * on the host before copying the data to the GPU.
-   * @note This should not be used in combination with the RAPIDS shuffle.
-   */
-  private def insertShuffleCoalesce(plan: SparkPlan): SparkPlan = plan match {
-    case exec: GpuShuffleExchangeExecBase =>
-      // always follow a GPU shuffle with a shuffle coalesce
-      GpuShuffleCoalesceExec(exec.withNewChildren(exec.children.map(insertShuffleCoalesce)),
-        rapidsConf.gpuTargetBatchSizeBytes)
-    case exec => exec.withNewChildren(plan.children.map(insertShuffleCoalesce))
-  }
-
   /**
    * Inserts a transition to be running on the CPU columnar
    */
@@ -796,10 +786,6 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
     }
     updatedPlan = insertColumnarFromGpu(updatedPlan)
     updatedPlan = insertCoalesce(updatedPlan)
-    // only insert shuffle coalesces when using normal shuffle
-    if (!GpuShuffleEnv.useGPUShuffle(rapidsConf)) {
-      updatedPlan = insertShuffleCoalesce(updatedPlan)
-    }
     if (plan.conf.adaptiveExecutionEnabled) {
       updatedPlan = optimizeAdaptiveTransitions(updatedPlan, None)
     } else {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index cc232052c65..cf71ba963a2 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -1819,6 +1819,16 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.
     .booleanConf
     .createWithDefault(false)
 
+  val SHUFFLE_WRITER_GPU_SERIALIZING =
+    conf("spark.rapids.shuffle.serializeOnGpu.enabled")
+      .doc("When true, the batch serialization for shuffle will run on the GPU. " +
+        "It requires that the shuffle writer currently in use is compatible with " +
+        "GPU serialization.")
+      .internal()
+      .startupOnly()
+      .booleanConf
+      .createWithDefault(false)
+
   // ALLUXIO CONFIGS
   val ALLUXIO_MASTER = conf("spark.rapids.alluxio.master")
     .doc("The Alluxio master hostname. If not set, read Alluxio master URL from " +
@@ -2838,6 +2848,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
 
   lazy val shuffleMultiThreadedReaderThreads: Int = get(SHUFFLE_MULTITHREADED_READER_THREADS)
 
+  lazy val isSerializingOnGpu: Boolean = get(SHUFFLE_WRITER_GPU_SERIALIZING)
+
   lazy val shuffleEnablePaddingPartition: Boolean = get(SHUFFLE_ENABLE_PADDING_PARTITION)
 
   def isUCXShuffleManagerMode: Boolean =
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala
index 1682dd13c22..eff6a8076db 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
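The new flag is internal and startup-only, so it has to be in place before the executors launch. A sketch of how it might be enabled for experimentation, assuming the standard RAPIDS plugin setup via `spark.plugins`; the gating added below additionally turns the feature off whenever the GPU (RAPIDS) shuffle is active:

```scala
import org.apache.spark.sql.SparkSession

// Must be set at startup; the conf is internal, so this is for experimentation only.
val spark = SparkSession.builder()
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
  .config("spark.rapids.sql.enabled", "true")
  .config("spark.rapids.shuffle.serializeOnGpu.enabled", "true")
  .getOrCreate()
```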
@@ -141,6 +141,11 @@ object GpuShuffleEnv extends Logging {
     isRapidsShuffleAvailable(conf)
   }
 
+  def serializingOnGpu(conf: RapidsConf): Boolean = {
+    // Serializing on GPU for CPU shuffle conflicts with GPU shuffle
+    conf.isSerializingOnGpu && (!useGPUShuffle(conf))
+  }
+
   def getCatalog: ShuffleBufferCatalog = if (env == null) {
     null
   } else {
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala
index 5323fc89019..de90a6c025b 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,9 +21,10 @@ import scala.concurrent.Future
 
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
+import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
 import com.nvidia.spark.rapids.shims.{GpuHashPartitioning, GpuRangePartitioning, ShimUnaryExecNode, ShuffleOriginUtil, SparkShimImpl}
 
-import org.apache.spark.{MapOutputStatistics, ShuffleDependency}
+import org.apache.spark.{MapOutputStatistics, ShuffleDependency, TaskContext}
 import org.apache.spark.rapids.shims.GpuShuffleExchangeExec
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
@@ -199,15 +200,17 @@ abstract class GpuShuffleExchangeExecBase(
     "dataSize" -> createSizeMetric(ESSENTIAL_LEVEL,"data size"),
     "dataReadSize" -> createSizeMetric(MODERATE_LEVEL, "data read size"),
     "rapidsShuffleSerializationTime" ->
-      createNanoTimingMetric(DEBUG_LEVEL,"rs. serialization time"),
+      createNanoTimingMetric(MODERATE_LEVEL,"rs. serialization time"),
     "rapidsShuffleDeserializationTime" ->
-      createNanoTimingMetric(DEBUG_LEVEL,"rs. deserialization time"),
+      createNanoTimingMetric(MODERATE_LEVEL,"rs. deserialization time"),
     "rapidsShuffleWriteTime" ->
       createNanoTimingMetric(ESSENTIAL_LEVEL,"rs. shuffle write time"),
     "rapidsShuffleCombineTime" ->
       createNanoTimingMetric(DEBUG_LEVEL,"rs. shuffle combine time"),
     "rapidsShuffleWriteIoTime" ->
       createNanoTimingMetric(DEBUG_LEVEL,"rs. shuffle write io time"),
+    "rapidsShufflePartitionTime" ->
+      createNanoTimingMetric(MODERATE_LEVEL, "rs. shuffle partition time"),
     "rapidsShuffleReadTime" ->
       createNanoTimingMetric(ESSENTIAL_LEVEL,"rs. shuffle read time")
   ) ++ GpuMetric.wrap(readMetrics) ++ GpuMetric.wrap(writeMetrics)
@@ -231,7 +234,10 @@ abstract class GpuShuffleExchangeExecBase(
   // This value must be lazy because the child's output may not have been resolved
   // yet in all cases.
   private lazy val serializer: Serializer = new GpuColumnarBatchSerializer(
-    gpuLongMetric("dataSize"))
+    gpuLongMetric("dataSize"), allMetrics("rapidsShuffleSerializationTime"),
+    allMetrics("rapidsShuffleDeserializationTime"),
+    gpuOutputPartitioning.serializingOnGPU, sparkTypes,
+  )
 
   @transient lazy val inputBatchRDD: RDD[ColumnarBatch] = child.executeColumnar()
 
@@ -314,7 +320,8 @@ object GpuShuffleExchangeExecBase {
   }
   val partitioner: GpuExpression = getPartitioner(newRdd, outputAttributes, newPartitioning)
   def getPartitioned: ColumnarBatch => Any = {
-    batch => partitioner.columnarEvalAny(batch)
+    val partitionMetric = metrics("rapidsShufflePartitionTime")
+    batch => partitionMetric.ns(partitioner.columnarEvalAny(batch))
   }
   val rddWithPartitionIds: RDD[Product2[Int, ColumnarBatch]] = {
     newRdd.mapPartitions { iter =>
@@ -323,12 +330,17 @@ object GpuShuffleExchangeExecBase {
         private var partitioned : Array[(ColumnarBatch, Int)] = _
         private var at = 0
         private val mutablePair = new MutablePair[Int, ColumnarBatch]()
-        private def partNextBatch(): Unit = {
-          if (partitioned != null) {
-            partitioned.map(_._1).safeClose()
-            partitioned = null
-            at = 0
+        Option(TaskContext.get()).foreach { tc =>
+          onTaskCompletion(tc) {
+            if (partitioned != null) {
+              partitioned.drop(at).map(_._1).safeClose()
+            }
           }
+        }
+
+        private def partNextBatch(): Unit = {
+          partitioned = null
+          at = 0
           if (iter.hasNext) {
             var batch = iter.next()
             while (batch.numRows == 0 && iter.hasNext) {
@@ -343,7 +355,6 @@ object GpuShuffleExchangeExecBase {
               metrics(GpuMetric.NUM_OUTPUT_ROWS) += batches._1.numRows()
             })
             metrics(GpuMetric.NUM_OUTPUT_BATCHES) += partitioned.length
-            at = 0
           } else {
             batch.close()
           }
@@ -359,10 +370,7 @@ object GpuShuffleExchangeExecBase {
         }
 
         override def next(): Product2[Int, ColumnarBatch] = {
-          if (partitioned == null || at >= partitioned.length) {
-            partNextBatch()
-          }
-          if (partitioned == null || at >= partitioned.length) {
+          if (!hasNext) {
             throw new NoSuchElementException("Walked off of the end...")
           }
           val tup = partitioned(at)
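The rewritten iterator above moves all of the work into `hasNext`: it is idempotent, fills a one-element buffer (`partitioned`/`at`), and `next()` simply validates and drains. The same contract in a minimal, self-contained form, with hypothetical names:

```scala
// A one-element lookahead iterator: hasNext does the work exactly once,
// next() only drains the buffered element.
class BufferedIter[T](source: Iterator[T]) extends Iterator[T] {
  private var onDeck: Option[T] = None

  override def hasNext: Boolean = {
    if (onDeck.isEmpty && source.hasNext) {
      onDeck = Some(source.next()) // the real work happens here, once
    }
    onDeck.isDefined
  }

  override def next(): T = {
    if (!hasNext) throw new NoSuchElementException("Walked off of the end...")
    val ret = onDeck.get
    onDeck = None
    ret
  }
}
```

This mirrors how `SerializedTableIterator` buffers its next batch in `onDeck`, and it is why `next()` in the partitioning iterator can safely delegate to `hasNext`.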