Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds RapidsBufferHandle as an indirection layer to RapidsBufferId #7512

Merged
merged 15 commits into from
Jan 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -101,8 +101,6 @@ trait RapidsBuffer extends AutoCloseable {
/** The storage tier for this buffer */
val storageTier: StorageTier

val spillCallback: SpillCallback

/**
* Get the columnar batch within this buffer. The caller must have
* successfully acquired the buffer beforehand.
Expand Down Expand Up @@ -171,13 +169,27 @@ trait RapidsBuffer extends AutoCloseable {
*/
def getSpillPriority: Long

/**
* Gets the spill metrics callback currently associated with this buffer.
* @return the current callback
*/
def getSpillCallback: SpillCallback

/**
* Set the spill priority for this buffer. Lower values are higher priority
* for spilling, meaning buffers with lower values will be preferred for
* spilling over buffers with a higher value.
* @note should only be called from the buffer catalog
* @param priority new priority value for this buffer
*/
def setSpillPriority(priority: Long): Unit

/**
* Update the metrics callback that will be invoked next time a spill occurs.
* @note should only be called from the buffer catalog
* @param spillCallback the new callback
*/
def setSpillCallback(spillCallback: SpillCallback): Unit
}

/**
Expand Down Expand Up @@ -226,9 +238,11 @@ sealed class DegenerateRapidsBuffer(

override def getSpillPriority: Long = Long.MaxValue

override val getSpillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback

override def setSpillPriority(priority: Long): Unit = {}

override def close(): Unit = {}
override def setSpillCallback(callback: SpillCallback): Unit = {}

override val spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback
override def close(): Unit = {}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,22 +34,184 @@ import org.apache.spark.sql.rapids.RapidsDiskBlockManager
*/
class DuplicateBufferException(s: String) extends RuntimeException(s)

/**
* An object that client code uses to interact with an underlying RapidsBufferId.
*
* A handle is obtained when a buffer, batch, or table is added to the spill framework
* via the `RapidsBufferCatalog` api.
*/
trait RapidsBufferHandle {
/** Identifier of the buffer that this handle refers to. */
val id: RapidsBufferId
revans2 marked this conversation as resolved.
Show resolved Hide resolved

/**
* Sets the spill priority for this handle and updates the maximum priority
* for the underlying `RapidsBuffer` if this new priority is the maximum.
* Several handles may refer to the same buffer, so the value set here is only
* one input to the priority the buffer ends up with.
* @param newPriority new priority for this handle
*/
def setSpillPriority(newPriority: Long): Unit
}

/**
* Catalog for lookup of buffers by ID. The constructor is only visible for testing, generally
* `RapidsBufferCatalog.singleton` should be used instead.
*/
class RapidsBufferCatalog extends Logging {
class RapidsBufferCatalog extends Arm {

/**
* Map of buffer IDs to the buffers registered under that ID, sorted by storage tier.
* Lookups that acquire a buffer for a handle go through this map.
*/
private[this] val bufferMap = new ConcurrentHashMap[RapidsBufferId, Seq[RapidsBuffer]]

/**
* Map of buffer IDs to the handles currently tracking that ID, in insertion order.
* New handles are appended, so the last element is the most recently added handle.
* The entry is removed once the final handle for an ID stops being tracked.
*/
private[this] val bufferIdToHandles =
new ConcurrentHashMap[RapidsBufferId, Seq[RapidsBufferHandleImpl]]()

/**
 * Catalog-side implementation of `RapidsBufferHandle`.
 *
 * @param id the `RapidsBufferId` this handle refers to
 * @param priority the spill priority requested through this handle
 * @param spillCallback the spill metrics callback supplied when this handle was created
 */
class RapidsBufferHandleImpl(
    override val id: RapidsBufferId,
    var priority: Long,
    spillCallback: SpillCallback)
  extends RapidsBufferHandle with Arm {

  /**
   * The spill metrics callback this handle was created with. Handles are created in
   * different parts of the code and may carry different callbacks; the catalog uses
   * this accessor to find the callback of the most recently added handle, which is
   * the one that gets reports of spill bytes if the `RapidsBuffer` this handle points
   * to is spilled.
   *
   * @return the spill callback associated with this handle
   */
  def getSpillCallback: SpillCallback = {
    spillCallback
  }

  /**
   * The spill priority recorded on this handle. Because several handles can share one
   * `RapidsBuffer`, this value is only meaningful to catalog code that recomputes the
   * maximum priority for the underlying buffer as handles are added and removed.
   *
   * @return this handle's spill priority
   */
  def getSpillPriority: Long = {
    priority
  }

  /**
   * Records the new priority on this handle, then asks the catalog to refresh the
   * underlying `RapidsBuffer` so it reflects the change.
   *
   * @param newPriority new priority for this handle
   */
  override def setSpillPriority(newPriority: Long): Unit = {
    priority = newPriority
    updateUnderlyingRapidsBuffer(this)
  }
}

/**
* Lookup the buffer that corresponds to the specified buffer ID at the highest storage tier,
* Makes a new `RapidsBufferHandle` associated with `id`, keeping track
* of the spill priority and callback within this handle.
*
* This function also adds the handle for internal tracking in the catalog.
*
* @param id the `RapidsBufferId` that this handle refers to
* @param spillPriority the spill priority specified on creation of the handle
* @param spillCallback this handle's spill callback
* @note public for testing
* @return a new instance of `RapidsBufferHandle`
*/
def makeNewHandle(
    id: RapidsBufferId,
    spillPriority: Long,
    spillCallback: SpillCallback): RapidsBufferHandle = {
  // build the handle first, then register it so the catalog can refresh the
  // priority and callback of the buffer it refers to
  val created = new RapidsBufferHandleImpl(id, spillPriority, spillCallback)
  trackNewHandle(created)
  created
}

/**
* Adds a handle to the internal `bufferIdToHandles` map.
*
* The priority and callback of the `RapidsBuffer` will also be updated.
*
* @param handle handle to start tracking
*/
private def trackNewHandle(handle: RapidsBufferHandleImpl): Unit = {
  // append to the handles already tracked for this id, starting from an empty
  // sequence when this is the first handle for the id
  bufferIdToHandles.compute(handle.id, (_, existing) =>
    Option(existing).getOrElse(Seq.empty[RapidsBufferHandleImpl]) :+ handle)
  // refresh the buffer's priority and callback now that a handle was added
  updateUnderlyingRapidsBuffer(handle)
}

/**
* Called when the `RapidsBufferHandle` is no longer needed by calling code
*
* If this is the last handle associated with a `RapidsBuffer`, `stopTrackingHandle`
* returns true, otherwise it returns false.
*
* @param handle handle to stop tracking
* @return true if this was the last handle tracked for the buffer, false if other handles remain
jlowe marked this conversation as resolved.
Show resolved Hide resolved
*/
private def stopTrackingHandle(handle: RapidsBufferHandle): Boolean = {
  withResource(acquireBuffer(handle)) { buffer =>
    val id = handle.id
    // Drop `handle` from the handles tracked for `id`. Returning null from the
    // remapping function removes the map entry altogether.
    val remaining = bufferIdToHandles.compute(id, (_, tracked) => {
      if (tracked == null) {
        throw new IllegalStateException(
          s"$id not found and we attempted to remove handles!")
      }
      if (tracked.size == 1) {
        require(tracked.head == handle,
          "Tried to remove a single handle, and we couldn't match on it")
        null
      } else {
        val kept = tracked.filterNot(_ == handle)
        // an empty result should not happen here, but remove the entry if it does
        if (kept.isEmpty) null else kept
      }
    })

    Option(remaining) match {
      case None =>
        // tell calling code that no more handles exist for this RapidsBuffer
        true
      case Some(stillTracked) =>
        // more handles remain: the buffer takes the maximum priority among them
        buffer.setSpillPriority(stillTracked.map(_.getSpillPriority).max)
        // the last spill callback inserted wins every time; it is the one that
        // gets the metrics associated with this buffer's spill
        buffer.setSpillCallback(stillTracked.last.getSpillCallback)
        false
    }
  }
}

/**
* Called by the catalog when a handle is first added to the catalog, or to refresh
* the priority of the underlying buffer if a handle's priority changed.
*/
private def updateUnderlyingRapidsBuffer(handle: RapidsBufferHandle): Unit = {
  withResource(acquireBuffer(handle)) { rapidsBuffer =>
    val tracked = bufferIdToHandles.get(rapidsBuffer.id)
    // the underlying RapidsBuffer takes the maximum priority over all handles
    // associated with it
    val highestPriority = tracked.iterator.map(_.getSpillPriority).max
    rapidsBuffer.setSpillPriority(highestPriority)
    // the most recently added handle supplies the spill callback
    rapidsBuffer.setSpillCallback(tracked.last.getSpillCallback)
  }
}

/**
* Lookup the buffer that corresponds to the specified handle at the highest storage tier,
* and acquire it.
* NOTE: It is the responsibility of the caller to close the buffer.
* @param id buffer identifier
* @param handle handle associated with this `RapidsBuffer`
* @return buffer that has been acquired
*/
def acquireBuffer(id: RapidsBufferId): RapidsBuffer = {
def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer = {
val id = handle.id
(0 until RapidsBufferCatalog.MAX_BUFFER_LOOKUP_ATTEMPTS).foreach { _ =>
val buffers = bufferMap.get(id)
if (buffers == null || buffers.isEmpty) {
Expand Down Expand Up @@ -124,6 +286,7 @@ class RapidsBufferCatalog extends Logging {
}
}
}

bufferMap.compute(buffer.id, updater)
}

Expand All @@ -142,10 +305,19 @@ class RapidsBufferCatalog extends Logging {
bufferMap.computeIfPresent(id, updater)
}

/** Remove a buffer ID from the catalog and release the resources of the registered buffers. */
def removeBuffer(id: RapidsBufferId): Unit = {
val buffers = bufferMap.remove(id)
buffers.safeFree()
/**
* Remove a buffer handle from the catalog and, if this was the final handle,
* release the resources of the registered buffers.
jlowe marked this conversation as resolved.
Show resolved Hide resolved
*/
def removeBuffer(handle: RapidsBufferHandle): Boolean = {
  val wasLastHandle = stopTrackingHandle(handle)
  if (wasLastHandle) {
    // nothing refers to the buffer any more: unregister it and free its resources
    val released = bufferMap.remove(handle.id)
    released.safeFree()
  }
  wasLastHandle
}

/** Return the number of buffers currently in the catalog. */
Expand Down Expand Up @@ -248,65 +420,76 @@ object RapidsBufferCatalog extends Logging with Arm {

/**
* Adds a contiguous table to the device storage, taking ownership of the table.
* @param id buffer ID to associate with this buffer
* @param table cudf table based from the contiguous buffer
* @param contigBuffer device memory buffer backing the table
* @param tableMeta metadata describing the buffer layout
* @param initialSpillPriority starting spill priority value for the buffer
* @param spillCallback a callback when the buffer is spilled. This should be very light weight.
* It should never allocate GPU memory and really just be used for metrics.
* @return RapidsBufferHandle associated with this buffer
*/
def addTable(
id: RapidsBufferId,
table: Table,
contigBuffer: DeviceMemoryBuffer,
tableMeta: TableMeta,
initialSpillPriority: Long,
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): Unit =
deviceStorage.addTable(id, table, contigBuffer, tableMeta, initialSpillPriority, spillCallback)
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): RapidsBufferHandle = {
val id =
deviceStorage.addTable(table, contigBuffer, tableMeta, initialSpillPriority)
singleton.makeNewHandle(id, initialSpillPriority, spillCallback)
}

/**
* Adds a contiguous table to the device storage, taking ownership of the table.
* @param id buffer ID to associate with this buffer
* @param contigTable contiguous table to track in device storage
* @param contigTable contiguous table to track in device storage
* @param initialSpillPriority starting spill priority value for the buffer
* @param spillCallback a callback when the buffer is spilled. This should be very light weight.
* It should never allocate GPU memory and really just be used for metrics.
* @return RapidsBufferHandle associated with this buffer
*/
def addContiguousTable(
id: RapidsBufferId,
contigTable: ContiguousTable,
initialSpillPriority: Long,
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): Unit =
deviceStorage.addContiguousTable(id, contigTable, initialSpillPriority, spillCallback)
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): RapidsBufferHandle = {
val id = deviceStorage.addContiguousTable(
contigTable, initialSpillPriority, spillCallback)
singleton.makeNewHandle(id, initialSpillPriority, spillCallback)
}

/**
* Adds a buffer to the device storage, taking ownership of the buffer.
* @param id buffer ID to associate with this buffer
* @param buffer buffer that will be owned by the store
* @param tableMeta metadata describing the buffer layout
* @param initialSpillPriority starting spill priority value for the buffer
* @param spillCallback a callback when the buffer is spilled. This should be very light weight.
* It should never allocate GPU memory and really just be used for metrics.
* @return RapidsBufferHandle associated with this buffer
*/
def addBuffer(
id: RapidsBufferId,
buffer: DeviceMemoryBuffer,
tableMeta: TableMeta,
initialSpillPriority: Long,
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): Unit =
deviceStorage.addBuffer(id, buffer, tableMeta, initialSpillPriority, spillCallback)
spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): RapidsBufferHandle = {
val id = deviceStorage.addBuffer(
buffer, tableMeta, initialSpillPriority, spillCallback)
singleton.makeNewHandle(id, initialSpillPriority, spillCallback)
}

/**
* Lookup the buffer that corresponds to the specified buffer ID and acquire it.
* Lookup the buffer that corresponds to the specified buffer handle and acquire it.
* NOTE: It is the responsibility of the caller to close the buffer.
* @param id buffer identifier
* @param handle buffer handle
* @return buffer that has been acquired
*/
def acquireBuffer(id: RapidsBufferId): RapidsBuffer = singleton.acquireBuffer(id)
def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer =
singleton.acquireBuffer(handle)

/** Remove a buffer ID from the catalog and release the resources of the registered buffer. */
def removeBuffer(id: RapidsBufferId): Unit = singleton.removeBuffer(id)
/**
* Remove a buffer handle from the catalog and, if this was the final handle,
* release the resources of the registered buffers.
*/
def removeBuffer(handle: RapidsBufferHandle): Unit =
singleton.removeBuffer(handle)

def getDiskBlockManager(): RapidsDiskBlockManager = diskBlockManager
}
Loading