diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala
index 80bfbf69c7e..8fee4144270 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids
 
 import java.io.{File, FileInputStream}
 import java.util.Optional
-import java.util.concurrent.{Callable, ConcurrentHashMap, ExecutionException, Executors, Future, LinkedBlockingQueue}
+import java.util.concurrent.{Callable, ConcurrentHashMap, ExecutionException, Executors, Future, LinkedBlockingQueue, TimeUnit}
 import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
 
 import scala.collection
@@ -28,6 +28,7 @@ import scala.collection.mutable.ListBuffer
 import ai.rapids.cudf.{NvtxColor, NvtxRange}
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.Arm.withResource
+import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
 import com.nvidia.spark.rapids.format.TableMeta
 import com.nvidia.spark.rapids.shuffle.{RapidsShuffleRequestHandler, RapidsShuffleServer, RapidsShuffleTransport}
@@ -644,40 +645,107 @@ abstract class RapidsShuffleThreadedReaderBase[K, C](
   private val futures = new mutable.Queue[Future[Option[BlockState]]]()
   private val serializerInstance = serializer.newInstance()
   private val limiter = new BytesInFlightLimiter(maxBytesInFlight)
-  private val fallbackIter: Iterator[(Any, Any)] = if (numReaderThreads == 1) {
-    // this is the non-optimized case, where we add metrics to capture the blocked
-    // time and the deserialization time as part of the shuffle read time.
-    new Iterator[(Any, Any)]() {
-      private var currentIter: Iterator[(Any, Any)] = _
-      override def hasNext: Boolean = fetcherIterator.hasNext || (
-        currentIter != null && currentIter.hasNext)
-
-      override def next(): (Any, Any) = {
-        val fetchTimeStart = System.nanoTime()
-        var readBlockedTime = 0L
-        if (currentIter == null || !currentIter.hasNext) {
-          val readBlockedStart = System.nanoTime()
-          val (_, stream) = fetcherIterator.next()
-          readBlockedTime = System.nanoTime() - readBlockedStart
-          currentIter = serializerInstance.deserializeStream(stream).asKeyValueIterator
+  private val fallbackIter: Iterator[(Any, Any)] with AutoCloseable =
+    if (numReaderThreads == 1) {
+      // this is the non-optimized case, where we add metrics to capture the blocked
+      // time and the deserialization time as part of the shuffle read time.
+      new Iterator[(Any, Any)]() with AutoCloseable {
+        private var currentIter: Iterator[(Any, Any)] = _
+        private var currentStream: AutoCloseable = _
+        override def hasNext: Boolean = fetcherIterator.hasNext || (
+          currentIter != null && currentIter.hasNext)
+
+        override def close(): Unit = {
+          if (currentStream != null) {
+            currentStream.close()
+            currentStream = null
+          }
+        }
+
+        override def next(): (Any, Any) = {
+          val fetchTimeStart = System.nanoTime()
+          var readBlockedTime = 0L
+          if (currentIter == null || !currentIter.hasNext) {
+            val readBlockedStart = System.nanoTime()
+            val (_, stream) = fetcherIterator.next()
+            readBlockedTime = System.nanoTime() - readBlockedStart
+            // this is stored only to call close on it
+            currentStream = stream
+            currentIter = serializerInstance.deserializeStream(stream).asKeyValueIterator
+          }
+          val res = currentIter.next()
+          val fetchTime = System.nanoTime() - fetchTimeStart
+          deserializationTimeNs.foreach(_ += (fetchTime - readBlockedTime))
+          shuffleReadTimeNs.foreach(_ += fetchTime)
+          res
         }
-        val res = currentIter.next()
-        val fetchTime = System.nanoTime() - fetchTimeStart
-        deserializationTimeNs.foreach(_ += (fetchTime - readBlockedTime))
-        shuffleReadTimeNs.foreach(_ += fetchTime)
-        res
       }
+    } else {
+      null
     }
-  } else {
-    null
-  }
 
-  // Register a completion handler to close any queued cbs.
+  // Register a completion handler to close any queued cbs,
+  // pending iterators, or futures.
   onTaskCompletion(context) {
+    // remove any materialized batches
     queued.forEach { case (_, cb:ColumnarBatch) =>
       cb.close()
     }
     queued.clear()
+
+    // close any materialized BlockState objects that are holding onto netty buffers or
+    // file descriptors
+    pendingIts.safeClose()
+    pendingIts.clear()
+
+    // We could have futures left that are either done or still in flight.
+    // We need to cancel them and then close out any `BlockState` objects
+    // that were created (releasing netty buffers or file descriptors).
+    val futuresAndCancellations = futures.map { f =>
+      val didCancel = f.cancel(true)
+      (f, didCancel)
+    }
+
+    // If we weren't able to cancel a future, we make a best-effort attempt to get its
+    // result so we can close it. The timeout prevents an (unlikely) infinite wait.
+    // If we do time out, this handler will throw.
+    var failedFuture: Option[Throwable] = None
+    futuresAndCancellations
+      .filter { case (_, didCancel) => !didCancel }
+      .foreach { case (future, _) =>
+        try {
+          // The future either completed successfully or finished with an exception;
+          // it finishes exceptionally when the underlying stream is closed as part
+          // of the shutdown process of the task.
+          future.get(10, TimeUnit.MILLISECONDS)
+            .foreach(_.close())
+        } catch {
+          case t: Throwable =>
+            // Capture only the first exception and don't worry about the others;
+            // we don't want to spam the UI or the log with one exception per
+            // block we are fetching.
+            if (failedFuture.isEmpty) {
+              failedFuture = Some(t)
+            }
+        }
+      }
+    futures.clear()
+    try {
+      if (fallbackIter != null) {
+        fallbackIter.close()
+      }
+    } catch {
+      case t: Throwable =>
+        if (failedFuture.isEmpty) {
+          failedFuture = Some(t)
+        } else {
+          failedFuture.get.addSuppressed(t)
+        }
+    } finally {
+      failedFuture.foreach { e =>
+        throw e
+      }
+    }
   }
 
   override def hasNext: Boolean = {
@@ -689,9 +757,26 @@ abstract class RapidsShuffleThreadedReaderBase[K, C](
     }
   }
 
-  case class BlockState(blockId: BlockId, batchIter: SerializedBatchIterator)
-    extends Iterator[(Any, Any)] {
-    private var nextBatchSize = batchIter.tryReadNextHeader().getOrElse(0L)
+  case class BlockState(
+      blockId: BlockId,
+      batchIter: SerializedBatchIterator,
+      origStream: AutoCloseable)
+    extends Iterator[(Any, Any)] with AutoCloseable {
+
+    private var nextBatchSize = {
+      var success = false
+      try {
+        val res = batchIter.tryReadNextHeader().getOrElse(0L)
+        success = true
+        res
+      } finally {
+        if (!success) {
+          // we tried to read from the stream, but something went wrong;
+          // let's close it
+          close()
+        }
+      }
+    }
 
     def getNextBatchSize: Long = nextBatchSize
 
@@ -699,8 +784,23 @@ abstract class RapidsShuffleThreadedReaderBase[K, C](
 
     override def next(): (Any, Any) = {
       val nextBatch = batchIter.next()
-      nextBatchSize = batchIter.tryReadNextHeader().getOrElse(0L)
-      nextBatch
+      var success = false
+      try {
+        nextBatchSize = batchIter.tryReadNextHeader().getOrElse(0L)
+        success = true
+        nextBatch
+      } finally {
+        if (!success) {
+          // the call to read the next header threw; we need to close `nextBatch`
+          nextBatch match {
+            case (_, cb: ColumnarBatch) => cb.close()
+          }
+        }
+      }
     }
+
+    override def close(): Unit = {
+      origStream.close() // make sure the underlying stream also gets closed on error
+    }
   }
 
@@ -723,7 +823,7 @@ abstract class RapidsShuffleThreadedReaderBase[K, C](
           waitTime += System.nanoTime() - waitTimeStart
           // if the future returned a block state, we have more work to do
           pending match {
-            case Some(leftOver@BlockState(_, _)) =>
+            case Some(leftOver@BlockState(_, _, _)) =>
               pendingIts.enqueue(leftOver)
             case _ => // done
           }
@@ -771,19 +871,27 @@ abstract class RapidsShuffleThreadedReaderBase[K, C](
   private def deserializeTask(blockState: BlockState): Unit = {
     val slot = RapidsShuffleInternalManagerBase.getNextReaderSlot
     futures += RapidsShuffleInternalManagerBase.queueReadTask(slot, () => {
-      var currentBatchSize = blockState.getNextBatchSize
-      var didFit = true
-      while (blockState.hasNext && didFit) {
-        val batch = blockState.next()
-        queued.offer(batch)
-        // peek at the next batch
-        currentBatchSize = blockState.getNextBatchSize
-        didFit = limiter.acquire(currentBatchSize)
-      }
-      if (!didFit) {
-        Some(blockState)
-      } else {
-        None // no further batches
+      var success = false
+      try {
+        var currentBatchSize = blockState.getNextBatchSize
+        var didFit = true
+        while (blockState.hasNext && didFit) {
+          val batch = blockState.next()
+          queued.offer(batch)
+          // peek at the next batch
+          currentBatchSize = blockState.getNextBatchSize
+          didFit = limiter.acquire(currentBatchSize)
+        }
+        success = true
+        if (!didFit) {
+          Some(blockState)
+        } else {
+          None // no further batches
+        }
+      } finally {
+        if (!success) {
+          blockState.close()
+        }
       }
     })
   }
@@ -830,7 +938,7 @@ abstract class RapidsShuffleThreadedReaderBase[K, C](
 
         val deserStream = serializerInstance.deserializeStream(inputStream)
         val batchIter = deserStream.asKeyValueIterator.asInstanceOf[SerializedBatchIterator]
-        val blockState = BlockState(blockId, batchIter)
+        val blockState = BlockState(blockId, batchIter, inputStream)
         // get the next known batch size (there could be multiple batches)
         if (limiter.acquire(blockState.getNextBatchSize)) {
           // we can fit at least the first batch in this block
diff --git a/tests/src/test/spark321/scala/org/apache/spark/sql/rapids/RapidsShuffleThreadedReaderSuite.scala b/tests/src/test/spark321/scala/org/apache/spark/sql/rapids/RapidsShuffleThreadedReaderSuite.scala
index b00e268d949..bedab81fc21 100644
--- a/tests/src/test/spark321/scala/org/apache/spark/sql/rapids/RapidsShuffleThreadedReaderSuite.scala
+++ b/tests/src/test/spark321/scala/org/apache/spark/sql/rapids/RapidsShuffleThreadedReaderSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -47,6 +47,15 @@ import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.sql.rapids.shims.RapidsShuffleThreadedReader
 import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}
 
+class InjectedShuffleErrorInTests extends Exception {
+}
+
+class ErrorInputStream(wrapped: InputStream) extends InputStream {
+  override def read(): Int = {
+    throw new InjectedShuffleErrorInTests
+  }
+}
+
 /**
  *
  * Code ported over from `BlockStoreShuffleReaderSuite` in Apache Spark.
@@ -56,15 +65,23 @@ import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}
  * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class
  * is final (final classes cannot be spied on).
  */
-class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer {
+class RecordingManagedBuffer(
+    underlyingBuffer: NioManagedBuffer,
+    injectError: Boolean) extends ManagedBuffer {
   var callsToRetain = 0
   var callsToRelease = 0
 
   override def size(): Long = underlyingBuffer.size()
   override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer()
-  override def createInputStream(): InputStream = underlyingBuffer.createInputStream()
+  override def createInputStream(): InputStream = {
+    val is = underlyingBuffer.createInputStream()
+    if (injectError) {
+      new ErrorInputStream(is)
+    } else {
+      is
+    }
+  }
   override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty()
-
   override def retain(): ManagedBuffer = {
     callsToRetain += 1
     underlyingBuffer.retain()
@@ -82,110 +99,133 @@ class RapidsShuffleThreadedReaderSuite
     RapidsShuffleInternalManagerBase.stopThreadPool()
   }
 
-  /**
-   * This test makes sure that, when data is read from a HashShuffleReader, the underlying
-   * ManagedBuffers that contain the data are eventually released.
-   */
-  Seq(1, 2).foreach { numReaderThreads =>
-    test(s"read() releases resources on completion - numThreads=$numReaderThreads") {
-      val testConf = new SparkConf(false)
-      // this sets the session and the SparkEnv
-      SparkSessionHolder.withSparkSession(testConf, _ => {
-        if (numReaderThreads > 1) {
-          RapidsShuffleInternalManagerBase.startThreadPoolIfNeeded(0, numReaderThreads)
-        }
-
-        val reduceId = 15
-        val shuffleId = 22
-        val numMaps = 6
-        val keyValuePairsPerMap = 10
-        val serializer = new GpuColumnarBatchSerializer(NoopMetric)
-
-        // Make a mock BlockManager that will return RecordingManagedByteBuffers of data, so that we
-        // can ensure retain() and release() are properly called.
-        val blockManager = mock(classOf[BlockManager])
-
-        // Create a buffer with some randomly generated key-value pairs to use as the shuffle data
-        // from each mappers (all mappers return the same shuffle data).
-        val byteOutputStream = new ByteArrayOutputStream()
-        val serializationStream = serializer.newInstance().serializeStream(byteOutputStream)
-        withResource(GpuColumnVector.emptyBatchFromTypes(Array.empty)) { emptyBatch =>
-          (0 until keyValuePairsPerMap).foreach { i =>
-            serializationStream.writeKey(i)
-            serializationStream.writeValue(GpuColumnVector.incRefCounts(emptyBatch))
-          }
+  def runShuffleRead(numReaderThreads: Int, injectError: Boolean = false): Unit = {
+    val testConf = new SparkConf(false)
+    // this sets the session and the SparkEnv
+    SparkSessionHolder.withSparkSession(testConf, _ => {
+      if (numReaderThreads > 1) {
+        RapidsShuffleInternalManagerBase.startThreadPoolIfNeeded(0, numReaderThreads)
+      }
+
+      val reduceId = 15
+      val shuffleId = 22
+      val numMaps = 6
+      val keyValuePairsPerMap = 10
+      val serializer = new GpuColumnarBatchSerializer(NoopMetric)
+
+      // Make a mock BlockManager that will return RecordingManagedBuffers of data, so that we
+      // can ensure retain() and release() are properly called.
+      val blockManager = mock(classOf[BlockManager])
+
+      // Create a buffer with some randomly generated key-value pairs to use as the shuffle data
+      // from each mapper (all mappers return the same shuffle data).
+      val byteOutputStream = new ByteArrayOutputStream()
+      val serializationStream = serializer.newInstance().serializeStream(byteOutputStream)
+      withResource(GpuColumnVector.emptyBatchFromTypes(Array.empty)) { emptyBatch =>
+        (0 until keyValuePairsPerMap).foreach { i =>
+          serializationStream.writeKey(i)
+          serializationStream.writeValue(GpuColumnVector.incRefCounts(emptyBatch))
         }
-
-        // Setup the mocked BlockManager to return RecordingManagedBuffers.
-        val localBlockManagerId = BlockManagerId("test-client", "test-client", 1)
-        when(blockManager.blockManagerId).thenReturn(localBlockManagerId)
-        val buffers = (0 until numMaps).map { mapId =>
-          // Create a ManagedBuffer with the shuffle data.
-          val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray))
-          val managedBuffer = new RecordingManagedBuffer(nioBuffer)
-
-          // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to
-          // fetch shuffle data.
+      }
+
+      // Setup the mocked BlockManager to return RecordingManagedBuffers.
+      val localBlockManagerId = BlockManagerId("test-client", "test-client", 1)
+      when(blockManager.blockManagerId).thenReturn(localBlockManagerId)
+      val buffers = (0 until numMaps).map { mapId =>
+        // Create a ManagedBuffer with the shuffle data.
+        val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray))
+        val managedBuffer = new RecordingManagedBuffer(nioBuffer, injectError)
+
+        // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to
+        // fetch shuffle data.
+        val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
+        when(blockManager.getLocalBlockData(meq(shuffleBlockId))).thenReturn(managedBuffer)
+        managedBuffer
+      }
+
+      // Make a mocked MapOutputTracker for the shuffle reader to use to determine what
+      // shuffle data to read.
+      val mapOutputTracker = mock(classOf[MapOutputTracker])
+      when(mapOutputTracker.getMapSizesByExecutorId(
+        shuffleId, 0, numMaps, reduceId, reduceId + 1)).thenReturn {
+        // Test a scenario where all data is local, to avoid creating a bunch of additional mocks
+        // for the code to read data over the network.
+        val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
           val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
-          when(blockManager.getLocalBlockData(meq(shuffleBlockId))).thenReturn(managedBuffer)
-          managedBuffer
+          (shuffleBlockId, byteOutputStream.size().toLong, mapId)
         }
-
-        // Make a mocked MapOutputTracker for the shuffle reader to use to determine what
-        // shuffle data to read.
-        val mapOutputTracker = mock(classOf[MapOutputTracker])
-        when(mapOutputTracker.getMapSizesByExecutorId(
-          shuffleId, 0, numMaps, reduceId, reduceId + 1)).thenReturn {
-          // Test a scenario where all data is local, to avoid creating a bunch of additional mocks
-          // for the code to read data over the network.
-          val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
-            val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
-            (shuffleBlockId, byteOutputStream.size().toLong, mapId)
+        Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).iterator
+      }
+
+      // Create a mocked shuffle handle to pass into the shuffle reader.
+      val shuffleHandle = {
+        val dependency = mock(classOf[GpuShuffleDependency[Int, Int, Int]])
+        when(dependency.serializer).thenReturn(serializer)
+        when(dependency.aggregator).thenReturn(None)
+        when(dependency.keyOrdering).thenReturn(None)
+        new ShuffleHandleWithMetrics[Int, Int, Int](
+          shuffleId, Map.empty, dependency)
+      }
+
+      val serializerManager = new SerializerManager(
+        serializer,
+        new SparkConf()
+          .set(config.SHUFFLE_COMPRESS, false)
+          .set(config.SHUFFLE_SPILL_COMPRESS, false))
+
+      val taskContext = TaskContext.empty()
+      val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
+      val shuffleReader = new RapidsShuffleThreadedReader[Int, Int](
+        0,
+        numMaps,
+        reduceId,
+        reduceId + 1,
+        shuffleHandle,
+        taskContext,
+        metrics,
+        1024 * 1024,
+        serializerManager,
+        blockManager,
+        mapOutputTracker = mapOutputTracker,
+        numReaderThreads = numReaderThreads)
+
+      if (injectError) {
+        var e: Throwable = null
+        assertThrows[InjectedShuffleErrorInTests] {
+          try {
+            shuffleReader.read().length
+          } catch {
+            case t: Throwable =>
+              e = t
+              throw t
          }
-        Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).iterator
-      }
-
-      // Create a mocked shuffle handle to pass into HashShuffleReader.
-      val shuffleHandle = {
-        val dependency = mock(classOf[GpuShuffleDependency[Int, Int, Int]])
-        when(dependency.serializer).thenReturn(serializer)
-        when(dependency.aggregator).thenReturn(None)
-        when(dependency.keyOrdering).thenReturn(None)
-        new ShuffleHandleWithMetrics[Int, Int, Int](
-          shuffleId, Map.empty, dependency)
       }
-
-      val serializerManager = new SerializerManager(
-        serializer,
-        new SparkConf()
-          .set(config.SHUFFLE_COMPRESS, false)
-          .set(config.SHUFFLE_SPILL_COMPRESS, false))
-
-      val taskContext = TaskContext.empty()
-      val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
-      val shuffleReader = new RapidsShuffleThreadedReader[Int, Int](
-        0,
-        numMaps,
-        reduceId,
-        reduceId + 1,
-        shuffleHandle,
-        taskContext,
-        metrics,
-        1024 * 1024,
-        serializerManager,
-        blockManager,
-        mapOutputTracker = mapOutputTracker,
-        numReaderThreads = numReaderThreads)
-
+        taskContext.markTaskCompleted(Some(e))
+      } else {
         assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps)
+        taskContext.markTaskCompleted(None)
+      }
+
+      // Calling .length above will have exhausted the iterator; make sure that exhausting the
+      // iterator caused retain and release to be called on each buffer.
+      buffers.foreach { buffer =>
+        assert(buffer.callsToRetain === 1)
+        assert(buffer.callsToRelease === 1)
+      }
+    })
+  }
 
-      // Calling .length above will have exhausted the iterator; make sure that exhausting the
-      // iterator caused retain and release to be called on each buffer.
-      buffers.foreach { buffer =>
-        assert(buffer.callsToRetain === 1)
-        assert(buffer.callsToRelease === 1)
-      }
-    })
+  /**
+   * This test makes sure that, when data is read from the shuffle reader, the underlying
+   * ManagedBuffers that contain the data are eventually released.
+   */
+  Seq(1, 2).foreach { numReaderThreads =>
+    test(s"read() releases resources on completion - numThreads=$numReaderThreads") {
+      runShuffleRead(numReaderThreads)
+    }
+
+    test(s"read() releases resources on error - numThreads=$numReaderThreads") {
+      runShuffleRead(numReaderThreads, injectError = true)
     }
   }
 }
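Illustration (editorial addition, not part of the patch): the fix combines two idioms — a close-on-error guard around each read, and best-effort cancellation plus time-bounded cleanup in the task-completion handler. The standalone sketch below reduces both to a minimal Scala program; `Resource`, `readOrClose`, and `CloseOnErrorSketch` are hypothetical names used only for this example, not spark-rapids APIs.

import java.util.concurrent.{Callable, Executors, TimeUnit}

object CloseOnErrorSketch {
  final class Resource(name: String) extends AutoCloseable {
    override def close(): Unit = println(s"closed $name")
  }

  // Close-on-error guard, as in BlockState.next() and deserializeTask: if
  // `read` throws, close the resource before the exception propagates; on
  // success, ownership stays with the caller.
  def readOrClose[T](res: Resource)(read: => T): T = {
    var success = false
    try {
      val out = read
      success = true
      out
    } finally {
      if (!success) {
        res.close()
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val r = new Resource("header-read")
    val header = readOrClose(r)(42L) // would close `r` if the read threw
    println(s"header size: $header")

    val pool = Executors.newFixedThreadPool(2)
    val futures = (0 until 2).map { i =>
      pool.submit(new Callable[Option[Resource]] {
        override def call(): Option[Resource] = Some(new Resource(s"block-$i"))
      })
    }
    // Completion-handler shape: cancel what we can, then make a best-effort,
    // time-bounded attempt to close anything that already materialized.
    futures.foreach { f =>
      if (!f.cancel(true)) {
        try {
          f.get(10, TimeUnit.MILLISECONDS).foreach(_.close())
        } catch {
          case _: Exception => // the real handler records the first failure and rethrows
        }
      }
    }
    pool.shutdown()
  }
}

Using a `success` flag in `finally` rather than catch-and-rethrow keeps the original exception intact while still guaranteeing cleanup, which is the same trade-off the patch makes in `BlockState` and `deserializeTask`.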