Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make RapidsBufferHandle AutoCloseable to prevent extra attempts to remove buffers #7548

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class DuplicateBufferException(s: String) extends RuntimeException(s) {}
* A handle is obtained when a buffer, batch, or table is added to the spill framework
* via the `RapidsBufferCatalog` api.
*/
trait RapidsBufferHandle {
trait RapidsBufferHandle extends AutoCloseable {
val id: RapidsBufferId

/**
Expand Down Expand Up @@ -70,6 +70,8 @@ class RapidsBufferCatalog extends AutoCloseable with Arm {
spillCallback: SpillCallback)
extends RapidsBufferHandle {

private var closed = false

override def setSpillPriority(newPriority: Long): Unit = {
priority = newPriority
updateUnderlyingRapidsBuffer(this)
Expand All @@ -94,6 +96,18 @@ class RapidsBufferCatalog extends AutoCloseable with Arm {
* @return the spill callback associated with this handle
*/
def getSpillCallback: SpillCallback = spillCallback

override def close(): Unit = synchronized {
// since the handle is stored in the catalog in addition to being
// handed out to potentially a `SpillableColumnarBatch` or `SpillableBuffer`
// there is a chance we may double close it. For example, a broadcast exec
// that is closing its spillable (and therefore the handle) + the handle being
// closed from the catalog's close method.
if (!closed) {
removeBuffer(this)
}
closed = true
}
}

/**
Expand Down Expand Up @@ -313,7 +327,7 @@ class RapidsBufferCatalog extends AutoCloseable with Arm {
* (`handle` was the last handle)
* false: if buffer was not removed due to other live handles.
*/
def removeBuffer(handle: RapidsBufferHandle): Boolean = {
private def removeBuffer(handle: RapidsBufferHandle): Boolean = {
// if this is the last handle, remove the buffer
if (stopTrackingHandle(handle)) {
val buffers = bufferMap.remove(handle.id)
Expand All @@ -329,7 +343,7 @@ class RapidsBufferCatalog extends AutoCloseable with Arm {

override def close(): Unit = {
bufferIdToHandles.values.forEach { handles =>
handles.foreach(removeBuffer)
handles.foreach(_.close())
}
bufferIdToHandles.clear()
}
Expand Down Expand Up @@ -495,12 +509,5 @@ object RapidsBufferCatalog extends Logging with Arm {
def acquireBuffer(handle: RapidsBufferHandle): RapidsBuffer =
singleton.acquireBuffer(handle)

/**
* Remove a buffer handle from the catalog and, if it this was the final handle,
* release the resources of the registered buffers.
*/
def removeBuffer(handle: RapidsBufferHandle): Unit =
singleton.removeBuffer(handle)

def getDiskBlockManager(): RapidsDiskBlockManager = diskBlockManager
}
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,7 @@ class ShuffleBufferCatalog(
// NOTE: Not synchronizing array buffer because this shuffle should be inactive.
bufferIds.foreach { id =>
tableMap.remove(id.tableId)
val didRemove = catalog.removeBuffer(bufferIdToHandle.get(id))
if (!didRemove) {
logWarning(s"Unable to remove $id from underlying storage when cleaning " +
s"shuffle blocks.")
}
bufferIdToHandle.get(id).close()
}
}
info.blockMap.forEachValue(Long.MaxValue, bufferRemover)
Expand Down Expand Up @@ -316,7 +312,7 @@ class ShuffleBufferCatalog(
def removeBuffer(handle: RapidsBufferHandle): Unit = {
val id = handle.id
tableMap.remove(id.tableId)
catalog.removeBuffer(handle)
handle.close()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class ShuffleReceivedBufferCatalog(
def removeBuffer(handle: RapidsBufferHandle): Unit = {
val id = handle.id
tableMap.remove(id.tableId)
catalog.removeBuffer(handle)
handle.close()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ class SpillableColumnarBatchImpl (
sparkTypes: Array[DataType],
semWait: GpuMetric)
extends SpillableColumnarBatch with Arm {
private var closed = false

/**
* The number of rows stored in this batch.
*/
Expand Down Expand Up @@ -113,11 +111,8 @@ class SpillableColumnarBatchImpl (
* Remove the `ColumnarBatch` from the cache.
*/
override def close(): Unit = {
if (!closed) {
// closing my reference
RapidsBufferCatalog.removeBuffer(handle)
closed = true
}
// closing my reference
handle.close()
}
}

Expand Down Expand Up @@ -224,8 +219,6 @@ class SpillableBuffer(
handle: RapidsBufferHandle,
semWait: GpuMetric) extends AutoCloseable with Arm {

private var closed = false

/**
* Set a new spill priority.
*/
Expand All @@ -246,10 +239,7 @@ class SpillableBuffer(
* Remove the buffer from the cache.
*/
override def close(): Unit = {
if (!closed) {
RapidsBufferCatalog.removeBuffer(handle)
closed = true
}
handle.close()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar with Arm {
val bufferHandle = new RapidsBufferHandle {
override val id: RapidsBufferId = bufferId
override def setSpillPriority(newPriority: Long): Unit = {}
override def close(): Unit = {}
}

assertThrows[NoSuchElementException](catalog.acquireBuffer(bufferHandle))
Expand All @@ -64,14 +65,14 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar with Arm {
val handle2 =
catalog.makeNewHandle(bufferId, -1, RapidsBuffer.defaultSpillCallback)

catalog.removeBuffer(handle1)
handle1.close()

// this does not throw
catalog.acquireBuffer(handle2).close()
// actually this doesn't throw either
catalog.acquireBuffer(handle1).close()

catalog.removeBuffer(handle2)
handle2.close()

assertThrows[NoSuchElementException](catalog.acquireBuffer(handle1))
assertThrows[NoSuchElementException](catalog.acquireBuffer(handle2))
Expand All @@ -94,7 +95,7 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar with Arm {
}

// removing the lower priority handle, keeps the high priority spill
catalog.removeBuffer(handle1)
handle1.close()
withResource(catalog.acquireBuffer(handle2)) { buff =>
assertResult(0)(buff.getSpillPriority)
}
Expand All @@ -108,12 +109,12 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar with Arm {

// removing the high priority spill (0) brings us down to the
// low priority that is remaining
catalog.removeBuffer(handle2)
handle2.close()
withResource(catalog.acquireBuffer(handle2)) { buff =>
assertResult(-1000)(buff.getSpillPriority)
}

catalog.removeBuffer(handle3)
handle3.close()
}

test("spill callbacks are updated as handles are registered and unregistered") {
Expand Down Expand Up @@ -146,17 +147,17 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar with Arm {

// removing handles brings back the prior inserted callback
// low priority that is remaining
catalog.removeBuffer(handle3)
handle3.close()
withResource(catalog.acquireBuffer(handle2)) { buff =>
assertResult(RapidsBuffer.defaultSpillCallback)(buff.getSpillCallback)
}

catalog.removeBuffer(handle2)
handle2.close()
withResource(catalog.acquireBuffer(handle1)) { buff =>
assertResult(null)(buff.getSpillCallback)
}

catalog.removeBuffer(handle1)
handle1.close()
}

test("buffer registering slower tier does not hide faster tier") {
Expand Down Expand Up @@ -286,7 +287,7 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar with Arm {
catalog.registerNewBuffer(buffer)
val handle = catalog.makeNewHandle(
bufferId, -1, RapidsBuffer.defaultSpillCallback)
catalog.removeBuffer(handle)
handle.close()
verify(buffer).free()
}

Expand All @@ -306,7 +307,7 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar with Arm {
catalog.registerNewBuffer(buffer3)

// removing the original handle removes all buffers from all tiers.
catalog.removeBuffer(handle)
handle.close()
verify(buffer).free()
verify(buffer2).free()
verify(buffer3).free()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
}
assertResult(bufferSizes.take(i+1).sum)(store.currentSize)
}
catalog.removeBuffer(bufferHandles(0))
bufferHandles(0).close()
assertResult(bufferSizes(1))(store.currentSize)
catalog.removeBuffer(bufferHandles(1))
bufferHandles(1).close()
assertResult(0)(store.currentSize)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSuga
devStore.synchronousSpill(0)
hostStore.synchronousSpill(0)
assert(bufferPath.exists)
catalog.removeBuffer(handle)
handle.close()
if (canShareDiskPaths) {
assert(bufferPath.exists())
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ class RapidsGdsStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSugar
}
}

catalog.removeBuffer(bufferHandles(0))
bufferHandles(0).close()
assert(paths(0).exists)
catalog.removeBuffer(bufferHandles(1))
bufferHandles(1).close()
assert(!paths(0).exists)
}
}
Expand Down Expand Up @@ -130,7 +130,7 @@ class RapidsGdsStoreSuite extends FunSuiteWithTempDir with Arm with MockitoSugar
assertResult(spillPriority)(buffer.getSpillPriority)
}

catalog.removeBuffer(handle)
handle.close()
if (canShareDiskPaths) {
assert(path.exists())
} else {
Expand Down