diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java
index c7913cd93e5..a59e03fdec3 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -79,6 +79,26 @@ public static long getTotalHostMemoryUsed(ColumnarBatch batch) {
     return sum;
   }
 
+  // The size in bytes of an offset entry is 4
+  final static int OFFSET_STEP = 4;
+
+  // The size in bytes of an offset entry is 4, so the shift value is 2.
+  final static int OFFSET_SHIFT_STEP = 2;
+
+  public static long getOffsetBufferSize(int numRows) {
+    // The size in bytes of an offset entry is 4, so the buffer size is:
+    // (numRows + 1) * 4.
+    return ((long)numRows + 1) << OFFSET_SHIFT_STEP;
+  }
+
+  public static long getValidityBufferSize(int numRows) {
+    // This is the same as ColumnView.getValidityBufferSize
+    // number of bytes required = Math.ceil(number of bits / 8)
+    long actualBytes = ((long) numRows + 7) >> 3;
+    // padding to a multiple of the padding boundary (64 bytes)
+    return ((actualBytes + 63) >> 6) << 6;
+  }
+
   private final ai.rapids.cudf.HostColumnVector cudfCv;
 
   /**
diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/SlicedGpuColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/SlicedGpuColumnVector.java
index 0693fdfe922..811eabda2ff 100644
--- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/SlicedGpuColumnVector.java
+++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/SlicedGpuColumnVector.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -176,11 +176,7 @@ private static long getSizeOf(HostColumnVectorCore cv, int start, int end) {
     if (end > start) {
       ai.rapids.cudf.HostMemoryBuffer validity = cv.getValidity();
       if (validity != null) {
-        // This is the same as ColumnView.getValidityBufferSize
-        // number of bytes required = Math.ceil(number of bits / 8)
-        long actualBytes = ((long) (end - start) + 7) >> 3;
-        // padding to the multiplies of the padding boundary(64 bytes)
-        total += ((actualBytes + 63) >> 6) << 6;
+        total += RapidsHostColumnVector.getValidityBufferSize(end - start);
       }
       ai.rapids.cudf.HostMemoryBuffer off = cv.getOffsets();
       if (off != null) {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala
index 92588885be0..aef243cd849 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchUtils.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,13 +16,15 @@
 
 package com.nvidia.spark.rapids
 
+import scala.collection.JavaConverters.seqAsJavaListConverter
 import scala.collection.mutable.ArrayBuffer
 
-import ai.rapids.cudf.Table
-import com.nvidia.spark.rapids.Arm.withResource
-import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq
+import ai.rapids.cudf.{DType, HostColumnVector, HostColumnVectorCore, HostMemoryBuffer, Table}
+import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
+import com.nvidia.spark.rapids.RapidsPluginImplicits.{AutoCloseableProducingSeq, AutoCloseableSeq}
 
 import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, MapType, StructType}
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 /**
  * Utility class with methods for calculating various metrics about GPU memory usage
@@ -212,4 +214,220 @@ object GpuBatchUtils {
     Option(retBatch)
   }
+
+  /**
+   * Only batches sliced on the CPU for shuffle are supported, meaning the internal
+   * columns are instances of SlicedGpuColumnVector.
+   */
+  def concatShuffleBatchesAndClose(batches: Seq[ColumnarBatch],
+      totalSize: Option[Long] = None): ColumnarBatch = {
+    val (nonEmptyCBs, emptyCbs) = batches.partition(_.numRows() > 0)
+    if (nonEmptyCBs.nonEmpty) {
+      emptyCbs.safeClose()
+      if (nonEmptyCBs.length == 1) {
+        nonEmptyCBs.head
+      } else { // more than one batch
+        withResource(nonEmptyCBs) { _ =>
+          concatShuffleBatches(nonEmptyCBs, totalSize)
+        }
+      }
+    } else {
+      assert(emptyCbs.nonEmpty)
+      emptyCbs.tail.safeClose()
+      emptyCbs.head
+    }
+  }
+
+  private def concatShuffleBatches(batches: Seq[ColumnarBatch],
+      totalSize: Option[Long]): ColumnarBatch = {
+    val numCols = batches.head.numCols()
+    // all batches should have the same number of columns
+    batches.tail.foreach(b => assert(numCols == b.numCols()))
+    val sizeSum = totalSize.getOrElse(
+      batches.map(SlicedGpuColumnVector.getTotalHostMemoryUsed).sum
+    ) + (numCols << 6) // For the validity padding, numCols * 64 for the worst case
+    val concatNumRows = batches.map(_.numRows()).sum
+    // Allocate a single buffer for the merged batch.
+    val concatHostCols = withResource(HostMemoryBuffer.allocate(sizeSum)) { allBuf =>
+      var outOff = 0L
+      (0 until numCols).safeMap { idx =>
+        val cols = batches.map(_.column(idx).asInstanceOf[SlicedGpuColumnVector])
+        // Concatenate the input sliced columns
+        val (concatCol, concatLen) = concatSlicedColumns(cols, concatNumRows, allBuf, outOff)
+        withResource(concatCol) { _ =>
+          outOff += concatLen
+          // The downstream shuffle writer expects SlicedGpuColumnVectors
+          new SlicedGpuColumnVector(concatCol, 0, concatNumRows)
+        }
+      }
+    }
+    new ColumnarBatch(concatHostCols.toArray, concatNumRows)
+  }
+
+  private def concatSlicedColumns(cols: Seq[SlicedGpuColumnVector], totalRowsNum: Int,
+      outBuf: HostMemoryBuffer, outOffset: Long): (RapidsHostColumnVector, Long) = {
+    // All should have the same type
+    val colSparkType = cols.head.dataType()
+    assert(cols.tail.forall(_.dataType() == colSparkType),
+      s"All the column types should be $colSparkType, but got (" +
+        s"${cols.map(_.dataType()).mkString("; ")})")
+    val (cudfHostColumn, colLen) = concatSlicedColumns(
+      cols.map(c => (c.getBase, c.getStart, c.getEnd)), outBuf, outOffset, Some(totalRowsNum))
+    (new RapidsHostColumnVector(colSparkType, cudfHostColumn), colLen)
+  }
+
+  /** (TODO Move concatenating HostColumnVectors to Rapids JNI) */
+  private def concatSlicedColumns(cols: Seq[(HostColumnVectorCore, Int, Int)],
+      outBuf: HostMemoryBuffer, outOffset: Long,
+      totalRowsNum: Option[Int] = None): (HostColumnVector, Long) = {
+    val colCudfType = cols.head._1.getType
+    val concatRowsNum = totalRowsNum.getOrElse(cols.map(c => c._3 - c._2).sum)
+    var curGlobalPos = outOffset
+    // 1) Validity buffer. It is required if any input column has a validity buffer.
+    val (concatValidityBuf, nullCount) = if (cols.exists(_._1.hasValidityVector)) {
+      val concatValidityLen = RapidsHostColumnVector.getValidityBufferSize(concatRowsNum)
+      closeOnExcept(outBuf.slice(curGlobalPos, concatValidityLen)) { destBuf =>
+        curGlobalPos += concatValidityLen
+        // Set all the bits to "1" by default.
+        destBuf.setMemory(0, concatValidityLen, 0xff.toByte)
+        var accNullCnt = 0L
+        var destRowsNum = 0
+        cols.foreach { case (c, sStart, sEnd) =>
+          val validityBuf = c.getValidity
+          if (validityBuf != null) {
+            // Has nulls, set them one by one
+            var rowId = sStart
+            while (rowId < sEnd) {
+              if (isNullAt(validityBuf, rowId)) {
+                setNullAt(destBuf, destRowsNum)
+                accNullCnt += 1
+              }
+              rowId += 1
+              destRowsNum += 1
+            }
+          } else { // no nulls, just update the dest rows number
+            destRowsNum += (sEnd - sStart)
+          }
+        }
+        assert(destRowsNum == concatRowsNum)
+        (destBuf, accNullCnt)
+      }
+    } else {
+      (null, 0L)
+    }
+
+    // 2) Offset buffer. All should have the same type, so only need to check the first one
+    val concatOffsetBuf = closeOnExcept(concatValidityBuf) { _ =>
+      if (colCudfType.hasOffsets) {
+        val concatOffsetLen = RapidsHostColumnVector.getOffsetBufferSize(concatRowsNum)
+        closeOnExcept(outBuf.slice(curGlobalPos, concatOffsetLen)) { destBuf =>
+          curGlobalPos += concatOffsetLen
+          val offBufStep = RapidsHostColumnVector.OFFSET_STEP
+          var destPos = 0L
+          var accOffsetValue = 0
+          // Compute offsets. All the inputs are assumed to have offset buffers.
+          // The first one is always 0
+          destBuf.setInt(destPos, accOffsetValue)
+          destPos += offBufStep
+          cols.foreach { case (c, sStart, sEnd) =>
+            val offBuf = c.getOffsets
+            val offBufEnd = sEnd << RapidsHostColumnVector.OFFSET_SHIFT_STEP
+            var curOffBufPos = sStart << RapidsHostColumnVector.OFFSET_SHIFT_STEP
+            val offsetDiff = accOffsetValue - offBuf.getInt(curOffBufPos)
+            curOffBufPos += offBufStep
+            while (curOffBufPos <= offBufEnd) {
+              destBuf.setInt(destPos, offBuf.getInt(curOffBufPos) + offsetDiff)
+              destPos += offBufStep
+              curOffBufPos += offBufStep
+            }
+            // The last entry is the offset value for the next buffer
+            accOffsetValue = destBuf.getInt(destPos - offBufStep)
+          }
+          assert(destPos == concatOffsetLen)
+          destBuf
+        }
+      } else {
+        null
+      }
+    }
+
+    // 3) data buffer
+    val nonEmptyDataCols = cols.filter(_._1.getData != null)
+    val concatDataBuf = closeOnExcept(Seq(concatValidityBuf, concatOffsetBuf)) { _ =>
+      if (nonEmptyDataCols.nonEmpty) {
+        // String or primitive type
+        type DataBufFunc = ((HostColumnVectorCore, Int, Int)) => (HostMemoryBuffer, Long, Long)
+        val getSlicedDataBuf: DataBufFunc = if (DType.STRING.equals(colCudfType)) {
+          // String type has both data and offset
+          c => { // c is (column, start, end)
+            val start = c._1.getStartListOffset(c._2)
+            (c._1.getData, start, c._1.getEndListOffset(c._3 - 1) - start)
+          }
+        } else { // non-nested type
+          c => { // c is (column, start, end)
+            val typeSize = colCudfType.getSizeInBytes.toLong
+            assert(typeSize > 0, s"Non-nested type is expected, but got $colCudfType")
+            (c._1.getData, c._2 * typeSize, (c._3 - c._2) * typeSize)
+          }
+        }
+        val nonEmptyDataBufs = nonEmptyDataCols.map(getSlicedDataBuf)
+        val concatDataLen = nonEmptyDataBufs.map(_._3).sum
+        closeOnExcept(outBuf.slice(curGlobalPos, concatDataLen)) { destBuf =>
+          curGlobalPos += concatDataLen
+          var destPos = 0L
+          // Just append the data buffers one by one
+          nonEmptyDataBufs.foreach { case (srcBuf, srcStart, srcLen) =>
+            destBuf.copyFromHostBuffer(destPos, srcBuf, srcStart, srcLen)
+            destPos += srcLen
+          }
+          destBuf
+        }
+      } else {
+        null
+      }
+    }
+
+    // 4) children
+    val concatNestedHcv = closeOnExcept(Seq(concatValidityBuf, concatOffsetBuf, concatDataBuf)) {
+      _ =>
+        if (colCudfType.isNestedType) {
+          // All should have the same number of children
+          val childrenNum = cols.head._1.getNumChildren
+          assert(childrenNum > 0, "Non-empty children are expected")
+          (0 until childrenNum).safeMap { idx =>
+            val sChildren = cols.map { case (c, start, end) =>
+              val childView = c.getChildColumnView(idx)
+              if (childView.getType.hasOffsets) {
+                (childView, c.getStartListOffset(start).toInt, c.getEndListOffset(end - 1).toInt)
+              } else {
+                (childView, start, end)
+              }
+            }
+            val (childCol, colLen) = concatSlicedColumns(sChildren, outBuf, curGlobalPos)
+            curGlobalPos += colLen
+            childCol
+          }.asInstanceOf[Seq[HostColumnVectorCore]].asJava
+        } else {
+          new java.util.ArrayList[HostColumnVectorCore]()
+        }
+    }
+
+    val cudfHostColumn = new HostColumnVector(
+      colCudfType, concatRowsNum, java.util.Optional.of(nullCount),
+      concatDataBuf, concatValidityBuf, concatOffsetBuf, concatNestedHcv)
+    (cudfHostColumn, curGlobalPos - outOffset)
+  }
+
+  private def setNullAt(validBuf: HostMemoryBuffer, rowId: Int): Unit = {
+    val bucket = rowId >> 3 // = (rowId / 8)
+    val curByte = validBuf.getByte(bucket)
+    val bitmask = (~(1 << (rowId & 0x7).toByte))
+    validBuf.setByte(bucket, (curByte & bitmask).toByte)
+  }
+
+  private def isNullAt(validBuf: HostMemoryBuffer, rowId: Int): Boolean = {
+    val b = validBuf.getByte(rowId >> 3) // = (rowId / 8)
+    val ret = b & (1 << (rowId & 0x7).toByte)
+    ret == 0
+  }
 
 }
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index f52c3b5f334..27f3cd2c3bc 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -1755,6 +1755,31 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.
     .integerConf
     .createWithDefault(20)
 
+  val SHUFFLE_WRITER_COALESCE_ENABLED = conf("spark.rapids.shuffle.writer.coalesce.enabled")
+    .doc("When false, disable coalescing of small batches for shuffle writes that slice" +
+      " batches on the CPU.")
+    .internal()
+    .booleanConf
+    .createWithDefault(true)
+
+  val SHUFFLE_WRITER_COALESCE_MIN_PARTITION_SIZE =
+    conf("spark.rapids.shuffle.writer.coalesce.minPartitionSize")
+      .doc("The minimum partition size for the coalescing shuffle write. Batches" +
+        " of a partition will be coalesced until their total size goes beyond this size," +
+        " then the coalesced partition data is pushed down to the shuffle writer for" +
+        " serialization.")
+      .internal()
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefault(5 * 1024 * 1024) // 5MB
+
+  val SHUFFLE_WRITER_COALESCE_TOTAL_PARTITIONS_SIZE =
+    conf("spark.rapids.shuffle.writer.coalesce.totalPartitionsSize")
+      .doc("The total size across all tasks for caching the batches to be coalesced" +
+        " when doing the shuffle write.")
+      .internal()
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefault(10L * 1024 * 1024 * 1024) // 10GB
+
   // ALLUXIO CONFIGS
   val ALLUXIO_MASTER = conf("spark.rapids.alluxio.master")
     .doc("The Alluxio master hostname. If not set, read Alluxio master URL from " +
@@ -2751,6 +2776,14 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
 
   lazy val shuffleMultiThreadedReaderThreads: Int = get(SHUFFLE_MULTITHREADED_READER_THREADS)
 
+  lazy val isShuffleWriteCoalesceEnabled: Boolean = get(SHUFFLE_WRITER_COALESCE_ENABLED)
+
+  lazy val shuffleWriteCoalesceMinPartSize: Long =
+    get(SHUFFLE_WRITER_COALESCE_MIN_PARTITION_SIZE)
+
+  lazy val shuffleWriteCoalesceTotalPartsSize: Long =
+    get(SHUFFLE_WRITER_COALESCE_TOTAL_PARTITIONS_SIZE)
+
   def isUCXShuffleManagerMode: Boolean = RapidsShuffleManagerMode
     .withName(get(SHUFFLE_MANAGER_MODE)) == RapidsShuffleManagerMode.UCX
 
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala
index 5323fc89019..97e4bc2d7f7 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,16 +17,21 @@
 package org.apache.spark.sql.rapids.execution
 
 import scala.collection.AbstractIterator
+import scala.collection.mutable
 import scala.concurrent.Future
 
+import ai.rapids.cudf.NvtxColor
 import com.nvidia.spark.rapids._
+import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.shims.{GpuHashPartitioning, GpuRangePartitioning, ShimUnaryExecNode, ShuffleOriginUtil, SparkShimImpl}
 
 import org.apache.spark.{MapOutputStatistics, ShuffleDependency}
+import org.apache.spark.internal.Logging
 import org.apache.spark.rapids.shims.GpuShuffleExchangeExec
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical.RoundRobinPartitioning
@@ -209,7 +214,9 @@ abstract class GpuShuffleExchangeExecBase(
     "rapidsShuffleWriteIoTime" ->
       createNanoTimingMetric(DEBUG_LEVEL,"rs. shuffle write io time"),
     "rapidsShuffleReadTime" ->
-      createNanoTimingMetric(ESSENTIAL_LEVEL,"rs. shuffle read time")
+      createNanoTimingMetric(ESSENTIAL_LEVEL,"rs. shuffle read time"),
+    "rapidsShuffleWriteCoalesceTime" ->
+      createNanoTimingMetric(MODERATE_LEVEL,"rs. shuffle write coalesce time")
   ) ++ GpuMetric.wrap(readMetrics) ++ GpuMetric.wrap(writeMetrics)
 
   // Spark doesn't report totalTime for this operator so we override metrics
@@ -273,7 +280,7 @@ abstract class GpuShuffleExchangeExecBase(
     }
   }
 
-object GpuShuffleExchangeExecBase {
+object GpuShuffleExchangeExecBase extends Logging {
   def prepareBatchShuffleDependency(
       rdd: RDD[ColumnarBatch],
       outputAttributes: Seq[Attribute],
@@ -312,76 +319,16 @@ object GpuShuffleExchangeExecBase {
     } else {
       rdd
     }
-    val partitioner: GpuExpression = getPartitioner(newRdd, outputAttributes, newPartitioning)
-    def getPartitioned: ColumnarBatch => Any = {
-      batch => partitioner.columnarEvalAny(batch)
-    }
-    val rddWithPartitionIds: RDD[Product2[Int, ColumnarBatch]] = {
-      newRdd.mapPartitions { iter =>
-        val getParts = getPartitioned
-        new AbstractIterator[Product2[Int, ColumnarBatch]] {
-          private var partitioned : Array[(ColumnarBatch, Int)] = _
-          private var at = 0
-          private val mutablePair = new MutablePair[Int, ColumnarBatch]()
-          private def partNextBatch(): Unit = {
-            if (partitioned != null) {
-              partitioned.map(_._1).safeClose()
-              partitioned = null
-              at = 0
-            }
-            if (iter.hasNext) {
-              var batch = iter.next()
-              while (batch.numRows == 0 && iter.hasNext) {
-                batch.close()
-                batch = iter.next()
-              }
-              // Get a non-empty batch or the last batch. So still need to
-              // check if it is empty for the later case.
-              if (batch.numRows > 0) {
-                partitioned = getParts(batch).asInstanceOf[Array[(ColumnarBatch, Int)]]
-                partitioned.foreach(batches => {
-                  metrics(GpuMetric.NUM_OUTPUT_ROWS) += batches._1.numRows()
-                })
-                metrics(GpuMetric.NUM_OUTPUT_BATCHES) += partitioned.length
-                at = 0
-              } else {
-                batch.close()
-              }
-            }
-          }
-
-          override def hasNext: Boolean = {
-            if (partitioned == null || at >= partitioned.length) {
-              partNextBatch()
-            }
-
-            partitioned != null && at < partitioned.length
-          }
-
-          override def next(): Product2[Int, ColumnarBatch] = {
-            if (partitioned == null || at >= partitioned.length) {
-              partNextBatch()
-            }
-            if (partitioned == null || at >= partitioned.length) {
-              throw new NoSuchElementException("Walked off of the end...")
-            }
-            val tup = partitioned(at)
-            mutablePair.update(tup._2, tup._1)
-            at += 1
-            mutablePair
-          }
-        }
-      }
-    }
+    val partitioner = getPartitioner(newRdd, outputAttributes, newPartitioning)
+    val f = makeShuffleIteratorFunc(useGPUShuffle, partitioner, metrics)
     // Now, we manually create a GpuShuffleDependency. Because pairs in rddWithPartitionIds
     // are in the form of (partitionId, row) and every partitionId is in the expected range
     // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough.
     // We do a GPU version because it allows us to know that the data is on the GPU so we can
     // detect it and do further processing if needed.
-    val dependency =
-      new GpuShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
-        rddWithPartitionIds,
+    val dependency = new GpuShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
+      newRdd.mapPartitions(f),
       new BatchPartitionIdPassthrough(newPartitioning.numPartitions),
       sparkTypes,
       serializer,
@@ -413,4 +360,183 @@ object GpuShuffleExchangeExecBase {
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
     }
   }
+
+  private def makeShuffleIteratorFunc(
+      useGPUShuffle: Boolean,
+      partitioner: GpuExpression,
+      metrics: Map[String, GpuMetric]
+  ): Iterator[ColumnarBatch] => Iterator[Product2[Int, ColumnarBatch]] = {
+    val getPartitioned: ColumnarBatch => Array[(ColumnarBatch, Int)] = {
+      batch => partitioner.columnarEvalAny(batch).asInstanceOf[Array[(ColumnarBatch, Int)]]
+    }
+    val sqlConf = SQLConf.get
+    val rapidsConf = new RapidsConf(sqlConf)
+    val coalShuffleWriteEnabled = rapidsConf.isShuffleWriteCoalesceEnabled
+    // Coalesce write for GPU shuffle is not supported yet
+    if (!useGPUShuffle && coalShuffleWriteEnabled) {
+      makeCoalesceShuffleIterator(getPartitioned, metrics,
+        rapidsConf.shuffleWriteCoalesceMinPartSize,
+        rapidsConf.shuffleWriteCoalesceTotalPartsSize,
+        sqlConf.numShufflePartitions)
+    } else {
+      makePassThroughShuffleIterator(getPartitioned, metrics)
+    }
+  }
+
+  private def makePassThroughShuffleIterator(
+      getParts: ColumnarBatch => Array[(ColumnarBatch, Int)],
+      metrics: Map[String, GpuMetric]
+  ): Iterator[ColumnarBatch] => Iterator[Product2[Int, ColumnarBatch]] = {
+    iter => new AbstractIterator[Product2[Int, ColumnarBatch]] {
+      private var partitioned: Array[(ColumnarBatch, Int)] = _
+      private var at = 0
+      private val mutablePair = new MutablePair[Int, ColumnarBatch]()
+
+      private def partNextBatch(): Unit = {
+        if (partitioned != null) {
+          partitioned.map(_._1).safeClose()
+          partitioned = null
+          at = 0
+        }
+        if (iter.hasNext) {
+          var batch = iter.next()
+          while (batch.numRows == 0 && iter.hasNext) {
+            batch.close()
+            batch = iter.next()
+          }
+          // Get a non-empty batch or the last batch. So still need to
+          // check if it is empty for the latter case.
+          if (batch.numRows > 0) {
+            partitioned = getParts(batch)
+            metrics(GpuMetric.NUM_OUTPUT_ROWS) += partitioned.map(_._1.numRows()).sum
+            metrics(GpuMetric.NUM_OUTPUT_BATCHES) += partitioned.length
+            at = 0
+          } else {
+            batch.close()
+          }
+        }
+      }
+
+      override def hasNext: Boolean = {
+        if (partitioned == null || at >= partitioned.length) {
+          partNextBatch()
+        }
+
+        partitioned != null && at < partitioned.length
+      }
+
+      override def next(): Product2[Int, ColumnarBatch] = {
+        if (partitioned == null || at >= partitioned.length) {
+          partNextBatch()
+        }
+        if (partitioned == null || at >= partitioned.length) {
+          throw new NoSuchElementException("Walked off of the end...")
+        }
+        val tup = partitioned(at)
+        mutablePair.update(tup._2, tup._1)
+        at += 1
+        mutablePair
+      }
+    }
+  }
+
+  private def makeCoalesceShuffleIterator(
+      getParts: ColumnarBatch => Array[(ColumnarBatch, Int)],
+      metrics: Map[String, GpuMetric],
+      confMinPartSize: Long,
+      confMaxTotalPartSize: Long,
+      shufflePartsNum: Int
+  ): Iterator[ColumnarBatch] => Iterator[Product2[Int, ColumnarBatch]] = {
+    val outputBatchesNum = metrics(GpuMetric.NUM_OUTPUT_BATCHES)
+    val outputRowsNum = metrics(GpuMetric.NUM_OUTPUT_ROWS)
+    val concatTime = metrics("rapidsShuffleWriteCoalesceTime")
+    val parallelism = SparkSession.active.leafNodeDefaultParallelism
+    val partSizePerCore = confMaxTotalPartSize / (parallelism * shufflePartsNum)
+    // At least 32KB
+    val finalMinPartSize = Math.max(Math.min(confMinPartSize, partSizePerCore), 32 * 1024)
+    logInfo(s"==> Creating CoalesceShuffleIterator (min partition size: " +
+      s"$finalMinPartSize) for shuffle write...")
+
+    iter => new AbstractIterator[Product2[Int, ColumnarBatch]] {
+      type PartBuffers = mutable.ArrayBuffer[ColumnarBatch]
+      private val partitioned = new mutable.LinkedHashMap[Int, (Long, PartBuffers)]()
+      private val readyParts = new mutable.Queue[(Int, Long, PartBuffers)]()
+      private var retPair: Option[(Int, ColumnarBatch)] = None
+      ScalableTaskCompletion.onTaskCompletion {
+        retPair.foreach(_._2.safeClose())
+        (partitioned.values.flatMap(_._2) ++ readyParts.flatMap(_._3)).toSeq.safeClose()
+      }
+      private var retConsumed = true
+
+      private def partNextBatch(): Unit = {
+        var batch = iter.next()
+        while (batch.numRows == 0 && iter.hasNext) {
+          batch.close()
+          batch = iter.next()
+        }
+        // Get a non-empty batch or the last batch. So still need to
+        // check if it is empty for the latter case.
+        if (batch.numRows > 0) {
+          val parts = getParts(batch)
+          withResource(parts.map(_._1).toSeq) { _ =>
+            parts.foreach { case (cb, partId) =>
+              val (accSize, buf) = partitioned.remove(partId).getOrElse((0L, new PartBuffers))
+              // For now only CPU shuffle is supported, so "SlicedGpuColumnVector" is expected.
+              val totalSize = accSize + SlicedGpuColumnVector.getTotalHostMemoryUsed(cb)
+              buf += SlicedGpuColumnVector.incRefCount(cb)
+              if (totalSize >= finalMinPartSize) {
+                // Part is ready to write
+                readyParts.enqueue((partId, totalSize, buf))
+              } else {
+                partitioned.put(partId, (totalSize, buf))
+              }
+            }
+          }
+        } else {
+          batch.close()
+        }
+      }
+
+      private def determineNextPart(): Option[(Int, ColumnarBatch)] = {
+        while (readyParts.isEmpty && iter.hasNext) {
+          partNextBatch()
+        }
+        (if (readyParts.nonEmpty) { // Get one from the ready buffer
+          Some(readyParts.dequeue())
+        } else {
+          partitioned.headOption.map { case (partId, (size, cbs)) =>
+            partitioned.remove(partId)
+            (partId, size, cbs)
+          }
+        }).map { case (partId, cbsSize, batches) =>
+          val merged = withResource(new NvtxWithMetrics(
+              "Shuffle Write Concat", NvtxColor.BLUE, concatTime)) { _ =>
+            GpuBatchUtils.concatShuffleBatchesAndClose(batches.toSeq, Some(cbsSize))
+          }
+          closeOnExcept(merged) { _ =>
+            outputBatchesNum += 1
+            outputRowsNum += merged.numRows()
+            (partId, merged)
+          }
+        }
+      }
+
+      override def hasNext: Boolean = {
+        if (retConsumed) {
+          val pair = determineNextPart()
+          retPair.foreach(_._2.close())
+          retPair = pair
+          retConsumed = false
+        }
+        retPair.isDefined
+      }
+
+      override def next(): Product2[Int, ColumnarBatch] = {
+        if (!hasNext) throw new NoSuchElementException("Walked off of the end...")
+        val ret = retPair.get
+        retConsumed = true
+        ret
+      }
+    }
+  }
 }
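
Note: the following is a standalone sketch, not part of the patch, illustrating the buffer sizing math that the new RapidsHostColumnVector helpers encode (4-byte offset entries, one bit of validity per row padded to 64 bytes). The object and method names here are illustrative only.

// Scala sketch of getOffsetBufferSize / getValidityBufferSize, under the
// assumptions above.
object BufferSizingSketch {
  // Offsets: one 4-byte entry per row plus a terminating entry, i.e. (numRows + 1) * 4.
  def offsetBufferSize(numRows: Int): Long = (numRows.toLong + 1) << 2

  // Validity: one bit per row, rounded up to whole bytes, then padded to 64 bytes.
  def validityBufferSize(numRows: Int): Long = {
    val actualBytes = (numRows.toLong + 7) >> 3
    ((actualBytes + 63) >> 6) << 6
  }

  def main(args: Array[String]): Unit = {
    // 1000 rows -> (1000 + 1) * 4 = 4004 bytes of offsets,
    // and ceil(1000 / 8) = 125 bytes of validity, padded up to 128 bytes.
    println(offsetBufferSize(1000))   // 4004
    println(validityBufferSize(1000)) // 128
  }
}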
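The three new settings added to RapidsConf are internal knobs. A hedged example of how they might be overridden when experimenting with the coalescing shuffle write is below; it assumes a running SparkSession named `spark` with the plugin loaded, and that the plugin picks the values up when the shuffle task reads its RapidsConf.

// Illustrative only: override the internal coalescing knobs with explicit byte values.
spark.conf.set("spark.rapids.shuffle.writer.coalesce.enabled", "true")
spark.conf.set("spark.rapids.shuffle.writer.coalesce.minPartitionSize",
  (5L * 1024 * 1024).toString)          // 5MB, the default
spark.conf.set("spark.rapids.shuffle.writer.coalesce.totalPartitionsSize",
  (10L * 1024 * 1024 * 1024).toString)  // 10GB, the default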