Cuda.deviceSynchronize as a last resort if we cannot spill enough #6849

Merged: 4 commits merged on Oct 20, 2022
Changes from 2 commits
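For context, the core of the change is a last-resort retry policy: when an allocation fails and the device store has nothing left to spill, the handler calls Cuda.deviceSynchronize so pending stream work and asynchronous frees can complete, then asks RMM to retry the allocation, up to a small per-thread retry budget. A minimal standalone sketch of that policy follows, with simplified accounting and illustrative names that are not the plugin's actual classes:

// Sketch only: simplified accounting and hypothetical names, not the code in this PR.
object LastResortSyncSketch {
  private val maxRetries = 2
  // Per-thread count of synchronize-and-retry attempts since the last success.
  private val retries = ThreadLocal.withInitial[Int](() => 0)

  // Stand-in for ai.rapids.cudf.Cuda.deviceSynchronize(): blocks until pending
  // device work, including asynchronous frees, has finished.
  private def deviceSynchronize(): Unit = ()

  /** Decide whether a failed device allocation should be retried. */
  def onAllocFailure(spillableBytes: Long): Boolean = {
    if (spillableBytes > 0) {
      retries.set(0)
      true                  // spill (elided here) and retry the allocation
    } else if (retries.get() < maxRetries) {
      retries.set(retries.get() + 1)
      deviceSynchronize()   // last resort: let in-flight frees land
      true                  // then retry once more
    } else {
      retries.set(0)
      false                 // out of options: surface the OOM
    }
  }
}

The synchronize matters because memory freed on other streams may not yet be visible to the allocator; waiting for the device to go idle can turn an interim failure into a satisfiable request.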
@@ -20,7 +20,7 @@ import java.io.File
 import java.lang.management.ManagementFactory
 import java.util.concurrent.atomic.AtomicLong
 
-import ai.rapids.cudf.{NvtxColor, NvtxRange, Rmm, RmmEventHandler}
+import ai.rapids.cudf.{Cuda, NvtxColor, NvtxRange, Rmm, RmmEventHandler}
 import com.sun.management.HotSpotDiagnosticMXBean
 
 import org.apache.spark.internal.Logging
@@ -30,53 +30,79 @@ import org.apache.spark.sql.rapids.execution.TrampolineUtil
  * RMM event handler to trigger spilling from the device memory store.
  * @param store device memory store that will be triggered to spill
  * @param oomDumpDir local directory to create heap dumps on GPU OOM
  * @param isGdsSpillEnabled true if GDS is enabled for device->disk spill
+ * @param maxFailedOOMRetries maximum number of retries for OOMs after
+ *                            depleting the device store
  */
 class DeviceMemoryEventHandler(
     store: RapidsDeviceMemoryStore,
     oomDumpDir: Option[String],
-    isGdsSpillEnabled: Boolean) extends RmmEventHandler with Logging {
+    isGdsSpillEnabled: Boolean,
+    maxFailedOOMRetries: Int) extends RmmEventHandler with Logging with Arm {
 
   // Flag that ensures we dump stack traces once and not for every allocation
   // failure. The assumption is that unhandled allocations will be fatal
   // to the process at this stage, so we only need this once before we exit.
   private var dumpStackTracesOnFailureToHandleOOM = true
 
+  // Thread local used to track the number of times a sync has been attempted, when
+  // handling OOMs after depleting the device store.
+  private val synchronizeAttempts = ThreadLocal.withInitial[Int](() => 0)
+
   /**
    * Handles RMM allocation failures by spilling buffers from device memory.
    * @param allocSize the byte amount that RMM failed to allocate
+   * @param retryCount the number of times this allocation has been retried after failure
    * @return true if allocation should be reattempted or false if it should fail
    */
-  override def onAllocFailure(allocSize: Long): Boolean = {
+  override def onAllocFailure(allocSize: Long, retryCount: Int): Boolean = {
     try {
-      val nvtx = new NvtxRange("onAllocFailure", NvtxColor.RED)
-      try {
+      withResource(new NvtxRange("onAllocFailure", NvtxColor.RED)) { _ =>
         val storeSize = store.currentSize
+        val attemptMsg = if (retryCount > 0) {
+          s"Attempt ${retryCount}. "
+        } else {
+          "First attempt. "
+        }
         logInfo(s"Device allocation of $allocSize bytes failed, device store has " +
-            s"$storeSize bytes. Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes.")
+            s"$storeSize bytes. $attemptMsg" +
+            s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes. ")
         if (storeSize == 0) {
-          logWarning(s"Device store exhausted, unable to allocate $allocSize bytes. " +
-              s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes.")
-          synchronized {
-            if (dumpStackTracesOnFailureToHandleOOM) {
-              dumpStackTracesOnFailureToHandleOOM = false
-              GpuSemaphore.dumpActiveStackTracesToLog()
-            }
-          }
-          oomDumpDir.foreach(heapDump)
-          return false
-        }
-        val targetSize = Math.max(storeSize - allocSize, 0)
-        logDebug(s"Targeting device store size of $targetSize bytes")
-        val amountSpilled = store.synchronousSpill(targetSize)
-        logInfo(s"Spilled $amountSpilled bytes from the device store")
-        if (isGdsSpillEnabled) {
-          TrampolineUtil.incTaskMetricsDiskBytesSpilled(amountSpilled)
-        } else {
-          TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled)
-        }
-        true
-      } finally {
-        nvtx.close()
+          var syncAttempt = synchronizeAttempts.get()
+          if (syncAttempt <= maxFailedOOMRetries) {
+            syncAttempt = syncAttempt + 1
+            synchronizeAttempts.set(syncAttempt)
+            Cuda.deviceSynchronize()
+            logWarning(s"[RETRY ${syncAttempt}] " +
+                s"Retrying allocation of $allocSize after a synchronize. " +
+                s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes.")
+            true
+          } else {
+            synchronizeAttempts.set(0)
+            logWarning(s"Device store exhausted, unable to allocate $allocSize bytes. " +
+                s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes.")
+            synchronized {
+              if (dumpStackTracesOnFailureToHandleOOM) {
+                dumpStackTracesOnFailureToHandleOOM = false
+                GpuSemaphore.dumpActiveStackTracesToLog()
+              }
+            }
+            oomDumpDir.foreach(heapDump)
+            false
+          }
+        } else {
+          synchronizeAttempts.set(0)
+          val targetSize = Math.max(storeSize - allocSize, 0)
+          logDebug(s"Targeting device store size of $targetSize bytes")
+          val amountSpilled = store.synchronousSpill(targetSize)
+          logInfo(s"Spilled $amountSpilled bytes from the device store")
+          if (isGdsSpillEnabled) {
+            TrampolineUtil.incTaskMetricsDiskBytesSpilled(amountSpilled)
+          } else {
+            TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled)
+          }
+          true
+        }
       }
     } catch {
       case t: Throwable =>
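For intuition on how the new retryCount parameter and the boolean return value interact, here is a toy driver loop. This is not RMM's actual control flow (the real allocator lives in native code); it only illustrates the contract the handler codes against:

// Toy illustration of the handler contract: retryCount starts at 0 on the
// first failure of an allocation and grows by one per retried failure.
object RetryLoopSketch {
  def allocateWithRetries(
      tryAlloc: Long => Option[Long],          // returns an address on success
      onAllocFailure: (Long, Int) => Boolean,  // (size, retryCount) => retry?
      size: Long): Option[Long] = {
    var retryCount = 0
    var result = tryAlloc(size)
    while (result.isEmpty && onAllocFailure(size, retryCount)) {
      retryCount += 1
      result = tryAlloc(size)
    }
    result
  }
}

Under this reading, retryCount == 0 corresponds to the "First attempt." message in the log above.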
@@ -202,7 +202,10 @@ object RapidsBufferCatalog extends Logging with Arm {
 
     logInfo("Installing GPU memory handler for spill")
     memoryEventHandler = new DeviceMemoryEventHandler(
-      deviceStorage, rapidsConf.gpuOomDumpDir, rapidsConf.isGdsSpillEnabled)
+      deviceStorage,
+      rapidsConf.gpuOomDumpDir,
+      rapidsConf.isGdsSpillEnabled,
+      rapidsConf.gpuOomMaxRetries)
     Rmm.setEventHandler(memoryEventHandler)
 
     _shouldUnspill = rapidsConf.isUnspillEnabled
12 changes: 12 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -339,6 +339,16 @@ object RapidsConf {
       .stringConf
       .createOptional
 
+  val GPU_OOM_MAX_RETRIES =
+    conf("spark.rapids.memory.gpu.oomMaxRetries")
+      .doc("The number of times that an OOM will be re-attempted after the device store " +
+        "can't spill anymore. In practice, we can use Cuda.deviceSynchronize to allow temporary " +
+        "state in the allocator and in the various streams to catch up, in hopes we can satisfy " +
+        "an allocation which was failing due to the interim state of memory.")
+      .internal()
+      .integerConf
+      .createWithDefault(2)
+
   private val RMM_ALLOC_MAX_FRACTION_KEY = "spark.rapids.memory.gpu.maxAllocFraction"
   private val RMM_ALLOC_MIN_FRACTION_KEY = "spark.rapids.memory.gpu.minAllocFraction"
   private val RMM_ALLOC_RESERVE_KEY = "spark.rapids.memory.gpu.reserve"
@@ -1773,6 +1783,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
 
   lazy val gpuOomDumpDir: Option[String] = get(GPU_OOM_DUMP_DIR)
 
+  lazy val gpuOomMaxRetries: Int = get(GPU_OOM_MAX_RETRIES)
+
   lazy val isUvmEnabled: Boolean = get(UVM_ENABLED)
 
   lazy val isPooledMemEnabled: Boolean = get(POOLED_MEM)
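The new knob is marked internal(), so it is not documented for end users. Because RapidsBufferCatalog reads it when the event handler is installed, it would have to be set before the executors start; an illustrative session setup (the app name and retry value here are just an example):

// Illustrative only: set the retry budget at startup, because the value is
// read when the RMM event handler is installed.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("oom-retry-experiment")                       // hypothetical app name
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
  .config("spark.rapids.memory.gpu.oomMaxRetries", "4")  // default is 2
  .getOrCreate()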