diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 2cc7ecd51517..1a9221ce3826 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -584,7 +584,6 @@ object RapidsConf { .booleanConf .createWithDefault(true) - // USER FACING SHUFFLE CONFIGS val SHUFFLE_TRANSPORT_ENABLE = conf("spark.rapids.shuffle.transport.enabled") .doc("When set to true, enable the Rapids Shuffle Transport for accelerated shuffle.") .booleanConf diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala index a81d98537ecf..96e8372550a1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala @@ -16,16 +16,16 @@ package com.nvidia.spark.rapids.shuffle -import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} import scala.collection.mutable import ai.rapids.cudf.{NvtxColor, NvtxRange} -import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.{GpuSemaphore, RapidsBuffer, RapidsConf, ShuffleReceivedBufferCatalog, ShuffleReceivedBufferId} import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.shuffle.RapidsShuffleFetchFailedException +import org.apache.spark.shuffle.{RapidsShuffleFetchFailedException, RapidsShuffleTimeoutException} import org.apache.spark.sql.rapids.{GpuShuffleEnv, ShuffleMetricsUpdater} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockBatchId, ShuffleBlockId} @@ -42,6 +42,8 @@ import org.apache.spark.storage.{BlockId, BlockManagerId, 
ShuffleBlockBatchId, S * @param blocksByAddress blocks to fetch * @param metricsUpdater instance of `ShuffleMetricsUpdater` to update the Spark * shuffle metrics + * @param timeoutSeconds a timeout in seconds that the iterator will wait while polling + * for batches */ class RapidsShuffleIterator( localBlockManagerId: BlockManagerId, @@ -49,7 +51,8 @@ class RapidsShuffleIterator( transport: RapidsShuffleTransport, blocksByAddress: Array[(BlockManagerId, Seq[(BlockId, Long, Int)])], metricsUpdater: ShuffleMetricsUpdater, - catalog: ShuffleReceivedBufferCatalog = GpuShuffleEnv.getReceivedCatalog) + catalog: ShuffleReceivedBufferCatalog = GpuShuffleEnv.getReceivedCatalog, + timeoutSeconds: Long = GpuShuffleEnv.shuffleFetchTimeoutSeconds) extends Iterator[ColumnarBatch] with Logging { @@ -80,6 +83,8 @@ class RapidsShuffleIterator( mapIndex: Int, errorMessage: String) extends ShuffleClientResult + // when batches (or errors) arrive from the transport, they are pushed + // to the `resolvedBatches` queue. 
private[this] val resolvedBatches = new LinkedBlockingQueue[ShuffleClientResult]() // Used to track requests that are pending where the number of [[ColumnarBatch]] results is @@ -277,6 +282,10 @@ class RapidsShuffleIterator( //TODO: on task completion we currently don't ask clients to stop/clean resources taskContext.foreach(_.addTaskCompletionListener[Unit](_ => receiveBufferCleaner())) + def pollForResult(timeoutSeconds: Long): Option[ShuffleClientResult] = { + Option(resolvedBatches.poll(timeoutSeconds, TimeUnit.SECONDS)) + } + override def next(): ColumnarBatch = { var cb: ColumnarBatch = null var sb: RapidsBuffer = null @@ -306,10 +315,12 @@ class RapidsShuffleIterator( } val blockedStart = System.currentTimeMillis() - val result = resolvedBatches.take() + var result: Option[ShuffleClientResult] = None + + result = pollForResult(timeoutSeconds) val blockedTime = System.currentTimeMillis() - blockedStart result match { - case BufferReceived(bufferId) => + case Some(BufferReceived(bufferId)) => val nvtxRangeAfterGettingBatch = new NvtxRange("RapidsShuffleIterator.gotBatch", NvtxColor.PURPLE) try { @@ -324,8 +335,9 @@ class RapidsShuffleIterator( } catalog.removeBuffer(bufferId) } - case TransferError(blockManagerId, shuffleBlockBatchId, mapIndex, errorMessage) => + case Some(TransferError(blockManagerId, shuffleBlockBatchId, mapIndex, errorMessage)) => taskContext.foreach(GpuSemaphore.releaseIfNecessary) + metricsUpdater.update(blockedTime, 0, 0, 0) val errorMsg = s"Transfer error detected by shuffle iterator, failing task. ${errorMessage}" logError(errorMsg) throw new RapidsShuffleFetchFailedException( @@ -335,6 +347,16 @@ class RapidsShuffleIterator( mapIndex, shuffleBlockBatchId.startReduceId, errorMsg) + case None => + // NOTE: this isn't perfect, since what we really want is the transport to + // bubble this error, but for now we'll make this a fatal exception. 
+ taskContext.foreach(GpuSemaphore.releaseIfNecessary) + metricsUpdater.update(blockedTime, 0, 0, 0) + val errMsg = s"Timed out after ${timeoutSeconds} seconds while waiting for a shuffle batch." + logError(errMsg) + throw new RapidsShuffleTimeoutException(errMsg) + case _ => + throw new IllegalStateException(s"Invalid result type $result") } cb } diff --git a/sql-plugin/src/main/scala/org/apache/spark/shuffle/RapidsShuffleFetchFailedException.scala b/sql-plugin/src/main/scala/org/apache/spark/shuffle/RapidsShuffleExceptions.scala similarity index 92% rename from sql-plugin/src/main/scala/org/apache/spark/shuffle/RapidsShuffleFetchFailedException.scala rename to sql-plugin/src/main/scala/org/apache/spark/shuffle/RapidsShuffleExceptions.scala index 5b6be5cb591f..c07947790113 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/shuffle/RapidsShuffleFetchFailedException.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/shuffle/RapidsShuffleExceptions.scala @@ -28,3 +28,5 @@ class RapidsShuffleFetchFailedException( extends FetchFailedException( bmAddress, shuffleId, mapId, mapIndex, reduceId, message) { } + +class RapidsShuffleTimeoutException(message: String) extends Exception(message) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala index 8ae9f5c0c8b2..33be7af56946 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala @@ -99,4 +99,6 @@ object GpuShuffleEnv extends Logging { def getReceivedCatalog: ShuffleReceivedBufferCatalog = env.getReceivedCatalog def rapidsShuffleCodec: Option[TableCompressionCodec] = env.rapidsShuffleCodec + + def shuffleFetchTimeoutSeconds: Long = env.getShuffleFetchTimeoutSeconds } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala 
b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala index 93d980838987..cfff086c00b8 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala @@ -21,9 +21,7 @@ import org.mockito.ArgumentCaptor import org.mockito.ArgumentMatchers._ import org.mockito.Mockito._ -import org.apache.spark.TaskContext -import org.apache.spark.shuffle.RapidsShuffleFetchFailedException -import org.apache.spark.sql.rapids.ShuffleMetricsUpdater +import org.apache.spark.shuffle.{RapidsShuffleFetchFailedException, RapidsShuffleTimeoutException} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { @@ -32,10 +30,12 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val cl = new RapidsShuffleIterator( RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), - null, + mockConf, mockTransport, blocksByAddress, - null) + testMetricsUpdater, + mockCatalog, + 123) when(mockTransaction.getStatus).thenReturn(TransactionStatus.Error) @@ -43,6 +43,10 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { assert(cl.hasNext) assertThrows[RapidsShuffleFetchFailedException](cl.next()) + + // not invoked, since we never blocked + verify(testMetricsUpdater, times(0)) + .update(any(), any(), any(), any()) } test("a transport error/cancellation raises a fetch failure") { @@ -51,12 +55,14 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress - val cl = new RapidsShuffleIterator( + val cl = spy(new RapidsShuffleIterator( RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), - null, + mockConf, mockTransport, blocksByAddress, - null) + testMetricsUpdater, + mockCatalog, + 123)) val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) 
when(mockTransport.makeClient(any(), any())).thenReturn(client) @@ -69,22 +75,57 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { assert(cl.hasNext) assertThrows[RapidsShuffleFetchFailedException](cl.next()) + verify(testMetricsUpdater, times(1)) + .update(any(), any(), any(), any()) + assertResult(0)(testMetricsUpdater.totalRemoteBlocksFetched) + assertResult(0)(testMetricsUpdater.totalRemoteBytesRead) + assertResult(0)(testMetricsUpdater.totalRowsFetched) + newMocks() } } - test("a new good batch is queued") { + test("a timeout while waiting for batches raises a fetch failure") { val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress - val mockMetrics = mock[ShuffleMetricsUpdater] + val cl = spy(new RapidsShuffleIterator( + RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), + mockConf, + mockTransport, + blocksByAddress, + testMetricsUpdater, + mockCatalog, + 123)) + + val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) + when(mockTransport.makeClient(any(), any())).thenReturn(client) + doNothing().when(client).doFetch(any(), ac.capture(), any()) + + // signal a timeout to the iterator + when(cl.pollForResult(any())).thenReturn(None) + + assertThrows[RapidsShuffleTimeoutException](cl.next()) + + verify(testMetricsUpdater, times(1)) + .update(any(), any(), any(), any()) + assertResult(0)(testMetricsUpdater.totalRemoteBlocksFetched) + assertResult(0)(testMetricsUpdater.totalRemoteBytesRead) + assertResult(0)(testMetricsUpdater.totalRowsFetched) + + newMocks() + } + + test("a new good batch is queued") { + val blocksByAddress = RapidsShuffleTestHelper.getBlocksByAddress val cl = new RapidsShuffleIterator( RapidsShuffleTestHelper.makeMockBlockManager("1", "1"), - null, + mockConf, mockTransport, blocksByAddress, - mockMetrics, - mockCatalog) + testMetricsUpdater, + mockCatalog, + 123) when(mockTransaction.getStatus).thenReturn(TransactionStatus.Error) @@ -107,5 +148,10 @@ class RapidsShuffleIteratorSuite extends 
RapidsShuffleTestHelper { assert(cl.hasNext) assertResult(cb)(cl.next()) + assertResult(1)(testMetricsUpdater.totalRemoteBlocksFetched) + assertResult(mockBuffer.size)(testMetricsUpdater.totalRemoteBytesRead) + assertResult(10)(testMetricsUpdater.totalRowsFetched) + + newMocks() } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala index 35c836555fc4..6f8af0123de0 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala @@ -19,7 +19,7 @@ package com.nvidia.spark.rapids.shuffle import java.util.concurrent.Executor import ai.rapids.cudf.{ColumnVector, ContiguousTable} -import com.nvidia.spark.rapids.{Arm, GpuColumnVector, MetaUtils, RapidsDeviceMemoryStore, ShuffleMetadata, ShuffleReceivedBufferCatalog} +import com.nvidia.spark.rapids.{Arm, GpuColumnVector, MetaUtils, RapidsConf, RapidsDeviceMemoryStore, ShuffleMetadata, ShuffleReceivedBufferCatalog} import com.nvidia.spark.rapids.format.TableMeta import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{spy, when} @@ -27,9 +27,26 @@ import org.scalatest.{BeforeAndAfterEach, FunSuite} import org.scalatest.mockito.MockitoSugar import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.sql.rapids.ShuffleMetricsUpdater import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockBatchId} +class TestShuffleMetricsUpdater extends ShuffleMetricsUpdater { + var totalRemoteBlocksFetched = 0L + var totalRemoteBytesRead = 0L + var totalRowsFetched = 0L + override def update( + fetchWaitTimeInMs: Long, + remoteBlocksFetched: Long, + remoteBytesRead: Long, + rowsFetched: Long): Unit = { + totalRemoteBlocksFetched += remoteBlocksFetched + 
totalRemoteBytesRead += remoteBytesRead + totalRowsFetched += rowsFetched + } +} + class RapidsShuffleTestHelper extends FunSuite with BeforeAndAfterEach with MockitoSugar @@ -42,6 +59,8 @@ class RapidsShuffleTestHelper extends FunSuite var mockHandler: RapidsShuffleFetchHandler = _ var mockStorage: RapidsDeviceMemoryStore = _ var mockCatalog: ShuffleReceivedBufferCatalog = _ + var mockConf: RapidsConf = _ + var testMetricsUpdater: TestShuffleMetricsUpdater = _ var client: RapidsShuffleClient = _ override def beforeEach(): Unit = { @@ -57,6 +76,9 @@ class RapidsShuffleTestHelper extends FunSuite mockHandler = mock[RapidsShuffleFetchHandler] mockStorage = mock[RapidsDeviceMemoryStore] mockCatalog = mock[ShuffleReceivedBufferCatalog] + mockConf = mock[RapidsConf] + testMetricsUpdater = spy(new TestShuffleMetricsUpdater) + client = spy(new RapidsShuffleClient( 1, mockConnection,