Add a timeout for RapidsShuffleIterator to prevent jobs from hanging indefinitely in cases where the transport doesn't report an error (NVIDIA#732)

* Add a timeout for RapidsShuffleIterator to prevent jobs from hanging indefinitely in cases where the transport doesn't report an error

Signed-off-by: Alessandro Bellina <[email protected]>

* Read the timeout configuration once and store it as a val

Signed-off-by: Alessandro Bellina <[email protected]>

* Use spark.network.timeout as our fetch timeout setting

Signed-off-by: Alessandro Bellina <[email protected]>

* No really, don't change SparkEnv

Signed-off-by: Alessandro Bellina <[email protected]>

* Pass the timeout to make it easier to mock in tests
abellina authored and JustPlay committed Sep 13, 2020
1 parent a852f32 commit e2adf57
Showing 6 changed files with 115 additions and 22 deletions.
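
The change is easiest to read as one pattern: instead of blocking forever with take() on the queue of resolved batches, the iterator polls with a bounded wait and raises a dedicated timeout exception when nothing arrives. A minimal sketch of that pattern follows; the names FetchResult, Batch, and FetchTimeoutException are illustrative stand-ins, not classes from this commit.

import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

// Stand-ins for the real result and exception types added in the diff below.
sealed trait FetchResult
case class Batch(rows: Int) extends FetchResult
class FetchTimeoutException(msg: String) extends Exception(msg)

object TimedPollExample {
  private val resolved = new LinkedBlockingQueue[FetchResult]()

  // Producer side: a transport callback would push completed batches or errors here.
  def offer(result: FetchResult): Unit = resolved.put(result)

  // Consumer side: poll with a bound instead of take(), so a transport that never
  // reports an error surfaces as a timeout rather than a task that hangs forever.
  def nextResult(timeoutSeconds: Long): FetchResult =
    Option(resolved.poll(timeoutSeconds, TimeUnit.SECONDS)) match {
      case Some(result) => result
      case None => throw new FetchTimeoutException(
        s"Timed out after $timeoutSeconds seconds waiting for a shuffle batch")
    }
}
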
@@ -584,7 +584,6 @@ object RapidsConf {
.booleanConf
.createWithDefault(true)

// USER FACING SHUFFLE CONFIGS
val SHUFFLE_TRANSPORT_ENABLE = conf("spark.rapids.shuffle.transport.enabled")
.doc("When set to true, enable the Rapids Shuffle Transport for accelerated shuffle.")
.booleanConf
@@ -16,16 +16,16 @@

package com.nvidia.spark.rapids.shuffle

import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

import scala.collection.mutable

import ai.rapids.cudf.{NvtxColor, NvtxRange}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.{GpuSemaphore, RapidsBuffer, RapidsConf, ShuffleReceivedBufferCatalog, ShuffleReceivedBufferId}

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.RapidsShuffleFetchFailedException
import org.apache.spark.shuffle.{RapidsShuffleFetchFailedException, RapidsShuffleTimeoutException}
import org.apache.spark.sql.rapids.{GpuShuffleEnv, ShuffleMetricsUpdater}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockBatchId, ShuffleBlockId}
@@ -42,14 +42,17 @@ import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockBatchId, ShuffleBlockId}
* @param blocksByAddress blocks to fetch
* @param metricsUpdater instance of `ShuffleMetricsUpdater` to update the Spark
* shuffle metrics
* @param timeoutSeconds a timeout in seconds that the iterator will wait while polling
* for batches
*/
class RapidsShuffleIterator(
localBlockManagerId: BlockManagerId,
rapidsConf: RapidsConf,
transport: RapidsShuffleTransport,
blocksByAddress: Array[(BlockManagerId, Seq[(BlockId, Long, Int)])],
metricsUpdater: ShuffleMetricsUpdater,
catalog: ShuffleReceivedBufferCatalog = GpuShuffleEnv.getReceivedCatalog)
catalog: ShuffleReceivedBufferCatalog = GpuShuffleEnv.getReceivedCatalog,
timeoutSeconds: Long = GpuShuffleEnv.shuffleFetchTimeoutSeconds)
extends Iterator[ColumnarBatch]
with Logging {

@@ -80,6 +83,8 @@ class RapidsShuffleIterator(
mapIndex: Int,
errorMessage: String) extends ShuffleClientResult

// when batches (or errors) arrive from the transport, they are pushed
// to the `resolvedBatches` queue.
private[this] val resolvedBatches = new LinkedBlockingQueue[ShuffleClientResult]()

// Used to track requests that are pending where the number of [[ColumnarBatch]] results is
@@ -277,6 +282,10 @@ class RapidsShuffleIterator(
//TODO: on task completion we currently don't ask clients to stop/clean resources
taskContext.foreach(_.addTaskCompletionListener[Unit](_ => receiveBufferCleaner()))

def pollForResult(timeoutSeconds: Long): Option[ShuffleClientResult] = {
Option(resolvedBatches.poll(timeoutSeconds, TimeUnit.SECONDS))
}

override def next(): ColumnarBatch = {
var cb: ColumnarBatch = null
var sb: RapidsBuffer = null
@@ -306,10 +315,12 @@
}

val blockedStart = System.currentTimeMillis()
val result = resolvedBatches.take()
var result: Option[ShuffleClientResult] = None

result = pollForResult(timeoutSeconds)
val blockedTime = System.currentTimeMillis() - blockedStart
result match {
case BufferReceived(bufferId) =>
case Some(BufferReceived(bufferId)) =>
val nvtxRangeAfterGettingBatch = new NvtxRange("RapidsShuffleIterator.gotBatch",
NvtxColor.PURPLE)
try {
@@ -324,8 +335,9 @@
}
catalog.removeBuffer(bufferId)
}
case TransferError(blockManagerId, shuffleBlockBatchId, mapIndex, errorMessage) =>
case Some(TransferError(blockManagerId, shuffleBlockBatchId, mapIndex, errorMessage)) =>
taskContext.foreach(GpuSemaphore.releaseIfNecessary)
metricsUpdater.update(blockedTime, 0, 0, 0)
val errorMsg = s"Transfer error detected by shuffle iterator, failing task. ${errorMessage}"
logError(errorMsg)
throw new RapidsShuffleFetchFailedException(
@@ -335,6 +347,16 @@
mapIndex,
shuffleBlockBatchId.startReduceId,
errorMsg)
case None =>
// NOTE: this isn't perfect, since what we really want is the transport to
// bubble this error, but for now we'll make this a fatal exception.
taskContext.foreach(GpuSemaphore.releaseIfNecessary)
metricsUpdater.update(blockedTime, 0, 0, 0)
val errMsg = s"Timed out after ${timeoutSeconds} seconds while waiting for a shuffle batch."
logError(errMsg)
throw new RapidsShuffleTimeoutException(errMsg)
case _ =>
throw new IllegalStateException(s"Invalid result type $result")
}
cb
}
@@ -28,3 +28,5 @@ class RapidsShuffleFetchFailedException(
extends FetchFailedException(
bmAddress, shuffleId, mapId, mapIndex, reduceId, message) {
}

class RapidsShuffleTimeoutException(message: String) extends Exception(message)
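
For context on the two exception types in this file: RapidsShuffleFetchFailedException extends Spark's FetchFailedException, which Spark treats as a recoverable shuffle-fetch failure and may answer by resubmitting the map stage, while RapidsShuffleTimeoutException extends plain Exception, so a timeout fails the task outright, matching the iterator's comment about making it a fatal error. The sketch below only illustrates that distinction at a call site and is not part of the commit.

import org.apache.spark.shuffle.{RapidsShuffleFetchFailedException, RapidsShuffleTimeoutException}

object ShuffleFailureKinds {
  // Illustrative only: classify the two failure modes introduced above.
  def describe(t: Throwable): String = t match {
    case _: RapidsShuffleFetchFailedException =>
      "fetch failure: Spark can regenerate the lost map output and retry"
    case _: RapidsShuffleTimeoutException =>
      "timeout: treated as fatal, the task fails"
    case other =>
      s"unexpected failure: ${other.getClass.getSimpleName}"
  }
}
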
@@ -99,4 +99,6 @@ object GpuShuffleEnv extends Logging {
def getReceivedCatalog: ShuffleReceivedBufferCatalog = env.getReceivedCatalog

def rapidsShuffleCodec: Option[TableCompressionCodec] = env.rapidsShuffleCodec

def shuffleFetchTimeoutSeconds: Long = env.getShuffleFetchTimeoutSeconds
}
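
A hedged sketch of the spark.network.timeout lookup the commit messages describe: per "Use spark.network.timeout as our fetch timeout setting", the value returned by getShuffleFetchTimeoutSeconds would come from that Spark setting. How the lookup is actually implemented is not part of this diff, so the helper below is an assumption built on the standard SparkConf API (the 120s fallback is Spark's documented default for spark.network.timeout).

import org.apache.spark.SparkConf

object ShuffleFetchTimeoutExample {
  // Assumed mapping: read spark.network.timeout in seconds, falling back to
  // Spark's conventional 120s default when the key is not set.
  def fromSparkConf(conf: SparkConf): Long =
    conf.getTimeAsSeconds("spark.network.timeout", "120s")
}
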
@@ -21,9 +21,7 @@ import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers._
import org.mockito.Mockito._

import org.apache.spark.TaskContext
import org.apache.spark.shuffle.RapidsShuffleFetchFailedException
import org.apache.spark.sql.rapids.ShuffleMetricsUpdater
import org.apache.spark.shuffle.{RapidsShuffleFetchFailedException, RapidsShuffleTimeoutException}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}

class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {
@@ -32,17 +30,23 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {

val cl = new RapidsShuffleIterator(
RapidsShuffleTestHelper.makeMockBlockManager("1", "1"),
null,
mockConf,
mockTransport,
blocksByAddress,
null)
testMetricsUpdater,
mockCatalog,
123)

when(mockTransaction.getStatus).thenReturn(TransactionStatus.Error)

when(mockTransport.makeClient(any(), any())).thenThrow(new IllegalStateException("Test"))

assert(cl.hasNext)
assertThrows[RapidsShuffleFetchFailedException](cl.next())

// not invoked, since we never blocked
verify(testMetricsUpdater, times(0))
.update(any(), any(), any(), any())
}

test("a transport error/cancellation raises a fetch failure") {
@@ -51,12 +55,14 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {

val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress

val cl = new RapidsShuffleIterator(
val cl = spy(new RapidsShuffleIterator(
RapidsShuffleTestHelper.makeMockBlockManager("1", "1"),
null,
mockConf,
mockTransport,
blocksByAddress,
null)
testMetricsUpdater,
mockCatalog,
123))

val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler])
when(mockTransport.makeClient(any(), any())).thenReturn(client)
@@ -69,22 +75,57 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper {
assert(cl.hasNext)
assertThrows[RapidsShuffleFetchFailedException](cl.next())

verify(testMetricsUpdater, times(1))
.update(any(), any(), any(), any())
assertResult(0)(testMetricsUpdater.totalRemoteBlocksFetched)
assertResult(0)(testMetricsUpdater.totalRemoteBytesRead)
assertResult(0)(testMetricsUpdater.totalRowsFetched)

newMocks()
}
}

test("a new good batch is queued") {
test("a timeout while waiting for batches raises a fetch failure") {
val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress

val mockMetrics = mock[ShuffleMetricsUpdater]
val cl = spy(new RapidsShuffleIterator(
RapidsShuffleTestHelper.makeMockBlockManager("1", "1"),
mockConf,
mockTransport,
blocksByAddress,
testMetricsUpdater,
mockCatalog,
123))

val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler])
when(mockTransport.makeClient(any(), any())).thenReturn(client)
doNothing().when(client).doFetch(any(), ac.capture(), any())

// signal a timeout to the iterator
when(cl.pollForResult(any())).thenReturn(None)

assertThrows[RapidsShuffleTimeoutException](cl.next())

verify(testMetricsUpdater, times(1))
.update(any(), any(), any(), any())
assertResult(0)(testMetricsUpdater.totalRemoteBlocksFetched)
assertResult(0)(testMetricsUpdater.totalRemoteBytesRead)
assertResult(0)(testMetricsUpdater.totalRowsFetched)

newMocks()
}

test("a new good batch is queued") {
val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress

val cl = new RapidsShuffleIterator(
RapidsShuffleTestHelper.makeMockBlockManager("1", "1"),
null,
mockConf,
mockTransport,
blocksByAddress,
mockMetrics,
mockCatalog)
testMetricsUpdater,
mockCatalog,
123)

when(mockTransaction.getStatus).thenReturn(TransactionStatus.Error)

@@ -107,5 +148,10 @@

assert(cl.hasNext)
assertResult(cb)(cl.next())
assertResult(1)(testMetricsUpdater.totalRemoteBlocksFetched)
assertResult(mockBuffer.size)(testMetricsUpdater.totalRemoteBytesRead)
assertResult(10)(testMetricsUpdater.totalRowsFetched)

newMocks()
}
}
@@ -19,17 +19,34 @@ package com.nvidia.spark.rapids.shuffle
import java.util.concurrent.Executor

import ai.rapids.cudf.{ColumnVector, ContiguousTable}
import com.nvidia.spark.rapids.{Arm, GpuColumnVector, MetaUtils, RapidsDeviceMemoryStore, ShuffleMetadata, ShuffleReceivedBufferCatalog}
import com.nvidia.spark.rapids.{Arm, GpuColumnVector, MetaUtils, RapidsConf, RapidsDeviceMemoryStore, ShuffleMetadata, ShuffleReceivedBufferCatalog}
import com.nvidia.spark.rapids.format.TableMeta
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.{spy, when}
import org.scalatest.{BeforeAndAfterEach, FunSuite}
import org.scalatest.mockito.MockitoSugar
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.sql.rapids.ShuffleMetricsUpdater
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockBatchId}

class TestShuffleMetricsUpdater extends ShuffleMetricsUpdater {
var totalRemoteBlocksFetched = 0L
var totalRemoteBytesRead = 0L
var totalRowsFetched = 0L
override def update(
fetchWaitTimeInMs: Long,
remoteBlocksFetched: Long,
remoteBytesRead: Long,
rowsFetched: Long): Unit = {
totalRemoteBlocksFetched += remoteBlocksFetched
totalRemoteBytesRead += remoteBytesRead
totalRowsFetched += rowsFetched
}
}

class RapidsShuffleTestHelper extends FunSuite
with BeforeAndAfterEach
with MockitoSugar
@@ -42,6 +59,8 @@ class RapidsShuffleTestHelper extends FunSuite
var mockHandler: RapidsShuffleFetchHandler = _
var mockStorage: RapidsDeviceMemoryStore = _
var mockCatalog: ShuffleReceivedBufferCatalog = _
var mockConf: RapidsConf = _
var testMetricsUpdater: TestShuffleMetricsUpdater = _
var client: RapidsShuffleClient = _

override def beforeEach(): Unit = {
@@ -57,6 +76,9 @@
mockHandler = mock[RapidsShuffleFetchHandler]
mockStorage = mock[RapidsDeviceMemoryStore]
mockCatalog = mock[ShuffleReceivedBufferCatalog]
mockConf = mock[RapidsConf]
testMetricsUpdater = spy(new TestShuffleMetricsUpdater)

client = spy(new RapidsShuffleClient(
1,
mockConnection,
