Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enables spillable/unspillable state for RapidsBuffer and allow buffer sharing #7572

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
ae399bf
Enables RapidsBuffer sharing in the spill framework
abellina Jan 16, 2023
1d4bb47
Scalastyle
abellina Jan 24, 2023
a64a3d3
Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer…
abellina Jan 24, 2023
eb94df7
If we keep holding onto the batch we wont spill to disk
abellina Jan 24, 2023
5f9ac7a
Merge branch 'spill/rapids_buffer_handle_dedup_final' of github.com:a…
abellina Jan 24, 2023
3c5c48c
Remove synchronize from DeviceMemoryEventHandler
abellina Jan 25, 2023
99771ac
Move definition of setSpillable to let data members be on top
abellina Jan 25, 2023
8737302
Lock the underlying buffer every time we add, to prevent multiple Rap…
abellina Jan 25, 2023
e6632d0
Add spillableOnAdd so that stores can expose whether they intend buff…
abellina Jan 25, 2023
51673a1
Retry allocation instead of redundant spilling if multiple threads tr…
abellina Jan 25, 2023
a689ed2
Minor whitespace change
abellina Jan 26, 2023
e25f576
Merge branch 'branch-23.02' of https://github.com/NVIDIA/spark-rapids…
abellina Jan 30, 2023
6ccdbca
Merge branch 'branch-23.04' of https://github.com/NVIDIA/spark-rapids…
abellina Jan 30, 2023
a38ea73
Rework locking always taking catalog lock first
abellina Jan 26, 2023
d762557
Fix typo
abellina Jan 31, 2023
c470791
Fix log messages
abellina Jan 31, 2023
139b09d
Code review comments
abellina Feb 3, 2023
f96b644
fix imports and build issue
abellina Feb 3, 2023
7f8521c
Fix RapidsHostMemoryStoreSuite and unspill test
abellina Feb 6, 2023
bbc7c54
Take care of some of the feedback
abellina Feb 6, 2023
bc2a215
Removes waitForPending from the RapidsBufferStore, address a few more…
abellina Feb 6, 2023
a0b039a
getExistingRapidsBufferAndAcquire private and static
abellina Feb 6, 2023
93c0007
Require that callbacks are null when adding buffers to the device store
abellina Feb 7, 2023
1cf8c04
Dont forget to set the spill store
abellina Feb 7, 2023
00aa5ca
Add third argument matcher when mocking synchronousSpill
abellina Feb 7, 2023
6cca74d
In RapidsGdsStoreSuite use the provided catalog
abellina Feb 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +35,7 @@ import org.apache.spark.sql.rapids.execution.TrampolineUtil
* depleting the device store
*/
class DeviceMemoryEventHandler(
catalog: RapidsBufferCatalog,
store: RapidsDeviceMemoryStore,
oomDumpDir: Option[String],
isGdsSpillEnabled: Boolean,
Expand All @@ -52,7 +53,7 @@ class DeviceMemoryEventHandler(

/**
* A small helper class that helps keep track of retry counts as we trigger
* synchronizes on a depleted store.
* synchronizes on a completely spilled store.
*/
class OOMRetryState {
private var synchronizeAttempts = 0
Expand Down Expand Up @@ -82,19 +83,19 @@ class DeviceMemoryEventHandler(
}

/**
* We reset our counters if `storeSize` is non-zero (as we only track when the store is
* depleted), or `retryCount` is less than or equal to what we had previously
* recorded in `shouldTrySynchronizing`. We do this to detect that the new failures
* are for a separate allocation (we need to give this new attempt a new set of
* retries.)
* We reset our counters if `storeSpillableSize` is non-zero (as we only track when all
* spillable buffers are removed from the store), or `retryCount` is less than or equal
* to what we had previously recorded in `shouldTrySynchronizing`.
* We do this to detect that the new failures are for a separate allocation (we need to
* give this new attempt a new set of retries.)
*
* For example, if an allocation fails and we deplete the store, `retryCountLastSynced`
* For example, if an allocation fails and we spill all eligible buffers, `retryCountLastSynced`
* is set to the last `retryCount` sent to us by cuDF as we keep allowing retries
* from cuDF. If we succeed, cuDF resets `retryCount`, and so the new count sent to us
* must be <= than what we saw last, so we can reset our tracking.
*/
def resetIfNeeded(retryCount: Int, storeSize: Long): Unit = {
if (storeSize != 0 || retryCount <= retryCountLastSynced) {
def resetIfNeeded(retryCount: Int, storeSpillableSize: Long): Unit = {
if (storeSpillableSize != 0 || retryCount <= retryCountLastSynced) {
reset()
}
}
Expand All @@ -108,28 +109,30 @@ class DeviceMemoryEventHandler(
*/
override def onAllocFailure(allocSize: Long, retryCount: Int): Boolean = {
// check arguments for good measure
require(allocSize >= 0,
require(allocSize >= 0,
s"onAllocFailure invoked with invalid allocSize $allocSize")

require(retryCount >= 0,
require(retryCount >= 0,
s"onAllocFailure invoked with invalid retryCount $retryCount")

try {
withResource(new NvtxRange("onAllocFailure", NvtxColor.RED)) { _ =>
val storeSize = store.currentSize
val storeSpillableSize = store.currentSpillableSize

val attemptMsg = if (retryCount > 0) {
s"Attempt ${retryCount}. "
} else {
"First attempt. "
}

val retryState = oomRetryState.get()
retryState.resetIfNeeded(retryCount, storeSize)
retryState.resetIfNeeded(retryCount, storeSpillableSize)

logInfo(s"Device allocation of $allocSize bytes failed, device store has " +
s"$storeSize bytes. $attemptMsg" +
s"$storeSize total and $storeSpillableSize spillable bytes. $attemptMsg" +
s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes. ")
if (storeSize == 0) {
if (storeSpillableSize == 0) {
if (retryState.shouldTrySynchronizing(retryCount)) {
Cuda.deviceSynchronize()
logWarning(s"[RETRY ${retryState.getRetriesSoFar}] " +
Expand All @@ -149,14 +152,17 @@ class DeviceMemoryEventHandler(
false
}
} else {
val targetSize = Math.max(storeSize - allocSize, 0)
val targetSize = Math.max(storeSpillableSize - allocSize, 0)
logDebug(s"Targeting device store size of $targetSize bytes")
val amountSpilled = store.synchronousSpill(targetSize)
logInfo(s"Spilled $amountSpilled bytes from the device store")
if (isGdsSpillEnabled) {
TrampolineUtil.incTaskMetricsDiskBytesSpilled(amountSpilled)
} else {
TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled)
val maybeAmountSpilled =
catalog.synchronousSpill(store, targetSize, Cuda.DEFAULT_STREAM)
maybeAmountSpilled.foreach { amountSpilled =>
logInfo(s"Spilled $amountSpilled bytes from the device store")
if (isGdsSpillEnabled) {
TrampolineUtil.incTaskMetricsDiskBytesSpilled(amountSpilled)
} else {
TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled)
}
}
true
}
Expand Down
Loading