Add in support for OOM retry #7822

Merged: 3 commits, Mar 6, 2023
1 change: 1 addition & 0 deletions docs/configs.md
@@ -50,6 +50,7 @@ Name | Description | Default Value | Applicable at
<a name="memory.gpu.pool"></a>spark.rapids.memory.gpu.pool|Select the RMM pooling allocator to use. Valid values are "DEFAULT", "ARENA", "ASYNC", and "NONE". With "DEFAULT", the RMM pool allocator is used; with "ARENA", the RMM arena allocator is used; with "ASYNC", the new CUDA stream-ordered memory allocator in CUDA 11.2+ is used. If set to "NONE", pooling is disabled and RMM just passes through to CUDA memory allocation directly.|ASYNC|Startup
<a name="memory.gpu.pooling.enabled"></a>spark.rapids.memory.gpu.pooling.enabled|Should RMM act as a pooling allocator for GPU memory, or should it just pass through to CUDA memory allocation directly. DEPRECATED: please use spark.rapids.memory.gpu.pool instead.|true|Startup
<a name="memory.gpu.reserve"></a>spark.rapids.memory.gpu.reserve|The amount of GPU memory that should remain unallocated by RMM and left for system use such as memory needed for kernels and kernel launches.|671088640|Startup
<a name="memory.gpu.state.debug"></a>spark.rapids.memory.gpu.state.debug|To better recover from out of memory errors, RMM will track several states for the threads that interact with the GPU. This provides a log of those state transitions to aid in debugging it. STDOUT or STDERR will have the logging go there empty string will disable logging and anything else will be treated as a file to write the logs to.||Startup
<a name="memory.gpu.unspill.enabled"></a>spark.rapids.memory.gpu.unspill.enabled|When a spilled GPU buffer is needed again, should it be unspilled, or only copied back into GPU memory temporarily. Unspilling may be useful for GPU buffers that are needed frequently, for example, broadcast variables; however, it may also increase GPU memory usage|false|Startup
<a name="memory.host.pageablePool.size"></a>spark.rapids.memory.host.pageablePool.size|The size of the pageable memory pool in bytes unless otherwise specified. Use 0 to disable the pool.|1073741824|Startup
<a name="memory.host.spillStorageSize"></a>spark.rapids.memory.host.spillStorageSize|Amount of off-heap host memory to use for buffering spilled GPU data before spilling to local disk. Use -1 to set the amount to the combined size of pinned and pageable memory pools.|-1|Startup
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -28,6 +28,7 @@ import ai.rapids.cudf.{BaseDeviceMemoryBuffer, MemoryBuffer, NvtxColor, NvtxRang
import com.nvidia.spark.rapids.{Arm, GpuDeviceManager, RapidsConf}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.ThreadFactoryBuilder
import com.nvidia.spark.rapids.jni.RmmSpark
import com.nvidia.spark.rapids.shuffle.{ClientConnection, MemoryRegistrationCallback, MessageType, MetadataTransportBuffer, TransportBuffer, TransportUtils}
import org.openucx.jucx._
import org.openucx.jucx.ucp._
@@ -105,7 +106,9 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf:
new ThreadFactoryBuilder()
.setNameFormat("progress-thread-%d")
.setDaemon(true)
.build))
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))

// The pending queues are used to enqueue [[PendingReceive]] or [[PendingSend]], from executor
// task threads and [[progressThread]] will hand them to the UcpWorker thread.
@@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer
import ai.rapids.cudf.{BaseDeviceMemoryBuffer, CudaMemoryBuffer, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer}
import com.nvidia.spark.rapids.{GpuDeviceManager, HashedPriorityQueue, RapidsConf}
import com.nvidia.spark.rapids.ThreadFactoryBuilder
import com.nvidia.spark.rapids.jni.RmmSpark
import com.nvidia.spark.rapids.shuffle._
import com.nvidia.spark.rapids.shuffle.{BounceBufferManager, BufferReceiveState, ClientConnection, PendingTransferRequest, RapidsShuffleClient, RapidsShuffleRequestHandler, RapidsShuffleServer, RapidsShuffleTransport, RefCountedDirectByteBuffer}

@@ -248,7 +249,9 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
new ThreadFactoryBuilder()
.setNameFormat("shuffle-transport-client-exec-%d")
.setDaemon(true)
.build),
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()),
// if we can't hand off because we are too busy, block the caller (in UCX's case,
// the progress thread)
new CallerRunsAndLogs())
@@ -258,7 +261,9 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
GpuDeviceManager.wrapThreadFactory(new ThreadFactoryBuilder()
.setNameFormat("shuffle-client-copy-thread-%d")
.setDaemon(true)
.build))
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))

override def makeClient(blockManagerId: BlockManagerId): RapidsShuffleClient = {
val peerExecutorId = blockManagerId.executorId.toLong
@@ -280,14 +285,18 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
GpuDeviceManager.wrapThreadFactory(new ThreadFactoryBuilder()
.setNameFormat(s"shuffle-server-conn-thread-${shuffleServerId.executorId}-%d")
.setDaemon(true)
.build))
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))

// This executor handles any task that would block (e.g. wait for spill synchronously due to OOM)
private[this] val serverCopyExecutor = Executors.newSingleThreadExecutor(
GpuDeviceManager.wrapThreadFactory(new ThreadFactoryBuilder()
.setNameFormat(s"shuffle-server-copy-thread-%d")
.setDaemon(true)
.build))
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))

// This is used to queue up on the server all the [[BufferSendState]] as the server waits for
// bounce buffers to become available (it is the equivalent of the transport's throttle, minus
@@ -296,7 +305,9 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
GpuDeviceManager.wrapThreadFactory(new ThreadFactoryBuilder()
.setNameFormat(s"shuffle-server-bss-thread-%d")
.setDaemon(true)
.build))
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))

/**
* Construct a server instance
@@ -356,7 +367,9 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
new ThreadFactoryBuilder()
.setNameFormat(s"shuffle-transport-throttle-monitor")
.setDaemon(true)
.build))
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))

// helper class to hold transfer requests that have a bounce buffer
// and should be ready to be handled by a `BufferReceiveState`
@@ -356,15 +356,26 @@ object GpuDeviceManager extends Logging {
}

/** Wrap a thread factory with one that will set the GPU device on each thread created. */
def wrapThreadFactory(factory: ThreadFactory): ThreadFactory = new ThreadFactory() {
def wrapThreadFactory(factory: ThreadFactory,
before: () => Unit = null,
after: () => Unit = null): ThreadFactory = new ThreadFactory() {
private[this] val devId = getDeviceId.getOrElse {
throw new IllegalStateException("Device ID is not set")
}

override def newThread(runnable: Runnable): Thread = {
factory.newThread(() => {
Cuda.setDevice(devId)
runnable.run()
try {
if (before != null) {
before()
}
Comment on lines +370 to +372

Collaborator: Should we add logging here? An exception from before()/after() might be difficult to contextualize since it runs in a different thread.

Collaborator: Consider

Suggested change
- if (before != null) {
-   before()
- }
+ Option(before).foreach(_.apply())

Collaborator (author): It is more functional, I get that. This is not performance-critical code, but it replaces a check and a branch, probably 3 or 4 instructions, with a call to a static method that creates an object and then a method call on that object with a function that is probably a separate class that had to be created, possibly as a singleton.

I personally prefer the null check, but if for consistency with other code styles we want the functional one-liner I am fine with it.

Collaborator: I am OK with your preference. Performance considerations are irrelevant here. Thanks for considering the suggestion.

I just realized that we probably need neither version of the null check if you make the default parameter value a nop, () => (), instead of null.

runnable.run()
} finally {
if (after != null) {
after()
}
}
})
}
}
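
For readers skimming the hunk above: the wrapper simply runs an optional hook when each new thread starts and another when its work finishes. Below is a condensed, standalone sketch of that pattern using the nop-default alternative (() => ()) floated in the review thread; it omits the Cuda.setDevice call and is not the plugin's actual implementation:

```scala
import java.util.concurrent.ThreadFactory

// Sketch only: wrap a ThreadFactory so every thread it creates runs `before` when it
// starts and `after` when its Runnable finishes, using nop defaults instead of null checks.
def wrapThreadFactory(factory: ThreadFactory,
    before: () => Unit = () => (),
    after: () => Unit = () => ()): ThreadFactory = new ThreadFactory {
  override def newThread(runnable: Runnable): Thread =
    factory.newThread(() => {
      before()  // e.g. RmmSpark.associateCurrentThreadWithShuffle()
      try {
        runnable.run()
      } finally {
        after()  // e.g. RmmSpark.removeCurrentThreadAssociation()
      }
    })
}
```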
@@ -21,6 +21,7 @@ import java.util.concurrent.{ConcurrentHashMap, Semaphore}
import scala.collection.mutable

import ai.rapids.cudf.{NvtxColor, NvtxRange}
import com.nvidia.spark.rapids.jni.RmmSpark
import org.apache.commons.lang3.mutable.MutableInt

import org.apache.spark.TaskContext
Expand Down Expand Up @@ -132,6 +133,7 @@ private final class GpuSemaphore() extends Logging with Arm {
}
logDebug(s"Task $taskAttemptId acquiring GPU with $permits permits")
semaphore.acquire(permits)
RmmSpark.associateCurrentThreadWithTask(taskAttemptId)
if (refs != null) {
refs.count.increment()
} else {
@@ -142,13 +144,17 @@ private final class GpuSemaphore() extends Logging with Arm {
context.addTaskCompletionListener[Unit](completeTask)
}
GpuDeviceManager.initializeFromTask()
} else {
// Already had the semaphore, but we don't know if the thread is new or not
RmmSpark.associateCurrentThreadWithTask(taskAttemptId)
}
}
}

def releaseIfNecessary(context: TaskContext): Unit = {
val nvtxRange = new NvtxRange("Release GPU", NvtxColor.RED)
try {
RmmSpark.removeCurrentThreadAssociation()
val taskAttemptId = context.taskAttemptId()
val refs = activeTasks.get(taskAttemptId)
if (refs != null && refs.count.getValue > 0) {
@@ -164,6 +170,7 @@

def completeTask(context: TaskContext): Unit = {
val taskAttemptId = context.taskAttemptId()
RmmSpark.taskDone(taskAttemptId)
val refs = activeTasks.remove(taskAttemptId)
if (refs == null) {
throw new IllegalStateException(s"Completion of unknown task $taskAttemptId")
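
Taken together, the three GpuSemaphore hooks above give each task thread a clear RmmSpark lifecycle: associate on acquire, drop the association on release, and report taskDone on completion. The helper below is a hypothetical illustration of that lifecycle, not the plugin's actual control flow; only the RmmSpark calls come from this PR:

```scala
import com.nvidia.spark.rapids.jni.RmmSpark

// Hypothetical helper sketching the association lifecycle; the real plugin does this
// inside GpuSemaphore's acquire path, releaseIfNecessary, and completeTask.
def withGpuTaskThread[T](taskAttemptId: Long)(body: => T): T = {
  RmmSpark.associateCurrentThreadWithTask(taskAttemptId)  // after the semaphore is acquired
  try {
    body
  } finally {
    RmmSpark.removeCurrentThreadAssociation()             // when the semaphore is released
  }
  // and once the task has fully completed: RmmSpark.taskDone(taskAttemptId)
}
```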
@@ -24,6 +24,7 @@ import com.nvidia.spark.rapids.RapidsBufferCatalog.getExistingRapidsBufferAndAcq
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.StorageTier.StorageTier
import com.nvidia.spark.rapids.format.TableMeta
import com.nvidia.spark.rapids.jni.RmmSpark

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.Logging
@@ -760,7 +761,19 @@ object RapidsBufferCatalog extends Logging with Arm {
rapidsConf.gpuOomDumpDir,
rapidsConf.isGdsSpillEnabled,
rapidsConf.gpuOomMaxRetries)
Rmm.setEventHandler(memoryEventHandler)

if (rapidsConf.sparkRmmStateEnable) {
val debugLoc = if (rapidsConf.sparkRmmDebugLocation.isEmpty) {
null
} else {
rapidsConf.sparkRmmDebugLocation
}

RmmSpark.setEventHandler(memoryEventHandler, debugLoc)
} else {
logWarning("SparkRMM retry has been disabled")
Rmm.setEventHandler(memoryEventHandler)
}

_shouldUnspill = rapidsConf.isUnspillEnabled
}
23 changes: 23 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -343,6 +343,25 @@ object RapidsConf {
.stringConf
.createWithDefault("NONE")

val SPARK_RMM_STATE_DEBUG = conf("spark.rapids.memory.gpu.state.debug")
.doc("To better recover from out of memory errors, RMM will track several states for " +
"the threads that interact with the GPU. This provides a log of those state " +
"transitions to aid in debugging it. STDOUT or STDERR will have the logging go there " +
"empty string will disable logging and anything else will be treated as a file to " +
"write the logs to.")
.startupOnly()
.stringConf
.createWithDefault("")

val SPARK_RMM_STATE_ENABLE = conf("spark.rapids.memory.gpu.state.enable")
.doc("Enabled or disable using the SparkRMM state tracking to improve " +
"OOM response. This includes possibly retrying parts of the processing in " +
"the case of an OOM")
.startupOnly()
.internal()
.booleanConf
.createWithDefault(true)

val GPU_OOM_DUMP_DIR = conf("spark.rapids.memory.gpu.oomDumpDir")
.doc("The path to a local directory where a heap dump will be created if the GPU " +
"encounters an unrecoverable out-of-memory (OOM) error. The filename will be of the " +
@@ -1959,6 +1978,10 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val rmmDebugLocation: String = get(RMM_DEBUG)

lazy val sparkRmmDebugLocation: String = get(SPARK_RMM_STATE_DEBUG)

lazy val sparkRmmStateEnable: Boolean = get(SPARK_RMM_STATE_ENABLE)

lazy val gpuOomDumpDir: Option[String] = get(GPU_OOM_DUMP_DIR)

lazy val gpuOomMaxRetries: Int = get(GPU_OOM_MAX_RETRIES)
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@ import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}

import scala.collection.mutable.ArrayBuffer

import com.nvidia.spark.rapids.jni.RmmSpark
import org.apache.commons.lang3.mutable.MutableLong

import org.apache.spark.SparkEnv
@@ -194,7 +195,9 @@ class RapidsShuffleHeartbeatEndpoint(pluginContext: PluginContext, conf: RapidsC
GpuDeviceManager.wrapThreadFactory(new ThreadFactoryBuilder()
.setNameFormat("rapids-shuffle-hb")
.setDaemon(true)
.build()))
.build(),
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))

private class InitializeShuffleManager(ctx: PluginContext,
shuffleManager: RapidsShuffleInternalManagerBase) extends Runnable {
@@ -22,6 +22,7 @@ import scala.collection.mutable

import ai.rapids.cudf.{NvtxColor, NvtxRange}
import com.nvidia.spark.rapids.{Arm, GpuSemaphore, NoopMetric, RapidsBuffer, RapidsBufferHandle, RapidsConf, ShuffleReceivedBufferCatalog}
import com.nvidia.spark.rapids.jni.RmmSpark

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
@@ -345,9 +346,14 @@ class RapidsShuffleIterator(

val blockedStart = System.currentTimeMillis()
var result: Option[ShuffleClientResult] = None

result = pollForResult(timeoutSeconds)
RmmSpark.threadCouldBlockOnShuffle()
try {
result = pollForResult(timeoutSeconds)
} finally {
RmmSpark.threadDoneWithShuffle()
}
val blockedTime = System.currentTimeMillis() - blockedStart

result match {
case Some(BufferReceived(handle)) =>
val nvtxRangeAfterGettingBatch = new NvtxRange("RapidsShuffleIterator.gotBatch",
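
The RapidsShuffleIterator change above brackets the blocking poll with a pair of RmmSpark calls, presumably so the OOM state machine can account for a thread that is waiting on shuffle data rather than actively using the GPU. A minimal sketch of that bracketing; blockingShuffleWait is a hypothetical helper and waitForBatch stands in for the blocking call (pollForResult in the PR):

```scala
import com.nvidia.spark.rapids.jni.RmmSpark

// Sketch: mark this thread as potentially blocked on shuffle only for the duration of
// the wait, and always clear the mark even if the wait throws.
def blockingShuffleWait[T](waitForBatch: => T): T = {
  RmmSpark.threadCouldBlockOnShuffle()
  try {
    waitForBatch
  } finally {
    RmmSpark.threadDoneWithShuffle()
  }
}
```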