From f723dfcaf07b9438f92389f8dc98f648610316b8 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Mon, 21 Aug 2023 09:30:52 -0500 Subject: [PATCH] Spillable host buffer (#9070) * Spillable host buffer --------- Signed-off-by: Alessandro Bellina --- .../com/nvidia/spark/rapids/MetaUtils.scala | 5 +- .../nvidia/spark/rapids/RapidsBuffer.scala | 34 ++++++ .../spark/rapids/RapidsBufferCatalog.scala | 56 ++++++---- .../spark/rapids/RapidsBufferStore.scala | 43 ++++++++ .../rapids/RapidsDeviceMemoryStore.scala | 19 +--- .../spark/rapids/RapidsHostMemoryStore.scala | 94 +++++++++++++++- .../spark/rapids/SpillableColumnarBatch.scala | 104 +++++++++++++++++- .../rapids/RapidsBufferCatalogSuite.scala | 5 +- .../rapids/RapidsHostMemoryStoreSuite.scala | 91 ++++++++++++++- .../rapids/SpillableColumnarBatchSuite.scala | 2 + 10 files changed, 399 insertions(+), 54 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala index 434f0aadbed..80acddcb257 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala @@ -155,12 +155,11 @@ object MetaUtils { } /** - * Constructs a table metadata buffer from a device buffer without describing any schema + * Constructs a table metadata buffer from a buffer length without describing any schema * for the buffer. */ - def getTableMetaNoTable(buffer: DeviceMemoryBuffer): TableMeta = { + def getTableMetaNoTable(bufferSize: Long): TableMeta = { val fbb = new FlatBufferBuilder(1024) - val bufferSize = buffer.getLength BufferMeta.startBufferMeta(fbb) BufferMeta.addId(fbb, 0) BufferMeta.addSize(fbb, bufferSize) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala index b53a6e3a6fb..2c6aa11df44 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala @@ -333,6 +333,32 @@ trait RapidsBuffer extends AutoCloseable { * @param priority new priority value for this buffer */ def setSpillPriority(priority: Long): Unit + + /** + * Function invoked by the `RapidsBufferStore.addBuffer` method that prompts + * the specific `RapidsBuffer` to check its reference counting to make itself + * spillable or not. Only `RapidsTable` and `RapidsHostMemoryBuffer` implement + * this method. + */ + def updateSpillability(): Unit = {} + + /** + * Obtains a read lock on this instance of `RapidsBuffer` and calls the function + * in `body` while holding the lock. + * @param body function that takes a `MemoryBuffer` and produces `K` + * @tparam K any return type specified by `body` + * @return the result of body(memoryBuffer) + */ + def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K + + /** + * Obtains a write lock on this instance of `RapidsBuffer` and calls the function + * in `body` while holding the lock. + * @param body function that takes a `MemoryBuffer` and produces `K` + * @tparam K any return type specified by `body` + * @return the result of body(memoryBuffer) + */ + def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K } /** @@ -385,5 +411,13 @@ sealed class DegenerateRapidsBuffer( override def setSpillPriority(priority: Long): Unit = {} + override def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K = { + throw new UnsupportedOperationException("degenerate buffer has no memory buffer") + } + + override def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K = { + throw new UnsupportedOperationException("degenerate buffer has no memory buffer") + } + override def close(): Unit = {} } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala index c7dce220cd2..9ee65503a58 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala @@ -21,7 +21,7 @@ import java.util.function.BiFunction import scala.collection.mutable.ArrayBuffer -import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, NvtxColor, NvtxRange, Rmm, Table} +import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, NvtxColor, NvtxRange, Rmm, Table} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsBufferCatalog.getExistingRapidsBufferAndAcquire import com.nvidia.spark.rapids.RapidsPluginImplicits._ @@ -63,7 +63,8 @@ trait RapidsBufferHandle extends AutoCloseable { * `RapidsBufferCatalog.singleton` should be used instead. */ class RapidsBufferCatalog( - deviceStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.deviceStorage) + deviceStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.deviceStorage, + hostStorage: RapidsHostMemoryStore = RapidsBufferCatalog.hostStorage) extends AutoCloseable with Logging { /** Map of buffer IDs to buffers sorted by storage tier */ @@ -198,7 +199,7 @@ class RapidsBufferCatalog( } /** - * Adds a buffer to the device storage. This does NOT take ownership of the + * Adds a buffer to the catalog and store. This does NOT take ownership of the * buffer, so it is the responsibility of the caller to close it. * * This version of `addBuffer` should not be called from the shuffle catalogs @@ -212,7 +213,7 @@ class RapidsBufferCatalog( * @return RapidsBufferHandle handle for this buffer */ def addBuffer( - buffer: DeviceMemoryBuffer, + buffer: MemoryBuffer, tableMeta: TableMeta, initialSpillPriority: Long, needsSync: Boolean = true): RapidsBufferHandle = synchronized { @@ -294,29 +295,42 @@ class RapidsBufferCatalog( } /** - * Adds a buffer to the device storage. This does NOT take ownership of the - * buffer, so it is the responsibility of the caller to close it. + * Adds a buffer to either the device or host storage. This does NOT take + * ownership of the buffer, so it is the responsibility of the caller to close it. * * @param id the RapidsBufferId to use for this buffer - * @param buffer buffer that will be owned by the store + * @param buffer buffer that will be owned by the target store * @param tableMeta metadata describing the buffer layout * @param initialSpillPriority starting spill priority value for the buffer * @param needsSync whether the spill framework should stream synchronize while adding - * this device buffer (defaults to true) + * this buffer (defaults to true) * @return RapidsBufferHandle handle for this RapidsBuffer */ def addBuffer( id: RapidsBufferId, - buffer: DeviceMemoryBuffer, + buffer: MemoryBuffer, tableMeta: TableMeta, initialSpillPriority: Long, needsSync: Boolean): RapidsBufferHandle = synchronized { - val rapidsBuffer = deviceStorage.addBuffer( - id, - buffer, - tableMeta, - initialSpillPriority, - needsSync) + val rapidsBuffer = buffer match { + case gpuBuffer: DeviceMemoryBuffer => + deviceStorage.addBuffer( + id, + gpuBuffer, + tableMeta, + initialSpillPriority, + needsSync) + case hostBuffer: HostMemoryBuffer => + hostStorage.addBuffer( + id, + hostBuffer, + tableMeta, + initialSpillPriority, + needsSync) + case _ => + throw new IllegalArgumentException( + s"Cannot call addBuffer with buffer $buffer") + } registerNewBuffer(rapidsBuffer) makeNewHandle(id, initialSpillPriority) } @@ -591,6 +605,8 @@ class RapidsBufferCatalog( if (!bufferHasSpilled) { // if the spillStore specifies a maximum size spill taking this ceiling // into account before trying to create a buffer there + // TODO: we may need to handle what happens if we can't spill anymore + // because all host buffers are being referenced. trySpillToMaximumSize(buffer, spillStore, stream) // copy the buffer to spillStore @@ -869,7 +885,7 @@ object RapidsBufferCatalog extends Logging { } /** - * Adds a buffer to the device storage. This does NOT take ownership of the + * Adds a buffer to the catalog and store. This does NOT take ownership of the * buffer, so it is the responsibility of the caller to close it. * @param buffer buffer that will be owned by the store * @param tableMeta metadata describing the buffer layout @@ -877,7 +893,7 @@ object RapidsBufferCatalog extends Logging { * @return RapidsBufferHandle associated with this buffer */ def addBuffer( - buffer: DeviceMemoryBuffer, + buffer: MemoryBuffer, tableMeta: TableMeta, initialSpillPriority: Long): RapidsBufferHandle = { singleton.addBuffer(buffer, tableMeta, initialSpillPriority) @@ -901,7 +917,7 @@ object RapidsBufferCatalog extends Logging { def getDiskBlockManager(): RapidsDiskBlockManager = diskBlockManager /** - * Given a `DeviceMemoryBuffer` find out if a `MemoryBuffer.EventHandler` is associated + * Given a `MemoryBuffer` find out if a `MemoryBuffer.EventHandler` is associated * with it. * * After getting the `RapidsBuffer` try to acquire it via `addReference`. @@ -910,7 +926,7 @@ object RapidsBufferCatalog extends Logging { * are adding it again). * * @note public for testing - * @param buffer - the `DeviceMemoryBuffer` to inspect + * @param buffer - the `MemoryBuffer` to inspect * @return - Some(RapidsBuffer): the handler is associated with a rapids buffer * and the rapids buffer is currently valid, or * @@ -919,7 +935,7 @@ object RapidsBufferCatalog extends Logging { * about to be removed). */ private def getExistingRapidsBufferAndAcquire( - buffer: DeviceMemoryBuffer): Option[RapidsBuffer] = { + buffer: MemoryBuffer): Option[RapidsBuffer] = { val eh = buffer.getEventHandler eh match { case null => diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala index 64ae09ae28a..cd89c3a5602 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala @@ -17,6 +17,7 @@ package com.nvidia.spark.rapids import java.util.Comparator +import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection.mutable @@ -233,6 +234,21 @@ abstract class RapidsBufferStore(val tier: StorageTier) /** Update bookkeeping for a new buffer */ protected def addBuffer(buffer: RapidsBufferBase): Unit = { buffers.add(buffer) + buffer.updateSpillability() + } + + /** + * Adds a buffer to the spill framework, stream synchronizing with the producer + * stream to ensure that the buffer is fully materialized, and can be safely copied + * as part of the spill. + * + * @param needsSync true if we should stream synchronize before adding the buffer + */ + protected def addBuffer(buffer: RapidsBufferBase, needsSync: Boolean): Unit = { + if (needsSync) { + Cuda.DEFAULT_STREAM.sync() + } + addBuffer(buffer) } override def close(): Unit = { @@ -258,6 +274,9 @@ abstract class RapidsBufferStore(val tier: StorageTier) private[this] var spillPriority: Long = initialSpillPriority + private[this] val rwl: ReentrantReadWriteLock = new ReentrantReadWriteLock() + + def meta: TableMeta = _meta /** Release the underlying resources for this buffer. */ @@ -409,6 +428,30 @@ abstract class RapidsBufferStore(val tier: StorageTier) spillPriority = priority } + override def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K = { + withResource(getMemoryBuffer) { buff => + val lock = rwl.readLock() + try { + lock.lock() + body(buff) + } finally { + lock.unlock() + } + } + } + + override def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K = { + withResource(getMemoryBuffer) { buff => + val lock = rwl.writeLock() + try { + lock.lock() + body(buff) + } finally { + lock.unlock() + } + } + } + /** Must be called with a lock on the buffer */ private def freeBuffer(): Unit = { releaseResources() diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala index 30a2c472101..336301784b3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala @@ -137,27 +137,10 @@ class RapidsDeviceMemoryStore(chunkedPackBounceBufferSize: Long = 128L*1024*1024 initialSpillPriority) freeOnExcept(rapidsTable) { _ => addBuffer(rapidsTable, needsSync) - rapidsTable.updateSpillability() rapidsTable } } - /** - * Adds a device buffer to the spill framework, stream synchronizing with the producer - * stream to ensure that the buffer is fully materialized, and can be safely copied - * as part of the spill. - * - * @param needsSync true if we should stream synchronize before adding the buffer - */ - private def addBuffer( - buffer: RapidsBufferBase, - needsSync: Boolean): Unit = { - if (needsSync) { - Cuda.DEFAULT_STREAM.sync() - } - addBuffer(buffer) - } - /** * The RapidsDeviceMemoryStore is the only store that supports setting a buffer spillable * or not. @@ -309,7 +292,7 @@ class RapidsDeviceMemoryStore(chunkedPackBounceBufferSize: Long = 128L*1024*1024 * - after adding a table to the store to mark the table as spillable if * all columns are spillable. */ - def updateSpillability(): Unit = { + override def updateSpillability(): Unit = { doSetSpillable(this, columnSpillability.size == numDistinctColumns) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala index 21604fd6b8b..986f7fa73e8 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala @@ -17,7 +17,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, NvtxColor, NvtxRange, PinnedMemoryPool} -import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import com.nvidia.spark.rapids.Arm.{closeOnExcept, freeOnExcept, withResource} import com.nvidia.spark.rapids.SpillPriorities.{applyPriorityOffset, HOST_MEMORY_BUFFER_SPILL_OFFSET} import com.nvidia.spark.rapids.StorageTier.StorageTier import com.nvidia.spark.rapids.format.TableMeta @@ -31,6 +31,12 @@ class RapidsHostMemoryStore( maxSize: Long) extends RapidsBufferStore(StorageTier.HOST) { + override def spillableOnAdd: Boolean = false + + override protected def setSpillable(buffer: RapidsBufferBase, spillable: Boolean): Unit = { + doSetSpillable(buffer, spillable) + } + override def getMaxSize: Option[Long] = Some(maxSize) private def allocateHostBuffer( @@ -47,6 +53,29 @@ class RapidsHostMemoryStore( HostMemoryBuffer.allocate(size, false) } + def addBuffer( + id: RapidsBufferId, + buffer: HostMemoryBuffer, + tableMeta: TableMeta, + initialSpillPriority: Long, + needsSync: Boolean): RapidsBuffer = { + buffer.incRefCount() + val rapidsBuffer = new RapidsHostMemoryBuffer( + id, + buffer.getLength, + tableMeta, + initialSpillPriority, + buffer) + freeOnExcept(rapidsBuffer) { _ => + logDebug(s"Adding host buffer for: [id=$id, size=${buffer.getLength}, " + + s"uncompressed=${rapidsBuffer.meta.bufferMeta.uncompressedSize}, " + + s"meta_id=${tableMeta.bufferMeta.id}, " + + s"meta_size=${tableMeta.bufferMeta.size}]") + addBuffer(rapidsBuffer, needsSync) + rapidsBuffer + } + } + override protected def createBuffer( other: RapidsBuffer, stream: Cuda.Stream): RapidsBufferBase = { @@ -95,13 +124,22 @@ class RapidsHostMemoryStore( meta: TableMeta, spillPriority: Long, buffer: HostMemoryBuffer) - extends RapidsBufferBase( - id, meta, spillPriority) { + extends RapidsBufferBase(id, meta, spillPriority) + with MemoryBuffer.EventHandler { override val storageTier: StorageTier = StorageTier.HOST - override def getMemoryBuffer: MemoryBuffer = { - buffer.incRefCount() - buffer + override def getMemoryBuffer: MemoryBuffer = synchronized { + buffer.synchronized { + setSpillable(this, false) + buffer.incRefCount() + buffer + } + } + + override def updateSpillability(): Unit = { + if (buffer.getRefCount == 1) { + setSpillable(this, true) + } } override protected def releaseResources(): Unit = { @@ -110,5 +148,49 @@ class RapidsHostMemoryStore( /** The size of this buffer in bytes. */ override def getMemoryUsedBytes: Long = size + + // If this require triggers, we are re-adding a `HostMemoryBuffer` outside of + // the catalog lock, which should not possible. The event handler is set to null + // when we free the `RapidsHostMemoryBuffer` and if the buffer is not free, we + // take out another handle (in the catalog). + // TODO: This is not robust (to rely on outside locking and addReference/free) + // and should be revisited. + require(buffer.setEventHandler(this) == null, + "HostMemoryBuffer with non-null event handler failed to add!!") + + /** + * Override from the MemoryBuffer.EventHandler interface. + * + * If we are being invoked we have the `buffer` lock, as this callback + * is being invoked from `MemoryBuffer.close` + * + * @param refCount - buffer's current refCount + */ + override def onClosed(refCount: Int): Unit = { + // refCount == 1 means only 1 reference exists to `buffer` in the + // RapidsHostMemoryBuffer (we own it) + if (refCount == 1) { + // setSpillable is being called here as an extension of `MemoryBuffer.close()` + // we hold the MemoryBuffer lock and we could be called from a Spark task thread + // Since we hold the MemoryBuffer lock, `incRefCount` waits for us. The only other + // call to `setSpillable` is also under this same MemoryBuffer lock (see: + // `getMemoryBuffer`) + setSpillable(this, true) + } + } + + /** + * We overwrite free to make sure we don't have a handler for the underlying + * buffer, since this `RapidsBuffer` is no longer tracked. + */ + override def free(): Unit = synchronized { + if (isValid) { + // it is going to be invalid when calling super.free() + buffer.setEventHandler(null) + } + super.free() + } } } + + diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala index 08997ea1612..3b76ad5fa83 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala @@ -16,8 +16,8 @@ package com.nvidia.spark.rapids -import ai.rapids.cudf.{ContiguousTable, DeviceMemoryBuffer} -import com.nvidia.spark.rapids.Arm.withResource +import ai.rapids.cudf.{ContiguousTable, DeviceMemoryBuffer, HostMemoryBuffer} +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import org.apache.spark.TaskContext import org.apache.spark.sql.types.DataType @@ -210,7 +210,6 @@ object SpillableColumnarBatch { } - /** * Just like a SpillableColumnarBatch but for buffers. */ @@ -241,6 +240,72 @@ class SpillableBuffer( } } +/** + * This represents a spillable `HostMemoryBuffer` and adds an interface to access + * this host buffer at the host layer, unlike `SpillableBuffer` (device) + * @param handle an object used to refer to this buffer in the spill framework + * @param length a metadata-only length that is kept in the `SpillableHostBuffer` + * instance. Used in cases where the backing host buffer is larger + * than the number of usable bytes. + * @param catalog this was added for tests, it defaults to + * `RapidsBufferCatalog.singleton` in the companion object. + */ +class SpillableHostBuffer(handle: RapidsBufferHandle, + val length: Long, + catalog: RapidsBufferCatalog) extends AutoCloseable { + /** + * Set a new spill priority. + */ + def setSpillPriority(priority: Long): Unit = { + handle.setSpillPriority(priority) + } + + /** + * Remove the buffer from the cache. + */ + override def close(): Unit = { + handle.close() + } + + /** + * Acquires the underlying `RapidsBuffer` and uses + * `RapidsBuffer.withMemoryBufferReadLock` to obtain a read lock + * that will held while invoking `body` with a `HostMemoryBuffer`. + * @param body function that takes a `HostMemoryBuffer` and produces `K` + * @tparam K any return type specified by `body` + * @return the result of body(hostMemoryBuffer) + */ + def withHostBufferReadOnly[K](body: HostMemoryBuffer => K): K = { + withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => + rapidsBuffer.withMemoryBufferReadLock { + case hmb: HostMemoryBuffer => body(hmb) + case memoryBuffer => + throw new IllegalStateException( + s"Expected a HostMemoryBuffer but instead got ${memoryBuffer}") + } + } + } + + /** + * Acquires the underlying `RapidsBuffer` and uses + * `RapidsBuffer.withMemoryBufferWriteLock` to obtain a write lock + * that will held while invoking `body` with a `HostMemoryBuffer`. + * @param body function that takes a `HostMemoryBuffer` and produces `K` + * @tparam K any return type specified by `body` + * @return the result of body(hostMemoryBuffer) + */ + def withHostBufferWriteLock[K](body: HostMemoryBuffer => K): K = { + withResource(catalog.acquireBuffer(handle)) { rapidsBuffer => + rapidsBuffer.withMemoryBufferWriteLock { + case hmb: HostMemoryBuffer => body(hmb) + case memoryBuffer => + throw new IllegalStateException( + s"Expected a HostMemoryBuffer but instead got ${memoryBuffer}") + } + } + } +} + object SpillableBuffer { /** @@ -249,12 +314,41 @@ object SpillableBuffer { * @param buffer the buffer to make spillable * @param priority the initial spill priority of this buffer */ - def apply(buffer: DeviceMemoryBuffer, + def apply( + buffer: DeviceMemoryBuffer, priority: Long): SpillableBuffer = { - val meta = MetaUtils.getTableMetaNoTable(buffer) + val meta = MetaUtils.getTableMetaNoTable(buffer.getLength) val handle = withResource(buffer) { _ => RapidsBufferCatalog.addBuffer(buffer, meta, priority) } new SpillableBuffer(handle) } } + +object SpillableHostBuffer { + + /** + * Create a new SpillableBuffer. + * @note This takes over ownership of buffer, and buffer should not be used after this. + * @param length the actual length of the data within the host buffer, which + * must be <= than buffer.getLength, otherwise this function throws + * and closes `buffer` + * @param buffer the buffer to make spillable + * @param priority the initial spill priority of this buffer + */ + def apply(buffer: HostMemoryBuffer, + length: Long, + priority: Long, + catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton): SpillableHostBuffer = { + closeOnExcept(buffer) { _ => + require(length <= buffer.getLength, + s"Attempted to add a host spillable with a length ${length} B which is " + + s"greater than the backing host buffer length ${buffer.getLength} B") + } + val meta = MetaUtils.getTableMetaNoTable(buffer.getLength) + val handle = withResource(buffer) { _ => + catalog.addBuffer(buffer, meta, priority) + } + new SpillableHostBuffer(handle, length, catalog) + } +} \ No newline at end of file diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala index 25fd915f44d..54827e12878 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala @@ -220,7 +220,7 @@ class RapidsBufferCatalogSuite extends AnyFunSuite with MockitoSugar { hostStore.setSpillStore(mockStore) val catalog = new RapidsBufferCatalog(deviceStore) val handle = withResource(DeviceMemoryBuffer.allocate(1024)) { buff => - val meta = MetaUtils.getTableMetaNoTable(buff) + val meta = MetaUtils.getTableMetaNoTable(buff.getLength) catalog.addBuffer( buff, meta, -1) } @@ -353,6 +353,9 @@ class RapidsBufferCatalogSuite extends AnyFunSuite with MockitoSugar { override def setSpillPriority(priority: Long): Unit = { currentPriority = priority } + + override def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K = { body(null) } + override def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K = { body(null) } override def close(): Unit = {} }) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala index a9a09ee0bbc..7a451d29bab 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala @@ -27,11 +27,11 @@ import org.mockito.Mockito.{never, spy, times, verify, when} import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.mockito.MockitoSugar +import org.apache.spark.SparkConf import org.apache.spark.sql.rapids.RapidsDiskBlockManager import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, LongType, StringType} import org.apache.spark.sql.vectorized.ColumnarBatch - class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar { private def buildContiguousTable(): ContiguousTable = { withResource(new Table.TestBuilder() @@ -162,6 +162,95 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar { } } + test("host buffer originated: get host memory buffer") { + val spillPriority = -10 + val hostStoreMaxSize = 1L * 1024 * 1024 + val mockStore = mock[RapidsDiskStore] + withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore => + withResource(new RapidsDeviceMemoryStore) { devStore => + val catalog = new RapidsBufferCatalog(devStore, hostStore) + devStore.setSpillStore(hostStore) + hostStore.setSpillStore(mockStore) + val hmb = HostMemoryBuffer.allocate(1L * 1024) + val spillableBuffer = + SpillableHostBuffer(hmb, hmb.getLength, spillPriority, catalog) + withResource(spillableBuffer) { _ => + // the refcount of 1 is the store + assertResult(1)(hmb.getRefCount) + spillableBuffer.withHostBufferReadOnly { memoryBuffer => + assertResult(hmb)(memoryBuffer) + assertResult(2)(memoryBuffer.getRefCount) + } + } + assertResult(0)(hmb.getRefCount) + } + } + } + + test("host buffer originated: get host memory buffer after spill") { + val spillPriority = -10 + val hostStoreMaxSize = 1L * 1024 * 1024 + val bm = new RapidsDiskBlockManager(new SparkConf()) + withResource(new RapidsDiskStore(bm)) { diskStore => + withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore => + withResource(new RapidsDeviceMemoryStore) { devStore => + val catalog = new RapidsBufferCatalog(devStore, hostStore) + devStore.setSpillStore(hostStore) + hostStore.setSpillStore(diskStore) + val hmb = HostMemoryBuffer.allocate(1L * 1024) + val spillableBuffer = SpillableHostBuffer( + hmb, + hmb.getLength, + spillPriority, + catalog) + assertResult(1)(hmb.getRefCount) + // we spill it + catalog.synchronousSpill(hostStore, 0) + withResource(spillableBuffer) { _ => + // the refcount of the original buffer is 0 because it spilled + assertResult(0)(hmb.getRefCount) + spillableBuffer.withHostBufferReadOnly { memoryBuffer => + assertResult(memoryBuffer.getLength)(hmb.getLength) + } + } + } + } + } + } + + test("host buffer originated: get host memory buffer OOM when unable to spill") { + val spillPriority = -10 + val hostStoreMaxSize = 1L * 1024 * 1024 + val bm = new RapidsDiskBlockManager(new SparkConf()) + withResource(new RapidsDiskStore(bm)) { diskStore => + withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore => + withResource(new RapidsDeviceMemoryStore) { devStore => + val catalog = new RapidsBufferCatalog(devStore, hostStore) + devStore.setSpillStore(hostStore) + hostStore.setSpillStore(diskStore) + val hmb = HostMemoryBuffer.allocate(1L * 1024) + val spillableBuffer = SpillableHostBuffer( + hmb, + hmb.getLength, + spillPriority, + catalog) + // spillable is 1K + assertResult(hmb.getLength)(hostStore.currentSpillableSize) + spillableBuffer.withHostBufferReadOnly { memoryBuffer => + // 0 because we have a reference to the memoryBuffer + assertResult(0)(hostStore.currentSpillableSize) + val spilled = catalog.synchronousSpill(hostStore, 0) + assertResult(0)(spilled.get) + } + assertResult(hmb.getLength)(hostStore.currentSpillableSize) + val spilled = catalog.synchronousSpill(hostStore, 0) + assertResult(1L * 1024)(spilled.get) + spillableBuffer.close() + } + } + } + } + test("buffer exceeds maximum size") { val sparkTypes = Array[DataType](LongType) val spillPriority = -10 diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala index 952ef09f9aa..a5209e9bd0e 100644 --- a/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala @@ -61,5 +61,7 @@ class SpillableColumnarBatchSuite extends AnyFunSuite { override def close(): Unit = {} override def getColumnarBatch( sparkTypes: Array[DataType]): ColumnarBatch = null + override def withMemoryBufferReadLock[K](body: MemoryBuffer => K): K = { body(null) } + override def withMemoryBufferWriteLock[K](body: MemoryBuffer => K): K = { body(null) } } }