diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala index b7539cf06461..44300b6100d8 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -101,8 +101,6 @@ trait RapidsBuffer extends AutoCloseable { /** The storage tier for this buffer */ val storageTier: StorageTier - val spillCallback: SpillCallback - /** * Get the columnar batch within this buffer. The caller must have * successfully acquired the buffer beforehand. @@ -171,13 +169,27 @@ trait RapidsBuffer extends AutoCloseable { */ def getSpillPriority: Long + /** + * Gets the spill metrics callback currently associated with this buffer. + * @return the current callback + */ + def getSpillCallback: SpillCallback + /** * Set the spill priority for this buffer. Lower values are higher priority * for spilling, meaning buffers with lower values will be preferred for * spilling over buffers with a higher value. + * @note should only be called from the buffer catalog * @param priority new priority value for this buffer */ def setSpillPriority(priority: Long): Unit + + /** + * Update the metrics callback that will be invoked next time a spill occurs. 
+ * @note should only be called from the buffer catalog + * @param spillCallback the new callback + */ + def setSpillCallback(spillCallback: SpillCallback): Unit } /** @@ -226,9 +238,11 @@ sealed class DegenerateRapidsBuffer( override def getSpillPriority: Long = Long.MaxValue + override val getSpillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback + override def setSpillPriority(priority: Long): Unit = {} - override def close(): Unit = {} + override def setSpillCallback(callback: SpillCallback): Unit = {} - override val spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback + override def close(): Unit = {} } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala index ffe5178f866a..d94b9951d253 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,22 +34,184 @@ import org.apache.spark.sql.rapids.RapidsDiskBlockManager */ class DuplicateBufferException(s: String) extends RuntimeException(s) {} +/** + * An object that client code uses to interact with an underlying RapidsBufferId. + * + * A handle is obtained when a buffer, batch, or table is added to the spill framework + * via the `RapidsBufferCatalog` api. + */ +trait RapidsBufferHandle { + val id: RapidsBufferId + + /** + * Sets the spill priority for this handle and updates the maximum priority + * for the underlying `RapidsBuffer` if this new priority is the maximum. 
+ * @param newPriority new priority for this handle + */ + def setSpillPriority(newPriority: Long): Unit +} + /** * Catalog for lookup of buffers by ID. The constructor is only visible for testing, generally * `RapidsBufferCatalog.singleton` should be used instead. */ -class RapidsBufferCatalog extends Logging { +class RapidsBufferCatalog extends Arm { + /** Map of buffer IDs to buffers sorted by storage tier */ private[this] val bufferMap = new ConcurrentHashMap[RapidsBufferId, Seq[RapidsBuffer]] + /** Map of buffer IDs to buffer handles in insertion order */ + private[this] val bufferIdToHandles = + new ConcurrentHashMap[RapidsBufferId, Seq[RapidsBufferHandleImpl]]() + + class RapidsBufferHandleImpl( + override val id: RapidsBufferId, + var priority: Long, + spillCallback: SpillCallback) + extends RapidsBufferHandle with Arm { + + override def setSpillPriority(newPriority: Long): Unit = { + priority = newPriority + updateUnderlyingRapidsBuffer(this) + } + + /** + * Get the spill priority that was associated with this handle. Since there can + * be multiple handles associated with one `RapidsBuffer`, the priority returned + * here is only useful for code in the catalog that updates the maximum priority + * for the underlying `RapidsBuffer` as handles are added and removed. + * + * @return this handle's spill priority + */ + def getSpillPriority: Long = priority + + /** + * Each handle was created in a different part of the code and as such could have + * different spill metrics callbacks. This function is used by the catalog to find + * out what the last spill callback added. This last callback gets reports of + * spill bytes if a spill were to occur to the `RapidsBuffer` this handle points to. 
+ * + * @return the spill callback associated with this handle + */ + def getSpillCallback: SpillCallback = spillCallback + } + /** - * Lookup the buffer that corresponds to the specified buffer ID at the highest storage tier, + * Makes a new `RapidsBufferHandle` associated with `id`, keeping track + * of the spill priority and callback within this handle. + * + * This function also adds the handle for internal tracking in the catalog. + * + * @param id the `RapidsBufferId` that this handle refers to + * @param spillPriority the spill priority specified on creation of the handle + * @param spillCallback this handle's spill callback + * @note public for testing + * @return a new instance of `RapidsBufferHandle` + */ + def makeNewHandle( + id: RapidsBufferId, + spillPriority: Long, + spillCallback: SpillCallback): RapidsBufferHandle = { + val handle = new RapidsBufferHandleImpl(id, spillPriority, spillCallback) + trackNewHandle(handle) + handle + } + + /** + * Adds a handle to the internal `bufferIdToHandles` map. + * + * The priority and callback of the `RapidsBuffer` will also be updated. + * + * @param handle handle to start tracking + */ + private def trackNewHandle(handle: RapidsBufferHandleImpl): Unit = { + bufferIdToHandles.compute(handle.id, (_, h) => { + var handles = h + if (handles == null) { + handles = Seq.empty[RapidsBufferHandleImpl] + } + handles :+ handle + }) + updateUnderlyingRapidsBuffer(handle) + } + + /** + * Called when the `RapidsBufferHandle` is no longer needed by calling code + * + * If this is the last handle associated with a `RapidsBuffer`, `stopTrackingHandle` + * returns true, otherwise it returns false. 
+ * + * @param handle handle to stop tracking + * @return + */ + private def stopTrackingHandle(handle: RapidsBufferHandle): Boolean = { + withResource(acquireBuffer(handle)) { buffer => + val id = handle.id + var maxPriority = Long.MinValue + val newHandles = bufferIdToHandles.compute(id, (_, handles) => { + if (handles == null) { + throw new IllegalStateException( + s"$id not found and we attempted to remove handles!") + } + if (handles.size == 1) { + require(handles.head == handle, + "Tried to remove a single handle, and we couldn't match on it") + null + } else { + val newHandles = handles.filter(h => h != handle).map { h => + if (h.getSpillPriority > maxPriority) { + maxPriority = h.getSpillPriority + } + h + } + if (newHandles.isEmpty) { + null // remove since no more handles exist, should not happen + } else { + // we pick the last spillCallback inserted as the winner every time + // this callback is going to get the metrics associated with this buffer's + // spill + newHandles + } + } + }) + + if (newHandles == null) { + // tell calling code that no more handles exist, + // for this RapidsBuffer + true + } else { + // more handles remain, our priority changed so we need to update things + buffer.setSpillPriority(maxPriority) + buffer.setSpillCallback(newHandles.last.getSpillCallback) + false // we have handles left + } + } + } + + /** + * Called by the catalog when a handle is first added to the catalog, or to refresh + * the priority of the underlying buffer if a handle's priority changed. 
+ */ + private def updateUnderlyingRapidsBuffer(handle: RapidsBufferHandle): Unit = { + withResource(acquireBuffer(handle)) { buffer => + val handles = bufferIdToHandles.get(buffer.id) + val maxPriority = handles.map(_.getSpillPriority).max + // update the priority of the underlying RapidsBuffer to be the + // maximum priority for all handles associated with it + buffer.setSpillPriority(maxPriority) + buffer.setSpillCallback(handles.last.getSpillCallback) + } + } + + /** + * Lookup the buffer that corresponds to the specified handle at the highest storage tier, * and acquire it. * NOTE: It is the responsibility of the caller to close the buffer. - * @param id buffer identifier + * @param handle handle associated with this `RapidsBuffer` * @return buffer that has been acquired */ - def acquireBuffer(id: RapidsBufferId): RapidsBuffer = { + def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer = { + val id = handle.id (0 until RapidsBufferCatalog.MAX_BUFFER_LOOKUP_ATTEMPTS).foreach { _ => val buffers = bufferMap.get(id) if (buffers == null || buffers.isEmpty) { @@ -124,6 +286,7 @@ class RapidsBufferCatalog extends Logging { } } } + bufferMap.compute(buffer.id, updater) } @@ -142,10 +305,19 @@ class RapidsBufferCatalog extends Logging { bufferMap.computeIfPresent(id, updater) } - /** Remove a buffer ID from the catalog and release the resources of the registered buffers. */ - def removeBuffer(id: RapidsBufferId): Unit = { - val buffers = bufferMap.remove(id) - buffers.safeFree() + /** + * Remove a buffer handle from the catalog and, if this was the final handle, + * release the resources of the registered buffers. Returns true if the buffer was removed. + */ + def removeBuffer(handle: RapidsBufferHandle): Boolean = { + // if this is the last handle, remove the buffer + if (stopTrackingHandle(handle)) { + val buffers = bufferMap.remove(handle.id) + buffers.safeFree() + true + } else { + false + } } /** Return the number of buffers currently in the catalog.
*/ @@ -248,65 +420,76 @@ object RapidsBufferCatalog extends Logging with Arm { /** * Adds a contiguous table to the device storage, taking ownership of the table. - * @param id buffer ID to associate with this buffer * @param table cudf table based from the contiguous buffer * @param contigBuffer device memory buffer backing the table * @param tableMeta metadata describing the buffer layout * @param initialSpillPriority starting spill priority value for the buffer * @param spillCallback a callback when the buffer is spilled. This should be very light weight. * It should never allocate GPU memory and really just be used for metrics. + * @return RapidsBufferHandle for the new buffer, which starts out using `spillCallback` */ def addTable( - id: RapidsBufferId, table: Table, contigBuffer: DeviceMemoryBuffer, tableMeta: TableMeta, initialSpillPriority: Long, - spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): Unit = - deviceStorage.addTable(id, table, contigBuffer, tableMeta, initialSpillPriority, spillCallback) + spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): RapidsBufferHandle = { + val id = + deviceStorage.addTable(table, contigBuffer, tableMeta, initialSpillPriority, spillCallback) + singleton.makeNewHandle(id, initialSpillPriority, spillCallback) + } /** * Adds a contiguous table to the device storage, taking ownership of the table. - * @param id buffer ID to associate with this buffer - * @param contigTable contiguous table to track in device storage + * @param contigTable contiguous table to track in device storage * @param initialSpillPriority starting spill priority value for the buffer * @param spillCallback a callback when the buffer is spilled. This should be very light weight. * It should never allocate GPU memory and really just be used for metrics.
+ * @return RapidsBufferHandle associated with this buffer */ def addContiguousTable( - id: RapidsBufferId, contigTable: ContiguousTable, initialSpillPriority: Long, - spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): Unit = - deviceStorage.addContiguousTable(id, contigTable, initialSpillPriority, spillCallback) + spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): RapidsBufferHandle = { + val id = deviceStorage.addContiguousTable( + contigTable, initialSpillPriority, spillCallback) + singleton.makeNewHandle(id, initialSpillPriority, spillCallback) + } /** * Adds a buffer to the device storage, taking ownership of the buffer. - * @param id buffer ID to associate with this buffer * @param buffer buffer that will be owned by the store * @param tableMeta metadata describing the buffer layout * @param initialSpillPriority starting spill priority value for the buffer * @param spillCallback a callback when the buffer is spilled. This should be very light weight. * It should never allocate GPU memory and really just be used for metrics. + * @return RapidsBufferHandle associated with this buffer */ def addBuffer( - id: RapidsBufferId, buffer: DeviceMemoryBuffer, tableMeta: TableMeta, initialSpillPriority: Long, - spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): Unit = - deviceStorage.addBuffer(id, buffer, tableMeta, initialSpillPriority, spillCallback) + spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): RapidsBufferHandle = { + val id = deviceStorage.addBuffer( + buffer, tableMeta, initialSpillPriority, spillCallback) + singleton.makeNewHandle(id, initialSpillPriority, spillCallback) + } /** - * Lookup the buffer that corresponds to the specified buffer ID and acquire it. + * Lookup the buffer that corresponds to the specified buffer handle and acquire it. * NOTE: It is the responsibility of the caller to close the buffer. 
- * @param id buffer identifier + * @param handle buffer handle * @return buffer that has been acquired */ - def acquireBuffer(id: RapidsBufferId): RapidsBuffer = singleton.acquireBuffer(id) + def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer = + singleton.acquireBuffer(handle) - /** Remove a buffer ID from the catalog and release the resources of the registered buffer. */ - def removeBuffer(id: RapidsBufferId): Unit = singleton.removeBuffer(id) + /** + * Remove a buffer handle from the catalog and, if it this was the final handle, + * release the resources of the registered buffers. + */ + def removeBuffer(handle: RapidsBufferHandle): Unit = + singleton.removeBuffer(handle) def getDiskBlockManager(): RapidsDiskBlockManager = diskBlockManager } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala index cf1db6e81ebd..ef97e99f0fdc 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferStore.scala @@ -254,7 +254,8 @@ abstract class RapidsBufferStore( } else { logDebug(s"Spilling $buffer ${buffer.id} to ${spillStore.name} " + s"total mem=${buffers.getTotalBytes}") - buffer.spillCallback(buffer.storageTier, spillStore.tier, buffer.size) + val spillCallback = buffer.getSpillCallback + spillCallback(buffer.storageTier, spillStore.tier, buffer.size) spillStore.copyBuffer(buffer, buffer.getMemoryBuffer, stream) } } finally { @@ -271,14 +272,16 @@ abstract class RapidsBufferStore( override val size: Long, override val meta: TableMeta, initialSpillPriority: Long, - override val spillCallback: SpillCallback, + initialSpillCallback: SpillCallback, catalog: RapidsBufferCatalog = RapidsBufferCatalog.singleton, deviceStorage: RapidsDeviceMemoryStore = RapidsBufferCatalog.getDeviceStorage) extends RapidsBuffer with Arm { private val MAX_UNSPILL_ATTEMPTS = 100 
private[this] var isValid = true protected[this] var refcount = 0 + private[this] var spillPriority: Long = initialSpillPriority + private[this] var spillCallback: SpillCallback = initialSpillCallback /** Release the underlying resources for this buffer. */ protected def releaseResources(): Unit @@ -431,6 +434,12 @@ abstract class RapidsBufferStore( spillPriority = priority } + override def getSpillCallback: SpillCallback = spillCallback + + override def setSpillCallback(callback: SpillCallback): Unit = { + spillCallback = callback + } + /** Must be called with a lock on the buffer */ private def freeBuffer(): Unit = { releaseResources() diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala index c38e01e484e9..da4a6d52f3dd 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStore.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -20,6 +20,7 @@ import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, HostMemoryBuff import com.nvidia.spark.rapids.StorageTier.StorageTier import com.nvidia.spark.rapids.format.TableMeta +import org.apache.spark.sql.rapids.TempSpillBufferId import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -46,27 +47,40 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog case b => throw new IllegalStateException(s"Unrecognized buffer: $b") } } - new RapidsDeviceMemoryBuffer(other.id, other.size, other.meta, None, - deviceBuffer, other.getSpillPriority, other.spillCallback) + new RapidsDeviceMemoryBuffer( + other.id, + other.size, + other.meta, + None, + deviceBuffer, + other.getSpillPriority, + other.getSpillCallback) } /** * Adds a contiguous table to the device storage, taking ownership of the table. - * @param id buffer ID to associate with this buffer + * * @param table cudf table based from the contiguous buffer * @param contigBuffer device memory buffer backing the table * @param tableMeta metadata describing the buffer layout * @param initialSpillPriority starting spill priority value for the buffer * @param spillCallback a callback when the buffer is spilled. This should be very light weight. * It should never allocate GPU memory and really just be used for metrics. + * @return RapidsBufferId identifying this table */ def addTable( - id: RapidsBufferId, table: Table, contigBuffer: DeviceMemoryBuffer, tableMeta: TableMeta, initialSpillPriority: Long, - spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): Unit = { + spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback): RapidsBufferId = { + // We increment this because this rapids device memory has two pointers to the buffer: + // the actual contig buffer, and the table. 
When this `RapidsBuffer` releases its resources, + // it will decrement the ref count for the contig buffer (negating this incRefCount), + // it will also close the table being passed here, which together brings the ref count + // to 0. + contigBuffer.incRefCount() + val id = TempSpillBufferId() freeOnExcept( new RapidsDeviceMemoryBuffer( id, @@ -77,8 +91,9 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog initialSpillPriority, spillCallback)) { buffer => logDebug(s"Adding table for: [id=$id, size=${buffer.size}, " + - s"meta_id=${buffer.meta.bufferMeta.id}, meta_size=${buffer.meta.bufferMeta.size}]") + s"meta_id=${buffer.meta.bufferMeta.id}, meta_size=${buffer.meta.bufferMeta.size}]") addDeviceBuffer(buffer, needsSync = true) + id } } @@ -87,20 +102,52 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog * contiguous table, so it is the responsibility of the caller to close it. The refcount of the * underlying device buffer will be incremented so the contiguous table can be closed before * this buffer is destroyed. - * @param id buffer ID to associate with this buffer + * + * This version of `addContiguousTable` creates a `TempSpillBufferId` to use + * to refer to this table. + * * @param contigTable contiguous table to track in storage * @param initialSpillPriority starting spill priority value for the buffer * @param spillCallback a callback when the buffer is spilled. This should be very light weight. * It should never allocate GPU memory and really just be used for metrics. 
* @param needsSync whether the spill framework should stream synchronize while adding * this device buffer (defaults to true) + * @return RapidsBufferId identifying this table */ def addContiguousTable( - id: RapidsBufferId, contigTable: ContiguousTable, initialSpillPriority: Long, spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback, - needsSync: Boolean = true): Unit = { + needsSync: Boolean = true): RapidsBufferId = { + addContiguousTable( + TempSpillBufferId(), + contigTable, + initialSpillPriority, + spillCallback, + needsSync) + } + + /** + * Adds a contiguous table to the device storage. This does NOT take ownership of the + * contiguous table, so it is the responsibility of the caller to close it. The refcount of the + * underlying device buffer will be incremented so the contiguous table can be closed before + * this buffer is destroyed. + * + * @param id the RapidsBufferId to use for this buffer + * @param contigTable contiguous table to track in storage + * @param initialSpillPriority starting spill priority value for the buffer + * @param spillCallback a callback when the buffer is spilled. This should be very light weight. + * It should never allocate GPU memory and really just be used for metrics. 
+ * @param needsSync whether the spill framework should stream synchronize while adding + * this device buffer (defaults to true) + * @return RapidsBufferId identifying this table + */ + def addContiguousTable( + id: RapidsBufferId, + contigTable: ContiguousTable, + initialSpillPriority: Long, + spillCallback: SpillCallback, + needsSync: Boolean): RapidsBufferId = { val contigBuffer = contigTable.getBuffer val size = contigBuffer.getLength val meta = MetaUtils.buildTableMeta(id.tableId, contigTable) @@ -115,15 +162,20 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog initialSpillPriority, spillCallback)) { buffer => logDebug(s"Adding table for: [id=$id, size=${buffer.size}, " + - s"uncompressed=${buffer.meta.bufferMeta.uncompressedSize}, " + - s"meta_id=${buffer.meta.bufferMeta.id}, meta_size=${buffer.meta.bufferMeta.size}]") + s"uncompressed=${buffer.meta.bufferMeta.uncompressedSize}, " + + s"meta_id=${buffer.meta.bufferMeta.id}, meta_size=${buffer.meta.bufferMeta.size}]") addDeviceBuffer(buffer, needsSync) + id } } /** - * Adds a buffer to the device storage, taking ownership of the buffer. - * @param id buffer ID to associate with this buffer + * Adds a buffer to the device storage. This does NOT take ownership of the + * buffer, so it is the responsibility of the caller to close it. + * + * This version of `addBuffer` creates a `TempSpillBufferId` to use to refer to + * this buffer. + * * @param buffer buffer that will be owned by the store * @param tableMeta metadata describing the buffer layout * @param initialSpillPriority starting spill priority value for the buffer @@ -131,14 +183,45 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog * It should never allocate GPU memory and really just be used for metrics. 
* @param needsSync whether the spill framework should stream synchronize while adding * this device buffer (defaults to true) + * @return RapidsBufferId identifying this buffer */ def addBuffer( - id: RapidsBufferId, buffer: DeviceMemoryBuffer, tableMeta: TableMeta, initialSpillPriority: Long, spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback, - needsSync: Boolean = true): Unit = { + needsSync: Boolean = true): RapidsBufferId = { + addBuffer( + TempSpillBufferId(), + buffer, + tableMeta, + initialSpillPriority, + spillCallback, + needsSync) + } + + /** + * Adds a buffer to the device storage. This does NOT take ownership of the + * buffer, so it is the responsibility of the caller to close it. + * + * @param id the RapidsBufferId to use for this buffer + * @param buffer buffer to track; the store increments its refcount rather than owning it + * @param tableMeta metadata describing the buffer layout + * @param initialSpillPriority starting spill priority value for the buffer + * @param spillCallback a callback when the buffer is spilled. This should be very light weight. + * It should never allocate GPU memory and really just be used for metrics.
+ * @param needsSync whether the spill framework should stream synchronize while adding + * this device buffer (defaults to true) + * @return RapidsBufferId identifying this buffer + */ + def addBuffer( + id: RapidsBufferId, + buffer: DeviceMemoryBuffer, + tableMeta: TableMeta, + initialSpillPriority: Long, + spillCallback: SpillCallback, + needsSync: Boolean): RapidsBufferId = { + buffer.incRefCount() freeOnExcept( new RapidsDeviceMemoryBuffer( id, @@ -149,10 +232,11 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog initialSpillPriority, spillCallback)) { buff => logDebug(s"Adding receive side table for: [id=$id, size=${buffer.getLength}, " + - s"uncompressed=${buff.meta.bufferMeta.uncompressedSize}, " + - s"meta_id=${tableMeta.bufferMeta.id}, " + - s"meta_size=${tableMeta.bufferMeta.size}]") + s"uncompressed=${buff.meta.bufferMeta.uncompressedSize}, " + + s"meta_id=${tableMeta.bufferMeta.id}, " + + s"meta_size=${tableMeta.bufferMeta.size}]") addDeviceBuffer(buff, needsSync) + id } } @@ -160,13 +244,16 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog * Adds a device buffer to the spill framework, stream synchronizing with the producer * stream to ensure that the buffer is fully materialized, and can be safely copied * as part of the spill. 
+ * * @param needsSync true if we should stream synchronize before adding the buffer */ - private def addDeviceBuffer(buffer: RapidsDeviceMemoryBuffer, needsSync: Boolean): Unit = { + private def addDeviceBuffer( + buffer: RapidsDeviceMemoryBuffer, + needsSync: Boolean): Unit = { if (needsSync) { Cuda.DEFAULT_STREAM.sync() } - addBuffer(buffer); + addBuffer(buffer) } class RapidsDeviceMemoryBuffer( @@ -176,7 +263,7 @@ class RapidsDeviceMemoryStore(catalog: RapidsBufferCatalog = RapidsBufferCatalog table: Option[Table], contigBuffer: DeviceMemoryBuffer, spillPriority: Long, - override val spillCallback: SpillCallback) + spillCallback: SpillCallback) extends RapidsBufferBase(id, size, meta, spillPriority, spillCallback) { override val storageTier: StorageTier = StorageTier.DEVICE diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala index 6abc7bd5cc32..05e183a417cf 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala @@ -56,8 +56,14 @@ class RapidsDiskStore( copyBufferToPath(hostBuffer, path, append = false) } logDebug(s"Spilled to $path $fileOffset:${incoming.size}") - new this.RapidsDiskBuffer(id, fileOffset, incoming.size, incoming.meta, - incoming.getSpillPriority, incoming.spillCallback, deviceStorage) + new RapidsDiskBuffer( + id, + fileOffset, + incoming.size, + incoming.meta, + incoming.getSpillPriority, + incoming.getSpillCallback, + deviceStorage) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsGdsStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsGdsStore.scala index 7936d6fdcbbe..b9c754c14aae 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsGdsStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsGdsStore.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, 
NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,7 +62,7 @@ class RapidsGdsStore( override val size: Long, override val meta: TableMeta, spillPriority: Long, - override val spillCallback: SpillCallback) + spillCallback: SpillCallback) extends RapidsBufferBase(id, size, meta, spillPriority, spillCallback) { override val storageTier: StorageTier = StorageTier.GDS @@ -124,8 +124,14 @@ class RapidsGdsStore( 0 } logDebug(s"Spilled to $path $fileOffset:${other.size} via GDS") - new RapidsGdsSingleShotBuffer(id, path, fileOffset, other.size, other.meta, - other.getSpillPriority, other.spillCallback) + new RapidsGdsSingleShotBuffer( + id, + path, + fileOffset, + other.size, + other.meta, + other.getSpillPriority, + other.getSpillCallback) } class BatchSpiller() extends AutoCloseable { @@ -161,8 +167,14 @@ class RapidsGdsStore( val id = other.id addBuffer(currentFile, id) - val gdsBuffer = new RapidsGdsBatchedBuffer(id, currentFile, currentOffset, - other.size, other.meta, other.getSpillPriority, other.spillCallback) + val gdsBuffer = new RapidsGdsBatchedBuffer( + id, + currentFile, + currentOffset, + other.size, + other.meta, + other.getSpillPriority, + other.getSpillCallback) currentOffset += alignUp(deviceBuffer.getLength) pendingBuffers += gdsBuffer gdsBuffer diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala index 15483fed2493..ef3dd77ff88e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -101,7 +101,7 @@ class RapidsHostMemoryStore( applyPriorityOffset(other.getSpillPriority, allocationMode.spillPriorityOffset), hostBuffer, allocationMode, - other.spillCallback, + other.getSpillCallback, deviceStorage) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala index 0f228a0735f5..fd1d1c55be82 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleBufferCatalog.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ import java.util.function.{Consumer, IntUnaryOperator} import scala.collection.mutable.ArrayBuffer -import ai.rapids.cudf.Cuda +import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer} import com.nvidia.spark.rapids.format.TableMeta import org.apache.spark.SparkEnv @@ -50,9 +50,117 @@ case class ShuffleBufferId( class ShuffleBufferCatalog( catalog: RapidsBufferCatalog, diskBlockManager: RapidsDiskBlockManager) extends Arm with Logging { + + private val deviceStore = RapidsBufferCatalog.getDeviceStorage + + private val bufferIdToHandle = new ConcurrentHashMap[RapidsBufferId, RapidsBufferHandle]() + + private def trackCachedHandle( + bufferId: ShuffleBufferId, + bufferHandle: RapidsBufferHandle): Unit = { + bufferIdToHandle.put(bufferId, bufferHandle) + } + + def removeCachedHandles(): Unit = { + bufferIdToHandle.forEach { case (_, handle) => + removeBuffer(handle) + } + } + + /** + * Adds a contiguous table shuffle table to the device storage. 
This does NOT take ownership of + * the contiguous table, so it is the responsibility of the caller to close it. + * The refcount of the underlying device buffer will be incremented so the contiguous table + * can be closed before this buffer is destroyed. + * + * @param blockId Spark's `ShuffleBlockId` that identifies this buffer + * @param contigTable contiguous table to track in storage + * @param initialSpillPriority starting spill priority value for the buffer + * @param defaultSpillCallback a callback when the buffer is spilled. This should be very light weight. + * It should never allocate GPU memory and really just be used for metrics. + * @param needsSync whether the spill framework should stream synchronize while adding + * this device buffer + * @return RapidsBufferHandle identifying this table + */ + def addContiguousTable( + blockId: ShuffleBlockId, + contigTable: ContiguousTable, + initialSpillPriority: Long, + defaultSpillCallback: SpillCallback, + needsSync: Boolean): RapidsBufferHandle = { + val bufferId = nextShuffleBufferId(blockId) + withResource(contigTable) { _ => + deviceStore.addContiguousTable( + bufferId, + contigTable, + initialSpillPriority, + defaultSpillCallback, + needsSync) + val handle = catalog.makeNewHandle( + bufferId, initialSpillPriority, defaultSpillCallback) + trackCachedHandle(bufferId, handle) + handle + } + } + + /** + * Adds a buffer to the device storage, taking ownership of the buffer. + * + * @param blockId Spark's `ShuffleBlockId` that identifies this buffer + * @param buffer buffer that will be owned by the store + * @param tableMeta metadata describing the buffer layout + * @param initialSpillPriority starting spill priority value for the buffer + * @param defaultSpillCallback a callback when the buffer is spilled. This should be very light weight. + * It should never allocate GPU memory and really just be used for metrics. 
+ * @return RapidsBufferHandle associated with this buffer + */ + def addBuffer( + blockId: ShuffleBlockId, + buffer: DeviceMemoryBuffer, + tableMeta: TableMeta, + initialSpillPriority: Long, + defaultSpillCallback: SpillCallback, + needsSync: Boolean): RapidsBufferHandle = { + val bufferId = nextShuffleBufferId(blockId) + // update the table metadata for the buffer ID generated above + tableMeta.bufferMeta.mutateId(bufferId.tableId) + // when we call `addBuffer` the store will incRefCount + withResource(buffer) { _ => + deviceStore.addBuffer( + bufferId, + buffer, + tableMeta, + initialSpillPriority, + defaultSpillCallback, + needsSync) + val handle = + catalog.makeNewHandle(bufferId, initialSpillPriority, defaultSpillCallback) + trackCachedHandle(bufferId, handle) + handle + } + } + + /** + * Register a new buffer with the catalog. An exception will be thrown if an + * existing buffer was registered with the same block ID (extremely unlikely) + */ + def addDegenerateRapidsBuffer( + blockId: ShuffleBlockId, + meta: TableMeta, + spillCallback: SpillCallback): RapidsBufferHandle = { + val bufferId = nextShuffleBufferId(blockId) + val buffer = new DegenerateRapidsBuffer(bufferId, meta) + catalog.registerNewBuffer(buffer) + val handle = + catalog.makeNewHandle(buffer.id, buffer.getSpillPriority, spillCallback) + trackCachedHandle(bufferId, handle) + handle + } + /** * Information stored for each active shuffle. * NOTE: ArrayBuffer in blockMap must be explicitly locked when using it! + * * @param blockMap mapping of block ID to array of buffers for the block */ private case class ShuffleInfo( @@ -88,7 +196,13 @@ class ShuffleBufferCatalog( // NOTE: Not synchronizing array buffer because this shuffle should be inactive. 
bufferIds.foreach { id => tableMap.remove(id.tableId) - catalog.removeBuffer(id) + val didRemove = catalog.removeBuffer(bufferIdToHandle.get(id)) + if (!didRemove) { + logWarning(s"Unable to remove from underlying storage ${id} when cleaning " + + s"shuffle blocks.") + } else { + logDebug(s"Did remove ${id}") + } } } info.blockMap.forEachValue(Long.MaxValue, bufferRemover) @@ -126,6 +240,20 @@ class ShuffleBufferCatalog( } } + def blockIdToBufferHandles(blockId: ShuffleBlockId): Array[RapidsBufferHandle] = { + val info = activeShuffles.get(blockId.shuffleId) + if (info == null) { + throw new NoSuchElementException(s"unknown shuffle ${blockId.shuffleId}") + } + val entries = info.blockMap.get(blockId) + if (entries == null) { + throw new NoSuchElementException(s"unknown shuffle block $blockId") + } + entries.synchronized { + entries.map(bufferIdToHandle.get).toArray + } + } + /** Get all the buffer metadata that correspond to a shuffle block identifier. */ def blockIdToMetas(blockId: ShuffleBlockId): Seq[TableMeta] = { blockIdToBuffersIds(blockId).map(catalog.getBufferMeta) @@ -151,76 +279,50 @@ class ShuffleBufferCatalog( blockBufferIds.synchronized { blockBufferIds.append(id) } - id } - /** Allocate a new table identifier for a shuffle block and update the shuffle block mapping. */ - def nextTableId(blockId: ShuffleBlockId): Int = { - val shuffleBufferId = nextShuffleBufferId(blockId) - shuffleBufferId.tableId - } - - /** Lookup the shuffle buffer identifier that corresponds to the specified table identifier. */ - def getShuffleBufferId(tableId: Int): ShuffleBufferId = { + /** Lookup the shuffle buffer handle that corresponds to the specified table identifier. 
*/ + def getShuffleBufferHandle(tableId: Int): RapidsBufferHandle = { val shuffleBufferId = tableMap.get(tableId) if (shuffleBufferId == null) { throw new NoSuchElementException(s"unknown table ID $tableId") } - shuffleBufferId + bufferIdToHandle.get(shuffleBufferId) } - /** - * Register a new buffer with the catalog. An exception will be thrown if an - * existing buffer was registered with the same buffer ID. - */ - def registerNewBuffer(buffer: RapidsBuffer): Unit = catalog.registerNewBuffer(buffer) - /** * Update the spill priority of a shuffle buffer that soon will be read locally. - * @param id shuffle buffer identifier of buffer to update + * @param handle shuffle buffer handle of buffer to update */ - def updateSpillPriorityForLocalRead(id: ShuffleBufferId): Unit = { - withResource(catalog.acquireBuffer(id)) { buffer => - buffer.setSpillPriority(SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY) - } + def updateSpillPriorityForLocalRead(handle: RapidsBufferHandle): Unit = { + handle.setSpillPriority(SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY) } /** - * Lookup the shuffle buffer that corresponds to the specified shuffle buffer ID and acquire it. + * Lookup the shuffle buffer that corresponds to the specified buffer handle and acquire it. * NOTE: It is the responsibility of the caller to close the buffer. 
- * @param id shuffle buffer identifier + * @param handle shuffle buffer handle * @return shuffle buffer that has been acquired */ - def acquireBuffer(id: ShuffleBufferId): RapidsBuffer = { - val buffer = catalog.acquireBuffer(id) + def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer = { + val buffer = catalog.acquireBuffer(handle) // Shuffle buffers that have been read are less likely to be read again, // so update the spill priority based on this access - val spillPriority = SpillPriorities.getShuffleOutputBufferReadPriority - buffer.setSpillPriority(spillPriority) + handle.setSpillPriority(SpillPriorities.getShuffleOutputBufferReadPriority) buffer } /** - * Lookup the shuffle buffer that corresponds to the specified table ID and acquire it. - * NOTE: It is the responsibility of the caller to close the buffer. - * @param tableId table identifier - * @return shuffle buffer that has been acquired - */ - def acquireBuffer(tableId: Int): RapidsBuffer = { - val shuffleBufferId = getShuffleBufferId(tableId) - acquireBuffer(shuffleBufferId) - } - - /** - * Remove a buffer and table given a buffer ID + * Remove a buffer and table given a buffer handle * NOTE: This function is not thread safe! The caller should only invoke if - * the [[ShuffleBufferId]] being removed is not being utilized by another thread. - * @param id buffer identifier + * the handle being removed is not being utilized by another thread. 
+ * @param handle buffer handle */ - def removeBuffer(id: ShuffleBufferId): Unit = { + def removeBuffer(handle: RapidsBufferHandle): Unit = { + val id = handle.id tableMap.remove(id.tableId) - catalog.removeBuffer(id) + catalog.removeBuffer(handle) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala index 6ebf375d232d..2fecf1bc7060 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShuffleReceivedBufferCatalog.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,9 @@ import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicInteger import java.util.function.IntUnaryOperator +import ai.rapids.cudf.DeviceMemoryBuffer +import com.nvidia.spark.rapids.format.TableMeta + import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.RapidsDiskBlockManager @@ -37,7 +40,10 @@ case class ShuffleReceivedBufferId( /** Catalog for lookup of shuffle buffers by block ID */ class ShuffleReceivedBufferCatalog( - catalog: RapidsBufferCatalog) extends Logging { + catalog: RapidsBufferCatalog) extends Arm with Logging { + + private val deviceStore = RapidsBufferCatalog.getDeviceStorage + /** Mapping of table ID to shuffle buffer ID */ private[this] val tableMap = new ConcurrentHashMap[Int, ShuffleReceivedBufferId] @@ -45,7 +51,7 @@ class ShuffleReceivedBufferCatalog( private[this] val tableIdCounter = new AtomicInteger(0) /** Allocate a new shuffle buffer identifier and update the shuffle block mapping. 
*/ - def nextShuffleReceivedBufferId(): ShuffleReceivedBufferId = { + private def nextShuffleReceivedBufferId(): ShuffleReceivedBufferId = { val tableId = tableIdCounter.getAndUpdate(ShuffleReceivedBufferCatalog.TABLE_ID_UPDATER) val id = ShuffleReceivedBufferId(tableId) val prev = tableMap.put(tableId, id) @@ -55,49 +61,78 @@ class ShuffleReceivedBufferCatalog( id } - /** Lookup the shuffle buffer identifier that corresponds to the specified table identifier. */ - def getShuffleBufferId(tableId: Int): ShuffleReceivedBufferId = { - val shuffleBufferId = tableMap.get(tableId) - if (shuffleBufferId == null) { - throw new NoSuchElementException(s"unknown table ID $tableId") - } - shuffleBufferId - } - /** - * Register a new buffer with the catalog. An exception will be thrown if an - * existing buffer was registered with the same buffer ID. + * Adds a buffer to the device storage, taking ownership of the buffer. + * + * This method associates a new `bufferId` which is tracked internally in this catalog. + * + * @param buffer buffer that will be owned by the store + * @param tableMeta metadata describing the buffer layout + * @param initialSpillPriority starting spill priority value for the buffer + * @param defaultSpillCallback a callback when the buffer is spilled. This should be very light weight. + * It should never allocate GPU memory and really just be used for metrics. 
+ * @param needsSync tells the store a synchronize in the current stream is required + * before storing this buffer + * @return RapidsBufferHandle associated with this buffer */ - def registerNewBuffer(buffer: RapidsBuffer): Unit = catalog.registerNewBuffer(buffer) + def addBuffer( + buffer: DeviceMemoryBuffer, + tableMeta: TableMeta, + initialSpillPriority: Long, + defaultSpillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback, + needsSync: Boolean): RapidsBufferHandle = { + val bufferId = nextShuffleReceivedBufferId() + tableMeta.bufferMeta.mutateId(bufferId.tableId) + // when we call `addBuffer` the store will incRefCount + withResource(buffer) { _ => + deviceStore.addBuffer( + bufferId, + buffer, + tableMeta, + initialSpillPriority, + defaultSpillCallback, + needsSync) + catalog.makeNewHandle(bufferId, initialSpillPriority, defaultSpillCallback) + } + } /** - * Lookup the shuffle buffer that corresponds to the specified shuffle buffer ID and acquire it. - * NOTE: It is the responsibility of the caller to close the buffer. - * @param id shuffle buffer identifier - * @return shuffle buffer that has been acquired + * Adds a degenerate buffer (zero rows or columns) + * + * @param meta metadata describing the buffer layout + * @param spillCallback a callback when the buffer is spilled. This should be very light weight. + * It should never allocate GPU memory and really just be used for metrics. + * @return RapidsBufferHandle associated with this buffer */ - def acquireBuffer(id: ShuffleReceivedBufferId): RapidsBuffer = catalog.acquireBuffer(id) + def addDegenerateRapidsBuffer( + meta: TableMeta, + spillCallback: SpillCallback): RapidsBufferHandle = { + val bufferId = nextShuffleReceivedBufferId() + val buffer = new DegenerateRapidsBuffer(bufferId, meta) + catalog.registerNewBuffer(buffer) + catalog.makeNewHandle(bufferId, -1, spillCallback) + } /** - * Lookup the shuffle buffer that corresponds to the specified table ID and acquire it. 
+ * Lookup the shuffle buffer that corresponds to the specified shuffle buffer + * handle and acquire it. * NOTE: It is the responsibility of the caller to close the buffer. - * @param tableId table identifier + * + * @param handle shuffle buffer handle * @return shuffle buffer that has been acquired */ - def acquireBuffer(tableId: Int): RapidsBuffer = { - val shuffleBufferId = getShuffleBufferId(tableId) - acquireBuffer(shuffleBufferId) - } + def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer = catalog.acquireBuffer(handle) /** - * Remove a buffer and table given a buffer ID + * Remove a buffer and table given a buffer handle * NOTE: This function is not thread safe! The caller should only invoke if - * the [[ShuffleReceivedBufferId]] being removed is not being utilized by another thread. - * @param id buffer identifier + * the handle being removed is not being utilized by another thread. + * @param handle buffer handle */ - def removeBuffer(id: ShuffleReceivedBufferId): Unit = { + def removeBuffer(handle: RapidsBufferHandle): Unit = { + val id = handle.id tableMap.remove(id.tableId) - catalog.removeBuffer(id) + catalog.removeBuffer(handle) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala index e42ba52ec911..725272f642af 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SpillableColumnarBatch.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -19,7 +19,6 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{ContiguousTable, DeviceMemoryBuffer} import org.apache.spark.TaskContext -import org.apache.spark.sql.rapids.TempSpillBufferId import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -54,15 +53,21 @@ trait SpillableColumnarBatch extends AutoCloseable { * spillable, even though in reality there is no backing buffer. It does this by just keeping the * row count in memory, and not dealing with the catalog at all. */ -class JustRowsColumnarBatch(numRows: Int, semWait: GpuMetric) extends SpillableColumnarBatch { +class JustRowsColumnarBatch(numRows: Int, semWait: GpuMetric) + extends SpillableColumnarBatch with Arm { override def numRows(): Int = numRows override def setSpillPriority(priority: Long): Unit = () // NOOP nothing to spill - override def getColumnarBatch(): ColumnarBatch = { + + private def makeJustRowsBatch(): ColumnarBatch = { GpuSemaphore.acquireIfNecessary(TaskContext.get(), semWait) new ColumnarBatch(Array.empty, numRows) } override def close(): Unit = () // NOOP nothing to close override val sizeInBytes: Long = 0L + + def getColumnarBatch(): ColumnarBatch = { + makeJustRowsBatch() + } } /** @@ -71,11 +76,12 @@ class JustRowsColumnarBatch(numRows: Int, semWait: GpuMetric) extends SpillableC * ownership of the life cycle of the batch. So don't call this constructor directly please * use `SpillableColumnarBatch.apply` instead. */ -class SpillableColumnarBatchImpl (id: TempSpillBufferId, +class SpillableColumnarBatchImpl ( + handle: RapidsBufferHandle, rowCount: Int, sparkTypes: Array[DataType], semWait: GpuMetric) - extends SpillableColumnarBatch with Arm { + extends SpillableColumnarBatch with Arm { private var closed = false /** @@ -83,35 +89,24 @@ class SpillableColumnarBatchImpl (id: TempSpillBufferId, */ override def numRows(): Int = rowCount - /** - * The ID that this is stored under. 
- * @note Use with caution because if this has been closed the id is no longer valid. - */ - def spillId: TempSpillBufferId = id + private def withRapidsBuffer[T](fn: RapidsBuffer => T): T = { + withResource(RapidsBufferCatalog.acquireBuffer(handle)) { rapidsBuffer => + fn(rapidsBuffer) + } + } override lazy val sizeInBytes: Long = - withResource(RapidsBufferCatalog.acquireBuffer(id)) { buff => - buff.size - } + withRapidsBuffer(_.size) /** * Set a new spill priority. */ override def setSpillPriority(priority: Long): Unit = { - withResource(RapidsBufferCatalog.acquireBuffer(id)) { rapidsBuffer => - rapidsBuffer.setSpillPriority(priority) - } + handle.setSpillPriority(priority) } - /** - * Get the columnar batch. - * @note It is the responsibility of the caller to close the batch. - * @note If the buffer is compressed data then the resulting batch will be built using - * `GpuCompressedColumnVector`, and it is the responsibility of the caller to deal - * with decompressing the data if necessary. - */ override def getColumnarBatch(): ColumnarBatch = { - withResource(RapidsBufferCatalog.acquireBuffer(id)) { rapidsBuffer => + withRapidsBuffer { rapidsBuffer => GpuSemaphore.acquireIfNecessary(TaskContext.get(), semWait) rapidsBuffer.getColumnarBatch(sparkTypes) } @@ -122,7 +117,8 @@ class SpillableColumnarBatchImpl (id: TempSpillBufferId, */ override def close(): Unit = { if (!closed) { - RapidsBufferCatalog.removeBuffer(id) + // closing my reference + RapidsBufferCatalog.removeBuffer(handle) closed = true } } @@ -131,9 +127,10 @@ class SpillableColumnarBatchImpl (id: TempSpillBufferId, object SpillableColumnarBatch extends Arm { /** * Create a new SpillableColumnarBatch. + * * @note This takes over ownership of batch, and batch should not be used after this. 
- * @param batch the batch to make spillable - * @param priority the initial spill priority of this batch + * @param batch the batch to make spillable + * @param priority the initial spill priority of this batch * @param spillCallback a callback when the buffer is spilled. This should be very light weight. * It should never allocate GPU memory and really just be used for metrics. */ @@ -146,10 +143,13 @@ object SpillableColumnarBatch extends Arm { batch.close() new JustRowsColumnarBatch(numRows, spillCallback.semaphoreWaitTime) } else { - val types = GpuColumnVector.extractTypes(batch) - val id = TempSpillBufferId() - addBatch(id, batch, priority, spillCallback) - new SpillableColumnarBatchImpl(id, numRows, types, spillCallback.semaphoreWaitTime) + val types = GpuColumnVector.extractTypes(batch) + val handle = addBatch(batch, priority, spillCallback) + new SpillableColumnarBatchImpl( + handle, + numRows, + types, + spillCallback.semaphoreWaitTime) } } @@ -167,85 +167,80 @@ object SpillableColumnarBatch extends Arm { sparkTypes: Array[DataType], priority: Long, spillCallback: SpillCallback): SpillableColumnarBatch = { - val id = TempSpillBufferId() - RapidsBufferCatalog.addContiguousTable(id, ct, priority, spillCallback) - new SpillableColumnarBatchImpl(id, ct.getRowCount.toInt, sparkTypes, - spillCallback.semaphoreWaitTime) + val handle = RapidsBufferCatalog.addContiguousTable(ct, priority, spillCallback) + withResource(RapidsBufferCatalog.acquireBuffer(handle)) { _ => + new SpillableColumnarBatchImpl( + handle, + ct.getRowCount.toInt, + sparkTypes, + spillCallback.semaphoreWaitTime) + } } private[this] def addBatch( - id: RapidsBufferId, batch: ColumnarBatch, initialSpillPriority: Long, - spillCallback: SpillCallback): Unit = { + spillCallback: SpillCallback): RapidsBufferHandle = { withResource(batch) { batch => val numColumns = batch.numCols() if (GpuCompressedColumnVector.isBatchCompressed(batch)) { val cv = 
batch.column(0).asInstanceOf[GpuCompressedColumnVector] val buff = cv.getTableBuffer - buff.incRefCount() - RapidsBufferCatalog.addBuffer(id, buff, cv.getTableMeta, initialSpillPriority, + RapidsBufferCatalog.addBuffer(buff, cv.getTableMeta, initialSpillPriority, spillCallback) } else if (GpuPackedTableColumn.isBatchPacked(batch)) { val cv = batch.column(0).asInstanceOf[GpuPackedTableColumn] - RapidsBufferCatalog.addContiguousTable(id, cv.getContiguousTable, initialSpillPriority, + RapidsBufferCatalog.addContiguousTable( + cv.getContiguousTable, + initialSpillPriority, spillCallback) } else if (numColumns > 0 && (0 until numColumns) .forall(i => batch.column(i).isInstanceOf[GpuColumnVectorFromBuffer])) { val cv = batch.column(0).asInstanceOf[GpuColumnVectorFromBuffer] - val table = GpuColumnVector.from(batch) val buff = cv.getBuffer - buff.incRefCount() - RapidsBufferCatalog.addTable(id, table, buff, cv.getTableMeta, initialSpillPriority, + // note the table here is handed over to the catalog + val table = GpuColumnVector.from(batch) + RapidsBufferCatalog.addTable(table, buff, cv.getTableMeta, initialSpillPriority, spillCallback) } else { withResource(GpuColumnVector.from(batch)) { tmpTable => withResource(tmpTable.contiguousSplit()) { contigTables => require(contigTables.length == 1, "Unexpected number of contiguous spit tables") - RapidsBufferCatalog.addContiguousTable(id, contigTables.head, initialSpillPriority, + RapidsBufferCatalog.addContiguousTable( + contigTables.head, + initialSpillPriority, spillCallback) } } } } } + } /** * Just like a SpillableColumnarBatch but for buffers. */ -class SpillableBuffer (id: TempSpillBufferId, semWait: GpuMetric) extends AutoCloseable with Arm { - private var closed = false +class SpillableBuffer( + handle: RapidsBufferHandle, + semWait: GpuMetric) extends AutoCloseable with Arm { - /** - * The ID that this is stored under. - * @note Use with caution because if this has been closed the id is no longer valid. 
- */ - def spillId: TempSpillBufferId = id - - lazy val sizeInBytes: Long = - withResource(RapidsBufferCatalog.acquireBuffer(id)) { buff => - buff.size - } + private var closed = false /** * Set a new spill priority. */ def setSpillPriority(priority: Long): Unit = { - withResource(RapidsBufferCatalog.acquireBuffer(id)) { rapidsBuffer => - rapidsBuffer.setSpillPriority(priority) - } + handle.setSpillPriority(priority) } /** - * Get the device buffer. - * @note It is the responsibility of the caller to close the buffer. + * Use the device buffer. */ def getDeviceBuffer(): DeviceMemoryBuffer = { - withResource(RapidsBufferCatalog.acquireBuffer(id)) { rapidsBuffer => - GpuSemaphore.acquireIfNecessary(TaskContext.get(), semWait) + withResource(RapidsBufferCatalog.acquireBuffer(handle)) { rapidsBuffer => rapidsBuffer.getDeviceMemoryBuffer } } @@ -255,7 +250,7 @@ class SpillableBuffer (id: TempSpillBufferId, semWait: GpuMetric) extends AutoCl */ override def close(): Unit = { if (!closed) { - RapidsBufferCatalog.removeBuffer(id) + RapidsBufferCatalog.removeBuffer(handle) closed = true } } @@ -274,9 +269,8 @@ object SpillableBuffer extends Arm { def apply(buffer: DeviceMemoryBuffer, priority: Long, spillCallback: SpillCallback): SpillableBuffer = { - val id = TempSpillBufferId() val meta = MetaUtils.getTableMetaNoTable(buffer) - RapidsBufferCatalog.addBuffer(id, buffer, meta, priority, spillCallback) - new SpillableBuffer(id, spillCallback.semaphoreWaitTime) + val handle = RapidsBufferCatalog.addBuffer(buffer, meta, priority, spillCallback) + new SpillableBuffer(handle, spillCallback.semaphoreWaitTime) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala index 6384789147fe..2584f0d0b00f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala +++ 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleClient.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ trait RapidsShuffleFetchHandler { * @return a boolean that lets the caller know the batch was accepted (true), or * rejected (false), in which case the caller should dispose of the batch. */ - def batchReceived(bufferId: ShuffleReceivedBufferId): Boolean + def batchReceived(handle: RapidsBufferHandle): Boolean /** * Called when the transport layer is not able to handle a fetch error for metadata @@ -336,7 +336,7 @@ class RapidsShuffleClient( } else { // Degenerate buffer (no device data) so no more data to request. // We need to trigger call in iterator, otherwise this batch is never handled. - handler.batchReceived(track(null, tableMeta).asInstanceOf[ShuffleReceivedBufferId]) + handler.batchReceived(track(null, tableMeta)) } } @@ -379,9 +379,9 @@ class RapidsShuffleClient( // hand buffer off to the catalog buffMetas.foreach { consumed: ConsumedBatchFromBounceBuffer => - val bId = track(consumed.contigBuffer, consumed.meta) - if (!consumed.handler.batchReceived(bId)) { - catalog.removeBuffer(bId) + val handle = track(consumed.contigBuffer, consumed.meta) + if (!consumed.handler.batchReceived(handle)) { + catalog.removeBuffer(handle) numBatchesRejected += 1 } transport.doneBytesInFlight(consumed.contigBuffer.getLength) @@ -425,22 +425,26 @@ class RapidsShuffleClient( * @return the [[RapidsBufferId]] to be used to look up the buffer from catalog */ private[shuffle] def track( - buffer: DeviceMemoryBuffer, meta: TableMeta): ShuffleReceivedBufferId = { - val id: ShuffleReceivedBufferId = catalog.nextShuffleReceivedBufferId() - logDebug(s"Adding buffer id ${id} to catalog") + buffer: DeviceMemoryBuffer, meta: TableMeta): 
RapidsBufferHandle = { if (buffer != null) { // add the buffer to the catalog so it is available for spill - devStorage.addBuffer(id, buffer, meta, - SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY, - // set needsSync to false because we already have stream synchronized after - // consuming the bounce buffer, so we know these buffers are synchronized - // w.r.t. the CPU - needsSync = false) + withResource(buffer) { _ => + catalog.addBuffer( + buffer, + meta, + SpillPriorities.INPUT_FROM_SHUFFLE_PRIORITY, + // set needsSync to false because we already have stream synchronized after + // consuming the bounce buffer, so we know these buffers are synchronized + // w.r.t. the CPU + needsSync = false) + } } else { // no device data, just tracking metadata - catalog.registerNewBuffer(new DegenerateRapidsBuffer(id, meta)) + catalog.addDegenerateRapidsBuffer( + meta, + RapidsBuffer.defaultSpillCallback) + } - id } override def close(): Unit = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala index 0092621e015c..1f44495759d6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIterator.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -21,7 +21,7 @@ import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} import scala.collection.mutable import ai.rapids.cudf.{NvtxColor, NvtxRange} -import com.nvidia.spark.rapids.{Arm, GpuSemaphore, NoopMetric, RapidsBuffer, RapidsConf, ShuffleReceivedBufferCatalog, ShuffleReceivedBufferId} +import com.nvidia.spark.rapids.{Arm, GpuSemaphore, NoopMetric, RapidsBuffer, RapidsBufferHandle, RapidsConf, ShuffleReceivedBufferCatalog} import org.apache.spark.TaskContext import org.apache.spark.internal.Logging @@ -66,10 +66,9 @@ class RapidsShuffleIterator( /** * A result for a successful buffer received - * @param bufferId - the shuffle received buffer id as tracked in the catalog + * @param handle - the shuffle received buffer handle as tracked in the catalog */ - case class BufferReceived( - bufferId: ShuffleReceivedBufferId) extends ShuffleClientResult + case class BufferReceived(handle: RapidsBufferHandle) extends ShuffleClientResult /** * A result for a failed attempt at receiving block metadata, or corresponding batches. 
@@ -222,7 +221,7 @@ class RapidsShuffleIterator( def clientDone: Boolean = clientExpectedBatches > 0 && clientExpectedBatches == clientResolvedBatches - def batchReceived(bufferId: ShuffleReceivedBufferId): Boolean = + def batchReceived(handle: RapidsBufferHandle): Boolean = { resolvedBatches.synchronized { if (taskComplete) { false @@ -235,7 +234,7 @@ class RapidsShuffleIterator( } totalBatchesResolved = totalBatchesResolved + 1 clientResolvedBatches = clientResolvedBatches + 1 - resolvedBatches.offer(BufferReceived(bufferId)) + resolvedBatches.offer(BufferReceived(handle)) if (clientDone) { logDebug(s"Task: $taskAttemptId Client $blockManagerId is " + @@ -250,6 +249,7 @@ class RapidsShuffleIterator( true } } + } override def transferError(errorMessage: String, throwable: Throwable): Unit = { resolvedBatches.synchronized { @@ -286,8 +286,8 @@ class RapidsShuffleIterator( logWarning(s"Iterator for task ${taskAttemptId} closing, " + s"but it is not done. Closing ${resolvedBatches.size()} resolved batches!!") resolvedBatches.forEach { - case BufferReceived(bufferId) => - GpuShuffleEnv.getReceivedCatalog.removeBuffer(bufferId) + case BufferReceived(handle) => + GpuShuffleEnv.getReceivedCatalog.removeBuffer(handle) case _ => } // tell the client to cancel pending requests @@ -349,11 +349,11 @@ class RapidsShuffleIterator( result = pollForResult(timeoutSeconds) val blockedTime = System.currentTimeMillis() - blockedStart result match { - case Some(BufferReceived(bufferId)) => + case Some(BufferReceived(handle)) => val nvtxRangeAfterGettingBatch = new NvtxRange("RapidsShuffleIterator.gotBatch", NvtxColor.PURPLE) try { - sb = catalog.acquireBuffer(bufferId) + sb = catalog.acquireBuffer(handle) cb = sb.getColumnarBatch(sparkTypes) metricsUpdater.update(blockedTime, 1, sb.size, cb.numRows()) } finally { @@ -362,7 +362,7 @@ class RapidsShuffleIterator( if (sb != null) { sb.close() } - catalog.removeBuffer(bufferId) + catalog.removeBuffer(handle) } case Some( 
TransferError(blockManagerId, shuffleBlockBatchId, mapIndex, errorMessage, throwable)) => diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala index f7083095e0db..f78bf3dc38c6 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsCachingReader.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,7 +62,7 @@ class RapidsCachingReader[K, C]( try { val blocksForRapidsTransport = new ArrayBuffer[(BlockManagerId, Seq[(BlockId, Long, Int)])]() val cachedBlocks = new ArrayBuffer[BlockId]() - val cachedBufferIds = new ArrayBuffer[ShuffleBufferId]() + val cachedBufferHandles = new ArrayBuffer[RapidsBufferHandle]() val blocksByAddressMap: Map[BlockManagerId, Seq[(BlockId, Long, Int)]] = blocksByAddress.toMap blocksByAddressMap.keys.foreach(blockManagerId => { @@ -75,27 +75,27 @@ class RapidsCachingReader[K, C]( blockInfos.foreach( blockInfo => { val blockId = blockInfo._1 - val shuffleBufferIds: IndexedSeq[ShuffleBufferId] = blockId match { + val shuffleBufferHandles: IndexedSeq[RapidsBufferHandle] = blockId match { case sbbid: ShuffleBlockBatchId => (sbbid.startReduceId to sbbid.endReduceId).flatMap { reduceId => cachedBlocks.append(blockId) val sBlockId = ShuffleBlockId(sbbid.shuffleId, sbbid.mapId, reduceId) - catalog.blockIdToBuffersIds(sBlockId) + catalog.blockIdToBufferHandles(sBlockId) } case sbid: ShuffleBlockId => cachedBlocks.append(blockId) - catalog.blockIdToBuffersIds(sbid) + catalog.blockIdToBufferHandles(sbid) case _ => throw new IllegalArgumentException( s"${blockId.getClass} $blockId is not currently supported") } - 
cachedBufferIds ++= shuffleBufferIds + cachedBufferHandles ++= shuffleBufferHandles // Update the spill priorities of these buffers to indicate they are about // to be read and therefore should not be spilled if possible. - shuffleBufferIds.foreach(catalog.updateSpillPriorityForLocalRead) + shuffleBufferHandles.foreach(catalog.updateSpillPriorityForLocalRead) - if (shuffleBufferIds.nonEmpty) { + if (shuffleBufferHandles.nonEmpty) { metrics.incLocalBlocksFetched(1) } }) @@ -140,10 +140,10 @@ class RapidsCachingReader[K, C]( val itRange = new NvtxRange("Shuffle Iterator prep", NvtxColor.BLUE) try { - val cachedIt = cachedBufferIds.iterator.map(bufferId => { + val cachedIt = cachedBufferHandles.iterator.map(bufferHandle => { // No good way to get a metric in here for semaphore wait time GpuSemaphore.acquireIfNecessary(context, NoopMetric) - val cb = withResource(catalog.acquireBuffer(bufferId)) { buffer => + val cb = withResource(catalog.acquireBuffer(bufferHandle)) { buffer => buffer.getColumnarBatch(sparkTypes) } val cachedBytesRead = GpuColumnVector.getTotalDeviceMemoryUsed(cb) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala index b1fe5237ad63..ac13a63dd368 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -22,7 +22,7 @@ import java.util.concurrent.{Callable, ConcurrentHashMap, ExecutionException, Ex import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.collection.mutable.ListBuffer import ai.rapids.cudf.{NvtxColor, NvtxRange} import com.nvidia.spark.rapids._ @@ -892,10 +892,13 @@ class RapidsCachingWriter[K, V]( catalog: ShuffleBufferCatalog, shuffleStorage: RapidsDeviceMemoryStore, rapidsShuffleServer: Option[RapidsShuffleServer], - metrics: Map[String, SQLMetric]) extends ShuffleWriter[K, V] with Logging { + metrics: Map[String, SQLMetric]) + extends ShuffleWriter[K, V] + with Logging + with Arm { private val numParts = handle.dependency.partitioner.numPartitions private val sizes = new Array[Long](numParts) - private val writtenBufferIds = new ArrayBuffer[ShuffleBufferId](numParts) + private val uncompressedMetric: SQLMetric = metrics("dataSize") override def write(records: Iterator[Product2[K, V]]): Unit = { @@ -913,45 +916,49 @@ class RapidsCachingWriter[K, V]( recordsWritten = recordsWritten + batch.numRows() var partSize: Long = 0 val blockId = ShuffleBlockId(handle.shuffleId, mapId, partId) - val bufferId = catalog.nextShuffleBufferId(blockId) if (batch.numRows > 0 && batch.numCols > 0) { // Add the table to the shuffle store - batch.column(0) match { + val handle = batch.column(0) match { case c: GpuPackedTableColumn => val contigTable = c.getContiguousTable partSize = c.getTableBuffer.getLength uncompressedMetric += partSize - shuffleStorage.addContiguousTable( - bufferId, + catalog.addContiguousTable( + blockId, contigTable, SpillPriorities.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY, + RapidsBuffer.defaultSpillCallback, // we don't need to sync here, because we sync on the cuda // stream after sliceInternalOnGpu (contiguous_split) needsSync = false) case c: GpuCompressedColumnVector => val buffer = c.getTableBuffer - buffer.incRefCount() partSize 
= buffer.getLength val tableMeta = c.getTableMeta - // update the table metadata for the buffer ID generated above - tableMeta.bufferMeta.mutateId(bufferId.tableId) uncompressedMetric += tableMeta.bufferMeta().uncompressedSize() - shuffleStorage.addBuffer( - bufferId, + catalog.addBuffer( + blockId, buffer, tableMeta, SpillPriorities.OUTPUT_FOR_SHUFFLE_INITIAL_PRIORITY, + RapidsBuffer.defaultSpillCallback, // we don't need to sync here, because we sync on the cuda // stream after compression. needsSync = false) - case c => throw new IllegalStateException(s"Unexpected column type: ${c.getClass}") + case c => + throw new IllegalStateException(s"Unexpected column type: ${c.getClass}") } bytesWritten += partSize sizes(partId) += partSize + handle } else { // no device data, tracking only metadata val tableMeta = MetaUtils.buildDegenerateTableMeta(batch) - catalog.registerNewBuffer(new DegenerateRapidsBuffer(bufferId, tableMeta)) + val handle = + catalog.addDegenerateRapidsBuffer( + blockId, + tableMeta, + RapidsBuffer.defaultSpillCallback) // The size of the data is really only used to tell if the data should be shuffled or not // a 0 indicates that we should not shuffle anything. 
This is here for the special case @@ -961,8 +968,8 @@ class RapidsCachingWriter[K, V]( if (batch.numRows > 0) { sizes(partId) += 100 } + handle } - writtenBufferIds.append(bufferId) } metricsReporter.incBytesWritten(bytesWritten) metricsReporter.incRecordsWritten(recordsWritten) @@ -975,7 +982,7 @@ class RapidsCachingWriter[K, V]( * Used to remove shuffle buffers when the writing task detects an error, calling `stop(false)` */ private def cleanStorage(): Unit = { - writtenBufferIds.foreach(catalog.removeBuffer) + catalog.removeCachedHandles() } override def stop(success: Boolean): Option[MapStatus] = { @@ -1129,8 +1136,8 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: B val catalog = getCatalogOrThrow val requestHandler = new RapidsShuffleRequestHandler() { override def acquireShuffleBuffer(tableId: Int): RapidsBuffer = { - val shuffleBufferId = catalog.getShuffleBufferId(tableId) - catalog.acquireBuffer(shuffleBufferId) + val handle = catalog.getShuffleBufferHandle(tableId) + catalog.acquireBuffer(handle) } override def getShuffleBufferMetas(sbbId: ShuffleBlockBatchId): Seq[TableMeta] = { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala index fc5ccdd9775d..94e0c9f762e7 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuPartitioningSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -194,13 +194,13 @@ class GpuPartitioningSuite extends FunSuite with Arm { } if (GpuCompressedColumnVector.isBatchCompressed(partBatch)) { val gccv = columns.head.asInstanceOf[GpuCompressedColumnVector] - val bufferId = MockRapidsBufferId(partIndex) val devBuffer = gccv.getTableBuffer // device store takes ownership of the buffer devBuffer.incRefCount() - deviceStore.addBuffer(bufferId, devBuffer, gccv.getTableMeta, spillPriority) + val handle = + RapidsBufferCatalog.addBuffer(devBuffer, gccv.getTableMeta, spillPriority) withResource(buildSubBatch(batch, startRow, endRow)) { expectedBatch => - withResource(catalog.acquireBuffer(bufferId)) { buffer => + withResource(catalog.acquireBuffer(handle)) { buffer => withResource(buffer.getColumnarBatch(sparkTypes)) { batch => compareBatches(expectedBatch, batch) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala index 54534951d93e..b721959be84c 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -17,8 +17,8 @@ package com.nvidia.spark.rapids import java.io.File -import java.util.NoSuchElementException +import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, MemoryBuffer} import com.nvidia.spark.rapids.StorageTier.{DEVICE, DISK, HOST, StorageTier} import com.nvidia.spark.rapids.format.TableMeta import org.mockito.Mockito._ @@ -26,15 +26,22 @@ import org.scalatest.FunSuite import org.scalatest.mockito.MockitoSugar import org.apache.spark.sql.rapids.RapidsDiskBlockManager +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.ColumnarBatch -class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar { +class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar with Arm { test("lookup unknown buffer") { val catalog = new RapidsBufferCatalog val bufferId = new RapidsBufferId { override val tableId: Int = 10 override def getDiskPath(m: RapidsDiskBlockManager): File = null } - assertThrows[NoSuchElementException](catalog.acquireBuffer(bufferId)) + val bufferHandle = new RapidsBufferHandle { + override val id: RapidsBufferId = bufferId + override def setSpillPriority(newPriority: Long): Unit = {} + } + + assertThrows[NoSuchElementException](catalog.acquireBuffer(bufferHandle)) assertThrows[NoSuchElementException](catalog.getBufferMeta(bufferId)) } @@ -47,19 +54,127 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar { assertThrows[DuplicateBufferException](catalog.registerNewBuffer(buffer2)) } + test("a second handle prevents buffer to be removed") { + val catalog = new RapidsBufferCatalog + val bufferId = MockBufferId(5) + val buffer = mockBuffer(bufferId) + catalog.registerNewBuffer(buffer) + val handle1 = + catalog.makeNewHandle(bufferId, -1, RapidsBuffer.defaultSpillCallback) + val handle2 = + catalog.makeNewHandle(bufferId, -1, RapidsBuffer.defaultSpillCallback) + + catalog.removeBuffer(handle1) + + // this does not throw + catalog.acquireBuffer(handle2).close() + // actually this doesn't throw either + 
catalog.acquireBuffer(handle1).close() + + catalog.removeBuffer(handle2) + + assertThrows[NoSuchElementException](catalog.acquireBuffer(handle1)) + assertThrows[NoSuchElementException](catalog.acquireBuffer(handle2)) + } + + test("spill priorities are updated as handles are registered and unregistered") { + val catalog = new RapidsBufferCatalog + val bufferId = MockBufferId(5) + val buffer = mockBuffer(bufferId, initialPriority = -1) + catalog.registerNewBuffer(buffer) + val handle1 = + catalog.makeNewHandle(bufferId, -1, RapidsBuffer.defaultSpillCallback) + withResource(catalog.acquireBuffer(handle1)) { buff => + assertResult(-1)(buff.getSpillPriority) + } + val handle2 = + catalog.makeNewHandle(bufferId, 0, RapidsBuffer.defaultSpillCallback) + withResource(catalog.acquireBuffer(handle2)) { buff => + assertResult(0)(buff.getSpillPriority) + } + + // removing the lower priority handle, keeps the high priority spill + catalog.removeBuffer(handle1) + withResource(catalog.acquireBuffer(handle2)) { buff => + assertResult(0)(buff.getSpillPriority) + } + + // adding a lower priority -1000 handle keeps the high priority (0) spill + val handle3 = + catalog.makeNewHandle(bufferId, -1000, RapidsBuffer.defaultSpillCallback) + withResource(catalog.acquireBuffer(handle3)) { buff => + assertResult(0)(buff.getSpillPriority) + } + + // removing the high priority spill (0) brings us down to the + // low priority that is remaining + catalog.removeBuffer(handle2) + withResource(catalog.acquireBuffer(handle2)) { buff => + assertResult(-1000)(buff.getSpillPriority) + } + + catalog.removeBuffer(handle3) + } + + test("spill callbacks are updated as handles are registered and unregistered") { + val catalog = new RapidsBufferCatalog + val bufferId = MockBufferId(5) + val buffer = mockBuffer(bufferId, initialPriority = -1) + catalog.registerNewBuffer(buffer) + val handle1 = + catalog.makeNewHandle(bufferId, -1, null) + withResource(catalog.acquireBuffer(handle1)) { buff => + 
assertResult(null)(buff.getSpillCallback) + } + val handle2 = + catalog.makeNewHandle(bufferId, 0, RapidsBuffer.defaultSpillCallback) + withResource(catalog.acquireBuffer(handle2)) { buff => + assertResult(RapidsBuffer.defaultSpillCallback)(buff.getSpillCallback) + } + + // adding a new handle puts a new callback in front, that's the new callback + val mySpillCallback = new SpillCallback { + override def apply(from: StorageTier, to: StorageTier, amount: Long): Unit = {} + override def semaphoreWaitTime: GpuMetric = null + } + + val handle3 = + catalog.makeNewHandle(bufferId, -1000, mySpillCallback) + withResource(catalog.acquireBuffer(handle3)) { buff => + assertResult(mySpillCallback)(buff.getSpillCallback) + } + + // removing handles brings back the prior inserted callback + // low priority that is remaining + catalog.removeBuffer(handle3) + withResource(catalog.acquireBuffer(handle2)) { buff => + assertResult(RapidsBuffer.defaultSpillCallback)(buff.getSpillCallback) + } + + catalog.removeBuffer(handle2) + withResource(catalog.acquireBuffer(handle1)) { buff => + assertResult(null)(buff.getSpillCallback) + } + + catalog.removeBuffer(handle1) + } + test("buffer registering slower tier does not hide faster tier") { val catalog = new RapidsBufferCatalog val bufferId = MockBufferId(5) val buffer = mockBuffer(bufferId, tier = DEVICE) catalog.registerNewBuffer(buffer) + val handle = catalog.makeNewHandle(bufferId, 0, RapidsBuffer.defaultSpillCallback) val buffer2 = mockBuffer(bufferId, tier = HOST) catalog.registerNewBuffer(buffer2) val buffer3 = mockBuffer(bufferId, tier = DISK) catalog.registerNewBuffer(buffer3) - val acquired = catalog.acquireBuffer(MockBufferId(5)) + val acquired = catalog.acquireBuffer(handle) assertResult(5)(acquired.id.tableId) assertResult(buffer)(acquired) - verify(buffer).addReference() + + // registering the handle acquires the buffer + verify(buffer, times(2)).addReference() } test("acquire buffer") { @@ -67,10 +182,13 @@ class 
RapidsBufferCatalogSuite extends FunSuite with MockitoSugar { val bufferId = MockBufferId(5) val buffer = mockBuffer(bufferId) catalog.registerNewBuffer(buffer) - val acquired = catalog.acquireBuffer(MockBufferId(5)) + val handle = catalog.makeNewHandle(bufferId, 0, RapidsBuffer.defaultSpillCallback) + val acquired = catalog.acquireBuffer(handle) assertResult(5)(acquired.id.tableId) assertResult(buffer)(acquired) - verify(buffer).addReference() + + // registering the handle acquires the buffer + verify(buffer, times(2)).addReference() } test("acquire buffer retries automatically") { @@ -78,10 +196,13 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar { val bufferId = MockBufferId(5) val buffer = mockBuffer(bufferId, acquireAttempts = 9) catalog.registerNewBuffer(buffer) - val acquired = catalog.acquireBuffer(MockBufferId(5)) + val handle = catalog.makeNewHandle(bufferId, 0, RapidsBuffer.defaultSpillCallback) + val acquired = catalog.acquireBuffer(handle) assertResult(5)(acquired.id.tableId) assertResult(buffer)(acquired) - verify(buffer, times(9)).addReference() + + // registering the handle acquires the buffer + verify(buffer, times(10)).addReference() } test("acquire buffer at specific tier") { @@ -110,7 +231,7 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar { val catalog = new RapidsBufferCatalog val bufferId = MockBufferId(5) val expectedMeta = new TableMeta - val buffer = mockBuffer(bufferId, meta = expectedMeta) + val buffer = mockBuffer(bufferId, tableMeta = expectedMeta) catalog.registerNewBuffer(buffer) val meta = catalog.getBufferMeta(bufferId) assertResult(expectedMeta)(meta) @@ -163,7 +284,9 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar { val bufferId = MockBufferId(5) val buffer = mockBuffer(bufferId) catalog.registerNewBuffer(buffer) - catalog.removeBuffer(bufferId) + val handle = catalog.makeNewHandle( + bufferId, -1, RapidsBuffer.defaultSpillCallback) + catalog.removeBuffer(handle) 
verify(buffer).free() } @@ -172,11 +295,18 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar { val bufferId = MockBufferId(5) val buffer = mockBuffer(bufferId, tier = DEVICE) catalog.registerNewBuffer(buffer) + val handle = catalog.makeNewHandle( + bufferId, -1, RapidsBuffer.defaultSpillCallback) + + // these next registrations don't get their own handle. This is an internal + // operation from the store where it has spilled to host and disk the RapidsBuffer val buffer2 = mockBuffer(bufferId, tier = HOST) catalog.registerNewBuffer(buffer2) val buffer3 = mockBuffer(bufferId, tier = DISK) catalog.registerNewBuffer(buffer3) - catalog.removeBuffer(bufferId) + + // removing the original handle removes all buffers from all tiers. + catalog.removeBuffer(handle) verify(buffer).free() verify(buffer2).free() verify(buffer3).free() @@ -184,17 +314,44 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar { private def mockBuffer( bufferId: RapidsBufferId, - meta: TableMeta = null, + tableMeta: TableMeta = null, tier: StorageTier = StorageTier.DEVICE, - acquireAttempts: Int = 1): RapidsBuffer = { - val buffer = mock[RapidsBuffer] - when(buffer.id).thenReturn(bufferId) - when(buffer.storageTier).thenReturn(tier) - when(buffer.meta).thenReturn(meta) - var stub = when(buffer.addReference()) - (0 until acquireAttempts - 1).foreach(_ => stub = stub.thenReturn(false)) - stub.thenReturn(true) - buffer + acquireAttempts: Int = 1, + initialPriority: Long = -1): RapidsBuffer = { + spy(new RapidsBuffer { + var _acquireAttempts: Int = acquireAttempts + var currentPriority: Long = initialPriority + var currentCallback: SpillCallback = null + override val id: RapidsBufferId = bufferId + override val size: Long = 0 + override val meta: TableMeta = tableMeta + override val storageTier: StorageTier = tier + override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = null + override def getMemoryBuffer: MemoryBuffer = null + override def 
copyToMemoryBuffer( + srcOffset: Long, + dst: MemoryBuffer, + dstOffset: Long, + length: Long, + stream: Cuda.Stream): Unit = {} + override def getDeviceMemoryBuffer: DeviceMemoryBuffer = null + override def addReference(): Boolean = { + if (_acquireAttempts > 0) { + _acquireAttempts -= 1 + } + _acquireAttempts == 0 + } + override def free(): Unit = {} + override def getSpillPriority: Long = currentPriority + override def getSpillCallback: SpillCallback = currentCallback + override def setSpillPriority(priority: Long): Unit = { + currentPriority = priority + } + override def setSpillCallback(spillCallback: SpillCallback): Unit = { + currentCallback = spillCallback + } + override def close(): Unit = {} + }) } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala index cb43a1198a70..df84da9efd06 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -50,7 +50,8 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { val spillPriority = 3 val bufferId = MockRapidsBufferId(7) withResource(buildContiguousTable()) { ct => - store.addContiguousTable(bufferId, ct, spillPriority) + store.addContiguousTable( + bufferId, ct, spillPriority, RapidsBuffer.defaultSpillCallback, false) } val captor: ArgumentCaptor[RapidsBuffer] = ArgumentCaptor.forClass(classOf[RapidsBuffer]) verify(catalog).registerNewBuffer(captor.capture()) @@ -69,7 +70,8 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) // store takes ownership of the buffer ct.getBuffer.incRefCount() - store.addBuffer(bufferId, ct.getBuffer, meta, spillPriority) + store.addBuffer( + bufferId, ct.getBuffer, meta, spillPriority, RapidsBuffer.defaultSpillCallback, false) meta } val captor: ArgumentCaptor[RapidsBuffer] = ArgumentCaptor.forClass(classOf[RapidsBuffer]) @@ -91,8 +93,16 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) // store takes ownership of the buffer ct.getBuffer.incRefCount() - store.addBuffer(bufferId, ct.getBuffer, meta, initialSpillPriority = 3) - withResource(catalog.acquireBuffer(bufferId)) { buffer => + store.addBuffer( + bufferId, + ct.getBuffer, + meta, + initialSpillPriority = 3, + RapidsBuffer.defaultSpillCallback, + needsSync = false) + val handle = + catalog.makeNewHandle(bufferId, -1, RapidsBuffer.defaultSpillCallback) + withResource(catalog.acquireBuffer(handle)) { buffer => withResource(buffer.getMemoryBuffer.asInstanceOf[DeviceMemoryBuffer]) { devbuf => withResource(HostMemoryBuffer.allocate(devbuf.getLength)) { actualHostBuffer => actualHostBuffer.copyFromDeviceBuffer(devbuf) @@ -106,7 +116,7 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { } test("get column batch") { - val catalog = new 
RapidsBufferCatalog + val catalog = RapidsBufferCatalog.singleton val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) withResource(new RapidsDeviceMemoryStore(catalog)) { store => @@ -117,8 +127,11 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct) // store takes ownership of the buffer ct.getBuffer.incRefCount() - store.addBuffer(bufferId, ct.getBuffer, meta, initialSpillPriority = 3) - withResource(catalog.acquireBuffer(bufferId)) { buffer => + store.addBuffer(bufferId, ct.getBuffer, meta, initialSpillPriority = 3, + RapidsBuffer.defaultSpillCallback, false) + val handle = + catalog.makeNewHandle(bufferId, -1, RapidsBuffer.defaultSpillCallback) + withResource(catalog.acquireBuffer(handle)) { buffer => withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => TestUtils.compareBatches(expectedBatch, actualBatch) } @@ -129,7 +142,7 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { } test("cannot receive spilled buffers") { - val catalog = new RapidsBufferCatalog + val catalog = RapidsBufferCatalog.singleton withResource(new RapidsDeviceMemoryStore(catalog)) { store => assertThrows[IllegalStateException](store.copyBuffer( mock[RapidsBuffer], mock[MemoryBuffer], Cuda.DEFAULT_STREAM)) @@ -137,21 +150,25 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { } test("size statistics") { - val catalog = new RapidsBufferCatalog + val catalog = RapidsBufferCatalog.singleton withResource(new RapidsDeviceMemoryStore(catalog)) { store => assertResult(0)(store.currentSize) val bufferSizes = new Array[Long](2) + val bufferHandles = new Array[RapidsBufferHandle](2) bufferSizes.indices.foreach { i => withResource(buildContiguousTable()) { ct => bufferSizes(i) = ct.getBuffer.getLength // store takes ownership of the table - 
store.addContiguousTable(MockRapidsBufferId(i), ct, initialSpillPriority = 0) + store.addContiguousTable(MockRapidsBufferId(i), ct, initialSpillPriority = 0, + RapidsBuffer.defaultSpillCallback, false) + bufferHandles(i) = + catalog.makeNewHandle(MockRapidsBufferId(i), 0, RapidsBuffer.defaultSpillCallback) } assertResult(bufferSizes.take(i+1).sum)(store.currentSize) } - catalog.removeBuffer(MockRapidsBufferId(0)) + catalog.removeBuffer(bufferHandles(0)) assertResult(bufferSizes(1))(store.currentSize) - catalog.removeBuffer(MockRapidsBufferId(1)) + catalog.removeBuffer(bufferHandles(1)) assertResult(0)(store.currentSize) } } @@ -167,7 +184,9 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { withResource(buildContiguousTable()) { ct => bufferSizes(i) = ct.getBuffer.getLength // store takes ownership of the table - store.addContiguousTable(MockRapidsBufferId(i), ct, spillPriorities(i)) + store.addContiguousTable( + MockRapidsBufferId(i), ct, spillPriorities(i), + RapidsBuffer.defaultSpillCallback, false) } } assert(spillStore.spilledBuffers.isEmpty) @@ -223,6 +242,10 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { override def getMemoryBuffer: MemoryBuffer = throw new UnsupportedOperationException + + override def getSpillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback + + override def setSpillCallback(spillCallback: SpillCallback): Unit = {} } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala index 9a593155101b..fd42f254b3c4 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,7 +41,6 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga } test("spill updates catalog") { - val bufferId = MockRapidsBufferId(7, canShareDiskPaths = false) val spillPriority = -7 val hostStoreMaxSize = 1L * 1024 * 1024 val catalog = spy(new RapidsBufferCatalog) @@ -52,21 +51,23 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager], catalog)) { diskStore => assertResult(0)(diskStore.currentSize) hostStore.setSpillStore(diskStore) - val bufferSize = addTableToStore(devStore, bufferId, spillPriority) + val (bufferSize, handle) = + addTableToStore(spillPriority) + val path = handle.id.getDiskPath(null) + assert(!path.exists()) devStore.synchronousSpill(0) hostStore.synchronousSpill(0) assertResult(0)(hostStore.currentSize) assertResult(bufferSize)(diskStore.currentSize) - val path = bufferId.getDiskPath(null) assert(path.exists) assertResult(bufferSize)(path.length) verify(catalog, times(3)).registerNewBuffer(ArgumentMatchers.any[RapidsBuffer]) verify(catalog).removeBufferTier( - ArgumentMatchers.eq(bufferId), ArgumentMatchers.eq(StorageTier.DEVICE)) - withResource(catalog.acquireBuffer(bufferId)) { buffer => + ArgumentMatchers.eq(handle.id), ArgumentMatchers.eq(StorageTier.DEVICE)) + withResource(catalog.acquireBuffer(handle)) { buffer => assertResult(StorageTier.DISK)(buffer.storageTier) assertResult(bufferSize)(buffer.size) - assertResult(bufferId)(buffer.id) + assertResult(handle.id)(buffer.id) assertResult(spillPriority)(buffer.getSpillPriority) } } @@ -77,9 +78,6 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga test("get columnar batch") { val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, 
DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) - val bufferId = MockRapidsBufferId(1, canShareDiskPaths = false) - val bufferPath = bufferId.getDiskPath(null) - assert(!bufferPath.exists) val spillPriority = -7 val hostStoreMaxSize = 1L * 1024 * 1024 val catalog = new RapidsBufferCatalog @@ -90,15 +88,16 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager], catalog, devStore)) { diskStore => hostStore.setSpillStore(diskStore) - addTableToStore(devStore, bufferId, spillPriority) - val expectedBatch = withResource(catalog.acquireBuffer(bufferId)) { buffer => + val (_, handle) = addTableToStore(spillPriority) + assert(!getDiskPath(handle.id).exists()) + val expectedBatch = withResource(catalog.acquireBuffer(handle)) { buffer => assertResult(StorageTier.DEVICE)(buffer.storageTier) buffer.getColumnarBatch(sparkTypes) } withResource(expectedBatch) { expectedBatch => devStore.synchronousSpill(0) hostStore.synchronousSpill(0) - withResource(catalog.acquireBuffer(bufferId)) { buffer => + withResource(catalog.acquireBuffer(handle)) { buffer => assertResult(StorageTier.DISK)(buffer.storageTier) withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => TestUtils.compareBatches(expectedBatch, actualBatch) @@ -111,9 +110,6 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga } test("get memory buffer") { - val bufferId = MockRapidsBufferId(1, canShareDiskPaths = false) - val bufferPath = bufferId.getDiskPath(null) - assert(!bufferPath.exists) val spillPriority = -7 val hostStoreMaxSize = 1L * 1024 * 1024 val catalog = new RapidsBufferCatalog @@ -123,8 +119,9 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga devStore.setSpillStore(hostStore) withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager], catalog)) { diskStore => hostStore.setSpillStore(diskStore) - addTableToStore(devStore, bufferId, 
spillPriority) - val expectedBuffer = withResource(catalog.acquireBuffer(bufferId)) { buffer => + val (_, handle) = addTableToStore(spillPriority) + assert(!getDiskPath(handle.id).exists()) + val expectedBuffer = withResource(catalog.acquireBuffer(handle)) { buffer => assertResult(StorageTier.DEVICE)(buffer.storageTier) withResource(buffer.getMemoryBuffer) { devbuf => closeOnExcept(HostMemoryBuffer.allocate(devbuf.getLength)) { hostbuf => @@ -136,7 +133,7 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga withResource(expectedBuffer) { expectedBuffer => devStore.synchronousSpill(0) hostStore.synchronousSpill(0) - withResource(catalog.acquireBuffer(bufferId)) { buffer => + withResource(catalog.acquireBuffer(handle)) { buffer => assertResult(StorageTier.DISK)(buffer.storageTier) withResource(buffer.getMemoryBuffer) { actualBuffer => assert(actualBuffer.isInstanceOf[HostMemoryBuffer]) @@ -159,9 +156,6 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga } private def testBufferFileDeletion(canShareDiskPaths: Boolean): Unit = { - val bufferId = MockRapidsBufferId(1, canShareDiskPaths) - val bufferPath = bufferId.getDiskPath(null) - assert(!bufferPath.exists) val spillPriority = -7 val hostStoreMaxSize = 1L * 1024 * 1024 val catalog = new RapidsBufferCatalog @@ -171,11 +165,13 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga devStore.setSpillStore(hostStore) withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager], catalog)) { diskStore => hostStore.setSpillStore(diskStore) - addTableToStore(devStore, bufferId, spillPriority) + val (_, handle) = addTableToStore(spillPriority) + val bufferPath = handle.id.getDiskPath(null) + assert(!bufferPath.exists()) devStore.synchronousSpill(0) hostStore.synchronousSpill(0) assert(bufferPath.exists) - catalog.removeBuffer(bufferId) + catalog.removeBuffer(handle) if (canShareDiskPaths) { assert(bufferPath.exists()) } else { @@ 
-186,22 +182,19 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga } } - private def addTableToStore( - devStore: RapidsDeviceMemoryStore, - bufferId: RapidsBufferId, - spillPriority: Long): Long = { + private def addTableToStore(spillPriority: Long): (Long, RapidsBufferHandle) = { withResource(buildContiguousTable()) { ct => val bufferSize = ct.getBuffer.getLength // store takes ownership of the table - devStore.addContiguousTable(bufferId, ct, spillPriority) - bufferSize + val handle = RapidsBufferCatalog.addContiguousTable( + ct, + spillPriority, + RapidsBuffer.defaultSpillCallback) + (bufferSize, handle) } } - case class MockRapidsBufferId( - tableId: Int, - override val canShareDiskPaths: Boolean) extends RapidsBufferId { - override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = - new File(TEST_FILES_ROOT, s"diskbuffer-$tableId") + def getDiskPath(bufferId: RapidsBufferId): File = { + new File(TEST_FILES_ROOT, s"diskbuffer-${bufferId.tableId}") } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsGdsStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsGdsStoreSuite.scala index 10c4c12959b3..54d1d4c963fd 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsGdsStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsGdsStoreSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -33,69 +33,77 @@ object GdsTest extends Tag("GdsTest") class RapidsGdsStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSugar { - test("single shot spill with shared path", GdsTest) { - assume(CuFile.libraryLoaded()) - verifySingleShotSpill(canShareDiskPaths = true) - } - - test("single shot spill with exclusive path", GdsTest) { - assume(CuFile.libraryLoaded()) - verifySingleShotSpill(canShareDiskPaths = false) - } - - test("batch spill", GdsTest) { - assume(CuFile.libraryLoaded()) - - val bufferIds = Array(MockRapidsBufferId(7), MockRapidsBufferId(8), MockRapidsBufferId(9)) - val diskBlockManager = mock[RapidsDiskBlockManager] - val paths = Array( - new File(TEST_FILES_ROOT, s"gdsbuffer-0"), new File(TEST_FILES_ROOT, s"gdsbuffer-1")) - when(diskBlockManager.getFile(any[BlockId]())) - .thenReturn(paths(0)) - .thenReturn(paths(1)) - paths.foreach(f => assert(!f.exists)) - val spillPriority = -7 - val catalog = spy(new RapidsBufferCatalog) - val batchWriteBufferSize = 16384 // Holds 2 buffers. 
- withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => - withResource(new RapidsGdsStore( - diskBlockManager, batchWriteBufferSize, catalog)) { gdsStore => - - devStore.setSpillStore(gdsStore) - assertResult(0)(gdsStore.currentSize) - - val bufferSizes = bufferIds.map(id => { - val size = addTableToStore(devStore, id, spillPriority) - devStore.synchronousSpill(0) - size - }) - val totalSize = bufferSizes.sum - assertResult(totalSize)(gdsStore.currentSize) - - assert(paths(0).exists) - assert(!paths(1).exists) - val alignedSize = Math.ceil((bufferSizes(0) + bufferSizes(1)) / 4096d).toLong * 4096 - assertResult(alignedSize)(paths(0).length) - - verify(catalog, times(6)).registerNewBuffer(ArgumentMatchers.any[RapidsBuffer]) - (bufferIds, bufferSizes).zipped.foreach { (id, size) => - verify(catalog).removeBufferTier( - ArgumentMatchers.eq(id), ArgumentMatchers.eq(StorageTier.DEVICE)) - withResource(catalog.acquireBuffer(id)) { buffer => - assertResult(StorageTier.GDS)(buffer.storageTier) - assertResult(id)(buffer.id) - assertResult(size)(buffer.size) - assertResult(spillPriority)(buffer.getSpillPriority) - } - } - - catalog.removeBuffer(bufferIds(0)) - assert(paths(0).exists) - catalog.removeBuffer(bufferIds(1)) - assert(!paths(0).exists) - } - } - } + test("single shot spill with shared path", GdsTest) { + println("Trying to load CuFile") + assume(CuFile.libraryLoaded()) + println("DID LOAD") + verifySingleShotSpill(canShareDiskPaths = true) + } + + test("single shot spill with exclusive path", GdsTest) { + assume(CuFile.libraryLoaded()) + verifySingleShotSpill(canShareDiskPaths = false) + } + + test("batch spill", GdsTest) { + assume(CuFile.libraryLoaded()) + + val bufferIds = Array(MockRapidsBufferId(7), MockRapidsBufferId(8), MockRapidsBufferId(9)) + val diskBlockManager = mock[RapidsDiskBlockManager] + val paths = Array( + new File(TEST_FILES_ROOT, s"gdsbuffer-0"), new File(TEST_FILES_ROOT, s"gdsbuffer-1")) + 
when(diskBlockManager.getFile(any[BlockId]())) + .thenReturn(paths(0)) + .thenReturn(paths(1)) + paths.foreach(f => assert(!f.exists)) + val spillPriority = -7 + val catalog = spy(new RapidsBufferCatalog) + val batchWriteBufferSize = 16384 // Holds 2 buffers. + withResource(new RapidsDeviceMemoryStore(catalog)) { devStore => + withResource(new RapidsGdsStore( + diskBlockManager, batchWriteBufferSize, catalog)) { gdsStore => + + devStore.setSpillStore(gdsStore) + assertResult(0)(gdsStore.currentSize) + + val bufferSizes = new Array[Long](bufferIds.length) + val bufferHandles = new Array[RapidsBufferHandle](bufferIds.length) + + bufferIds.zipWithIndex.foreach { case(id, ix) => + val size = addTableToStore(devStore, id, spillPriority) + devStore.synchronousSpill(0) + bufferSizes(ix) = size + bufferHandles(ix) = + catalog.makeNewHandle(id, spillPriority, RapidsBuffer.defaultSpillCallback) + } + + val totalSize = bufferSizes.sum + assertResult(totalSize)(gdsStore.currentSize) + + assert(paths(0).exists) + assert(!paths(1).exists) + val alignedSize = Math.ceil((bufferSizes(0) + bufferSizes(1)) / 4096d).toLong * 4096 + assertResult(alignedSize)(paths(0).length) + + verify(catalog, times(6)).registerNewBuffer(ArgumentMatchers.any[RapidsBuffer]) + (bufferIds, bufferSizes, bufferHandles).zipped.foreach { (id, size, handle) => + verify(catalog).removeBufferTier( + ArgumentMatchers.eq(id), ArgumentMatchers.eq(StorageTier.DEVICE)) + withResource(catalog.acquireBuffer(handle)) { buffer => + assertResult(StorageTier.GDS)(buffer.storageTier) + assertResult(id)(buffer.id) + assertResult(size)(buffer.size) + assertResult(spillPriority)(buffer.getSpillPriority) + } + } + + catalog.removeBuffer(bufferHandles(0)) + assert(paths(0).exists) + catalog.removeBuffer(bufferHandles(1)) + assert(!paths(0).exists) + } + } + } private def verifySingleShotSpill(canShareDiskPaths: Boolean): Assertion = { val bufferId = MockRapidsBufferId(7, canShareDiskPaths) @@ -109,6 +117,8 @@ class 
RapidsGdsStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSugar devStore.setSpillStore(gdsStore) assertResult(0)(gdsStore.currentSize) val bufferSize = addTableToStore(devStore, bufferId, spillPriority) + val handle = + catalog.makeNewHandle(bufferId, spillPriority, RapidsBuffer.defaultSpillCallback) devStore.synchronousSpill(0) assertResult(bufferSize)(gdsStore.currentSize) assert(path.exists) @@ -116,14 +126,14 @@ class RapidsGdsStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSugar verify(catalog, times(2)).registerNewBuffer(ArgumentMatchers.any[RapidsBuffer]) verify(catalog).removeBufferTier( ArgumentMatchers.eq(bufferId), ArgumentMatchers.eq(StorageTier.DEVICE)) - withResource(catalog.acquireBuffer(bufferId)) { buffer => + withResource(catalog.acquireBuffer(handle)) { buffer => assertResult(StorageTier.GDS)(buffer.storageTier) assertResult(bufferSize)(buffer.size) assertResult(bufferId)(buffer.id) assertResult(spillPriority)(buffer.getSpillPriority) } - catalog.removeBuffer(bufferId) + catalog.removeBuffer(handle) if (canShareDiskPaths) { assert(path.exists()) } else { @@ -140,7 +150,8 @@ class RapidsGdsStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSugar withResource(buildContiguousTable()) { ct => val bufferSize = ct.getBuffer.getLength // store takes ownership of the table - devStore.addContiguousTable(bufferId, ct, spillPriority) + devStore.addContiguousTable(bufferId, ct, spillPriority, + RapidsBuffer.defaultSpillCallback, false) bufferSize } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala index f5455d627336..3d363d146b28 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. 
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,7 +52,6 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { } test("spill updates catalog") { - val bufferId = MockRapidsBufferId(7) val spillPriority = -7 val hostStoreMaxSize = 1L * 1024 * 1024 val catalog = spy(new RapidsBufferCatalog) @@ -63,11 +62,14 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { assertResult(hostStoreMaxSize)(hostStore.numBytesFree) devStore.setSpillStore(hostStore) - val bufferSize = withResource(buildContiguousTable()) { ct => + val (bufferSize, handle) = withResource(buildContiguousTable()) { ct => val len = ct.getBuffer.getLength // store takes ownership of the table - devStore.addContiguousTable(bufferId, ct, spillPriority) - len + val handle = RapidsBufferCatalog.addContiguousTable( + ct, + spillPriority, + RapidsBuffer.defaultSpillCallback) + (len, handle) } devStore.synchronousSpill(0) @@ -75,11 +77,11 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { assertResult(hostStoreMaxSize - bufferSize)(hostStore.numBytesFree) verify(catalog, times(2)).registerNewBuffer(ArgumentMatchers.any[RapidsBuffer]) verify(catalog).removeBufferTier( - ArgumentMatchers.eq(bufferId), ArgumentMatchers.eq(StorageTier.DEVICE)) - withResource(catalog.acquireBuffer(bufferId)) { buffer => + ArgumentMatchers.eq(handle.id), ArgumentMatchers.eq(StorageTier.DEVICE)) + withResource(catalog.acquireBuffer(handle)) { buffer => assertResult(StorageTier.HOST)(buffer.storageTier) assertResult(bufferSize)(buffer.size) - assertResult(bufferId)(buffer.id) + assertResult(handle.id)(buffer.id) assertResult(spillPriority)(buffer.getSpillPriority) } } @@ -87,7 +89,6 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { } test("get columnar batch") { - val bufferId = 
MockRapidsBufferId(7) val spillPriority = -10 val hostStoreMaxSize = 1L * 1024 * 1024 val catalog = new RapidsBufferCatalog @@ -98,9 +99,12 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { withResource(buildContiguousTable()) { ct => withResource(HostMemoryBuffer.allocate(ct.getBuffer.getLength)) { expectedBuffer => expectedBuffer.copyFromDeviceBuffer(ct.getBuffer) - devStore.addContiguousTable(bufferId, ct, spillPriority) + val handle = RapidsBufferCatalog.addContiguousTable( + ct, + spillPriority, + RapidsBuffer.defaultSpillCallback) devStore.synchronousSpill(0) - withResource(catalog.acquireBuffer(bufferId)) { buffer => + withResource(catalog.acquireBuffer(handle)) { buffer => withResource(buffer.getMemoryBuffer) { actualBuffer => assert(actualBuffer.isInstanceOf[HostMemoryBuffer]) assertResult(expectedBuffer.asByteBuffer) { @@ -117,7 +121,6 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { test("get memory buffer") { val sparkTypes = Array[DataType](IntegerType, StringType, DoubleType, DecimalType(ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION, 5)) - val bufferId = MockRapidsBufferId(7) val spillPriority = -10 val hostStoreMaxSize = 1L * 1024 * 1024 val catalog = new RapidsBufferCatalog @@ -129,9 +132,12 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { withResource(buildContiguousTable()) { ct => withResource(GpuColumnVector.from(ct.getTable, sparkTypes)) { expectedBatch => - devStore.addContiguousTable(bufferId, ct, spillPriority) + val handle = RapidsBufferCatalog.addContiguousTable( + ct, + spillPriority, + RapidsBuffer.defaultSpillCallback) devStore.synchronousSpill(0) - withResource(catalog.acquireBuffer(bufferId)) { buffer => + withResource(catalog.acquireBuffer(handle)) { buffer => assertResult(StorageTier.HOST)(buffer.storageTier) withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => TestUtils.compareBatches(expectedBatch, actualBatch) @@ -145,8 
+151,6 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { test("buffer exceeds maximum size") { val sparkTypes = Array[DataType](LongType) - val bigBufferId = MockRapidsBufferId(7) - val smallBufferId = MockRapidsBufferId(8) val spillPriority = -10 val hostStoreMaxSize = 256 val catalog = new RapidsBufferCatalog @@ -162,19 +166,24 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { withResource(buildContiguousTable(1)) { smallTable => withResource(GpuColumnVector.from(bigTable.getTable, sparkTypes)) { expectedBatch => // store takes ownership of the table - devStore.addContiguousTable(bigBufferId, bigTable, spillPriority) + val bigHandle = RapidsBufferCatalog.addContiguousTable( + bigTable, + spillPriority, + RapidsBuffer.defaultSpillCallback) devStore.synchronousSpill(0) verify(mockStore, never()).copyBuffer(ArgumentMatchers.any[RapidsBuffer], ArgumentMatchers.any[MemoryBuffer], ArgumentMatchers.any[Cuda.Stream]) - withResource(catalog.acquireBuffer(bigBufferId)) { buffer => + withResource(catalog.acquireBuffer(bigHandle)) { buffer => assertResult(StorageTier.HOST)(buffer.storageTier) withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch => TestUtils.compareBatches(expectedBatch, actualBatch) } } - devStore.addContiguousTable(smallBufferId, smallTable, spillPriority) + devStore.addContiguousTable( + smallTable, spillPriority, + RapidsBuffer.defaultSpillCallback, false) devStore.synchronousSpill(0) val rapidsBufferCaptor: ArgumentCaptor[RapidsBuffer] = ArgumentCaptor.forClass(classOf[RapidsBuffer]) @@ -183,7 +192,7 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar { verify(mockStore).copyBuffer(rapidsBufferCaptor.capture(), memoryBufferCaptor.capture(), ArgumentMatchers.any[Cuda.Stream]) withResource(memoryBufferCaptor.getValue) { _ => - assertResult(bigBufferId)(rapidsBufferCaptor.getValue.id) + assertResult(bigHandle.id)(rapidsBufferCaptor.getValue.id) } } } diff 
--git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala index a06f51cf4da2..4a348a2ad1a8 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleIteratorSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shuffle -import com.nvidia.spark.rapids.{RapidsBuffer, ShuffleReceivedBufferId} +import com.nvidia.spark.rapids.{RapidsBuffer, RapidsBufferHandle} import org.mockito.ArgumentCaptor import org.mockito.ArgumentMatchers._ import org.mockito.Mockito._ @@ -181,19 +181,18 @@ class RapidsShuffleIteratorSuite extends RapidsShuffleTestHelper { val ac = ArgumentCaptor.forClass(classOf[RapidsShuffleFetchHandler]) when(mockTransport.makeClient(any())).thenReturn(client) doNothing().when(client).doFetch(any(), ac.capture()) - val bufferId = ShuffleReceivedBufferId(1) val mockBuffer = mock[RapidsBuffer] val cb = new ColumnarBatch(Array.empty, 10) - + val handle = mock[RapidsBufferHandle] when(mockBuffer.getColumnarBatch(Array.empty)).thenReturn(cb) - when(mockCatalog.acquireBuffer(any[ShuffleReceivedBufferId]())).thenReturn(mockBuffer) + when(mockCatalog.acquireBuffer(any[RapidsBufferHandle]())).thenReturn(mockBuffer) doNothing().when(mockCatalog).removeBuffer(any()) cl.start() val handler = ac.getValue.asInstanceOf[RapidsShuffleFetchHandler] handler.start(1) - handler.batchReceived(bufferId) + handler.batchReceived(handle) verify(mockTransport, times(0)).cancelPending(handler) diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala 
b/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala index 7be39d01b9e6..4433b78933c2 100644 --- a/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/SpillableColumnarBatchSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,14 +29,19 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.TempLocalBlockId class SpillableColumnarBatchSuite extends FunSuite with Arm { + test("close updates catalog") { val id = TempSpillBufferId(0, TempLocalBlockId(new UUID(1, 2))) val mockBuffer = new MockBuffer(id) val catalog = RapidsBufferCatalog.singleton val oldBufferCount = catalog.numBuffers catalog.registerNewBuffer(mockBuffer) + val handle = catalog.makeNewHandle(id, -1, RapidsBuffer.defaultSpillCallback) assertResult(oldBufferCount + 1)(catalog.numBuffers) - val spillableBatch = new SpillableColumnarBatchImpl(id, 5, Array[DataType](IntegerType), + val spillableBatch = new SpillableColumnarBatchImpl( + handle, + 5, + Array[DataType](IntegerType), NoopMetric) spillableBatch.close() assertResult(oldBufferCount)(catalog.numBuffers) @@ -55,7 +60,10 @@ class SpillableColumnarBatchSuite extends FunSuite with Arm { override def getSpillPriority: Long = 0 override def setSpillPriority(priority: Long): Unit = {} override def close(): Unit = {} - override def getColumnarBatch(sparkTypes: Array[DataType]): ColumnarBatch = null - override val spillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback + override def getColumnarBatch( + sparkTypes: Array[DataType]): ColumnarBatch = null + + override val getSpillCallback: SpillCallback = RapidsBuffer.defaultSpillCallback + override def setSpillCallback(spillCallback: 
SpillCallback): Unit = {} } }