Skip to content

Commit

Permalink
Enables spillable/unspillable state for RapidsBuffer and allow buffer…
Browse files Browse the repository at this point in the history
… sharing (#7572)

* Enables RapidsBuffer sharing in the spill framework

Signed-off-by: Alessandro Bellina <[email protected]>

* Scalastyle

* Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala

* If we keep holding onto the batch we wont spill to disk

* Remove synchronize from DeviceMemoryEventHandler

* Move definition of setSpillable to let data members be on top

* Lock the underlying buffer every time we add, to prevent multiple RapidsDeviceMemoryBuffer instances being created pointing at the same underlying

* Add spillableOnAdd so that stores can expose whether they intend buffers to be spillable as soon as they are added, or handled via ref counting

* Retry allocation instead of redundant spilling if multiple threads tried to spill at the same time

* Minor whitespace change

* Rework locking always taking catalog lock first

* Fix typo

* Fix log messages

* Code review comments

* fix imports and build issue

* Fix RapidsHostMemoryStoreSuite and unspill test

* Take care of some of the feedback

* Removes waitForPending from the RapidsBufferStore, address a few more comments, fix bug

* getExistingRapidsBufferAndAcquire private and static

* Require that callbacks are null when adding buffers to the device store

* Dont forget to set the spill store

* Add third argument matcher when mocking synchronousSpill

* In RapidsGdsStoreSuite use the provided catalog

---------

Signed-off-by: Alessandro Bellina <[email protected]>
  • Loading branch information
abellina authored Feb 8, 2023
1 parent 4e44978 commit 690017e
Show file tree
Hide file tree
Showing 17 changed files with 1,064 additions and 622 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +35,7 @@ import org.apache.spark.sql.rapids.execution.TrampolineUtil
* depleting the device store
*/
class DeviceMemoryEventHandler(
catalog: RapidsBufferCatalog,
store: RapidsDeviceMemoryStore,
oomDumpDir: Option[String],
isGdsSpillEnabled: Boolean,
Expand All @@ -52,7 +53,7 @@ class DeviceMemoryEventHandler(

/**
* A small helper class that helps keep track of retry counts as we trigger
* synchronizes on a depleted store.
* synchronizes on a completely spilled store.
*/
class OOMRetryState {
private var synchronizeAttempts = 0
Expand Down Expand Up @@ -82,19 +83,19 @@ class DeviceMemoryEventHandler(
}

/**
* We reset our counters if `storeSize` is non-zero (as we only track when the store is
* depleted), or `retryCount` is less than or equal to what we had previously
* recorded in `shouldTrySynchronizing`. We do this to detect that the new failures
* are for a separate allocation (we need to give this new attempt a new set of
* retries.)
* We reset our counters if `storeSpillableSize` is non-zero (as we only track when all
* spillable buffers are removed from the store), or `retryCount` is less than or equal
* to what we had previously recorded in `shouldTrySynchronizing`.
* We do this to detect that the new failures are for a separate allocation (we need to
* give this new attempt a new set of retries.)
*
* For example, if an allocation fails and we deplete the store, `retryCountLastSynced`
* For example, if an allocation fails and we spill all eligible buffers, `retryCountLastSynced`
* is set to the last `retryCount` sent to us by cuDF as we keep allowing retries
* from cuDF. If we succeed, cuDF resets `retryCount`, and so the new count sent to us
* must be <= than what we saw last, so we can reset our tracking.
*/
def resetIfNeeded(retryCount: Int, storeSize: Long): Unit = {
if (storeSize != 0 || retryCount <= retryCountLastSynced) {
def resetIfNeeded(retryCount: Int, storeSpillableSize: Long): Unit = {
if (storeSpillableSize != 0 || retryCount <= retryCountLastSynced) {
reset()
}
}
Expand All @@ -108,28 +109,30 @@ class DeviceMemoryEventHandler(
*/
override def onAllocFailure(allocSize: Long, retryCount: Int): Boolean = {
// check arguments for good measure
require(allocSize >= 0,
require(allocSize >= 0,
s"onAllocFailure invoked with invalid allocSize $allocSize")

require(retryCount >= 0,
require(retryCount >= 0,
s"onAllocFailure invoked with invalid retryCount $retryCount")

try {
withResource(new NvtxRange("onAllocFailure", NvtxColor.RED)) { _ =>
val storeSize = store.currentSize
val storeSpillableSize = store.currentSpillableSize

val attemptMsg = if (retryCount > 0) {
s"Attempt ${retryCount}. "
} else {
"First attempt. "
}

val retryState = oomRetryState.get()
retryState.resetIfNeeded(retryCount, storeSize)
retryState.resetIfNeeded(retryCount, storeSpillableSize)

logInfo(s"Device allocation of $allocSize bytes failed, device store has " +
s"$storeSize bytes. $attemptMsg" +
s"$storeSize total and $storeSpillableSize spillable bytes. $attemptMsg" +
s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes. ")
if (storeSize == 0) {
if (storeSpillableSize == 0) {
if (retryState.shouldTrySynchronizing(retryCount)) {
Cuda.deviceSynchronize()
logWarning(s"[RETRY ${retryState.getRetriesSoFar}] " +
Expand All @@ -149,14 +152,17 @@ class DeviceMemoryEventHandler(
false
}
} else {
val targetSize = Math.max(storeSize - allocSize, 0)
val targetSize = Math.max(storeSpillableSize - allocSize, 0)
logDebug(s"Targeting device store size of $targetSize bytes")
val amountSpilled = store.synchronousSpill(targetSize)
logInfo(s"Spilled $amountSpilled bytes from the device store")
if (isGdsSpillEnabled) {
TrampolineUtil.incTaskMetricsDiskBytesSpilled(amountSpilled)
} else {
TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled)
val maybeAmountSpilled =
catalog.synchronousSpill(store, targetSize, Cuda.DEFAULT_STREAM)
maybeAmountSpilled.foreach { amountSpilled =>
logInfo(s"Spilled $amountSpilled bytes from the device store")
if (isGdsSpillEnabled) {
TrampolineUtil.incTaskMetricsDiskBytesSpilled(amountSpilled)
} else {
TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled)
}
}
true
}
Expand Down
Loading

0 comments on commit 690017e

Please sign in to comment.