Skip to content

Commit

Permalink
Cleanup async state when multi-threaded shuffle readers fail (NVIDIA#…
Browse files Browse the repository at this point in the history
…10637)

* Cleanup async state when multi-threaded shuffle readers fail

Signed-off-by: Alessandro Bellina <[email protected]>

---------

Signed-off-by: Alessandro Bellina <[email protected]>
  • Loading branch information
abellina authored Apr 1, 2024
1 parent c28c7fa commit b14b01e
Show file tree
Hide file tree
Showing 2 changed files with 296 additions and 148 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids

import java.io.{File, FileInputStream}
import java.util.Optional
import java.util.concurrent.{Callable, ConcurrentHashMap, ExecutionException, Executors, Future, LinkedBlockingQueue}
import java.util.concurrent.{Callable, ConcurrentHashMap, ExecutionException, Executors, Future, LinkedBlockingQueue, TimeUnit}
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}

import scala.collection
Expand All @@ -28,6 +28,7 @@ import scala.collection.mutable.ListBuffer
import ai.rapids.cudf.{NvtxColor, NvtxRange}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion
import com.nvidia.spark.rapids.format.TableMeta
import com.nvidia.spark.rapids.shuffle.{RapidsShuffleRequestHandler, RapidsShuffleServer, RapidsShuffleTransport}
Expand Down Expand Up @@ -644,40 +645,107 @@ abstract class RapidsShuffleThreadedReaderBase[K, C](
private val futures = new mutable.Queue[Future[Option[BlockState]]]()
private val serializerInstance = serializer.newInstance()
private val limiter = new BytesInFlightLimiter(maxBytesInFlight)
private val fallbackIter: Iterator[(Any, Any)] = if (numReaderThreads == 1) {
// this is the non-optimized case, where we add metrics to capture the blocked
// time and the deserialization time as part of the shuffle read time.
new Iterator[(Any, Any)]() {
private var currentIter: Iterator[(Any, Any)] = _
override def hasNext: Boolean = fetcherIterator.hasNext || (
currentIter != null && currentIter.hasNext)

override def next(): (Any, Any) = {
val fetchTimeStart = System.nanoTime()
var readBlockedTime = 0L
if (currentIter == null || !currentIter.hasNext) {
val readBlockedStart = System.nanoTime()
val (_, stream) = fetcherIterator.next()
readBlockedTime = System.nanoTime() - readBlockedStart
currentIter = serializerInstance.deserializeStream(stream).asKeyValueIterator
private val fallbackIter: Iterator[(Any, Any)] with AutoCloseable =
if (numReaderThreads == 1) {
// this is the non-optimized case, where we add metrics to capture the blocked
// time and the deserialization time as part of the shuffle read time.
new Iterator[(Any, Any)]() with AutoCloseable {
private var currentIter: Iterator[(Any, Any)] = _
private var currentStream: AutoCloseable = _
override def hasNext: Boolean = fetcherIterator.hasNext || (
currentIter != null && currentIter.hasNext)

override def close(): Unit = {
if (currentStream != null) {
currentStream.close()
currentStream = null
}
}

override def next(): (Any, Any) = {
val fetchTimeStart = System.nanoTime()
var readBlockedTime = 0L
if (currentIter == null || !currentIter.hasNext) {
val readBlockedStart = System.nanoTime()
val (_, stream) = fetcherIterator.next()
readBlockedTime = System.nanoTime() - readBlockedStart
// this is stored only to call close on it
currentStream = stream
currentIter = serializerInstance.deserializeStream(stream).asKeyValueIterator
}
val res = currentIter.next()
val fetchTime = System.nanoTime() - fetchTimeStart
deserializationTimeNs.foreach(_ += (fetchTime - readBlockedTime))
shuffleReadTimeNs.foreach(_ += fetchTime)
res
}
val res = currentIter.next()
val fetchTime = System.nanoTime() - fetchTimeStart
deserializationTimeNs.foreach(_ += (fetchTime - readBlockedTime))
shuffleReadTimeNs.foreach(_ += fetchTime)
res
}
} else {
null
}
} else {
null
}

// Register a completion handler to close any queued cbs.
// Register a completion handler to close any queued cbs,
// pending iterators, or futures
onTaskCompletion(context) {
// remove any materialized batches
queued.forEach {
case (_, cb:ColumnarBatch) => cb.close()
}
queued.clear()

// close any materialized BlockState objects that are holding onto netty buffers or
// file descriptors
pendingIts.safeClose()
pendingIts.clear()

// we could have futures left that are either done or in flight
// we need to cancel them and then close out any `BlockState`
// objects that were created (to remove netty buffers or file descriptors)
val futuresAndCancellations = futures.map { f =>
val didCancel = f.cancel(true)
(f, didCancel)
}

// if we weren't able to cancel, we are going to make a best attempt at getting the future
// and we are going to close it. The timeout is to prevent an (unlikely) infinite wait.
// If we do timeout then this handler is going to throw.
var failedFuture: Option[Throwable] = None
futuresAndCancellations
.filter { case (_, didCancel) => !didCancel }
.foreach { case (future, _) =>
try {
// this could either be a successful future, or it finished with exception
// the case when it will fail with exception is when the underlying stream is closed
// as part of the shutdown process of the task.
future.get(10, TimeUnit.MILLISECONDS)
.foreach(_.close())
} catch {
case t: Throwable =>
// this is going to capture the first exception and not worry about others
// because we probably don't want to spam the UI or log with an exception per
// block we are fetching
if (failedFuture.isEmpty) {
failedFuture = Some(t)
}
}
}
futures.clear()
try {
if (fallbackIter != null) {
fallbackIter.close()
}
} catch {
case t: Throwable =>
if (failedFuture.isEmpty) {
failedFuture = Some(t)
} else {
failedFuture.get.addSuppressed(t)
}
} finally {
failedFuture.foreach { e =>
throw e
}
}
}

override def hasNext: Boolean = {
Expand All @@ -689,18 +757,50 @@ abstract class RapidsShuffleThreadedReaderBase[K, C](
}
}

case class BlockState(blockId: BlockId, batchIter: SerializedBatchIterator)
extends Iterator[(Any, Any)] {
private var nextBatchSize = batchIter.tryReadNextHeader().getOrElse(0L)
case class BlockState(
blockId: BlockId,
batchIter: SerializedBatchIterator,
origStream: AutoCloseable)
extends Iterator[(Any, Any)] with AutoCloseable {

private var nextBatchSize = {
var success = false
try {
val res = batchIter.tryReadNextHeader().getOrElse(0L)
success = true
res
} finally {
if (!success) {
// we tried to read from a stream, but something happened
// lets close it
close()
}
}
}

def getNextBatchSize: Long = nextBatchSize

override def hasNext: Boolean = batchIter.hasNext

override def next(): (Any, Any) = {
val nextBatch = batchIter.next()
nextBatchSize = batchIter.tryReadNextHeader().getOrElse(0L)
nextBatch
var success = false
try {
nextBatchSize = batchIter.tryReadNextHeader().getOrElse(0L)
success = true
nextBatch
} finally {
if (!success) {
// the call to get a next header threw. We need to close `nextBatch`.
nextBatch match {
case (_, cb: ColumnarBatch) => cb.close()
}
}
}
}

override def close(): Unit = {
origStream.close() // make sure we call this on error
}
}

Expand All @@ -723,7 +823,7 @@ abstract class RapidsShuffleThreadedReaderBase[K, C](
waitTime += System.nanoTime() - waitTimeStart
// if the future returned a block state, we have more work to do
pending match {
case Some(leftOver@BlockState(_, _)) =>
case Some(leftOver@BlockState(_, _, _)) =>
pendingIts.enqueue(leftOver)
case _ => // done
}
Expand Down Expand Up @@ -771,19 +871,27 @@ abstract class RapidsShuffleThreadedReaderBase[K, C](
private def deserializeTask(blockState: BlockState): Unit = {
val slot = RapidsShuffleInternalManagerBase.getNextReaderSlot
futures += RapidsShuffleInternalManagerBase.queueReadTask(slot, () => {
var currentBatchSize = blockState.getNextBatchSize
var didFit = true
while (blockState.hasNext && didFit) {
val batch = blockState.next()
queued.offer(batch)
// peek at the next batch
currentBatchSize = blockState.getNextBatchSize
didFit = limiter.acquire(currentBatchSize)
}
if (!didFit) {
Some(blockState)
} else {
None // no further batches
var success = false
try {
var currentBatchSize = blockState.getNextBatchSize
var didFit = true
while (blockState.hasNext && didFit) {
val batch = blockState.next()
queued.offer(batch)
// peek at the next batch
currentBatchSize = blockState.getNextBatchSize
didFit = limiter.acquire(currentBatchSize)
}
success = true
if (!didFit) {
Some(blockState)
} else {
None // no further batches
}
} finally {
if (!success) {
blockState.close()
}
}
})
}
Expand Down Expand Up @@ -830,7 +938,7 @@ abstract class RapidsShuffleThreadedReaderBase[K, C](

val deserStream = serializerInstance.deserializeStream(inputStream)
val batchIter = deserStream.asKeyValueIterator.asInstanceOf[SerializedBatchIterator]
val blockState = BlockState(blockId, batchIter)
val blockState = BlockState(blockId, batchIter, inputStream)
// get the next known batch size (there could be multiple batches)
if (limiter.acquire(blockState.getNextBatchSize)) {
// we can fit at least the first batch in this block
Expand Down
Loading

0 comments on commit b14b01e

Please sign in to comment.