Skip to content

Commit

Permalink
Adds RapidsBufferHandle as an indirection layer to RapidsBufferId
Browse files Browse the repository at this point in the history
Signed-off-by: Alessandro Bellina <[email protected]>
  • Loading branch information
abellina committed Jan 13, 2023
1 parent 2737942 commit e190406
Show file tree
Hide file tree
Showing 22 changed files with 1,095 additions and 442 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -101,8 +101,6 @@ trait RapidsBuffer extends AutoCloseable {
/** The storage tier for this buffer */
val storageTier: StorageTier

val spillCallback: SpillCallback

/**
* Get the columnar batch within this buffer. The caller must have
* successfully acquired the buffer beforehand.
Expand Down Expand Up @@ -171,13 +169,27 @@ trait RapidsBuffer extends AutoCloseable {
*/
def getSpillPriority: Long

/**
* Gets the spill metrics callback currently associated with this buffer.
* @return the current callback
*/
def getSpillCallback: SpillCallback

/**
* Set the spill priority for this buffer. Lower values are higher priority
* for spilling, meaning buffers with lower values will be preferred for
* spilling over buffers with a higher value.
* @note should only be called from the buffer catalog
* @param priority new priority value for this buffer
*/
def setSpillPriority(priority: Long): Unit

/**
* Update the metrics callback that will be invoked next time a spill occurs.
* @note should only be called from the buffer catalog
* @param spillCallback the new callback
*/
def setSpillCallback(spillCallback: SpillCallback): Unit
}

/**
Expand Down Expand Up @@ -226,9 +238,11 @@ sealed class DegenerateRapidsBuffer(

override def getSpillPriority: Long = Long.MaxValue

override val getSpillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback

override def setSpillPriority(priority: Long): Unit = {}

override def close(): Unit = {}
override def setSpillCallback(callback: SpillCallback): Unit = {}

override val spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback
override def close(): Unit = {}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,22 +34,184 @@ import org.apache.spark.sql.rapids.RapidsDiskBlockManager
*/
class DuplicateBufferException(s: String) extends RuntimeException(s) {}

/**
* An object that client code uses to interact with an underlying RapidsBufferId.
*
* A handle is obtained when a buffer, batch, or table is added to the spill framework
* via the `RapidsBufferCatalog` api. The catalog keeps a sequence of handles per
* `RapidsBufferId`, so several handles may refer to the same id.
*/
trait RapidsBufferHandle {
/** The `RapidsBufferId` that this handle refers to. */
val id: RapidsBufferId

/**
* Sets the spill priority for this handle and updates the maximum priority
* for the underlying `RapidsBuffer` if this new priority is the maximum.
* @param newPriority new priority for this handle
*/
def setSpillPriority(newPriority: Long): Unit
}

/**
* Catalog for lookup of buffers by ID. The constructor is only visible for testing, generally
* `RapidsBufferCatalog.singleton` should be used instead.
*/
class RapidsBufferCatalog extends Logging {
class RapidsBufferCatalog extends Arm {

/** Map of buffer IDs to buffers sorted by storage tier */
private[this] val bufferMap = new ConcurrentHashMap[RapidsBufferId, Seq[RapidsBuffer]]

/** Map of buffer IDs to buffer handles in insertion order */
private[this] val bufferIdToHandles =
new ConcurrentHashMap[RapidsBufferId, Seq[RapidsBufferHandleImpl]]()

/**
* Catalog-side implementation of `RapidsBufferHandle`.
*
* It stores the spill priority and the spill callback supplied at construction.
* Changing the priority through `setSpillPriority` also calls
* `updateUnderlyingRapidsBuffer`, which refreshes the `RapidsBuffer` behind this handle.
*
* @param id the `RapidsBufferId` that this handle refers to
* @param priority this handle's spill priority; reassigned by `setSpillPriority`
* @param spillCallback this handle's spill metrics callback
*/
class RapidsBufferHandleImpl(
override val id: RapidsBufferId,
var priority: Long,
spillCallback: SpillCallback)
extends RapidsBufferHandle with Arm {

override def setSpillPriority(newPriority: Long): Unit = {
// store the handle-local value first: the refresh below reads it back
// through `getSpillPriority` when computing the maximum over all handles
priority = newPriority
updateUnderlyingRapidsBuffer(this)
}

/**
* Get the spill priority that was associated with this handle. Since there can
* be multiple handles associated with one `RapidsBuffer`, the priority returned
* here is only useful for code in the catalog that updates the maximum priority
* for the underlying `RapidsBuffer` as handles are added and removed.
*
* @return this handle's spill priority
*/
def getSpillPriority: Long = priority

/**
* Each handle was created in a different part of the code and as such could have
* different spill metrics callbacks. This function is used by the catalog to find
* out which spill callback was added last. That last callback gets reports of
* spill bytes if a spill were to occur to the `RapidsBuffer` this handle points to.
*
* @return the spill callback associated with this handle
*/
def getSpillCallback: SpillCallback = spillCallback
}

/**
* Lookup the buffer that corresponds to the specified buffer ID at the highest storage tier,
* Makes a new `RapidsBufferHandle` associated with `id`, keeping track
* of the spill priority and callback within this handle.
*
* This function also adds the handle for internal tracking in the catalog.
*
* @param id the `RapidsBufferId` that this handle refers to
* @param spillPriority the spill priority specified on creation of the handle
* @param spillCallback this handle's spill callback
* @note public for testing
* @return a new instance of `RapidsBufferHandle`
*/
def makeNewHandle(
    id: RapidsBufferId,
    spillPriority: Long,
    spillCallback: SpillCallback): RapidsBufferHandle = {
  // the concrete handle type stays inside the catalog; callers only see the trait
  val created = new RapidsBufferHandleImpl(id, spillPriority, spillCallback)
  // register it so the buffer's priority and callback reflect this handle
  trackNewHandle(created)
  created
}

/**
 * Registers `handle` in `bufferIdToHandles`, appending it after any handles already
 * tracked for the same id, and then refreshes the priority and callback of the
 * `RapidsBuffer` via `updateUnderlyingRapidsBuffer`.
 *
 * @param handle handle to start tracking
 */
private def trackNewHandle(handle: RapidsBufferHandleImpl): Unit = {
  bufferIdToHandles.compute(handle.id, (_, existing) => {
    // `compute` hands us null when the id has no entry yet
    Option(existing).getOrElse(Seq.empty[RapidsBufferHandleImpl]) :+ handle
  })
  updateUnderlyingRapidsBuffer(handle)
}

/**
 * Called when the `RapidsBufferHandle` is no longer needed by calling code.
 *
 * The handle is removed from `bufferIdToHandles`. If other handles remain for the same
 * id, the `RapidsBuffer` is updated with the maximum spill priority among them and with
 * the spill callback of the last remaining handle.
 *
 * @param handle handle to stop tracking
 * @return true if `handle` was the last handle tracked for its id (the mapping is
 *         removed), false if other handles remain
 * @throws IllegalStateException if no handles are tracked for the handle's id
 * @throws IllegalArgumentException if exactly one handle is tracked for the id and it
 *                                  is not `handle`
 */
private def stopTrackingHandle(handle: RapidsBufferHandle): Boolean = {
  withResource(acquireBuffer(handle)) { buffer =>
    val id = handle.id
    // Keep the remapping function limited to the membership change; everything derived
    // from the surviving handles is computed from the returned sequence afterwards.
    val remaining = bufferIdToHandles.compute(id, (_, handles) => {
      if (handles == null) {
        throw new IllegalStateException(
          s"$id not found and we attempted to remove handles!")
      }
      if (handles.size == 1) {
        require(handles.head == handle,
          "Tried to remove a single handle, and we couldn't match on it")
        null // returning null removes the mapping for this id
      } else {
        val others = handles.filterNot(_ == handle)
        // an empty result should not happen here, but never store an empty sequence
        if (others.isEmpty) null else others
      }
    })

    Option(remaining) match {
      case None =>
        // tell calling code that no more handles exist for this RapidsBuffer
        true
      case Some(left) =>
        // more handles remain: the buffer takes the maximum priority among them
        buffer.setSpillPriority(left.map(_.getSpillPriority).max)
        // the last handle inserted wins and receives the spill metrics for this buffer
        buffer.setSpillCallback(left.last.getSpillCallback)
        false
    }
  }
}

/**
 * Refreshes the `RapidsBuffer` behind `handle` from the handles currently tracked for
 * it. The catalog calls this when a handle is first tracked and whenever a handle's
 * priority changes.
 *
 * @param handle a handle referring to the buffer to refresh
 */
private def updateUnderlyingRapidsBuffer(handle: RapidsBufferHandle): Unit = {
  withResource(acquireBuffer(handle)) { buffer =>
    val tracked = bufferIdToHandles.get(buffer.id)
    // the buffer carries the maximum priority over all of its handles
    val highestPriority = tracked.map(_.getSpillPriority).max
    buffer.setSpillPriority(highestPriority)
    // the callback of the last handle in insertion order is the one installed
    val latestCallback = tracked.last.getSpillCallback
    buffer.setSpillCallback(latestCallback)
  }
}

/**
* Lookup the buffer that corresponds to the specified handle at the highest storage tier,
* and acquire it.
* NOTE: It is the responsibility of the caller to close the buffer.
* @param id buffer identifier
* @param handle handle associated with this `RapidsBuffer`
* @return buffer that has been acquired
*/
def acquireBuffer(id: RapidsBufferId): RapidsBuffer = {
def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer = {
val id = handle.id
(0 until RapidsBufferCatalog.MAX_BUFFER_LOOKUP_ATTEMPTS).foreach { _ =>
val buffers = bufferMap.get(id)
if (buffers == null || buffers.isEmpty) {
Expand Down Expand Up @@ -124,6 +286,7 @@ class RapidsBufferCatalog extends Logging {
}
}
}

bufferMap.compute(buffer.id, updater)
}

Expand All @@ -142,10 +305,19 @@ class RapidsBufferCatalog extends Logging {
bufferMap.computeIfPresent(id, updater)
}

/** Remove a buffer ID from the catalog and release the resources of the registered buffers. */
def removeBuffer(id: RapidsBufferId): Unit = {
val buffers = bufferMap.remove(id)
buffers.safeFree()
/**
 * Remove a buffer handle from the catalog and, if this was the final handle,
 * release the resources of the registered buffers.
 *
 * @param handle handle to remove
 * @return true if this was the last handle and the buffers were removed and freed,
 *         false if other handles still refer to the buffer
 */
def removeBuffer(handle: RapidsBufferHandle): Boolean = {
  val wasLastHandle = stopTrackingHandle(handle)
  if (wasLastHandle) {
    // nothing refers to the buffer any more: drop it from the catalog and free it
    val buffers = bufferMap.remove(handle.id)
    buffers.safeFree()
  }
  wasLastHandle
}

/** Return the number of buffers currently in the catalog. */
Expand Down Expand Up @@ -248,65 +420,76 @@ object RapidsBufferCatalog extends Logging with Arm {

/**
* Adds a contiguous table to the device storage, taking ownership of the table.
* @param id buffer ID to associate with this buffer
* @param table cudf table based from the contiguous buffer
* @param contigBuffer device memory buffer backing the table
* @param tableMeta metadata describing the buffer layout
* @param initialSpillPriority starting spill priority value for the buffer
* @param spillCallback a callback when the buffer is spilled. This should be very light weight.
* It should never allocate GPU memory and really just be used for metrics.
* @return RapidsBufferHandle associated with this buffer
*/
def addTable(
id: RapidsBufferId,
table: Table,
contigBuffer: DeviceMemoryBuffer,
tableMeta: TableMeta,
initialSpillPriority: Long,
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): Unit =
deviceStorage.addTable(id, table, contigBuffer, tableMeta, initialSpillPriority, spillCallback)
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): RapidsBufferHandle = {
val id =
deviceStorage.addTable(table, contigBuffer, tableMeta, initialSpillPriority)
singleton.makeNewHandle(id, initialSpillPriority, spillCallback)
}

/**
* Adds a contiguous table to the device storage, taking ownership of the table.
* @param id buffer ID to associate with this buffer
* @param contigTable contiguous table to track in device storage
* @param contigTable contiguous table to track in device storage
* @param initialSpillPriority starting spill priority value for the buffer
* @param spillCallback a callback when the buffer is spilled. This should be very light weight.
* It should never allocate GPU memory and really just be used for metrics.
* @return RapidsBufferHandle associated with this buffer
*/
def addContiguousTable(
id: RapidsBufferId,
contigTable: ContiguousTable,
initialSpillPriority: Long,
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): Unit =
deviceStorage.addContiguousTable(id, contigTable, initialSpillPriority, spillCallback)
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): RapidsBufferHandle = {
val id = deviceStorage.addContiguousTable(
contigTable, initialSpillPriority, spillCallback)
singleton.makeNewHandle(id, initialSpillPriority, spillCallback)
}

/**
* Adds a buffer to the device storage, taking ownership of the buffer.
* @param id buffer ID to associate with this buffer
* @param buffer buffer that will be owned by the store
* @param tableMeta metadata describing the buffer layout
* @param initialSpillPriority starting spill priority value for the buffer
* @param spillCallback a callback when the buffer is spilled. This should be very light weight.
* It should never allocate GPU memory and really just be used for metrics.
* @return RapidsBufferHandle associated with this buffer
*/
def addBuffer(
id: RapidsBufferId,
buffer: DeviceMemoryBuffer,
tableMeta: TableMeta,
initialSpillPriority: Long,
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): Unit =
deviceStorage.addBuffer(id, buffer, tableMeta, initialSpillPriority, spillCallback)
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): RapidsBufferHandle = {
val id = deviceStorage.addBuffer(
buffer, tableMeta, initialSpillPriority, spillCallback)
singleton.makeNewHandle(id, initialSpillPriority, spillCallback)
}

/**
* Lookup the buffer that corresponds to the specified buffer ID and acquire it.
* Lookup the buffer that corresponds to the specified buffer handle and acquire it.
* NOTE: It is the responsibility of the caller to close the buffer.
* @param id buffer identifier
* @param handle buffer handle
* @return buffer that has been acquired
*/
def acquireBuffer(id: RapidsBufferId): RapidsBuffer = singleton.acquireBuffer(id)
def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer =
singleton.acquireBuffer(handle)

/** Remove a buffer ID from the catalog and release the resources of the registered buffer. */
def removeBuffer(id: RapidsBufferId): Unit = singleton.removeBuffer(id)
/**
* Remove a buffer handle from the catalog and, if this was the final handle,
* release the resources of the registered buffers.
*/
def removeBuffer(handle: RapidsBufferHandle): Unit =
singleton.removeBuffer(handle)

def getDiskBlockManager(): RapidsDiskBlockManager = diskBlockManager
}
Loading

0 comments on commit e190406

Please sign in to comment.