Skip to content

Commit

Permalink
Spillable host buffer (#9070)
Browse files Browse the repository at this point in the history
* Spillable host buffer

---------

Signed-off-by: Alessandro Bellina <[email protected]>
  • Loading branch information
abellina authored Aug 21, 2023
1 parent 5cafe66 commit f723dfc
Show file tree
Hide file tree
Showing 10 changed files with 399 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,11 @@ object MetaUtils {
}

/**
* Constructs a table metadata buffer from a device buffer without describing any schema
* Constructs a table metadata buffer from a buffer length without describing any schema
* for the buffer.
*/
def getTableMetaNoTable(buffer: DeviceMemoryBuffer): TableMeta = {
def getTableMetaNoTable(bufferSize: Long): TableMeta = {
val fbb = new FlatBufferBuilder(1024)
val bufferSize = buffer.getLength
BufferMeta.startBufferMeta(fbb)
BufferMeta.addId(fbb, 0)
BufferMeta.addSize(fbb, bufferSize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,32 @@ trait RapidsBuffer extends AutoCloseable {
* @param priority new priority value for this buffer
*/
def setSpillPriority(priority: Long): Unit

/**
* Function invoked by the `RapidsBufferStore.addBuffer` method that prompts
* the specific `RapidsBuffer` to check its reference counting to make itself
* spillable or not. Only `RapidsTable` and `RapidsHostMemoryBuffer` implement
* this method.
*/
def updateSpillability(): Unit = {}

/**
* Obtains a read lock on this instance of `RapidsBuffer` and calls the function
* in `body` while holding the lock.
* @param body function that takes a `MemoryBuffer` and produces `K`
* @tparam K any return type specified by `body`
* @return the result of body(memoryBuffer)
*/
def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K

/**
* Obtains a write lock on this instance of `RapidsBuffer` and calls the function
* in `body` while holding the lock.
* @param body function that takes a `MemoryBuffer` and produces `K`
* @tparam K any return type specified by `body`
* @return the result of body(memoryBuffer)
*/
def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K
}

/**
Expand Down Expand Up @@ -385,5 +411,13 @@ sealed class DegenerateRapidsBuffer(

override def setSpillPriority(priority: Long): Unit = {}

override def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K = {
throw new UnsupportedOperationException("degenerate buffer has no memory buffer")
}

override def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K = {
throw new UnsupportedOperationException("degenerate buffer has no memory buffer")
}

override def close(): Unit = {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.util.function.BiFunction

import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, NvtxColor, NvtxRange, Rmm, Table}
import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, NvtxColor, NvtxRange, Rmm, Table}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsBufferCatalog.getExistingRapidsBufferAndAcquire
import com.nvidia.spark.rapids.RapidsPluginImplicits._
Expand Down Expand Up @@ -63,7 +63,8 @@ trait RapidsBufferHandle extends AutoCloseable {
* `RapidsBufferCatalog.singleton` should be used instead.
*/
class RapidsBufferCatalog(
deviceStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.deviceStorage)
deviceStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.deviceStorage,
hostStorage: RapidsHostMemoryStore = RapidsBufferCatalog.hostStorage)
extends AutoCloseable with Logging {

/** Map of buffer IDs to buffers sorted by storage tier */
Expand Down Expand Up @@ -198,7 +199,7 @@ class RapidsBufferCatalog(
}

/**
* Adds a buffer to the device storage. This does NOT take ownership of the
* Adds a buffer to the catalog and store. This does NOT take ownership of the
* buffer, so it is the responsibility of the caller to close it.
*
* This version of `addBuffer` should not be called from the shuffle catalogs
Expand All @@ -212,7 +213,7 @@ class RapidsBufferCatalog(
* @return RapidsBufferHandle handle for this buffer
*/
def addBuffer(
buffer: DeviceMemoryBuffer,
buffer: MemoryBuffer,
tableMeta: TableMeta,
initialSpillPriority: Long,
needsSync: Boolean = true): RapidsBufferHandle = synchronized {
Expand Down Expand Up @@ -294,29 +295,42 @@ class RapidsBufferCatalog(
}

/**
* Adds a buffer to the device storage. This does NOT take ownership of the
* buffer, so it is the responsibility of the caller to close it.
* Adds a buffer to either the device or host storage. This does NOT take
* ownership of the buffer, so it is the responsibility of the caller to close it.
*
* @param id the RapidsBufferId to use for this buffer
* @param buffer buffer that will be owned by the store
* @param buffer buffer that will be owned by the target store
* @param tableMeta metadata describing the buffer layout
* @param initialSpillPriority starting spill priority value for the buffer
* @param needsSync whether the spill framework should stream synchronize while adding
* this device buffer (defaults to true)
* this buffer (defaults to true)
* @return RapidsBufferHandle handle for this RapidsBuffer
*/
def addBuffer(
id: RapidsBufferId,
buffer: DeviceMemoryBuffer,
buffer: MemoryBuffer,
tableMeta: TableMeta,
initialSpillPriority: Long,
needsSync: Boolean): RapidsBufferHandle = synchronized {
val rapidsBuffer = deviceStorage.addBuffer(
id,
buffer,
tableMeta,
initialSpillPriority,
needsSync)
val rapidsBuffer = buffer match {
case gpuBuffer: DeviceMemoryBuffer =>
deviceStorage.addBuffer(
id,
gpuBuffer,
tableMeta,
initialSpillPriority,
needsSync)
case hostBuffer: HostMemoryBuffer =>
hostStorage.addBuffer(
id,
hostBuffer,
tableMeta,
initialSpillPriority,
needsSync)
case _ =>
throw new IllegalArgumentException(
s"Cannot call addBuffer with buffer $buffer")
}
registerNewBuffer(rapidsBuffer)
makeNewHandle(id, initialSpillPriority)
}
Expand Down Expand Up @@ -591,6 +605,8 @@ class RapidsBufferCatalog(
if (!bufferHasSpilled) {
// if the spillStore specifies a maximum size spill taking this ceiling
// into account before trying to create a buffer there
// TODO: we may need to handle what happens if we can't spill anymore
// because all host buffers are being referenced.
trySpillToMaximumSize(buffer, spillStore, stream)

// copy the buffer to spillStore
Expand Down Expand Up @@ -869,15 +885,15 @@ object RapidsBufferCatalog extends Logging {
}

/**
* Adds a buffer to the device storage. This does NOT take ownership of the
* Adds a buffer to the catalog and store. This does NOT take ownership of the
* buffer, so it is the responsibility of the caller to close it.
* @param buffer buffer that will be owned by the store
* @param tableMeta metadata describing the buffer layout
* @param initialSpillPriority starting spill priority value for the buffer
* @return RapidsBufferHandle associated with this buffer
*/
def addBuffer(
buffer: DeviceMemoryBuffer,
buffer: MemoryBuffer,
tableMeta: TableMeta,
initialSpillPriority: Long): RapidsBufferHandle = {
singleton.addBuffer(buffer, tableMeta, initialSpillPriority)
Expand All @@ -901,7 +917,7 @@ object RapidsBufferCatalog extends Logging {
def getDiskBlockManager(): RapidsDiskBlockManager = diskBlockManager

/**
* Given a `DeviceMemoryBuffer` find out if a `MemoryBuffer.EventHandler` is associated
* Given a `MemoryBuffer` find out if a `MemoryBuffer.EventHandler` is associated
* with it.
*
* After getting the `RapidsBuffer` try to acquire it via `addReference`.
Expand All @@ -910,7 +926,7 @@ object RapidsBufferCatalog extends Logging {
* are adding it again).
*
* @note public for testing
* @param buffer - the `DeviceMemoryBuffer` to inspect
* @param buffer - the `MemoryBuffer` to inspect
* @return - Some(RapidsBuffer): the handler is associated with a rapids buffer
* and the rapids buffer is currently valid, or
*
Expand All @@ -919,7 +935,7 @@ object RapidsBufferCatalog extends Logging {
* about to be removed).
*/
private def getExistingRapidsBufferAndAcquire(
buffer: DeviceMemoryBuffer): Option[RapidsBuffer] = {
buffer: MemoryBuffer): Option[RapidsBuffer] = {
val eh = buffer.getEventHandler
eh match {
case null =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.nvidia.spark.rapids

import java.util.Comparator
import java.util.concurrent.locks.ReentrantReadWriteLock

import scala.collection.mutable

Expand Down Expand Up @@ -233,6 +234,21 @@ abstract class RapidsBufferStore(val tier: StorageTier)
/** Update bookkeeping for a new buffer */
protected def addBuffer(buffer: RapidsBufferBase): Unit = {
buffers.add(buffer)
buffer.updateSpillability()
}

/**
* Adds a buffer to the spill framework, stream synchronizing with the producer
* stream to ensure that the buffer is fully materialized, and can be safely copied
* as part of the spill.
*
* @param needsSync true if we should stream synchronize before adding the buffer
*/
protected def addBuffer(buffer: RapidsBufferBase, needsSync: Boolean): Unit = {
if (needsSync) {
Cuda.DEFAULT_STREAM.sync()
}
addBuffer(buffer)
}

override def close(): Unit = {
Expand All @@ -258,6 +274,9 @@ abstract class RapidsBufferStore(val tier: StorageTier)

private[this] var spillPriority: Long = initialSpillPriority

private[this] val rwl: ReentrantReadWriteLock = new ReentrantReadWriteLock()


def meta: TableMeta = _meta

/** Release the underlying resources for this buffer. */
Expand Down Expand Up @@ -409,6 +428,30 @@ abstract class RapidsBufferStore(val tier: StorageTier)
spillPriority = priority
}

override def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K = {
withResource(getMemoryBuffer) { buff =>
val lock = rwl.readLock()
try {
lock.lock()
body(buff)
} finally {
lock.unlock()
}
}
}

override def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K = {
withResource(getMemoryBuffer) { buff =>
val lock = rwl.writeLock()
try {
lock.lock()
body(buff)
} finally {
lock.unlock()
}
}
}

/** Must be called with a lock on the buffer */
private def freeBuffer(): Unit = {
releaseResources()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,27 +137,10 @@ class RapidsDeviceMemoryStore(chunkedPackBounceBufferSize: Long = 128L*1024*1024
initialSpillPriority)
freeOnExcept(rapidsTable) { _ =>
addBuffer(rapidsTable, needsSync)
rapidsTable.updateSpillability()
rapidsTable
}
}

/**
* Adds a device buffer to the spill framework, stream synchronizing with the producer
* stream to ensure that the buffer is fully materialized, and can be safely copied
* as part of the spill.
*
* @param needsSync true if we should stream synchronize before adding the buffer
*/
private def addBuffer(
buffer: RapidsBufferBase,
needsSync: Boolean): Unit = {
if (needsSync) {
Cuda.DEFAULT_STREAM.sync()
}
addBuffer(buffer)
}

/**
* The RapidsDeviceMemoryStore is the only store that supports setting a buffer spillable
* or not.
Expand Down Expand Up @@ -309,7 +292,7 @@ class RapidsDeviceMemoryStore(chunkedPackBounceBufferSize: Long = 128L*1024*1024
* - after adding a table to the store to mark the table as spillable if
* all columns are spillable.
*/
def updateSpillability(): Unit = {
override def updateSpillability(): Unit = {
doSetSpillable(this, columnSpillability.size == numDistinctColumns)
}

Expand Down
Loading

0 comments on commit f723dfc

Please sign in to comment.