diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala
index 16d4ac739f9..baedd0edd00 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala
@@ -20,7 +20,7 @@ import java.io.File
 import java.lang.management.ManagementFactory
 import java.util.concurrent.atomic.AtomicLong
 
-import ai.rapids.cudf.{NvtxColor, NvtxRange, Rmm, RmmEventHandler}
+import ai.rapids.cudf.{Cuda, NvtxColor, NvtxRange, Rmm, RmmEventHandler}
 import com.sun.management.HotSpotDiagnosticMXBean
 
 import org.apache.spark.internal.Logging
@@ -30,53 +30,129 @@ import org.apache.spark.sql.rapids.execution.TrampolineUtil
  * RMM event handler to trigger spilling from the device memory store.
  * @param store device memory store that will be triggered to spill
  * @param oomDumpDir local directory to create heap dumps on GPU OOM
+ * @param isGdsSpillEnabled true if GDS is enabled for device->disk spill
+ * @param maxFailedOOMRetries maximum number of retries for OOMs after
+ *        depleting the device store
  */
 class DeviceMemoryEventHandler(
     store: RapidsDeviceMemoryStore,
     oomDumpDir: Option[String],
-    isGdsSpillEnabled: Boolean) extends RmmEventHandler with Logging {
+    isGdsSpillEnabled: Boolean,
+    maxFailedOOMRetries: Int) extends RmmEventHandler with Logging with Arm {
 
   // Flag that ensures we dump stack traces once and not for every allocation
   // failure. The assumption is that unhandled allocations will be fatal
   // to the process at this stage, so we only need this once before we exit.
   private var dumpStackTracesOnFailureToHandleOOM = true
 
+  // Thread local used to track the number of times a sync has been attempted when
+  // handling OOMs after depleting the device store.
+  private val oomRetryState =
+    ThreadLocal.withInitial[OOMRetryState](() => new OOMRetryState)
+
+  /**
+   * A small helper class that helps keep track of retry counts as we trigger
+   * synchronizations on a depleted store.
+   */
+  class OOMRetryState {
+    private var synchronizeAttempts = 0
+    private var retryCountLastSynced = 0
+
+    def getRetriesSoFar: Int = synchronizeAttempts
+
+    private def reset(): Unit = {
+      synchronizeAttempts = 0
+      retryCountLastSynced = 0
+    }
+
+    /**
+     * If we have synchronized fewer times than `maxFailedOOMRetries` we allow
+     * this retry to proceed, and track the `retryCount` provided by cuDF. If we
+     * are above our counter, we reset our state.
+     */
+    def shouldTrySynchronizing(retryCount: Int): Boolean = {
+      if (synchronizeAttempts < maxFailedOOMRetries) {
+        retryCountLastSynced = retryCount
+        synchronizeAttempts += 1
+        true
+      } else {
+        reset()
+        false
+      }
+    }
+
+    /**
+     * We reset our counters if `storeSize` is non-zero (as we only track when the store is
+     * depleted), or `retryCount` is less than or equal to what we had previously
+     * recorded in `shouldTrySynchronizing`. We do this to detect that the new failures
+     * are for a separate allocation (we need to give this new attempt a new set of
+     * retries).
+     *
+     * For example, if an allocation fails and we deplete the store, `retryCountLastSynced`
+     * is set to the last `retryCount` sent to us by cuDF as we keep allowing retries
+     * from cuDF. If we succeed, cuDF resets `retryCount`, so the new count sent to us
+     * must be less than or equal to what we saw last, and we can reset our tracking.
+     */
+    def resetIfNeeded(retryCount: Int, storeSize: Long): Unit = {
+      if (storeSize != 0 || retryCount <= retryCountLastSynced) {
+        reset()
+      }
+    }
+  }
+
   /**
    * Handles RMM allocation failures by spilling buffers from device memory.
    * @param allocSize the byte amount that RMM failed to allocate
+   * @param retryCount the number of times this allocation has been retried after failure
    * @return true if allocation should be reattempted or false if it should fail
    */
-  override def onAllocFailure(allocSize: Long): Boolean = {
+  override def onAllocFailure(allocSize: Long, retryCount: Int): Boolean = {
     try {
-      val nvtx = new NvtxRange("onAllocFailure", NvtxColor.RED)
-      try {
+      withResource(new NvtxRange("onAllocFailure", NvtxColor.RED)) { _ =>
         val storeSize = store.currentSize
+        val attemptMsg = if (retryCount > 0) {
+          s"Attempt ${retryCount}. "
+        } else {
+          "First attempt. "
+        }
+
+        val retryState = oomRetryState.get()
+        retryState.resetIfNeeded(retryCount, storeSize)
+
         logInfo(s"Device allocation of $allocSize bytes failed, device store has " +
-          s"$storeSize bytes. Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes.")
+          s"$storeSize bytes. $attemptMsg" +
+          s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes. ")
         if (storeSize == 0) {
-          logWarning(s"Device store exhausted, unable to allocate $allocSize bytes. " +
+          if (retryState.shouldTrySynchronizing(retryCount)) {
+            Cuda.deviceSynchronize()
+            logWarning(s"[RETRY ${retryState.getRetriesSoFar}] " +
+              s"Retrying allocation of $allocSize after a synchronize. " +
+              s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes.")
+            true
+          } else {
+            logWarning(s"Device store exhausted, unable to allocate $allocSize bytes. " +
               s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes.")
-          synchronized {
-            if (dumpStackTracesOnFailureToHandleOOM) {
-              dumpStackTracesOnFailureToHandleOOM = false
-              GpuSemaphore.dumpActiveStackTracesToLog()
+            synchronized {
+              if (dumpStackTracesOnFailureToHandleOOM) {
+                dumpStackTracesOnFailureToHandleOOM = false
+                GpuSemaphore.dumpActiveStackTracesToLog()
+              }
             }
+            oomDumpDir.foreach(heapDump)
+            false
           }
-          oomDumpDir.foreach(heapDump)
-          return false
-        }
-        val targetSize = Math.max(storeSize - allocSize, 0)
-        logDebug(s"Targeting device store size of $targetSize bytes")
-        val amountSpilled = store.synchronousSpill(targetSize)
-        logInfo(s"Spilled $amountSpilled bytes from the device store")
-        if (isGdsSpillEnabled) {
-          TrampolineUtil.incTaskMetricsDiskBytesSpilled(amountSpilled)
         } else {
-          TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled)
+          val targetSize = Math.max(storeSize - allocSize, 0)
+          logDebug(s"Targeting device store size of $targetSize bytes")
+          val amountSpilled = store.synchronousSpill(targetSize)
+          logInfo(s"Spilled $amountSpilled bytes from the device store")
+          if (isGdsSpillEnabled) {
+            TrampolineUtil.incTaskMetricsDiskBytesSpilled(amountSpilled)
+          } else {
+            TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled)
+          }
+          true
         }
-        true
-      } finally {
-        nvtx.close()
       }
     } catch {
       case t: Throwable =>
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
index 39f98e9685f..ffe5178f866 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
@@ -202,7 +202,10 @@ object RapidsBufferCatalog extends Logging with Arm {
 
       logInfo("Installing GPU memory handler for spill")
       memoryEventHandler = new DeviceMemoryEventHandler(
-        deviceStorage, rapidsConf.gpuOomDumpDir, rapidsConf.isGdsSpillEnabled)
+        deviceStorage,
+        rapidsConf.gpuOomDumpDir,
+        rapidsConf.isGdsSpillEnabled,
+        rapidsConf.gpuOomMaxRetries)
       Rmm.setEventHandler(memoryEventHandler)
 
       _shouldUnspill = rapidsConf.isUnspillEnabled
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index cfefdc83125..5b664fa4801 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -339,6 +339,16 @@ object RapidsConf {
     .stringConf
     .createOptional
 
+  val GPU_OOM_MAX_RETRIES =
+    conf("spark.rapids.memory.gpu.oomMaxRetries")
+      .doc("The number of times that an OOM will be re-attempted after the device store " +
+        "can't spill anymore. In practice, we can use Cuda.deviceSynchronize to allow temporary " +
+        "state in the allocator and in the various streams to catch up, in hopes we can satisfy " +
+        "an allocation which was failing due to the interim state of memory.")
+      .internal()
+      .integerConf
+      .createWithDefault(2)
+
   private val RMM_ALLOC_MAX_FRACTION_KEY = "spark.rapids.memory.gpu.maxAllocFraction"
   private val RMM_ALLOC_MIN_FRACTION_KEY = "spark.rapids.memory.gpu.minAllocFraction"
   private val RMM_ALLOC_RESERVE_KEY = "spark.rapids.memory.gpu.reserve"
@@ -1773,6 +1783,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
 
   lazy val gpuOomDumpDir: Option[String] = get(GPU_OOM_DUMP_DIR)
 
+  lazy val gpuOomMaxRetries: Int = get(GPU_OOM_MAX_RETRIES)
+
   lazy val isUvmEnabled: Boolean = get(UVM_ENABLED)
 
   lazy val isPooledMemEnabled: Boolean = get(POOLED_MEM)
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala
new file mode 100644
index 00000000000..f49ae74421a
--- /dev/null
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.when
+import org.scalatest.FunSuite
+import org.scalatest.mockito.MockitoSugar
+
+class DeviceMemoryEventHandlerSuite extends FunSuite with MockitoSugar {
+
+  test("a failed allocation should be retried if we spilled enough") {
+    val mockStore = mock[RapidsDeviceMemoryStore]
+    when(mockStore.currentSize).thenReturn(1024)
+    when(mockStore.synchronousSpill(any())).thenReturn(1024)
+    val handler = new DeviceMemoryEventHandler(mockStore, None, false, 2)
+    assertResult(true)(handler.onAllocFailure(1024, 0))
+  }
+
+  test("when we deplete the store, retry up to max failed OOM retries") {
+    val mockStore = mock[RapidsDeviceMemoryStore]
+    when(mockStore.currentSize).thenReturn(0)
+    when(mockStore.synchronousSpill(any())).thenReturn(0)
+    val handler = new DeviceMemoryEventHandler(mockStore, None, false, 2)
+    assertResult(true)(handler.onAllocFailure(1024, 0)) // sync
+    assertResult(true)(handler.onAllocFailure(1024, 1)) // sync 2
+    assertResult(false)(handler.onAllocFailure(1024, 2)) // cuDF would OOM here
+  }
+
+  test("we reset our OOM state after a successful retry") {
+    val mockStore = mock[RapidsDeviceMemoryStore]
+    when(mockStore.currentSize).thenReturn(0)
+    when(mockStore.synchronousSpill(any())).thenReturn(0)
+    val handler = new DeviceMemoryEventHandler(mockStore, None, false, 2)
+    // with this call we sync and mark our attempts at 1, storing 0 as the last count
+    assertResult(true)(handler.onAllocFailure(1024, 0))
+    // this retryCount is still 0, so we should be back at 1 for attempts
+    assertResult(true)(handler.onAllocFailure(1024, 0))
+    assertResult(true)(handler.onAllocFailure(1024, 1))
+    assertResult(false)(handler.onAllocFailure(1024, 2)) // cuDF would OOM here
+  }
+}
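
The retry bookkeeping in this patch is easier to follow in isolation, so here is a minimal, standalone Scala sketch; it is not part of the patch, and OomRetrySketch, RetryStateSketch, and maxRetries are hypothetical names that mirror the OOMRetryState logic added to DeviceMemoryEventHandler (maxRetries stands in for spark.rapids.memory.gpu.oomMaxRetries). The trace in main replays the depleted-store sequence covered by DeviceMemoryEventHandlerSuite.

// Standalone sketch only: mirrors the retry/reset flow without RMM or cuDF.
object OomRetrySketch {
  class RetryStateSketch(maxRetries: Int) {
    private var synchronizeAttempts = 0
    private var retryCountLastSynced = 0

    private def reset(): Unit = {
      synchronizeAttempts = 0
      retryCountLastSynced = 0
    }

    // Allow another synchronize-and-retry while we are under the retry budget;
    // once the budget is spent, reset the counters and report failure.
    def shouldTrySynchronizing(retryCount: Int): Boolean = {
      if (synchronizeAttempts < maxRetries) {
        retryCountLastSynced = retryCount
        synchronizeAttempts += 1
        true
      } else {
        reset()
        false
      }
    }

    // A non-empty store, or a retryCount that did not grow past the last one we
    // synced on, means this is a new allocation attempt: start counting over.
    def resetIfNeeded(retryCount: Int, storeSize: Long): Unit = {
      if (storeSize != 0 || retryCount <= retryCountLastSynced) {
        reset()
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val state = new RetryStateSketch(maxRetries = 2)
    // Depleted store (size 0): two synchronize retries are granted, the third gives up.
    Seq(0, 1, 2).foreach { retryCount =>
      state.resetIfNeeded(retryCount, storeSize = 0)
      println(s"retryCount=$retryCount -> retry=${state.shouldTrySynchronizing(retryCount)}")
    }
  }
}

Run as-is this prints retry=true, retry=true, retry=false: two synchronize-and-retry attempts are allowed before the handler gives up, which is the behavior the new test suite asserts with maxFailedOOMRetries set to 2.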