Skip to content

Commit

Permalink
Add in a new API for high priority allocation and add in spilling (#9225
Browse files Browse the repository at this point in the history
)

Signed-off-by: Robert (Bobby) Evans <[email protected]>
  • Loading branch information
revans2 authored Sep 12, 2023
1 parent e8ba8ef commit 8399836
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 28 deletions.
87 changes: 77 additions & 10 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.TaskContext
private class HostAlloc(nonPinnedLimit: Long) {
private var currentNonPinnedAllocated: Long = 0L
private var currentNonPinnedReserved: Long = 0L
private val pinnedLimit: Long = PinnedMemoryPool.getTotalPoolSizeBytes()
private val pinnedLimit: Long = PinnedMemoryPool.getTotalPoolSizeBytes
// For now we are going to assume that we are the only ones calling into the pinned pool
// That is not really true, but should be okay.
private var currentPinnedAllocated: Long = 0L
Expand All @@ -48,15 +48,16 @@ private class HostAlloc(nonPinnedLimit: Long) {
* An allocation that has not been completed yet. It is blocked waiting for more resources.
*/
private class BlockedAllocation(val amount: Long, val taskId: Long) {
var shouldWake = false
private var shouldWake = false

def isReady: Boolean = shouldWake

/**
* Wait until we should retry the allocation because it might succeed. It is not
* guaranteed though.
* It is required that the parent lock is held before this is called.
*/
def waitUntilPossiblyReady(): Unit = {
shouldWake = false
while (!shouldWake) {
HostAlloc.this.wait(1000)
}
Expand Down Expand Up @@ -280,12 +281,16 @@ private class HostAlloc(nonPinnedLimit: Long) {
ret
}

private def checkSize(amount: Long, tryPinned: Boolean): Unit = {
val pinnedFailed = (isPinnedOnly || tryPinned) && (amount > pinnedLimit)
private def canNeverSucceed(amount: Long, preferPinned: Boolean): Boolean = {
val pinnedFailed = (isPinnedOnly || preferPinned) && (amount > pinnedLimit)
val nonPinnedFailed = isPinnedOnly || (amount > nonPinnedLimit)
if (pinnedFailed && nonPinnedFailed) {
pinnedFailed && nonPinnedFailed
}

private def checkSize(amount: Long, preferPinned: Boolean): Unit = {
if (canNeverSucceed(amount, preferPinned)) {
throw new IllegalArgumentException(s"The amount requested $amount is larger than the " +
s"maximum pool size ${math.max(pinnedLimit, nonPinnedLimit)}")
s"maximum pool size ${math.max(pinnedLimit, nonPinnedLimit)}")
}
}

Expand All @@ -311,16 +316,52 @@ private class HostAlloc(nonPinnedLimit: Long) {
do {
ret = tryAlloc(amount, preferPinned)
if (ret.isEmpty) {
if (blocked == null) {
blocked = new BlockedAllocation(amount, TaskContext.get().taskAttemptId())
}
blocked = new BlockedAllocation(amount, TaskContext.get().taskAttemptId())
pendingAllowedQueue.offer(blocked)
var amountSpilled: Option[Long] = None
// None for amountSpilled means we need to retry because of a race.
// forall returns true for None in this case.
while(!blocked.isReady && amountSpilled.forall(_ > 0)) {
amountSpilled = RapidsBufferCatalog.synchronousSpill(
RapidsBufferCatalog.getHostStorage, amount)
}
// Wait until we think we are ready to allocate something
blocked.waitUntilPossiblyReady()
}
} while(ret.isEmpty)
ret.get
}

/**
 * Allocate a buffer at the highest priority possible. This never waits for another
 * allocation to be released; the only thing it does to make room is ask
 * `RapidsBufferCatalog.synchronousSpill` to spill host storage. If the allocation cannot
 * happen for whatever reason a None is returned instead of blocking.
 *
 * @param amount the number of bytes requested
 * @param preferPinned forwarded to `canNeverSucceed` and `tryAlloc` to indicate that pinned
 *                     memory should be preferred
 * @return the allocated buffer, or None if the request can never fit, or if spilling did not
 *         make the request ready, or if the retry after spilling still failed
 */
def allocHighPriority(amount: Long,
    preferPinned: Boolean = true): Option[HostMemoryBuffer] = synchronized {
  if (canNeverSucceed(amount, preferPinned)) {
    // The request is larger than the limits checked by canNeverSucceed, so do not even try.
    None
  } else {
    tryAlloc(amount, preferPinned).orElse {
      // This entry is queued with Long.MinValue as its task id.
      // NOTE(review): presumably that orders it ahead of every real task - confirm against
      // the ordering used by pendingAllowedQueue.
      val blocked = new BlockedAllocation(amount, Long.MinValue)
      pendingAllowedQueue.offer(blocked)
      try {
        var amountSpilled: Option[Long] = None
        // None for amountSpilled means we need to retry because of a race.
        // forall returns true for None in this case.
        while (!blocked.isReady && amountSpilled.forall(_ > 0)) {
          amountSpilled = RapidsBufferCatalog.synchronousSpill(
            RapidsBufferCatalog.getHostStorage, amount)
        }

        if (blocked.isReady) {
          tryAlloc(amount, preferPinned)
        } else {
          None
        }
      } finally {
        // Never leave our entry behind, even when spilling throws. A ready entry is left
        // alone exactly as before.
        // NOTE(review): this assumes whoever marks an entry ready also takes it out of the
        // queue - confirm.
        if (!blocked.isReady) {
          pendingAllowedQueue.remove(blocked)
        }
      }
    }
  }
}

def reserve(amount: Long, preferPinned: Boolean): HostMemoryReservation = synchronized {
var ret: Option[HostMemoryReservation] = None
var blocked: BlockedAllocation = null
Expand Down Expand Up @@ -377,6 +418,16 @@ object HostAlloc {
getSingleton.alloc(amount, preferPinned)
}

/**
 * Allocate a HostMemoryBuffer at the highest priority. This will not block waiting for a
 * free. It may spill data to make room for the allocation, but it will do so at the highest
 * priority. If the allocation cannot be made to work, None is returned and the caller needs
 * a backup plan.
 *
 * @param amount the amount of memory requested
 * @param preferPinned forwarded unchanged to the singleton allocator
 * @return the buffer if the allocation succeeded, otherwise None
 */
def allocHighPriority(amount: Long, preferPinned: Boolean = true): Option[HostMemoryBuffer] =
  getSingleton.allocHighPriority(amount, preferPinned)

/**
 * Forward a reservation request to the singleton allocator.
 *
 * @param amount the amount of memory to reserve
 * @param preferPinned forwarded unchanged to the singleton allocator
 * @return the reservation produced by the singleton's `reserve`
 */
def reserve(amount: Long, preferPinned: Boolean = true): HostMemoryReservation =
  getSingleton.reserve(amount, preferPinned)
Expand Down Expand Up @@ -426,6 +477,22 @@ object HostAlloc {
}
}

/**
 * Search a handler, descending through any nested `MultiEventHandler`s, for the first
 * handler that `eh` is defined at. The `a` side of a `MultiEventHandler` is searched
 * completely before the `b` side is looked at.
 *
 * @param handler the handler to start from
 * @param eh partial function applied to every handler that is not a `MultiEventHandler`
 * @return the result of `eh` for the first handler it is defined at, or None if there is
 *         no such handler
 */
private def findEventHandlerInternal[K](handler: MemoryBuffer.EventHandler,
    eh: PartialFunction[MemoryBuffer.EventHandler, K]): Option[K] = {
  handler match {
    case MultiEventHandler(first, second) =>
      val fromFirst = findEventHandlerInternal(first, eh)
      // Only walk the second side when the first one produced nothing.
      if (fromFirst.isDefined) {
        fromFirst
      } else {
        findEventHandlerInternal(second, eh)
      }
    case leaf =>
      eh.lift(leaf)
  }
}

/**
 * Look through the event handler(s) installed on `buff` for the first one that `eh` is
 * defined at. The buffer's monitor is held while the handler is read and searched.
 *
 * @param buff the buffer whose event handler is inspected
 * @param eh partial function selecting and transforming the handler of interest
 * @return the result of `eh` for the first matching handler, or None if nothing matched
 */
def findEventHandler[K](buff: HostMemoryBuffer)(
    eh: PartialFunction[MemoryBuffer.EventHandler, K]): Option[K] = buff.synchronized {
  val rootHandler = buff.getEventHandler
  findEventHandlerInternal(rootHandler, eh)
}

private case class MultiEventHandler(a: MemoryBuffer.EventHandler,
b: MemoryBuffer.EventHandler)
extends MemoryBuffer.EventHandler {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,8 +625,8 @@ class RapidsBufferCatalog(
}
}
}
Some(totalSpilled)
}
Some(totalSpilled)
}
}

Expand Down Expand Up @@ -1023,20 +1023,31 @@ object RapidsBufferCatalog extends Logging {
* brand new to the store, or the `RapidsBuffer` is invalid and
* about to be removed).
*/
private def getExistingRapidsBufferAndAcquire(
buffer: MemoryBuffer): Option[RapidsBuffer] = {
val eh = buffer.getEventHandler
eh match {
case null =>
None
case rapidsBuffer: RapidsBuffer =>
if (rapidsBuffer.addReference()) {
Some(rapidsBuffer)
} else {
None
}
private def getExistingRapidsBufferAndAcquire(buffer: MemoryBuffer): Option[RapidsBuffer] = {
buffer match {
case hb: HostMemoryBuffer =>
HostAlloc.findEventHandler(hb) {
case rapidsBuffer: RapidsBuffer =>
if (rapidsBuffer.addReference()) {
Some(rapidsBuffer)
} else {
None
}
}.flatten
case _ =>
throw new IllegalStateException("Unknown event handler")
val eh = buffer.getEventHandler
eh match {
case null =>
None
case rapidsBuffer: RapidsBuffer =>
if (rapidsBuffer.addReference()) {
Some(rapidsBuffer)
} else {
None
}
case _ =>
throw new IllegalStateException("Unknown event handler")
}
}
}
}
Expand Down
Loading

0 comments on commit 8399836

Please sign in to comment.