diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala index a97ff235527..1e019770f08 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandler.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,6 +35,7 @@ import org.apache.spark.sql.rapids.execution.TrampolineUtil * depleting the device store */ class DeviceMemoryEventHandler( + catalog: RapidsBufferCatalog, store: RapidsDeviceMemoryStore, oomDumpDir: Option[String], isGdsSpillEnabled: Boolean, @@ -52,7 +53,7 @@ class DeviceMemoryEventHandler( /** * A small helper class that helps keep track of retry counts as we trigger - * synchronizes on a depleted store. + * synchronizes on a completely spilled store. */ class OOMRetryState { private var synchronizeAttempts = 0 @@ -82,19 +83,19 @@ class DeviceMemoryEventHandler( } /** - * We reset our counters if `storeSize` is non-zero (as we only track when the store is - * depleted), or `retryCount` is less than or equal to what we had previously - * recorded in `shouldTrySynchronizing`. We do this to detect that the new failures - * are for a separate allocation (we need to give this new attempt a new set of - * retries.) + * We reset our counters if `storeSpillableSize` is non-zero (as we only track when all + * spillable buffers are removed from the store), or `retryCount` is less than or equal + * to what we had previously recorded in `shouldTrySynchronizing`. + * We do this to detect that the new failures are for a separate allocation (we need to + * give this new attempt a new set of retries.) * - * For example, if an allocation fails and we deplete the store, `retryCountLastSynced` + * For example, if an allocation fails and we spill all eligible buffers, `retryCountLastSynced` * is set to the last `retryCount` sent to us by cuDF as we keep allowing retries * from cuDF. If we succeed, cuDF resets `retryCount`, and so the new count sent to us * must be <= than what we saw last, so we can reset our tracking. */ - def resetIfNeeded(retryCount: Int, storeSize: Long): Unit = { - if (storeSize != 0 || retryCount <= retryCountLastSynced) { + def resetIfNeeded(retryCount: Int, storeSpillableSize: Long): Unit = { + if (storeSpillableSize != 0 || retryCount <= retryCountLastSynced) { reset() } } @@ -108,15 +109,17 @@ class DeviceMemoryEventHandler( */ override def onAllocFailure(allocSize: Long, retryCount: Int): Boolean = { // check arguments for good measure - require(allocSize >= 0, + require(allocSize >= 0, s"onAllocFailure invoked with invalid allocSize $allocSize") - require(retryCount >= 0, + require(retryCount >= 0, s"onAllocFailure invoked with invalid retryCount $retryCount") try { withResource(new NvtxRange("onAllocFailure", NvtxColor.RED)) { _ => val storeSize = store.currentSize + val storeSpillableSize = store.currentSpillableSize + val attemptMsg = if (retryCount > 0) { s"Attempt ${retryCount}. 
" } else { @@ -124,12 +127,12 @@ class DeviceMemoryEventHandler( } val retryState = oomRetryState.get() - retryState.resetIfNeeded(retryCount, storeSize) + retryState.resetIfNeeded(retryCount, storeSpillableSize) logInfo(s"Device allocation of $allocSize bytes failed, device store has " + - s"$storeSize bytes. $attemptMsg" + + s"$storeSize total and $storeSpillableSize spillable bytes. $attemptMsg" + s"Total RMM allocated is ${Rmm.getTotalBytesAllocated} bytes. ") - if (storeSize == 0) { + if (storeSpillableSize == 0) { if (retryState.shouldTrySynchronizing(retryCount)) { Cuda.deviceSynchronize() logWarning(s"[RETRY ${retryState.getRetriesSoFar}] " + @@ -149,14 +152,17 @@ class DeviceMemoryEventHandler( false } } else { - val targetSize = Math.max(storeSize - allocSize, 0) + val targetSize = Math.max(storeSpillableSize - allocSize, 0) logDebug(s"Targeting device store size of $targetSize bytes") - val amountSpilled = store.synchronousSpill(targetSize) - logInfo(s"Spilled $amountSpilled bytes from the device store") - if (isGdsSpillEnabled) { - TrampolineUtil.incTaskMetricsDiskBytesSpilled(amountSpilled) - } else { - TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled) + val maybeAmountSpilled = + catalog.synchronousSpill(store, targetSize, Cuda.DEFAULT_STREAM) + maybeAmountSpilled.foreach { amountSpilled => + logInfo(s"Spilled $amountSpilled bytes from the device store") + if (isGdsSpillEnabled) { + TrampolineUtil.incTaskMetricsDiskBytesSpilled(amountSpilled) + } else { + TrampolineUtil.incTaskMetricsMemoryBytesSpilled(amountSpilled) + } } true } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala index 4dded36f265..0b86cea5272 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala @@ -19,14 +19,16 @@ package com.nvidia.spark.rapids import java.util.concurrent.ConcurrentHashMap import java.util.function.BiFunction -import ai.rapids.cudf.{ContiguousTable, DeviceMemoryBuffer, Rmm, Table} +import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, MemoryBuffer, NvtxColor, NvtxRange, Rmm} +import com.nvidia.spark.rapids.RapidsBufferCatalog.getExistingRapidsBufferAndAcquire import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.StorageTier.StorageTier import com.nvidia.spark.rapids.format.TableMeta import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging -import org.apache.spark.sql.rapids.RapidsDiskBlockManager +import org.apache.spark.sql.rapids.{RapidsDiskBlockManager, TempSpillBufferId} +import org.apache.spark.sql.rapids.execution.TrampolineUtil /** * Exception thrown when inserting a buffer into the catalog with a duplicate buffer ID @@ -55,7 +57,9 @@ trait RapidsBufferHandle extends AutoCloseable { * Catalog for lookup of buffers by ID. The constructor is only visible for testing, generally * `RapidsBufferCatalog.singleton` should be used instead. 
*/ -class RapidsBufferCatalog extends AutoCloseable with Arm { +class RapidsBufferCatalog( + deviceStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.deviceStorage) + extends AutoCloseable with Arm with Logging { /** Map of buffer IDs to buffers sorted by storage tier */ private[this] val bufferMap = new ConcurrentHashMap[RapidsBufferId, Seq[RapidsBuffer]] @@ -64,6 +68,9 @@ class RapidsBufferCatalog extends AutoCloseable with Arm { private[this] val bufferIdToHandles = new ConcurrentHashMap[RapidsBufferId, Seq[RapidsBufferHandleImpl]]() + /** A counter used to skip a spill attempt if we detect a different thread has spilled */ + @volatile private[this] var spillCount: Integer = 0 + class RapidsBufferHandleImpl( override val id: RapidsBufferId, var priority: Long, @@ -202,6 +209,160 @@ class RapidsBufferCatalog extends AutoCloseable with Arm { } } + /** + * Adds a buffer to the device storage. This does NOT take ownership of the + * buffer, so it is the responsibility of the caller to close it. + * + * This version of `addBuffer` should not be called from the shuffle catalogs + * since they provide their own ids. + * + * @param buffer buffer that will be owned by the store + * @param tableMeta metadata describing the buffer layout + * @param initialSpillPriority starting spill priority value for the buffer + * @param spillCallback a callback when the buffer is spilled. This should be very light weight. + * It should never allocate GPU memory and really just be used for metrics. + * @param needsSync whether the spill framework should stream synchronize while adding + * this device buffer (defaults to true) + * @return RapidsBufferHandle handle for this buffer + */ + def addBuffer( + buffer: DeviceMemoryBuffer, + tableMeta: TableMeta, + initialSpillPriority: Long, + spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback, + needsSync: Boolean = true): RapidsBufferHandle = synchronized { + // first time we see `buffer` + val existing = getExistingRapidsBufferAndAcquire(buffer) + existing match { + case None => + addBuffer( + TempSpillBufferId(), + buffer, + tableMeta, + initialSpillPriority, + spillCallback, + needsSync) + case Some(rapidsBuffer) => + withResource(rapidsBuffer) { _ => + makeNewHandle(rapidsBuffer.id, initialSpillPriority, spillCallback) + } + } + } + + /** + * Adds a contiguous table to the device storage. This does NOT take ownership of the + * contiguous table, so it is the responsibility of the caller to close it. The refcount of the + * underlying device buffer will be incremented so the contiguous table can be closed before + * this buffer is destroyed. + * + * This version of `addContiguousTable` should not be called from the shuffle catalogs + * since they provide their own ids. + * + * @param contigTable contiguous table to track in storage + * @param initialSpillPriority starting spill priority value for the buffer + * @param spillCallback a callback when the buffer is spilled. This should be very light weight. + * It should never allocate GPU memory and really just be used for metrics. 
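+ *
+ * As with `addBuffer`, adding the same underlying device buffer twice only
+ * creates a second handle (sketch; `devBuf`, `meta` and `priority` are
+ * illustrative names):
+ * {{{
+ *   val h1 = catalog.addContiguousTable(ct, priority) // first sighting: registers the buffer
+ *   val h2 = catalog.addContiguousTable(ct, priority) // dedup: new handle, same RapidsBuffer
+ *   // the RapidsBuffer is freed only after both h1 and h2 are closed
+ * }}}
+ *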
+ * @param needsSync whether the spill framework should stream synchronize while adding + * this device buffer (defaults to true) + * @return RapidsBufferHandle handle for this table + */ + def addContiguousTable( + contigTable: ContiguousTable, + initialSpillPriority: Long, + spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback, + needsSync: Boolean = true): RapidsBufferHandle = synchronized { + val existing = getExistingRapidsBufferAndAcquire(contigTable.getBuffer) + existing match { + case None => + addContiguousTable( + TempSpillBufferId(), + contigTable, + initialSpillPriority, + spillCallback, + needsSync) + case Some(rapidsBuffer) => + withResource(rapidsBuffer) { _ => + makeNewHandle(rapidsBuffer.id, initialSpillPriority, spillCallback) + } + } + } + + /** + * Adds a contiguous table to the device storage. This does NOT take ownership of the + * contiguous table, so it is the responsibility of the caller to close it. The refcount of the + * underlying device buffer will be incremented so the contiguous table can be closed before + * this buffer is destroyed. + * + * @param id the RapidsBufferId to use for this buffer + * @param contigTable contiguous table to track in storage + * @param initialSpillPriority starting spill priority value for the buffer + * @param spillCallback a callback when the buffer is spilled. This should be very light weight. + * It should never allocate GPU memory and really just be used for metrics. + * @param needsSync whether the spill framework should stream synchronize while adding + * this device buffer (defaults to true) + * @return RapidsBufferHandle handle for this table + */ + def addContiguousTable( + id: RapidsBufferId, + contigTable: ContiguousTable, + initialSpillPriority: Long, + spillCallback: SpillCallback, + needsSync: Boolean): RapidsBufferHandle = synchronized { + addBuffer( + id, + contigTable.getBuffer, + MetaUtils.buildTableMeta(id.tableId, contigTable), + initialSpillPriority, + spillCallback, + needsSync) + } + + /** + * Adds a buffer to the device storage. This does NOT take ownership of the + * buffer, so it is the responsibility of the caller to close it. + * + * @param id the RapidsBufferId to use for this buffer + * @param buffer buffer that will be owned by the store + * @param tableMeta metadata describing the buffer layout + * @param initialSpillPriority starting spill priority value for the buffer + * @param spillCallback a callback when the buffer is spilled. This should be very light weight. + * It should never allocate GPU memory and really just be used for metrics. 
+ * @param needsSync whether the spill framework should stream synchronize while adding + * this device buffer (defaults to true) + * @return RapidsBufferHandle handle for this RapidsBuffer + */ + def addBuffer( + id: RapidsBufferId, + buffer: DeviceMemoryBuffer, + tableMeta: TableMeta, + initialSpillPriority: Long, + spillCallback: SpillCallback, + needsSync: Boolean): RapidsBufferHandle = synchronized { + logDebug(s"Adding buffer ${id} to ${deviceStorage}") + val rapidsBuffer = deviceStorage.addBuffer( + id, + buffer, + tableMeta, + initialSpillPriority, + spillCallback, + needsSync) + registerNewBuffer(rapidsBuffer) + makeNewHandle(id, initialSpillPriority, spillCallback) + } + + /** + * Register a degenerate RapidsBufferId given a TableMeta + * @note this is called from the shuffle catalogs only + */ + def registerDegenerateBuffer( + bufferId: RapidsBufferId, + meta: TableMeta, + spillCallback: SpillCallback): RapidsBufferHandle = synchronized { + val buffer = new DegenerateRapidsBuffer(bufferId, meta) + registerNewBuffer(buffer) + makeNewHandle(buffer.id, buffer.getSpillPriority, spillCallback) + } + /** * Called by the catalog when a handle is first added to the catalog, or to refresh * the priority of the underlying buffer if a handle's priority changed. @@ -229,7 +390,8 @@ class RapidsBufferCatalog extends AutoCloseable with Arm { (0 until RapidsBufferCatalog.MAX_BUFFER_LOOKUP_ATTEMPTS).foreach { _ => val buffers = bufferMap.get(id) if (buffers == null || buffers.isEmpty) { - throw new NoSuchElementException(s"Cannot locate buffers associated with ID: $id") + throw new NoSuchElementException( + s"Cannot locate buffers associated with ID: $id") } val buffer = buffers.head if (buffer.addReference()) { @@ -264,6 +426,7 @@ class RapidsBufferCatalog extends AutoCloseable with Arm { * * @param id buffer identifier * @param tier storage tier to check + * @note public for testing * @return true if the buffer is stored in multiple tiers */ def isBufferSpilled(id: RapidsBufferId, tier: StorageTier): Boolean = { @@ -283,6 +446,7 @@ class RapidsBufferCatalog extends AutoCloseable with Arm { /** * Register a new buffer with the catalog. An exception will be thrown if an * existing buffer was registered with the same buffer ID and storage tier. + * @note public for testing */ def registerNewBuffer(buffer: RapidsBuffer): Unit = { val updater = new BiFunction[RapidsBufferId, Seq[RapidsBuffer], Seq[RapidsBuffer]] { @@ -304,8 +468,173 @@ class RapidsBufferCatalog extends AutoCloseable with Arm { bufferMap.compute(buffer.id, updater) } - /** Remove a buffer ID from the catalog at the specified storage tier. */ - def removeBufferTier(id: RapidsBufferId, tier: StorageTier): Unit = { + /** + * Free memory in `store` by spilling buffers to the spill store synchronously. 
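+ *
+ * For example, the device memory event handler above responds to an RMM
+ * allocation failure with (sketch):
+ * {{{
+ *   // spill until at most targetSize spillable bytes remain in the device store
+ *   val spilled: Option[Long] =
+ *     catalog.synchronousSpill(deviceStore, targetSize, Cuda.DEFAULT_STREAM)
+ * }}}
+ *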
+ * @param store store to spill from + * @param targetTotalSize maximum total size of this store after spilling completes + * @param stream CUDA stream to use or null for default stream + * @return optionally number of bytes that were spilled, or None if this called + * made no attempt to spill due to a detected spill race + */ + def synchronousSpill( + store: RapidsBufferStore, + targetTotalSize: Long, + stream: Cuda.Stream = Cuda.DEFAULT_STREAM): Option[Long] = { + val spillStore = store.spillStore + if (spillStore == null) { + throw new OutOfMemoryError("Requested to spill without a spill store") + } + require(targetTotalSize >= 0, s"Negative spill target size: $targetTotalSize") + logWarning(s"Targeting a ${store.name} size of $targetTotalSize. " + + s"Current total ${store.currentSize}. " + + s"Current spillable ${store.currentSpillableSize}") + + // we try to spill in this thread. If another thread is also spilling, we let that + // thread win and we return letting RMM retry the alloc + var rmmShouldRetryAlloc = false + + // total amount spilled in this invocation + var totalSpilled: Long = 0 + + if (store.currentSpillableSize > targetTotalSize) { + withResource(new NvtxRange(s"${store.name} sync spill", NvtxColor.ORANGE)) { _ => + logWarning(s"${store.name} store spilling to reduce usage from " + + s"${store.currentSize} total (${store.currentSpillableSize} spillable) " + + s"to $targetTotalSize bytes") + + // If the store has 0 spillable bytes left, it has exhausted. + var exhausted = false + + while (!exhausted && !rmmShouldRetryAlloc && + store.currentSpillableSize > targetTotalSize) { + val mySpillCount = spillCount + synchronized { + if (spillCount == mySpillCount) { + spillCount += 1 + val nextSpillable = store.nextSpillable() + if (nextSpillable != null) { + // we have a buffer (nextSpillable) to spill + spillAndFreeBuffer(nextSpillable, spillStore, stream) + totalSpilled += nextSpillable.size + } + } else { + rmmShouldRetryAlloc = true + } + } + if (!rmmShouldRetryAlloc && totalSpilled <= 0) { + // we didn't spill in this iteration, exit loop + exhausted = true + logWarning("Unable to spill enough to meet request. 
" + + s"Total=${store.currentSize} " + + s"Spillable=${store.currentSpillableSize} " + + s"Target=$targetTotalSize") + } + } + } + } + + if (rmmShouldRetryAlloc) { + // if we are going to retry, and didn't spill, returning None prevents extra + // logs where we say we spilled 0 bytes from X store + None + } else { + Some(totalSpilled) + } + } + + /** + * Given a specific `RapidsBuffer` spill it to `spillStore` + * @note called with catalog lock held + */ + private def spillAndFreeBuffer( + buffer: RapidsBuffer, + spillStore: RapidsBufferStore, + stream: Cuda.Stream): Unit = { + if (buffer.addReference()) { + withResource(buffer) { _ => + logDebug(s"Spilling $buffer ${buffer.id} to ${spillStore.name}") + val bufferHasSpilled = isBufferSpilled(buffer.id, buffer.storageTier) + if (!bufferHasSpilled) { + val spillCallback = buffer.getSpillCallback + spillCallback(buffer.storageTier, spillStore.tier, buffer.size) + + // if the spillStore specifies a maximum size spill taking this ceiling + // into account before trying to create a buffer there + trySpillToMaximumSize(buffer, spillStore, stream) + + // copy the buffer to spillStore + val newBuffer = spillStore.copyBuffer(buffer, buffer.getMemoryBuffer, stream) + + // once spilled, we get back a new RapidsBuffer instance in this new tier + registerNewBuffer(newBuffer) + } else { + logDebug(s"Skipping spilling $buffer ${buffer.id} to ${spillStore.name} as it is " + + s"already stored in multiple tiers") + } + } + // we can now remove the old tier linkage + removeBufferTier(buffer.id, buffer.storageTier) + // and free + buffer.safeFree() + } + } + + /** + * If `spillStore` defines a maximum size, spill to make room for `buffer`. + */ + private def trySpillToMaximumSize( + buffer: RapidsBuffer, + spillStore: RapidsBufferStore, + stream: Cuda.Stream): Unit = { + val spillStoreMaxSize = spillStore.getMaxSize + if (spillStoreMaxSize.isDefined) { + // this spillStore has a maximum size requirement (host only). We need to spill from it + // in order to make room for `buffer`. + val targetTotalSize = + math.max(spillStoreMaxSize.get - buffer.size, 0) + val maybeAmountSpilled = synchronousSpill(spillStore, targetTotalSize, stream) + maybeAmountSpilled.foreach { amountSpilled => + if (amountSpilled != 0) { + logInfo(s"Spilled $amountSpilled bytes from the ${spillStore.name} store") + TrampolineUtil.incTaskMetricsDiskBytesSpilled(amountSpilled) + } + } + } + } + + /** + * Copies `buffer` to the `deviceStorage` store, registering a new `RapidsBuffer` in + * the process + * @param buffer - buffer to copy + * @param memoryBuffer - cuDF MemoryBuffer to copy from + * @param stream - Cuda.Stream to synchronize on + * @return - The `RapidsBuffer` instance that was added to the device store. + */ + def unspillBufferToDeviceStore( + buffer: RapidsBuffer, + memoryBuffer: MemoryBuffer, + stream: Cuda.Stream): RapidsBuffer = synchronized { + // try to acquire the buffer, if it's already in the store + // do not create a new one, else add a reference + acquireBuffer(buffer.id, StorageTier.DEVICE) match { + case None => + val newBuffer = deviceStorage.copyBuffer( + buffer, + memoryBuffer, + stream) + newBuffer.addReference() // add a reference since we are about to use it + registerNewBuffer(newBuffer) + newBuffer + case Some(existingBuffer) => + existingBuffer + } + } + + /** + * Remove a buffer ID from the catalog at the specified storage tier. 
+ * @note public for testing + */ + def removeBufferTier(id: RapidsBufferId, tier: StorageTier): Unit = synchronized { val updater = new BiFunction[RapidsBufferId, Seq[RapidsBuffer], Seq[RapidsBuffer]] { override def apply(key: RapidsBufferId, value: Seq[RapidsBuffer]): Seq[RapidsBuffer] = { val updated = value.filter(_.storageTier != tier) @@ -327,11 +656,11 @@ class RapidsBufferCatalog extends AutoCloseable with Arm { * (`handle` was the last handle) * false: if buffer was not removed due to other live handles. */ - private def removeBuffer(handle: RapidsBufferHandle): Boolean = { + private def removeBuffer(handle: RapidsBufferHandle): Boolean = synchronized { // if this is the last handle, remove the buffer if (stopTrackingHandle(handle)) { - val buffers = bufferMap.remove(handle.id) - buffers.safeFree() + logDebug(s"Removing buffer ${handle.id}") + bufferMap.remove(handle.id).safeFree() true } else { false @@ -350,9 +679,9 @@ class RapidsBufferCatalog extends AutoCloseable with Arm { } object RapidsBufferCatalog extends Logging with Arm { + private val MAX_BUFFER_LOOKUP_ATTEMPTS = 100 - val singleton = new RapidsBufferCatalog private var deviceStorage: RapidsDeviceMemoryStore = _ private var hostStorage: RapidsHostMemoryStore = _ private var diskBlockManager: RapidsDiskBlockManager = _ @@ -360,6 +689,16 @@ object RapidsBufferCatalog extends Logging with Arm { private var gdsStorage: RapidsGdsStore = _ private var memoryEventHandler: DeviceMemoryEventHandler = _ private var _shouldUnspill: Boolean = _ + private var _singleton: RapidsBufferCatalog = null + + def singleton: RapidsBufferCatalog = { + if (_singleton == null) { + synchronized { + _singleton = new RapidsBufferCatalog(deviceStorage) + } + } + _singleton + } private lazy val conf: SparkConf = { val env = SparkEnv.get @@ -399,6 +738,7 @@ object RapidsBufferCatalog extends Logging with Arm { logInfo("Installing GPU memory handler for spill") memoryEventHandler = new DeviceMemoryEventHandler( + singleton, deviceStorage, rapidsConf.gpuOomDumpDir, rapidsConf.isGdsSpillEnabled, @@ -413,8 +753,11 @@ object RapidsBufferCatalog extends Logging with Arm { closeImpl() } - private def closeImpl(): Unit = { - singleton.close() + private def closeImpl(): Unit = synchronized { + if (_singleton != null) { + _singleton.close() + _singleton = null + } if (memoryEventHandler != null) { // Workaround for shutdown ordering problems where device buffers allocated with this handler @@ -445,25 +788,6 @@ object RapidsBufferCatalog extends Logging with Arm { def shouldUnspill: Boolean = _shouldUnspill - /** - * Adds a contiguous table to the device storage, taking ownership of the table. - * @param table cudf table based from the contiguous buffer - * @param contigBuffer device memory buffer backing the table - * @param tableMeta metadata describing the buffer layout - * @param initialSpillPriority starting spill priority value for the buffer - * @param spillCallback a callback when the buffer is spilled. This should be very light weight. - * It should never allocate GPU memory and really just be used for metrics. - * @return RapidsBufferHandle associated with this buffer - */ - def addTable( - table: Table, - contigBuffer: DeviceMemoryBuffer, - tableMeta: TableMeta, - initialSpillPriority: Long, - spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): RapidsBufferHandle = { - deviceStorage.addTable(table, contigBuffer, tableMeta, initialSpillPriority) - } - /** * Adds a contiguous table to the device storage. 
This does NOT take ownership of the * contiguous table, so it is the responsibility of the caller to close it. The refcount of the @@ -479,7 +803,7 @@ object RapidsBufferCatalog extends Logging with Arm { contigTable: ContiguousTable, initialSpillPriority: Long, spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): RapidsBufferHandle = { - deviceStorage.addContiguousTable(contigTable, initialSpillPriority, spillCallback) + singleton.addContiguousTable(contigTable, initialSpillPriority, spillCallback) } /** @@ -497,7 +821,7 @@ object RapidsBufferCatalog extends Logging with Arm { tableMeta: TableMeta, initialSpillPriority: Long, spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): RapidsBufferHandle = { - deviceStorage.addBuffer(buffer, tableMeta, initialSpillPriority, spillCallback) + singleton.addBuffer(buffer, tableMeta, initialSpillPriority, spillCallback) } /** @@ -510,4 +834,40 @@ object RapidsBufferCatalog extends Logging with Arm { singleton.acquireBuffer(handle) def getDiskBlockManager(): RapidsDiskBlockManager = diskBlockManager + + /** + * Given a `DeviceMemoryBuffer` find out if a `MemoryBuffer.EventHandler` is associated + * with it. + * + * After getting the `RapidsBuffer` try to acquire it via `addReference`. + * If successful, we can point to this buffer with a new handle, otherwise the buffer is + * about to be removed/freed (unlikely, because we are holding onto the reference as we + * are adding it again). + * + * @note public for testing + * @param buffer - the `DeviceMemoryBuffer` to inspect + * @return - Some(RapidsBuffer): the handler is associated with a rapids buffer + * and the rapids buffer is currently valid, or + * + * - None: if no `RapidsBuffer` is associated with this buffer (it is + * brand new to the store, or the `RapidsBuffer` is invalid and + * about to be removed). + */ + private def getExistingRapidsBufferAndAcquire( + buffer: DeviceMemoryBuffer): Option[RapidsBuffer] = { + val eh = buffer.getEventHandler + eh match { + case null => + None + case rapidsBuffer: RapidsBuffer => + if (rapidsBuffer.addReference()) { + Some(rapidsBuffer) + } else { + None + } + case _ => + throw new IllegalStateException("Unknown event handler") + } + } } + diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala index 91e58983d5a..9bd86632d56 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala @@ -17,10 +17,10 @@ package com.nvidia.spark.rapids import java.util.Comparator -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicLong -import ai.rapids.cudf.{BaseDeviceMemoryBuffer, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, NvtxColor, NvtxRange} +import scala.collection.mutable + +import ai.rapids.cudf.{BaseDeviceMemoryBuffer, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer} import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.StorageTier.{DEVICE, StorageTier} import com.nvidia.spark.rapids.format.TableMeta @@ -29,19 +29,13 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.ColumnarBatch -object RapidsBufferStore { - private val FREE_WAIT_TIMEOUT = 10 * 1000 -} - /** * Base class for all buffer store types. 
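 *
 * Stores are chained into spill tiers with `setSpillStore`; a sketch of
 * typical wiring (store names are illustrative):
 * {{{
 *   deviceStorage.setSpillStore(hostStorage) // DEVICE spills to HOST
 *   hostStorage.setSpillStore(diskStorage)   // HOST spills to DISK
 * }}}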
* * @param tier storage tier of this store * @param catalog catalog to register this store */ -abstract class RapidsBufferStore( - val tier: StorageTier, - catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton) +abstract class RapidsBufferStore(val tier: StorageTier) extends AutoCloseable with Logging with Arm { val name: String = tier.toString @@ -50,24 +44,47 @@ abstract class RapidsBufferStore( private[this] val comparator: Comparator[RapidsBufferBase] = (o1: RapidsBufferBase, o2: RapidsBufferBase) => java.lang.Long.compare(o1.getSpillPriority, o2.getSpillPriority) + // buffers: contains all buffers in this store, whether spillable or not private[this] val buffers = new java.util.HashMap[RapidsBufferId, RapidsBufferBase] + // spillable: contains only those buffers that are currently spillable private[this] val spillable = new HashedPriorityQueue[RapidsBufferBase](comparator) + // spilling: contains only those buffers that are currently being spilled, but + // have not been removed from the store + private[this] val spilling = new mutable.HashSet[RapidsBufferId]() + // total bytes stored, regardless of spillable status private[this] var totalBytesStored: Long = 0L + // total bytes that are currently eligible to be spilled + private[this] var totalBytesSpillable: Long = 0L def add(buffer: RapidsBufferBase): Unit = synchronized { val old = buffers.put(buffer.id, buffer) + // it is unlikely that the buffer was in this collection, but removing + // anyway. We assume the buffer is safe in this tier, and is not spilling + spilling.remove(buffer.id) if (old != null) { throw new DuplicateBufferException(s"duplicate buffer registered: ${buffer.id}") } - spillable.offer(buffer) totalBytesStored += buffer.size + + // device buffers "spillability" is handled via DeviceMemoryBuffer ref counting + // so spillableOnAdd should be false, all other buffer tiers are spillable at + // all times. + if (spillableOnAdd) { + if (spillable.offer(buffer)) { + totalBytesSpillable += buffer.size + } + } } def remove(id: RapidsBufferId): Unit = synchronized { + // when removing a buffer we no longer need to know if it was spilling + spilling.remove(id) val obj = buffers.remove(id) if (obj != null) { - spillable.remove(obj) totalBytesStored -= obj.size + if (spillable.remove(obj)) { + totalBytesSpillable -= obj.size + } } } @@ -76,6 +93,7 @@ abstract class RapidsBufferStore( val buffs = buffers.values().toArray(new Array[RapidsBufferBase](0)) buffers.clear() spillable.clear() + spilling.clear() buffs } // We need to release the `RapidsBufferStore` lock to prevent a lock order inversion @@ -84,8 +102,45 @@ abstract class RapidsBufferStore( values.safeFree() } + /** + * Sets a buffers state to spillable or non-spillable. + * + * If the buffer is currently being spilled or it is no longer in the `buffers` collection + * (e.g. it is not in this store), the action is skipped. + * + * @param buffer the buffer to mark as spillable or not + * @param isSpillable whether the buffer should now be spillable + */ + def setSpillable(buffer: RapidsBufferBase, isSpillable: Boolean): Unit = synchronized { + if (isSpillable) { + // if this buffer is in the store and isn't currently spilling + if (!spilling.contains(buffer.id) && buffers.containsKey(buffer.id)) { + // try to add it to the spillable collection + if (spillable.offer(buffer)) { + totalBytesSpillable += buffer.size + logDebug(s"Buffer ${buffer.id} is spillable. 
" + + s"total=${totalBytesStored} spillable=${totalBytesSpillable}") + } // else it was already there (unlikely) + } + } else { + if (spillable.remove(buffer)) { + totalBytesSpillable -= buffer.size + logDebug(s"Buffer ${buffer.id} is not spillable. " + + s"total=${totalBytesStored}, spillable=${totalBytesSpillable}") + } // else it was already removed + } + } + def nextSpillableBuffer(): RapidsBufferBase = synchronized { - spillable.poll() + val buffer = spillable.poll() + if (buffer != null) { + // mark the id as "spilling" (this buffer is in the middle of a spill operation) + spilling.add(buffer.id) + totalBytesSpillable -= buffer.size + logDebug(s"Spilling buffer ${buffer.id}. size=${buffer.size} " + + s"total=${totalBytesStored}, new spillable=${totalBytesSpillable}") + } + buffer } def updateSpillPriority(buffer: RapidsBufferBase, priority:Long): Unit = synchronized { @@ -94,26 +149,34 @@ abstract class RapidsBufferStore( } def getTotalBytes: Long = synchronized { totalBytesStored } + + def getTotalSpillableBytes: Long = synchronized { totalBytesSpillable } } - private[this] val pendingFreeBytes = new AtomicLong(0L) + /** + * Stores that need to stay within a specific byte limit of buffers stored override + * this function. Only the `HostMemoryBufferStore` requires such a limit. + * @return maximum amount of bytes that can be stored in the store, None for no + * limit + */ + def getMaxSize: Option[Long] = None private[this] val buffers = new BufferTracker - /** Tracks buffers that are waiting on outstanding references to be freed. */ - private[this] val pendingFreeBuffers = new ConcurrentHashMap[RapidsBufferId, RapidsBufferBase] - - /** A monitor that can be used to wait for memory to be freed from this store. */ - protected[this] val memoryFreedMonitor = new Object - /** A store that can be used for spilling. */ - private[this] var spillStore: RapidsBufferStore = _ - - private[this] val nvtxSyncSpillName: String = name + " sync spill" + var spillStore: RapidsBufferStore = _ /** Return the current byte total of buffers in this store. */ def currentSize: Long = buffers.getTotalBytes + def currentSpillableSize: Long = buffers.getTotalSpillableBytes + + /** + * A store that manages spillability of buffers should override this method + * to false, otherwise `BufferTracker` treats buffers as always spillable. + */ + protected def spillableOnAdd: Boolean = true + /** * Specify another store that can be used when this store needs to spill. * @note Only one spill store can be registered. This will throw if a @@ -134,79 +197,31 @@ abstract class RapidsBufferStore( * for `memoryBuffer` is transferred to this store. The store may close * `memoryBuffer` if necessary. * @param stream CUDA stream to use for copy or null - * @return new buffer that was created + * @return the new buffer that was created */ - def copyBuffer(buffer: RapidsBuffer, memoryBuffer: MemoryBuffer, stream: Cuda.Stream) - : RapidsBufferBase = { + def copyBuffer( + buffer: RapidsBuffer, + memoryBuffer: MemoryBuffer, + stream: Cuda.Stream): RapidsBufferBase = { freeOnExcept(createBuffer(buffer, memoryBuffer, stream)) { newBuffer => addBuffer(newBuffer) newBuffer } } - /** - * Free memory in this store by spilling buffers to the spill store synchronously. 
- * @param targetTotalSize maximum total size of this store after spilling completes - * @return number of bytes that were spilled - */ - def synchronousSpill(targetTotalSize: Long): Long = - synchronousSpill(targetTotalSize, Cuda.DEFAULT_STREAM) - - /** - * Free memory in this store by spilling buffers to the spill store synchronously. - * @param targetTotalSize maximum total size of this store after spilling completes - * @param stream CUDA stream to use or null for default stream - * @return number of bytes that were spilled - */ - def synchronousSpill(targetTotalSize: Long, stream: Cuda.Stream): Long = { - require(targetTotalSize >= 0, s"Negative spill target size: $targetTotalSize") - - var totalSpilled: Long = 0 - if (buffers.getTotalBytes > targetTotalSize) { - val nvtx = new NvtxRange(nvtxSyncSpillName, NvtxColor.ORANGE) - try { - logDebug(s"$name store spilling to reduce usage from " + - s"${buffers.getTotalBytes} to $targetTotalSize bytes") - var waited = false - var exhausted = false - while (!exhausted && buffers.getTotalBytes > targetTotalSize) { - val amountSpilled = trySpillAndFreeBuffer(stream) - if (amountSpilled != 0) { - totalSpilled += amountSpilled - waited = false - } else { - if (!waited && pendingFreeBytes.get > 0) { - waited = true - logWarning(s"Cannot spill further, waiting for ${pendingFreeBytes.get} " + - " bytes of pending buffers to be released") - memoryFreedMonitor.synchronized { - val memNeeded = buffers.getTotalBytes - targetTotalSize - if (memNeeded > 0 && memNeeded <= pendingFreeBytes.get) { - // This could be a futile wait if the thread(s) holding the pending buffers open - // are here waiting for more memory. - memoryFreedMonitor.wait(RapidsBufferStore.FREE_WAIT_TIMEOUT) - } - } - } else { - logWarning("Unable to spill enough to meet request. " + - s"Total=${buffers.getTotalBytes} Target=$targetTotalSize") - exhausted = true - } - } - } - logDebug(s"$this spill complete") - } finally { - nvtx.close() - } - } + protected def doSetSpillable(buffer: RapidsBufferBase, isSpillable: Boolean): Unit = { + buffers.setSpillable(buffer, isSpillable) + } - totalSpilled + protected def setSpillable(buffer: RapidsBufferBase, isSpillable: Boolean): Unit = { + throw new NotImplementedError(s"This store ${this} does not implement setSpillable") } /** * Create a new buffer from an existing buffer in another store. * If the data transfer will be performed asynchronously, this method is responsible for * adding a reference to the existing buffer and later closing it when the transfer completes. + * * @note DO NOT close the buffer unless adding a reference! * @note `createBuffer` impls should synchronize against `stream` before returning, if needed. * @param buffer data from another store @@ -214,56 +229,24 @@ abstract class RapidsBufferStore( * for `memoryBuffer` is transferred to this store. The store may close * `memoryBuffer` if necessary. * @param stream CUDA stream to use or null - * @return new buffer tracking the data in this store + * @return the new buffer that was created. 
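+ *
+ * The catalog drives a spill through the public `copyBuffer` above, which
+ * delegates to this method, e.g. (sketch of the catalog's spill path):
+ * {{{
+ *   val newBuffer = spillStore.copyBuffer(buffer, buffer.getMemoryBuffer, stream)
+ *   registerNewBuffer(newBuffer) // the catalog then tracks the new tier
+ * }}}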
*/ - protected def createBuffer(buffer: RapidsBuffer, memoryBuffer: MemoryBuffer, stream: Cuda.Stream) - : RapidsBufferBase + protected def createBuffer( + buffer: RapidsBuffer, + memoryBuffer: MemoryBuffer, + stream: Cuda.Stream): RapidsBufferBase /** Update bookkeeping for a new buffer */ - protected def addBuffer(buffer: RapidsBufferBase): Unit = synchronized { + protected def addBuffer(buffer: RapidsBufferBase): Unit = { buffers.add(buffer) - catalog.registerNewBuffer(buffer) } override def close(): Unit = { buffers.freeAll() } - private def trySpillAndFreeBuffer(stream: Cuda.Stream): Long = synchronized { - val bufferToSpill = buffers.nextSpillableBuffer() - if (bufferToSpill != null) { - spillAndFreeBuffer(bufferToSpill, stream) - bufferToSpill.size - } else { - 0 - } - } - - private def spillAndFreeBuffer(buffer: RapidsBufferBase, stream: Cuda.Stream): Unit = { - if (spillStore == null) { - throw new OutOfMemoryError("Requested to spill without a spill store") - } - // If we fail to get a reference then this buffer has since been freed and probably best - // to return back to the outer loop to see if enough has been freed. - if (buffer.addReference()) { - try { - if (catalog.isBufferSpilled(buffer.id, buffer.storageTier)) { - logDebug(s"Skipping spilling $buffer ${buffer.id} to ${spillStore.name} as it is " + - s"already stored in multiple tiers total mem=${buffers.getTotalBytes}") - catalog.removeBufferTier(buffer.id, buffer.storageTier) - } else { - logDebug(s"Spilling $buffer ${buffer.id} to ${spillStore.name} " + - s"total mem=${buffers.getTotalBytes}") - val spillCallback = buffer.getSpillCallback - spillCallback(buffer.storageTier, spillStore.tier, buffer.size) - spillStore.copyBuffer(buffer, buffer.getMemoryBuffer, stream) - } - } finally { - buffer.close() - } - catalog.removeBufferTier(buffer.id, buffer.storageTier) - buffer.free() - } + def nextSpillable(): RapidsBuffer = { + buffers.nextSpillableBuffer() } /** Base class for all buffers in this store. */ @@ -273,11 +256,12 @@ abstract class RapidsBufferStore( override val meta: TableMeta, initialSpillPriority: Long, initialSpillCallback: SpillCallback, - catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton, - deviceStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.getDeviceStorage) + catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton) extends RapidsBuffer with Arm { private val MAX_UNSPILL_ATTEMPTS = 100 - private[this] var isValid = true + + // isValid and refcount must be used with the `RapidsBufferBase` lock held + protected[this] var isValid = true protected[this] var refcount = 0 private[this] var spillPriority: Long = initialSpillPriority @@ -299,17 +283,6 @@ abstract class RapidsBufferStore( */ protected def materializeMemoryBuffer: MemoryBuffer = getMemoryBuffer - /** - * Determine if a buffer is currently acquired. - * @note Unless this is called by the thread that currently "owns" an - * acquired buffer, the acquisition state could be changing - * asynchronously, and therefore the result cannot always be used as a - * proxy for the result obtained from the addReference method. - */ - def isAcquired: Boolean = synchronized { - refcount > 0 - } - override def addReference(): Boolean = synchronized { if (isValid) { refcount += 1 @@ -352,6 +325,11 @@ abstract class RapidsBufferStore( } } + /** + * TODO: we want to remove this method from the buffer, instead we want the catalog + * to be responsible for producing the DeviceMemoryBuffer by asking the buffer. 
This + * hides the RapidsBuffer from clients and simplifies locking. + */ override def getDeviceMemoryBuffer: DeviceMemoryBuffer = { if (RapidsBufferCatalog.shouldUnspill) { (0 until MAX_UNSPILL_ATTEMPTS).foreach { _ => @@ -363,12 +341,12 @@ abstract class RapidsBufferStore( case _ => try { logDebug(s"Unspilling $this $id to $DEVICE") - val newBuffer = deviceStorage.copyBuffer( - this, materializeMemoryBuffer, Cuda.DEFAULT_STREAM) - if (newBuffer.addReference()) { - withResource(newBuffer) { _ => - return newBuffer.getDeviceMemoryBuffer - } + val newBuffer = catalog.unspillBufferToDeviceStore( + this, + materializeMemoryBuffer, + Cuda.DEFAULT_STREAM) + withResource(newBuffer) { _ => + return newBuffer.getDeviceMemoryBuffer } } catch { case _: DuplicateBufferException => @@ -393,20 +371,25 @@ abstract class RapidsBufferStore( } } + /** + * close() is called by client code to decrease the ref count of this RapidsBufferBase. + * In the off chance that by the time close is invoked, the buffer was freed (not valid) + * then this close call winds up freeing the resources of the rapids buffer. + */ override def close(): Unit = synchronized { if (refcount == 0) { throw new IllegalStateException("Buffer already closed") } refcount -= 1 if (refcount == 0 && !isValid) { - pendingFreeBuffers.remove(id) - pendingFreeBytes.addAndGet(-size) freeBuffer() } } /** - * Mark the buffer as freed and no longer valid. + * Mark the buffer as freed and no longer valid. This is called by the store when removing a + * buffer (it is no longer tracked). + * * @note The resources may not be immediately released if the buffer has outstanding references. * In that case the resources will be released when the reference count reaches zero. */ @@ -416,9 +399,6 @@ abstract class RapidsBufferStore( buffers.remove(id) if (refcount == 0) { freeBuffer() - } else { - pendingFreeBuffers.put(id, this) - pendingFreeBytes.addAndGet(size) } } else { logWarning(s"Trying to free an invalid buffer => $id, size = $size, $this") @@ -443,9 +423,6 @@ abstract class RapidsBufferStore( /** Must be called with a lock on the buffer */ private def freeBuffer(): Unit = { releaseResources() - memoryFreedMonitor.synchronized { - memoryFreedMonitor.notifyAll() - } } override def toString: String = s"$name buffer size=$size" diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala index b6833074e19..469fa5fa3d6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala @@ -16,11 +16,10 @@ package com.nvidia.spark.rapids -import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer, Table} +import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer} import com.nvidia.spark.rapids.StorageTier.StorageTier import com.nvidia.spark.rapids.format.TableMeta -import org.apache.spark.sql.rapids.TempSpillBufferId import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -28,10 +27,15 @@ import org.apache.spark.sql.vectorized.ColumnarBatch * Buffer storage using device memory. 
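 *
 * With this refactor the store no longer registers buffers itself; the catalog
 * layers registration and handles on top (sketch, mirroring the catalog's
 * `addBuffer`):
 * {{{
 *   val rb = deviceStorage.addBuffer(id, buf, meta, priority, callback, needsSync)
 *   catalog.registerNewBuffer(rb)
 *   val handle = catalog.makeNewHandle(id, priority, callback)
 * }}}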
* @param catalog catalog to register this store */ -class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton) - extends RapidsBufferStore(StorageTier.DEVICE, catalog) with Arm { +class RapidsDeviceMemoryStore + extends RapidsBufferStore(StorageTier.DEVICE) with Arm { - override protected def createBuffer(other: RapidsBuffer, memoryBuffer: MemoryBuffer, + // The RapidsDeviceMemoryStore handles spillability via ref counting + override protected def spillableOnAdd: Boolean = false + + override protected def createBuffer( + other: RapidsBuffer, + memoryBuffer: MemoryBuffer, stream: Cuda.Stream): RapidsBufferBase = { val deviceBuffer = { memoryBuffer match { @@ -51,158 +55,17 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog other.id, other.size, other.meta, - None, deviceBuffer, other.getSpillPriority, other.getSpillCallback) } - /** - * Adds a contiguous table to the device storage, taking ownership of the table. - * - * @param table cudf table based from the contiguous buffer - * @param contigBuffer device memory buffer backing the table - * @param tableMeta metadata describing the buffer layout - * @param initialSpillPriority starting spill priority value for the buffer - * @param spillCallback a callback when the buffer is spilled. This should be very light weight. - * It should never allocate GPU memory and really just be used for metrics. - * @return RapidsBufferHandle handle for this table - */ - def addTable( - table: Table, - contigBuffer: DeviceMemoryBuffer, - tableMeta: TableMeta, - initialSpillPriority: Long, - spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): RapidsBufferHandle = { - // We increment this because this rapids device memory has two pointers to the buffer: - // the actual contig buffer, and the table. When this `RapidsBuffer` releases its resources, - // it will decrement the ref count for the contig buffer (negating this incRefCount), - // it will also close the table being passed here, which together brings the ref count - // to 0. - contigBuffer.incRefCount() - val id = TempSpillBufferId() - freeOnExcept( - new RapidsDeviceMemoryBuffer( - id, - contigBuffer.getLength, - tableMeta, - Some(table), - contigBuffer, - initialSpillPriority, - spillCallback)) { buffer => - logDebug(s"Adding table for: [id=$id, size=${buffer.size}, " + - s"meta_id=${buffer.meta.bufferMeta.id}, meta_size=${buffer.meta.bufferMeta.size}]") - addDeviceBuffer(buffer, needsSync = true) - catalog.makeNewHandle(id, initialSpillPriority, spillCallback) - } - } - - /** - * Adds a contiguous table to the device storage. This does NOT take ownership of the - * contiguous table, so it is the responsibility of the caller to close it. The refcount of the - * underlying device buffer will be incremented so the contiguous table can be closed before - * this buffer is destroyed. - * - * This version of `addContiguousTable` creates a `TempSpillBufferId` to use - * to refer to this table. - * - * @param contigTable contiguous table to track in storage - * @param initialSpillPriority starting spill priority value for the buffer - * @param spillCallback a callback when the buffer is spilled. This should be very light weight. - * It should never allocate GPU memory and really just be used for metrics. 
- * @param needsSync whether the spill framework should stream synchronize while adding - * this device buffer (defaults to true) - * @return RapidsBufferHandle handle for this table - */ - def addContiguousTable( - contigTable: ContiguousTable, - initialSpillPriority: Long, - spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback, - needsSync: Boolean = true): RapidsBufferHandle = { - addContiguousTable( - TempSpillBufferId(), - contigTable, - initialSpillPriority, - spillCallback, - needsSync) - } - - /** - * Adds a contiguous table to the device storage. This does NOT take ownership of the - * contiguous table, so it is the responsibility of the caller to close it. The refcount of the - * underlying device buffer will be incremented so the contiguous table can be closed before - * this buffer is destroyed. - * - * @param id the RapidsBufferId to use for this buffer - * @param contigTable contiguous table to track in storage - * @param initialSpillPriority starting spill priority value for the buffer - * @param spillCallback a callback when the buffer is spilled. This should be very light weight. - * It should never allocate GPU memory and really just be used for metrics. - * @param needsSync whether the spill framework should stream synchronize while adding - * this device buffer (defaults to true) - * @return RapidsBufferHandle handle for this table - */ - def addContiguousTable( - id: RapidsBufferId, - contigTable: ContiguousTable, - initialSpillPriority: Long, - spillCallback: SpillCallback, - needsSync: Boolean): RapidsBufferHandle = { - val contigBuffer = contigTable.getBuffer - val size = contigBuffer.getLength - val meta = MetaUtils.buildTableMeta(id.tableId, contigTable) - contigBuffer.incRefCount() - freeOnExcept( - new RapidsDeviceMemoryBuffer( - id, - size, - meta, - None, - contigBuffer, - initialSpillPriority, - spillCallback)) { buffer => - logDebug(s"Adding table for: [id=$id, size=${buffer.size}, " + - s"uncompressed=${buffer.meta.bufferMeta.uncompressedSize}, " + - s"meta_id=${buffer.meta.bufferMeta.id}, meta_size=${buffer.meta.bufferMeta.size}]") - addDeviceBuffer(buffer, needsSync) - catalog.makeNewHandle(id, initialSpillPriority, spillCallback) - } - } - /** * Adds a buffer to the device storage. This does NOT take ownership of the * buffer, so it is the responsibility of the caller to close it. * - * This version of `addBuffer` creates a `TempSpillBufferId` to use to refer to - * this buffer. - * - * @param buffer buffer that will be owned by the store - * @param tableMeta metadata describing the buffer layout - * @param initialSpillPriority starting spill priority value for the buffer - * @param spillCallback a callback when the buffer is spilled. This should be very light weight. - * It should never allocate GPU memory and really just be used for metrics. - * @param needsSync whether the spill framework should stream synchronize while adding - * this device buffer (defaults to true) - * @return RapidsBufferHandle handle for this buffer - */ - def addBuffer( - buffer: DeviceMemoryBuffer, - tableMeta: TableMeta, - initialSpillPriority: Long, - spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback, - needsSync: Boolean = true): RapidsBufferHandle = { - addBuffer( - TempSpillBufferId(), - buffer, - tableMeta, - initialSpillPriority, - spillCallback, - needsSync) - } - - /** - * Adds a buffer to the device storage. This does NOT take ownership of the - * buffer, so it is the responsibility of the caller to close it. 
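+ *
+ * Ownership sketch (`buf` is the caller's `DeviceMemoryBuffer`):
+ * {{{
+ *   val rb = deviceStore.addBuffer(id, buf, meta, priority, callback, needsSync = true)
+ *   buf.close() // caller drops its reference; the store kept its own via incRefCount
+ * }}}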
+ * This function is called only from the RapidsBufferCatalog, under the + * catalog lock. * * @param id the RapidsBufferId to use for this buffer * @param buffer buffer that will be owned by the store @@ -212,7 +75,7 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog * It should never allocate GPU memory and really just be used for metrics. * @param needsSync whether the spill framework should stream synchronize while adding * this device buffer (defaults to true) - * @return RapidsBufferHandle handle for this RapidsBuffer + * @return the RapidsBuffer instance that was added. */ def addBuffer( id: RapidsBufferId, @@ -220,23 +83,22 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog tableMeta: TableMeta, initialSpillPriority: Long, spillCallback: SpillCallback, - needsSync: Boolean): RapidsBufferHandle = { + needsSync: Boolean): RapidsBuffer = { buffer.incRefCount() - freeOnExcept( - new RapidsDeviceMemoryBuffer( - id, - buffer.getLength, - tableMeta, - None, - buffer, - initialSpillPriority, - spillCallback)) { buff => + val rapidsBuffer = new RapidsDeviceMemoryBuffer( + id, + buffer.getLength, + tableMeta, + buffer, + initialSpillPriority, + spillCallback) + freeOnExcept(rapidsBuffer) { _ => logDebug(s"Adding receive side table for: [id=$id, size=${buffer.getLength}, " + - s"uncompressed=${buff.meta.bufferMeta.uncompressedSize}, " + + s"uncompressed=${rapidsBuffer.meta.bufferMeta.uncompressedSize}, " + s"meta_id=${tableMeta.bufferMeta.id}, " + s"meta_size=${tableMeta.bufferMeta.size}]") - addDeviceBuffer(buff, needsSync) - catalog.makeNewHandle(id, initialSpillPriority, spillCallback) + addDeviceBuffer(rapidsBuffer, needsSync) + rapidsBuffer } } @@ -256,36 +118,95 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog addBuffer(buffer) } + /** + * The RapidsDeviceMemoryStore is the only store that supports setting a buffer spillable + * or not. + */ + override protected def setSpillable(buffer: RapidsBufferBase, spillable: Boolean): Unit = { + doSetSpillable(buffer, spillable) + } + class RapidsDeviceMemoryBuffer( id: RapidsBufferId, size: Long, meta: TableMeta, - table: Option[Table], contigBuffer: DeviceMemoryBuffer, spillPriority: Long, spillCallback: SpillCallback) - extends RapidsBufferBase(id, size, meta, spillPriority, spillCallback) { + extends RapidsBufferBase(id, size, meta, spillPriority, spillCallback) + with MemoryBuffer.EventHandler { + override val storageTier: StorageTier = StorageTier.DEVICE - override protected def releaseResources(): Unit = { + // If this require triggers, we are re-adding a `DeviceMemoryBuffer` outside of + // the catalog lock, which should not possible. The event handler is set to null + // when we free the `RapidsDeviceMemoryBuffer` and if the buffer is not free, we + // take out another handle (in the catalog). + // TODO: This is not robust (to rely on outside locking and addReference/free) + // and should be revisited. + require(contigBuffer.setEventHandler(this) == null, + "DeviceMemoryBuffer with non-null event handler failed to add!!") + + /** + * Override from the MemoryBuffer.EventHandler interface. 
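+ *
+ * A typical sequence (refCounts are illustrative):
+ * {{{
+ *   val dmb = rapidsBuffer.getDeviceMemoryBuffer // refCount 1 -> 2: not spillable
+ *   dmb.close() // refCount 2 -> 1: onClosed(1) fires, the buffer is spillable again
+ * }}}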
+ * + * If we are being invoked we have the `contigBuffer` lock, as this callback + * is being invoked from `MemoryBuffer.close` + * + * @param refCount - contigBuffer's current refCount + */ + override def onClosed(refCount: Int): Unit = { + // refCount == 1 means only 1 reference exists to `contigBuffer` in the + // RapidsDeviceMemoryBuffer (we own it) + if (refCount == 1) { + // setSpillable is being called here as an extension of `MemoryBuffer.close()` + // we hold the MemoryBuffer lock and we could be called from a Spark task thread + // Since we hold the MemoryBuffer lock, `incRefCount` waits for us. The only other + // call to `setSpillable` is also under this same MemoryBuffer lock (see: + // `getDeviceMemoryBuffer`) + setSpillable(this, true) + } + } + + override protected def releaseResources(): Unit = synchronized { + // we need to disassociate this RapidsBuffer from the underlying buffer contigBuffer.close() - table.foreach(_.close()) } - override def getDeviceMemoryBuffer: DeviceMemoryBuffer = { - contigBuffer.incRefCount() - contigBuffer + /** + * Get and increase the reference count of the device memory buffer + * in this RapidsBuffer, while making the RapidsBuffer non-spillable. + * + * @note It is the responsibility of the caller to close the DeviceMemoryBuffer + */ + override def getDeviceMemoryBuffer: DeviceMemoryBuffer = synchronized { + contigBuffer.synchronized { + setSpillable(this, false) + contigBuffer.incRefCount() + contigBuffer + } } override def getMemoryBuffer: MemoryBuffer = getDeviceMemoryBuffer override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = { - if (table.isDefined) { - //REFCOUNT ++ of all columns - GpuColumnVectorFromBuffer.from(table.get, contigBuffer, meta, sparkTypes) - } else { - columnarBatchFromDeviceBuffer(contigBuffer, sparkTypes) + // calling `getDeviceMemoryBuffer` guarantees that we have marked this RapidsBuffer + // as not spillable and increased its refCount atomically + withResource(getDeviceMemoryBuffer) { buff => + columnarBatchFromDeviceBuffer(buff, sparkTypes) + } + } + + /** + * We overwrite free to make sure we don't have a handler for the underlying + * contigBuffer, since this `RapidsBuffer` is no longer tracked. + */ + override def free(): Unit = synchronized { + if (isValid) { + // it is going to be invalid when calling super.free() + contigBuffer.setEventHandler(null) } + super.free() } } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala index 12c312c52f4..18b3b9839a7 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala @@ -27,14 +27,13 @@ import com.nvidia.spark.rapids.format.TableMeta import org.apache.spark.sql.rapids.RapidsDiskBlockManager /** A buffer store using files on the local disks. 
*/ -class RapidsDiskStore( - diskBlockManager: RapidsDiskBlockManager, - catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton, - deviceStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.getDeviceStorage) - extends RapidsBufferStore(StorageTier.DISK, catalog) { +class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager) + extends RapidsBufferStore(StorageTier.DISK) { private[this] val sharedBufferFiles = new ConcurrentHashMap[RapidsBufferId, File] - override protected def createBuffer(incoming: RapidsBuffer, incomingBuffer: MemoryBuffer, + override protected def createBuffer( + incoming: RapidsBuffer, + incomingBuffer: MemoryBuffer, stream: Cuda.Stream): RapidsBufferBase = { withResource(incomingBuffer) { _ => val hostBuffer = incomingBuffer match { @@ -62,8 +61,7 @@ class RapidsDiskStore( incoming.size, incoming.meta, incoming.getSpillPriority, - incoming.getSpillCallback, - deviceStorage) + incoming.getSpillCallback) } } @@ -95,10 +93,9 @@ class RapidsDiskStore( size: Long, meta: TableMeta, spillPriority: Long, - spillCallback: SpillCallback, - deviceStorage: RapidsDeviceMemoryStore) + spillCallback: SpillCallback) extends RapidsBufferBase( - id, size, meta, spillPriority, spillCallback, deviceStorage = deviceStorage) { + id, size, meta, spillPriority, spillCallback) { private[this] var hostBuffer: Option[HostMemoryBuffer] = None override val storageTier: StorageTier = StorageTier.DISK diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsGdsStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsGdsStore.scala index b9c754c14aa..a613004b072 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsGdsStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsGdsStore.scala @@ -32,9 +32,8 @@ import org.apache.spark.sql.rapids.{RapidsDiskBlockManager, TempSpillBufferId} /** A buffer store using GPUDirect Storage (GDS). */ class RapidsGdsStore( diskBlockManager: RapidsDiskBlockManager, - batchWriteBufferSize: Long, - catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton) - extends RapidsBufferStore(StorageTier.GDS, catalog) with Arm { + batchWriteBufferSize: Long) + extends RapidsBufferStore(StorageTier.GDS) with Arm { private[this] val batchSpiller = new BatchSpiller() override protected def createBuffer(other: RapidsBuffer, otherBuffer: MemoryBuffer, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala index ef3dd77ff88..9317df134b9 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala @@ -21,20 +21,15 @@ import com.nvidia.spark.rapids.SpillPriorities.{applyPriorityOffset, HOST_MEMORY import com.nvidia.spark.rapids.StorageTier.StorageTier import com.nvidia.spark.rapids.format.TableMeta -import org.apache.spark.sql.rapids.execution.TrampolineUtil - /** * A buffer store using host memory. 
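With the catalog and device-store parameters gone, allocation in this store no longer spills to make room; allocateHostBuffer below is now a straight three-tier fallback: the pinned pool first, then a slice of the internal pageable pool, then an unpooled direct allocation that may exceed the configured maximum. A compact sketch of that policy, with stand-in allocators rather than the real PinnedMemoryPool and address-space allocator:

object HostAllocSketch {
  sealed trait AllocMode
  case object Pinned extends AllocMode
  case object Pooled extends AllocMode
  case object Direct extends AllocMode

  // Stand-ins for the pinned pool and the sliced pageable pool; both answer
  // None when they cannot satisfy the request.
  def tryPinned(size: Long): Option[Array[Byte]] = None
  def tryPooled(size: Long): Option[Array[Byte]] = None

  def allocateHost(size: Long): (Array[Byte], AllocMode) =
    tryPinned(size).map(b => (b, Pinned))
      .orElse(tryPooled(size).map(b => (b, Pooled)))
      // last resort: allocate outside both pools, possibly exceeding the max
      .getOrElse((new Array[Byte](size.toInt), Direct))
}

Freeing space is now entirely the catalog's job via synchronousSpill, which is what keeps this allocation path free of spill loops.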
 * @param maxSize maximum size in bytes for all buffers in this store
 * @param pageableMemoryPoolSize maximum size in bytes for the internal pageable memory pool
- * @param catalog buffer catalog to use with this store
 */
class RapidsHostMemoryStore(
    maxSize: Long,
-   pageableMemoryPoolSize: Long,
-   catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton,
-   deviceStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.getDeviceStorage)
- extends RapidsBufferStore(StorageTier.HOST, catalog) {
+   pageableMemoryPoolSize: Long)
+ extends RapidsBufferStore(StorageTier.HOST) {
  private[this] val pool = HostMemoryBuffer.allocate(pageableMemoryPoolSize, false)
  private[this] val addressAllocator = new AddressSpaceAllocator(pageableMemoryPoolSize)
  private[this] var haveLoggedMaxExceeded = false
@@ -44,40 +39,26 @@ class RapidsHostMemoryStore(
  private case object Pooled extends AllocationMode(HOST_MEMORY_BUFFER_PAGEABLE_OFFSET)
  private case object Direct extends AllocationMode(HOST_MEMORY_BUFFER_DIRECT_OFFSET)

-  // Returns an allocated host buffer and its allocation mode
+  override def getMaxSize: Option[Long] = Some(maxSize)
+
+  // Returns an allocated host buffer and its allocation mode: pinned memory is tried
+  // first, then a slice of the internal pageable pool, falling back to a direct,
+  // unpooled allocation as a last resort
  private def allocateHostBuffer(size: Long): (HostMemoryBuffer, AllocationMode) = {
-   // spill to keep within the targeted size
-   val amountSpilled = synchronousSpill(math.max(maxSize - size, 0))
-   if (amountSpilled != 0) {
-     logInfo(s"Spilled $amountSpilled bytes from the host memory store")
-     TrampolineUtil.incTaskMetricsDiskBytesSpilled(amountSpilled)
+   var buffer: HostMemoryBuffer = PinnedMemoryPool.tryAllocate(size)
+   if (buffer != null) {
+     return (buffer, Pinned)
    }

-   var buffer: HostMemoryBuffer = null
-   while (buffer == null) {
-     buffer = PinnedMemoryPool.tryAllocate(size)
-     if (buffer != null) {
-       return (buffer, Pinned)
-     }
-
-     if (size > pageableMemoryPoolSize) {
-       if (!haveLoggedMaxExceeded) {
-         logWarning(s"Exceeding host spill max of $pageableMemoryPoolSize bytes to accommodate " +
-           s"a buffer of $size bytes. Consider increasing pageable memory store size.")
-         haveLoggedMaxExceeded = true
-       }
-       return (HostMemoryBuffer.allocate(size, false), Direct)
-     }
+   val allocation = addressAllocator.allocate(size)
+   if (allocation.isDefined) {
+     buffer = pool.slice(allocation.get, size)
+     return (buffer, Pooled)
+   }

-     val allocation = addressAllocator.allocate(size)
-     if (allocation.isDefined) {
-       buffer = pool.slice(allocation.get, size)
-     } else {
-       val targetSize = math.max(currentSize - size, 0)
-       synchronousSpill(targetSize)
-     }
+   if (!haveLoggedMaxExceeded) {
+     logWarning(s"Exceeding host spill max of $pageableMemoryPoolSize bytes to accommodate " +
+       s"a buffer of $size bytes. 
Consider increasing pageable memory store size.") + haveLoggedMaxExceeded = true } - (buffer, Pooled) + (HostMemoryBuffer.allocate(size, false), Direct) } override protected def createBuffer(other: RapidsBuffer, otherBuffer: MemoryBuffer, @@ -86,8 +67,10 @@ class RapidsHostMemoryStore( val (hostBuffer, allocationMode) = allocateHostBuffer(other.size) try { otherBuffer match { - case devBuffer: DeviceMemoryBuffer => hostBuffer.copyFromDeviceBuffer(devBuffer, stream) - case _ => throw new IllegalStateException("copying from buffer without device memory") + case devBuffer: DeviceMemoryBuffer => + hostBuffer.copyFromDeviceBuffer(devBuffer, stream) + case _ => + throw new IllegalStateException("copying from buffer without device memory") } } catch { case e: Exception => @@ -101,8 +84,7 @@ class RapidsHostMemoryStore( applyPriorityOffset(other.getSpillPriority, allocationMode.spillPriorityOffset), hostBuffer, allocationMode, - other.getSpillCallback, - deviceStorage) + other.getSpillCallback) } } @@ -120,10 +102,9 @@ class RapidsHostMemoryStore( spillPriority: Long, buffer: HostMemoryBuffer, allocationMode: AllocationMode, - spillCallback: SpillCallback, - deviceStorage: RapidsDeviceMemoryStore) + spillCallback: SpillCallback) extends RapidsBufferBase( - id, size, meta, spillPriority, spillCallback, deviceStorage = deviceStorage) { + id, size, meta, spillPriority, spillCallback) { override val storageTier: StorageTier = StorageTier.HOST override def getMemoryBuffer: MemoryBuffer = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala index 236e87fe4b0..0c79d1124c3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala @@ -80,7 +80,7 @@ class ShuffleBufferCatalog( * It should never allocate GPU memory and really just be used for metrics. 
* @param needsSync whether the spill framework should stream synchronize while adding * this device buffer (defaults to true) - * @return RapidsBufferId identifying this table + * @return RapidsBufferHandle identifying this table */ def addContiguousTable( blockId: ShuffleBlockId, @@ -90,7 +90,7 @@ class ShuffleBufferCatalog( needsSync: Boolean): RapidsBufferHandle = { val bufferId = nextShuffleBufferId(blockId) withResource(contigTable) { _ => - val handle = deviceStore.addContiguousTable( + val handle = catalog.addContiguousTable( bufferId, contigTable, initialSpillPriority, @@ -124,7 +124,7 @@ class ShuffleBufferCatalog( tableMeta.bufferMeta.mutateId(bufferId.tableId) // when we call `addBuffer` the store will incRefCount withResource(buffer) { _ => - val handle = deviceStore.addBuffer( + val handle = catalog.addBuffer( bufferId, buffer, tableMeta, @@ -145,10 +145,7 @@ class ShuffleBufferCatalog( meta: TableMeta, spillCallback: SpillCallback): RapidsBufferHandle = { val bufferId = nextShuffleBufferId(blockId) - val buffer = new DegenerateRapidsBuffer(bufferId, meta) - catalog.registerNewBuffer(buffer) - val handle = - catalog.makeNewHandle(buffer.id, buffer.getSpillPriority, spillCallback) + val handle = catalog.registerDegenerateBuffer(bufferId, meta, spillCallback) trackCachedHandle(bufferId, handle) handle } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala index e6b58b90b6a..c868c7e48a6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala @@ -85,7 +85,7 @@ class ShuffleReceivedBufferCatalog( tableMeta.bufferMeta.mutateId(bufferId.tableId) // when we call `addBuffer` the store will incRefCount withResource(buffer) { _ => - deviceStore.addBuffer( + catalog.addBuffer( bufferId, buffer, tableMeta, @@ -107,9 +107,7 @@ class ShuffleReceivedBufferCatalog( meta: TableMeta, spillCallback: SpillCallback): RapidsBufferHandle = { val bufferId = nextShuffleReceivedBufferId() - val buffer = new DegenerateRapidsBuffer(bufferId, meta) - catalog.registerNewBuffer(buffer) - catalog.makeNewHandle(bufferId, -1, spillCallback) + catalog.registerDegenerateBuffer(bufferId, meta, spillCallback) } /** diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala index 09d7134954b..44049e348fb 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala @@ -191,9 +191,7 @@ object SpillableColumnarBatch extends Arm { .forall(i => batch.column(i).isInstanceOf[GpuColumnVectorFromBuffer])) { val cv = batch.column(0).asInstanceOf[GpuColumnVectorFromBuffer] val buff = cv.getBuffer - // note the table here is handed over to the catalog - val table = GpuColumnVector.from(batch) - RapidsBufferCatalog.addTable(table, buff, cv.getTableMeta, initialSpillPriority, + RapidsBufferCatalog.addBuffer(buff, cv.getTableMeta, initialSpillPriority, spillCallback) } else { withResource(GpuColumnVector.from(batch)) { tmpTable => diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala index ec9971a2d92..be370209e32 100644 --- 
a/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,28 +24,46 @@ import org.scalatest.mockito.MockitoSugar class DeviceMemoryEventHandlerSuite extends FunSuite with MockitoSugar { test("a failed allocation should be retried if we spilled enough") { + val mockCatalog = mock[RapidsBufferCatalog] val mockStore = mock[RapidsDeviceMemoryStore] when(mockStore.currentSize).thenReturn(1024) - when(mockStore.synchronousSpill(any())).thenReturn(1024) - val handler = new DeviceMemoryEventHandler(mockStore, None, false, 2) + when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(1024)) + val handler = new DeviceMemoryEventHandler( + mockCatalog, + mockStore, + None, + false, + 2) assertResult(true)(handler.onAllocFailure(1024, 0)) } test("when we deplete the store, retry up to max failed OOM retries") { + val mockCatalog = mock[RapidsBufferCatalog] val mockStore = mock[RapidsDeviceMemoryStore] when(mockStore.currentSize).thenReturn(0) - when(mockStore.synchronousSpill(any())).thenReturn(0) - val handler = new DeviceMemoryEventHandler(mockStore, None, false, 2) + when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(0)) + val handler = new DeviceMemoryEventHandler( + mockCatalog, + mockStore, + None, + false, + 2) assertResult(true)(handler.onAllocFailure(1024, 0)) // sync assertResult(true)(handler.onAllocFailure(1024, 1)) // sync 2 assertResult(false)(handler.onAllocFailure(1024, 2)) // cuDF would OOM here } test("we reset our OOM state after a successful retry") { + val mockCatalog = mock[RapidsBufferCatalog] val mockStore = mock[RapidsDeviceMemoryStore] when(mockStore.currentSize).thenReturn(0) - when(mockStore.synchronousSpill(any())).thenReturn(0) - val handler = new DeviceMemoryEventHandler(mockStore, None, false, 2) + when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(0)) + val handler = new DeviceMemoryEventHandler( + mockCatalog, + mockStore, + None, + false, + 2) // with this call we sync, and we mark our attempts at 1, we store 0 as the last count assertResult(true)(handler.onAllocFailure(1024, 0)) // this retryCount is still 0, we should be back at 1 for attempts @@ -55,18 +73,30 @@ class DeviceMemoryEventHandlerSuite extends FunSuite with MockitoSugar { } test("a negative allocation cannot be retried and handler throws") { + val mockCatalog = mock[RapidsBufferCatalog] val mockStore = mock[RapidsDeviceMemoryStore] when(mockStore.currentSize).thenReturn(1024) - when(mockStore.synchronousSpill(any())).thenReturn(1024) - val handler = new DeviceMemoryEventHandler(mockStore, None, false, 2) + when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(1024)) + val handler = new DeviceMemoryEventHandler( + mockCatalog, + mockStore, + None, + false, + 2) assertThrows[IllegalArgumentException](handler.onAllocFailure(-1, 0)) } test("a negative retry count is invalid") { + val mockCatalog = mock[RapidsBufferCatalog] val mockStore = mock[RapidsDeviceMemoryStore] when(mockStore.currentSize).thenReturn(1024) - when(mockStore.synchronousSpill(any())).thenReturn(1024) - val handler = new DeviceMemoryEventHandler(mockStore, None, false, 2) + 
when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(1024)) + val handler = new DeviceMemoryEventHandler( + mockCatalog, + mockStore, + None, + false, + 2) assertThrows[IllegalArgumentException](handler.onAllocFailure(1024, -1)) } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala index 1f87c6f7520..662f8a5d5e0 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala @@ -150,8 +150,9 @@ class GpuPartitioningSuite extends FunSuite with Arm { TestUtils.withGpuSparkSession(conf) { _ => GpuShuffleEnv.init(new RapidsConf(conf), new RapidsDiskBlockManager(conf)) val spillPriority = 7L - val catalog = new RapidsBufferCatalog - withResource(new RapidsDeviceMemoryStore(catalog)) { deviceStore => + + withResource(new RapidsDeviceMemoryStore) { store => + val catalog = new RapidsBufferCatalog(store) val partitionIndices = Array(0, 2, 2) val gp = new GpuPartitioning { override val numPartitions: Int = partitionIndices.length @@ -195,7 +196,7 @@ class GpuPartitioningSuite extends FunSuite with Arm { if (GpuCompressedColumnVector.isBatchCompressed(partBatch)) { val gccv = columns.head.asInstanceOf[GpuCompressedColumnVector] val devBuffer = gccv.getTableBuffer - val handle = deviceStore.addBuffer(devBuffer, gccv.getTableMeta, spillPriority) + val handle = catalog.addBuffer(devBuffer, gccv.getTableMeta, spillPriority) withResource(buildSubBatch(batch, startRow, endRow)) { expectedBatch => withResource(catalog.acquireBuffer(handle)) { buffer => withResource(buffer.getColumnarBatch(sparkTypes)) { batch => diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala index 2a951eb2902..dc7958d1c31 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala @@ -21,6 +21,7 @@ import java.io.File import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, MemoryBuffer} import com.nvidia.spark.rapids.StorageTier.{DEVICE, DISK, HOST, StorageTier} import com.nvidia.spark.rapids.format.TableMeta +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ import org.scalatest.FunSuite import org.scalatest.mockito.MockitoSugar @@ -252,6 +253,44 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar with Arm { assert(!catalog.isBufferSpilled(bufferId, DISK)) } + test("multiple calls to unspill return existing DEVICE buffer") { + val deviceStore = spy(new RapidsDeviceMemoryStore) + val mockStore = mock[RapidsBufferStore] + val hostStore = new RapidsHostMemoryStore(10000, 1000) + deviceStore.setSpillStore(hostStore) + hostStore.setSpillStore(mockStore) + val catalog = new RapidsBufferCatalog(deviceStore) + val handle = withResource(DeviceMemoryBuffer.allocate(1024)) { buff => + val meta = MetaUtils.getTableMetaNoTable(buff) + catalog.addBuffer( + buff, meta, -1, RapidsBuffer.defaultSpillCallback) + } + withResource(handle) { _ => + catalog.synchronousSpill(deviceStore, 0) + val acquiredHostBuffer = catalog.acquireBuffer(handle) + withResource(acquiredHostBuffer) { _ => + assertResult(HOST)(acquiredHostBuffer.storageTier) + val unspilled = + catalog.unspillBufferToDeviceStore( + acquiredHostBuffer, + acquiredHostBuffer.getMemoryBuffer, + Cuda.DEFAULT_STREAM) + 
withResource(unspilled) { _ => + assertResult(DEVICE)(unspilled.storageTier) + } + val unspilledSame = catalog.unspillBufferToDeviceStore( + acquiredHostBuffer, + acquiredHostBuffer.getMemoryBuffer, + Cuda.DEFAULT_STREAM) + withResource(unspilledSame) { _ => + assertResult(unspilled)(unspilledSame) + } + // verify that we invoked the copy function exactly once + verify(deviceStore, times(1)).copyBuffer(any(), any(), any()) + } + } + } + test("remove buffer tier") { val catalog = new RapidsBufferCatalog val bufferId = MockBufferId(5) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala index a00ea1c031d..82322c3e863 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala @@ -25,7 +25,7 @@ import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuff import com.nvidia.spark.rapids.StorageTier.StorageTier import com.nvidia.spark.rapids.format.TableMeta import org.mockito.ArgumentCaptor -import org.mockito.Mockito.verify +import org.mockito.Mockito.{spy, verify} import org.scalatest.FunSuite import org.scalatest.mockito.MockitoSugar @@ -45,12 +45,12 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { } test("add table registers with catalog") { - val catalog = mock[RapidsBufferCatalog] - withResource(new RapidsDeviceMemoryStore(catalog)) { store => + withResource(new RapidsDeviceMemoryStore) { store => + val catalog = spy(new RapidsBufferCatalog(store)) val spillPriority = 3 val bufferId = MockRapidsBufferId(7) withResource(buildContiguousTable()) { ct => - store.addContiguousTable( + catalog.addContiguousTable( bufferId, ct, spillPriority, RapidsBuffer.defaultSpillCallback, false) } val captor: ArgumentCaptor[RapidsBuffer] = ArgumentCaptor.forClass(classOf[RapidsBuffer]) @@ -61,16 +61,108 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { } } + test("a table is not spillable until the owner closes it") { + withResource(new RapidsDeviceMemoryStore) { store => + val catalog = spy(new RapidsBufferCatalog(store)) + val spillPriority = 3 + val bufferId = MockRapidsBufferId(7) + val ct = buildContiguousTable() + val buffSize = ct.getBuffer.getLength + withResource(ct) { _ => + catalog.addContiguousTable( + bufferId, + ct, + spillPriority, + RapidsBuffer.defaultSpillCallback, + false) + assertResult(buffSize)(store.currentSize) + assertResult(0)(store.currentSpillableSize) + } + // after closing the original table, the RapidsBuffer should be spillable + assertResult(buffSize)(store.currentSize) + assertResult(buffSize)(store.currentSpillableSize) + } + } + + test("a buffer is not spillable until the owner closes columns referencing it") { + withResource(new RapidsDeviceMemoryStore) { store => + val catalog = spy(new RapidsBufferCatalog(store)) + val spillPriority = 3 + val bufferId = MockRapidsBufferId(7) + val ct = buildContiguousTable() + val buffSize = ct.getBuffer.getLength + withResource(ct) { _ => + val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) + withResource(ct) { _ => + store.addBuffer( + bufferId, + ct.getBuffer, + meta, + spillPriority, + RapidsBuffer.defaultSpillCallback, + false) + assertResult(buffSize)(store.currentSize) + assertResult(0)(store.currentSpillableSize) + } + } + // after closing the original table, the RapidsBuffer should be spillable + 
assertResult(buffSize)(store.currentSize) + assertResult(buffSize)(store.currentSpillableSize) + } + } + + test("a buffer is not spillable when the underlying device buffer is obtained from it") { + withResource(new RapidsDeviceMemoryStore) { store => + val catalog = spy(new RapidsBufferCatalog(store)) + val spillPriority = 3 + val bufferId = MockRapidsBufferId(7) + val ct = buildContiguousTable() + val underlyingBuff = ct.getBuffer + val buffSize = ct.getBuffer.getLength + val buffer = withResource(ct) { _ => + val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) + val buffer = store.addBuffer( + bufferId, + ct.getBuffer, + meta, + spillPriority, + RapidsBuffer.defaultSpillCallback, + false) + assertResult(buffSize)(store.currentSize) + assertResult(0)(store.currentSpillableSize) + buffer + } + + // after closing the original table, the RapidsBuffer should be spillable + assertResult(buffSize)(store.currentSize) + assertResult(buffSize)(store.currentSpillableSize) + + // if a device memory buffer is obtained from the buffer, it is no longer spillable + withResource(buffer.getDeviceMemoryBuffer) { deviceBuffer => + assertResult(buffSize)(store.currentSize) + assertResult(0)(store.currentSpillableSize) + } + + // once the DeviceMemoryBuffer is closed, the RapidsBuffer should be spillable again + assertResult(buffSize)(store.currentSpillableSize) + } + } + test("add buffer registers with catalog") { - val catalog = mock[RapidsBufferCatalog] - withResource(new RapidsDeviceMemoryStore(catalog)) { store => + withResource(new RapidsDeviceMemoryStore) { store => + val catalog = spy(new RapidsBufferCatalog(store)) val spillPriority = 3 val bufferId = MockRapidsBufferId(7) val meta = withResource(buildContiguousTable()) { ct => val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) withResource(ct) { _ => - store.addBuffer( - bufferId, ct.getBuffer, meta, spillPriority, RapidsBuffer.defaultSpillCallback, false) + catalog.addBuffer( + bufferId, + ct.getBuffer, + meta, + spillPriority, + RapidsBuffer.defaultSpillCallback, + false) } meta } @@ -84,15 +176,15 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { } test("get memory buffer") { - val catalog = new RapidsBufferCatalog - withResource(new RapidsDeviceMemoryStore(catalog)) { store => + withResource(new RapidsDeviceMemoryStore) { store => + val catalog = spy(new RapidsBufferCatalog(store)) val bufferId = MockRapidsBufferId(7) withResource(buildContiguousTable()) { ct => withResource(HostMemoryBuffer.allocate(ct.getBuffer.getLength)) { expectedHostBuffer => expectedHostBuffer.copyFromDeviceBuffer(ct.getBuffer) val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) val handle = withResource(ct) { _ => - store.addBuffer( + catalog.addBuffer( bufferId, ct.getBuffer, meta, @@ -114,17 +206,21 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { } test("get column batch") { - val catalog = new RapidsBufferCatalog - val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, - DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) - withResource(new RapidsDeviceMemoryStore(catalog)) { store => + withResource(new RapidsDeviceMemoryStore) { store => + val catalog = new RapidsBufferCatalog(store) + val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, + DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) val bufferId = MockRapidsBufferId(7) withResource(buildContiguousTable()) { ct => withResource(GpuColumnVector.from(ct.getTable, 
sparkTypes)) { expectedBatch => val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) val handle = withResource(ct) { _ => - store.addBuffer(bufferId, ct.getBuffer, meta, initialSpillPriority = 3, + catalog.addBuffer( + bufferId, + ct.getBuffer, + meta, + initialSpillPriority = 3, RapidsBuffer.defaultSpillCallback, false) } withResource(catalog.acquireBuffer(handle)) { buffer => @@ -138,16 +234,16 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { } test("cannot receive spilled buffers") { - val catalog = new RapidsBufferCatalog - withResource(new RapidsDeviceMemoryStore(catalog)) { store => + withResource(new RapidsDeviceMemoryStore) { store => assertThrows[IllegalStateException](store.copyBuffer( mock[RapidsBuffer], mock[MemoryBuffer], Cuda.DEFAULT_STREAM)) } } test("size statistics") { - val catalog = new RapidsBufferCatalog - withResource(new RapidsDeviceMemoryStore(catalog)) { store => + + withResource(new RapidsDeviceMemoryStore) { store => + val catalog = new RapidsBufferCatalog(store) assertResult(0)(store.currentSize) val bufferSizes = new Array[Long](2) val bufferHandles = new Array[RapidsBufferHandle](2) @@ -156,7 +252,10 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { bufferSizes(i) = ct.getBuffer.getLength // store takes ownership of the table bufferHandles(i) = - store.addContiguousTable(MockRapidsBufferId(i), ct, initialSpillPriority = 0, + catalog.addContiguousTable( + MockRapidsBufferId(i), + ct, + initialSpillPriority = 0, RapidsBuffer.defaultSpillCallback, false) } assertResult(bufferSizes.take(i+1).sum)(store.currentSize) @@ -169,17 +268,17 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { } test("spill") { - val catalog = new RapidsBufferCatalog - val spillStore = new MockSpillStore(catalog) + val spillStore = new MockSpillStore val spillPriorities = Array(0, -1, 2) val bufferSizes = new Array[Long](spillPriorities.length) - withResource(new RapidsDeviceMemoryStore(catalog)) { store => + withResource(new RapidsDeviceMemoryStore) { store => + val catalog = new RapidsBufferCatalog(store) store.setSpillStore(spillStore) spillPriorities.indices.foreach { i => withResource(buildContiguousTable()) { ct => bufferSizes(i) = ct.getBuffer.getLength // store takes ownership of the table - store.addContiguousTable( + catalog.addContiguousTable( MockRapidsBufferId(i), ct, spillPriorities(i), RapidsBuffer.defaultSpillCallback, false) } @@ -188,21 +287,21 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { // asking to spill 0 bytes should not spill val sizeBeforeSpill = store.currentSize - store.synchronousSpill(sizeBeforeSpill) + catalog.synchronousSpill(store, sizeBeforeSpill) assert(spillStore.spilledBuffers.isEmpty) assertResult(sizeBeforeSpill)(store.currentSize) - store.synchronousSpill(sizeBeforeSpill + 1) + catalog.synchronousSpill(store, sizeBeforeSpill + 1) assert(spillStore.spilledBuffers.isEmpty) assertResult(sizeBeforeSpill)(store.currentSize) // spilling 1 byte should force one buffer to spill in priority order - store.synchronousSpill(sizeBeforeSpill - 1) + catalog.synchronousSpill(store, sizeBeforeSpill - 1) assertResult(1)(spillStore.spilledBuffers.length) assertResult(bufferSizes.drop(1).sum)(store.currentSize) assertResult(1)(spillStore.spilledBuffers(0).tableId) // spilling to zero should force all buffers to spill in priority order - store.synchronousSpill(0) + catalog.synchronousSpill(store, 0) 
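      // The assertions around this spill-to-zero call pin down victim selection:
      // spillable buffers drain in ascending spill-priority order until the store
      // reaches the target size. That policy in isolation (illustrative types,
      // not the store's internals):
      case class Victim(tableId: Int, size: Long, priority: Long)
      def chooseVictims(spillable: Seq[Victim], current: Long, target: Long): Seq[Victim] = {
        var remaining = current
        spillable.sortBy(_.priority).takeWhile { v =>
          val mustSpill = remaining > target // keep evicting while above target
          if (mustSpill) remaining -= v.size
          mustSpill
        }
      }
      // With priorities (0, -1, 2), a target one byte below the current size
      // evicts only the priority -1 buffer (tableId 1), and a zero target then
      // drains the rest in priority order, tableId 0 then tableId 2, which is
      // exactly what the surrounding checks assert.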
assertResult(3)(spillStore.spilledBuffers.length) assertResult(0)(store.currentSize) assertResult(0)(spillStore.spilledBuffers(1).tableId) @@ -215,8 +314,7 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { throw new UnsupportedOperationException } - class MockSpillStore(catalog: RapidsBufferCatalog) - extends RapidsBufferStore(StorageTier.HOST, catalog) with Arm { + class MockSpillStore extends RapidsBufferStore(StorageTier.HOST) with Arm { val spilledBuffers = new ArrayBuffer[RapidsBufferId] override protected def createBuffer( diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala index a487df0e6da..c885c2d7d17 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala @@ -44,20 +44,20 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga val bufferId = MockRapidsBufferId(7, canShareDiskPaths = false) val spillPriority = -7 val hostStoreMaxSize = 1L * 1024 * 1024 - val catalog = spy(new RapidsBufferCatalog) - withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource(new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize, catalog)) { + withResource(new RapidsDeviceMemoryStore) { devStore => + val catalog = spy(new RapidsBufferCatalog(devStore)) + withResource(new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize)) { hostStore => devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager], catalog)) { diskStore => + withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager])) { diskStore => assertResult(0)(diskStore.currentSize) hostStore.setSpillStore(diskStore) val (bufferSize, handle) = - addTableToStore(devStore, bufferId, spillPriority) + addTableToCatalog(catalog, bufferId, spillPriority) val path = handle.id.getDiskPath(null) assert(!path.exists()) - devStore.synchronousSpill(0) - hostStore.synchronousSpill(0) + catalog.synchronousSpill(devStore, 0) + catalog.synchronousSpill(hostStore, 0) assertResult(0)(hostStore.currentSize) assertResult(bufferSize)(diskStore.currentSize) assert(path.exists) @@ -84,27 +84,34 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) val spillPriority = -7 val hostStoreMaxSize = 1L * 1024 * 1024 - val catalog = new RapidsBufferCatalog - withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource(new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize, catalog)) { + withResource(new RapidsDeviceMemoryStore) { devStore => + val catalog = new RapidsBufferCatalog(devStore) + withResource(new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize)) { hostStore => devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager], catalog, devStore)) { + withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager])) { diskStore => hostStore.setSpillStore(diskStore) - val (_, handle) = addTableToStore(devStore, bufferId, spillPriority) + val (_, handle) = addTableToCatalog(catalog, bufferId, spillPriority) assert(!handle.id.getDiskPath(null).exists()) - val expectedBatch = withResource(catalog.acquireBuffer(handle)) { buffer => + val expectedTable = withResource(catalog.acquireBuffer(handle)) { buffer => assertResult(StorageTier.DEVICE)(buffer.storageTier) - 
buffer.getColumnarBatch(sparkTypes) + withResource(buffer.getColumnarBatch(sparkTypes)) { beforeSpill => + withResource(GpuColumnVector.from(beforeSpill)) { table => + table.contiguousSplit()(0) + } + } // closing the batch from the store so that we can spill it } - withResource(expectedBatch) { expectedBatch => - devStore.synchronousSpill(0) - hostStore.synchronousSpill(0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.DISK)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) + withResource(expectedTable) { _ => + withResource( + GpuColumnVector.from(expectedTable.getTable, sparkTypes)) { expectedBatch => + catalog.synchronousSpill(devStore, 0) + catalog.synchronousSpill(hostStore, 0) + withResource(catalog.acquireBuffer(handle)) { buffer => + assertResult(StorageTier.DISK)(buffer.storageTier) + withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => + TestUtils.compareBatches(expectedBatch, actualBatch) + } } } } @@ -119,14 +126,14 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga assert(!bufferPath.exists) val spillPriority = -7 val hostStoreMaxSize = 1L * 1024 * 1024 - val catalog = new RapidsBufferCatalog - withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource(new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize, catalog)) { + withResource(new RapidsDeviceMemoryStore) { devStore => + val catalog = new RapidsBufferCatalog(devStore) + withResource(new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize)) { hostStore => devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager], catalog)) { diskStore => + withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager])) { diskStore => hostStore.setSpillStore(diskStore) - val (_, handle) = addTableToStore(devStore, bufferId, spillPriority) + val (_, handle) = addTableToCatalog(catalog, bufferId, spillPriority) assert(!handle.id.getDiskPath(null).exists()) val expectedBuffer = withResource(catalog.acquireBuffer(handle)) { buffer => assertResult(StorageTier.DEVICE)(buffer.storageTier) @@ -138,8 +145,8 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga } } withResource(expectedBuffer) { expectedBuffer => - devStore.synchronousSpill(0) - hostStore.synchronousSpill(0) + catalog.synchronousSpill(devStore, 0) + catalog.synchronousSpill(hostStore, 0) withResource(catalog.acquireBuffer(handle)) { buffer => assertResult(StorageTier.DISK)(buffer.storageTier) withResource(buffer.getMemoryBuffer) { actualBuffer => @@ -168,18 +175,18 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga assert(!bufferPath.exists) val spillPriority = -7 val hostStoreMaxSize = 1L * 1024 * 1024 - val catalog = new RapidsBufferCatalog - withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource(new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize, catalog)) { + withResource(new RapidsDeviceMemoryStore) { devStore => + val catalog = new RapidsBufferCatalog(devStore) + withResource(new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize)) { hostStore => devStore.setSpillStore(hostStore) - withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager], catalog)) { diskStore => + withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager])) { diskStore => hostStore.setSpillStore(diskStore) - val (_, handle) = addTableToStore(devStore, 
bufferId, spillPriority) + val (_, handle) = addTableToCatalog(catalog, bufferId, spillPriority) val bufferPath = handle.id.getDiskPath(null) assert(!bufferPath.exists()) - devStore.synchronousSpill(0) - hostStore.synchronousSpill(0) + catalog.synchronousSpill(devStore, 0) + catalog.synchronousSpill(hostStore, 0) assert(bufferPath.exists) handle.close() if (canShareDiskPaths) { @@ -192,14 +199,14 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga } } - private def addTableToStore( - deviceMemoryStore: RapidsDeviceMemoryStore, + private def addTableToCatalog( + catalog: RapidsBufferCatalog, bufferId: RapidsBufferId, spillPriority: Long): (Long, RapidsBufferHandle) = { withResource(buildContiguousTable()) { ct => val bufferSize = ct.getBuffer.getLength // store takes ownership of the table - val handle = deviceMemoryStore.addContiguousTable( + val handle = catalog.addContiguousTable( bufferId, ct, spillPriority, diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsGdsStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsGdsStoreSuite.scala index fbb3b6b5750..6d99b632bea 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsGdsStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsGdsStoreSuite.scala @@ -57,11 +57,11 @@ class RapidsGdsStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSugar .thenReturn(paths(1)) paths.foreach(f => assert(!f.exists)) val spillPriority = -7 - val catalog = spy(new RapidsBufferCatalog) val batchWriteBufferSize = 16384 // Holds 2 buffers. - withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => + withResource(new RapidsDeviceMemoryStore) { devStore => + val catalog = spy(new RapidsBufferCatalog(devStore)) withResource(new RapidsGdsStore( - diskBlockManager, batchWriteBufferSize, catalog)) { gdsStore => + diskBlockManager, batchWriteBufferSize)) { gdsStore => devStore.setSpillStore(gdsStore) assertResult(0)(gdsStore.currentSize) @@ -70,8 +70,8 @@ class RapidsGdsStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSugar val bufferHandles = new Array[RapidsBufferHandle](bufferIds.length) bufferIds.zipWithIndex.foreach { case(id, ix) => - val (size, handle) = addTableToStore(devStore, id, spillPriority) - devStore.synchronousSpill(0) + val (size, handle) = addTableToCatalog(catalog, id, spillPriority) + catalog.synchronousSpill(devStore, 0) bufferSizes(ix) = size bufferHandles(ix) = handle } @@ -109,14 +109,14 @@ class RapidsGdsStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSugar val path = bufferId.getDiskPath(null) assert(!path.exists) val spillPriority = -7 - val catalog = spy(new RapidsBufferCatalog) - withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource(new RapidsGdsStore(mock[RapidsDiskBlockManager], 4096, catalog)) { + withResource(new RapidsDeviceMemoryStore) { devStore => + val catalog = spy(new RapidsBufferCatalog(devStore)) + withResource(new RapidsGdsStore(mock[RapidsDiskBlockManager], 4096)) { gdsStore => devStore.setSpillStore(gdsStore) assertResult(0)(gdsStore.currentSize) - val (bufferSize, handle) = addTableToStore(devStore, bufferId, spillPriority) - devStore.synchronousSpill(0) + val (bufferSize, handle) = addTableToCatalog(catalog, bufferId, spillPriority) + catalog.synchronousSpill(devStore, 0) assertResult(bufferSize)(gdsStore.currentSize) assert(path.exists) assertResult(bufferSize)(path.length) @@ -140,14 +140,14 @@ class RapidsGdsStoreSuite extends FunSuiteWithTempDir with Arm with 
MockitoSugar } } - private def addTableToStore( - devStore: RapidsDeviceMemoryStore, + private def addTableToCatalog( + catalog: RapidsBufferCatalog, bufferId: RapidsBufferId, spillPriority: Long): (Long, RapidsBufferHandle) = { withResource(buildContiguousTable()) { ct => val bufferSize = ct.getBuffer.getLength // store takes ownership of the table - val handle = devStore.addContiguousTable(bufferId, ct, spillPriority, + val handle = catalog.addContiguousTable(bufferId, ct, spillPriority, RapidsBuffer.defaultSpillCallback, false) (bufferSize, handle) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala index bc6dd6fefce..16895182051 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala @@ -21,12 +21,15 @@ import java.math.RoundingMode import ai.rapids.cudf.{ContiguousTable, Cuda, HostColumnVector, HostMemoryBuffer, MemoryBuffer, Table} import org.mockito.{ArgumentCaptor, ArgumentMatchers} +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{never, spy, times, verify, when} import org.scalatest.FunSuite import org.scalatest.mockito.MockitoSugar import org.apache.spark.sql.rapids.RapidsDiskBlockManager import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, LongType, StringType} +import org.apache.spark.sql.vectorized.ColumnarBatch + class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { private def buildContiguousTable(): ContiguousTable = { @@ -54,25 +57,27 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { test("spill updates catalog") { val spillPriority = -7 val hostStoreMaxSize = 1L * 1024 * 1024 - val catalog = spy(new RapidsBufferCatalog) - withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource(new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize, catalog)) { + val mockStore = mock[RapidsHostMemoryStore] + withResource(new RapidsDeviceMemoryStore) { devStore => + val catalog = spy(new RapidsBufferCatalog(devStore)) + withResource(new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize)) { hostStore => assertResult(0)(hostStore.currentSize) assertResult(hostStoreMaxSize)(hostStore.numBytesFree) devStore.setSpillStore(hostStore) + hostStore.setSpillStore(mockStore) val (bufferSize, handle) = withResource(buildContiguousTable()) { ct => val len = ct.getBuffer.getLength // store takes ownership of the table - val handle = devStore.addContiguousTable( + val handle = catalog.addContiguousTable( ct, spillPriority, RapidsBuffer.defaultSpillCallback) (len, handle) } - devStore.synchronousSpill(0) + catalog.synchronousSpill(devStore, 0) assertResult(bufferSize)(hostStore.currentSize) assertResult(hostStoreMaxSize - bufferSize)(hostStore.numBytesFree) verify(catalog, times(2)).registerNewBuffer(ArgumentMatchers.any[RapidsBuffer]) @@ -91,25 +96,29 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { test("get columnar batch") { val spillPriority = -10 val hostStoreMaxSize = 1L * 1024 * 1024 - val catalog = new RapidsBufferCatalog - withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource(new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize, catalog)) { + val mockStore = mock[RapidsHostMemoryStore] + withResource(new RapidsDeviceMemoryStore) { devStore => + val 
catalog = new RapidsBufferCatalog(devStore) + withResource(new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize)) { hostStore => devStore.setSpillStore(hostStore) - withResource(buildContiguousTable()) { ct => - withResource(HostMemoryBuffer.allocate(ct.getBuffer.getLength)) { expectedBuffer => - expectedBuffer.copyFromDeviceBuffer(ct.getBuffer) - val handle = devStore.addContiguousTable( - ct, - spillPriority, - RapidsBuffer.defaultSpillCallback) - devStore.synchronousSpill(0) - withResource(catalog.acquireBuffer(handle)) { buffer => - withResource(buffer.getMemoryBuffer) { actualBuffer => - assert(actualBuffer.isInstanceOf[HostMemoryBuffer]) - assertResult(expectedBuffer.asByteBuffer) { - actualBuffer.asInstanceOf[HostMemoryBuffer].asByteBuffer - } + hostStore.setSpillStore(mockStore) + var expectedBuffer: HostMemoryBuffer = null + val handle = withResource(buildContiguousTable()) { ct => + expectedBuffer = HostMemoryBuffer.allocate(ct.getBuffer.getLength) + expectedBuffer.copyFromDeviceBuffer(ct.getBuffer) + catalog.addContiguousTable( + ct, + spillPriority, + RapidsBuffer.defaultSpillCallback) + } + withResource(expectedBuffer) { _ => + catalog.synchronousSpill(devStore, 0) + withResource(catalog.acquireBuffer(handle)) { buffer => + withResource(buffer.getMemoryBuffer) { actualBuffer => + assert(actualBuffer.isInstanceOf[HostMemoryBuffer]) + assertResult(expectedBuffer.asByteBuffer) { + actualBuffer.asInstanceOf[HostMemoryBuffer].asByteBuffer } } } @@ -123,26 +132,32 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) val spillPriority = -10 val hostStoreMaxSize = 1L * 1024 * 1024 - val catalog = new RapidsBufferCatalog - withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource( - new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize, catalog, devStore)) { + val mockStore = mock[RapidsHostMemoryStore] + withResource(new RapidsDeviceMemoryStore) { devStore => + val catalog = new RapidsBufferCatalog(devStore) + withResource(new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize)) { hostStore => devStore.setSpillStore(hostStore) - withResource(buildContiguousTable()) { ct => - withResource(GpuColumnVector.from(ct.getTable, sparkTypes)) { - expectedBatch => - val handle = devStore.addContiguousTable( - ct, - spillPriority, - RapidsBuffer.defaultSpillCallback) - devStore.synchronousSpill(0) - withResource(catalog.acquireBuffer(handle)) { buffer => - assertResult(StorageTier.HOST)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) - } - } + hostStore.setSpillStore(mockStore) + var expectedBatch: ColumnarBatch = null + val handle = withResource(buildContiguousTable()) { ct => + // make a copy of the table so we can compare it later to the + // one reconstituted after the spill + withResource(ct.getTable.contiguousSplit()) { copied => + expectedBatch = GpuColumnVector.from(copied(0).getTable, sparkTypes) + } + catalog.addContiguousTable( + ct, + spillPriority, + RapidsBuffer.defaultSpillCallback) + } + withResource(expectedBatch) { _ => + catalog.synchronousSpill(devStore, 0) + withResource(catalog.acquireBuffer(handle)) { buffer => + assertResult(StorageTier.HOST)(buffer.storageTier) + withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => + TestUtils.compareBatches(expectedBatch, actualBatch) + } } } } @@ -153,50 +168,68 @@ class 
RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { val sparkTypes = Array[DataType](LongType) val spillPriority = -10 val hostStoreMaxSize = 256 - val catalog = new RapidsBufferCatalog - withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => + withResource(new RapidsDeviceMemoryStore) { devStore => + val catalog = new RapidsBufferCatalog(devStore) val mockStore = mock[RapidsBufferStore] + val mockBuff = mock[mockStore.RapidsBufferBase] + when(mockBuff.id).thenReturn(new RapidsBufferId { + override val tableId: Int = 0 + override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = null + }) + when(mockStore.getMaxSize).thenAnswer(_ => None) + when(mockStore.copyBuffer(any(), any(), any())).thenReturn(mockBuff) when(mockStore.tier) thenReturn (StorageTier.DISK) - withResource( - new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize, catalog, devStore)) { - hostStore => - devStore.setSpillStore(hostStore) - hostStore.setSpillStore(mockStore) - withResource(buildContiguousTable(1024 * 1024)) { bigTable => - withResource(buildContiguousTable(1)) { smallTable => - withResource(GpuColumnVector.from(bigTable.getTable, sparkTypes)) { expectedBatch => - // store takes ownership of the table - val bigHandle = devStore.addContiguousTable( + withResource(new RapidsHostMemoryStore(hostStoreMaxSize, hostStoreMaxSize)) { hostStore => + devStore.setSpillStore(hostStore) + hostStore.setSpillStore(mockStore) + var bigHandle: RapidsBufferHandle = null + var bigTable = buildContiguousTable(1024 * 1024) + var smallTable = buildContiguousTable(1) + closeOnExcept(bigTable) { _ => + closeOnExcept(smallTable) { _ => + // make a copy of the table so we can compare it later to the + // one reconstituted after the spill + val expectedBatch = + withResource(bigTable.getTable.contiguousSplit()) { expectedTable => + GpuColumnVector.from(expectedTable(0).getTable, sparkTypes) + } + withResource(expectedBatch) { _ => + bigHandle = withResource(bigTable) { _ => + catalog.addContiguousTable( bigTable, spillPriority, RapidsBuffer.defaultSpillCallback) - devStore.synchronousSpill(0) - verify(mockStore, never()).copyBuffer(ArgumentMatchers.any[RapidsBuffer], - ArgumentMatchers.any[MemoryBuffer], - ArgumentMatchers.any[Cuda.Stream]) - withResource(catalog.acquireBuffer(bigHandle)) { buffer => - assertResult(StorageTier.HOST)(buffer.storageTier) - withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => - TestUtils.compareBatches(expectedBatch, actualBatch) - } - } - - devStore.addContiguousTable( - smallTable, spillPriority, - RapidsBuffer.defaultSpillCallback, false) - devStore.synchronousSpill(0) - val rapidsBufferCaptor: ArgumentCaptor[RapidsBuffer] = - ArgumentCaptor.forClass(classOf[RapidsBuffer]) - val memoryBufferCaptor: ArgumentCaptor[MemoryBuffer] = - ArgumentCaptor.forClass(classOf[MemoryBuffer]) - verify(mockStore).copyBuffer(rapidsBufferCaptor.capture(), - memoryBufferCaptor.capture(), ArgumentMatchers.any[Cuda.Stream]) - withResource(memoryBufferCaptor.getValue) { _ => - assertResult(bigHandle.id)(rapidsBufferCaptor.getValue.id) + } // close the bigTable so it can be spilled + bigTable = null + catalog.synchronousSpill(devStore, 0) + verify(mockStore, never()).copyBuffer(ArgumentMatchers.any[RapidsBuffer], + ArgumentMatchers.any[MemoryBuffer], + ArgumentMatchers.any[Cuda.Stream]) + withResource(catalog.acquireBuffer(bigHandle)) { buffer => + assertResult(StorageTier.HOST)(buffer.storageTier) + withResource(buffer.getColumnarBatch(sparkTypes)) { 
actualBatch => + TestUtils.compareBatches(expectedBatch, actualBatch) } } } + withResource(smallTable) { _ => + catalog.addContiguousTable( + smallTable, spillPriority, + RapidsBuffer.defaultSpillCallback, false) + } // close the smallTable so it can be spilled + smallTable = null + catalog.synchronousSpill(devStore, 0) + val rapidsBufferCaptor: ArgumentCaptor[RapidsBuffer] = + ArgumentCaptor.forClass(classOf[RapidsBuffer]) + val memoryBufferCaptor: ArgumentCaptor[MemoryBuffer] = + ArgumentCaptor.forClass(classOf[MemoryBuffer]) + verify(mockStore).copyBuffer(rapidsBufferCaptor.capture(), + memoryBufferCaptor.capture(), ArgumentMatchers.any[Cuda.Stream]) + withResource(memoryBufferCaptor.getValue) { _ => + assertResult(bigHandle.id)(rapidsBufferCaptor.getValue.id) + } } + } } } }
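Stepping back, the overall shape after this change: stores no longer hold a catalog reference; the catalog owns spill policy, walks a store's spillable buffers, moves each into the store's configured spill tier, and reports Option[Long] (None when the store has no spill tier configured). A compact, self-contained model of that flow, using hypothetical names rather than the plugin's classes, and ignoring the streams, locking, and metrics the real synchronousSpill also handles:

import scala.collection.mutable

case class Buf(id: Int, size: Long, priority: Long, var spillable: Boolean)

class Store(val tier: String) {
  val buffers = mutable.ArrayBuffer[Buf]()
  var spillStore: Option[Store] = None
  def currentSize: Long = buffers.map(_.size).sum
  def currentSpillableSize: Long = buffers.filter(_.spillable).map(_.size).sum
}

class Catalog {
  // Mirrors the Option[Long] contract DeviceMemoryEventHandler consumes: None
  // when `store` has nowhere to spill, otherwise the number of bytes moved.
  // Only spillable buffers are candidates, lowest priority first.
  def synchronousSpill(store: Store, targetSize: Long): Option[Long] =
    store.spillStore.map { next =>
      var spilled = 0L
      val candidates = store.buffers.filter(_.spillable).sortBy(_.priority)
      for (b <- candidates if store.currentSize > targetSize) {
        store.buffers -= b // leaves this tier...
        next.buffers += b  // ...and lands in the next one
        spilled += b.size
      }
      spilled
    }
}

object SpillFlowSketch extends App {
  // wiring mirrors the suites above: device -> host -> disk
  val disk = new Store("DISK")
  val host = new Store("HOST")
  val device = new Store("DEVICE")
  host.spillStore = Some(disk)
  device.spillStore = Some(host)
  device.buffers += Buf(1, 1024, priority = -7, spillable = true)
  println(new Catalog().synchronousSpill(device, 0)) // Some(1024): moved to host
  println(new Catalog().synchronousSpill(host, 0))   // Some(1024): moved to disk
  println(new Catalog().synchronousSpill(disk, 0))   // None: no tier below disk
}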