
Commit

Add closeOnExcept to clean up code that must close resources only on exceptions (#456)

Signed-off-by: Jason Lowe <[email protected]>
jlowe authored Jul 29, 2020
1 parent 76a1d4d commit 42d0564
Showing 7 changed files with 137 additions and 74 deletions.
35 changes: 35 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/Arm.scala
@@ -15,6 +15,8 @@
*/
package com.nvidia.spark.rapids

import scala.collection.mutable.ArrayBuffer

import com.nvidia.spark.rapids.RapidsPluginImplicits._

/** Implementation of the automatic-resource-management pattern */
@@ -37,4 +39,37 @@ trait Arm {
r.safeClose()
}
}

/** Executes the provided code block, closing the resource only if an exception occurs */
def closeOnExcept[T <: AutoCloseable, V](r: T)(block: T => V): V = {
try {
block(r)
} catch {
case t: Throwable =>
r.safeClose()
throw t
}
}

/** Executes the provided code block, closing the resources only if an exception occurs */
def closeOnExcept[T <: AutoCloseable, V](r: Seq[T])(block: Seq[T] => V): V = {
try {
block(r)
} catch {
case t: Throwable =>
r.safeClose()
throw t
}
}

/** Executes the provided code block, closing the resources only if an exception occurs */
def closeOnExcept[T <: AutoCloseable, V](r: ArrayBuffer[T])(block: ArrayBuffer[T] => V): V = {
try {
block(r)
} catch {
case t: Throwable =>
r.safeClose()
throw t
}
}
}
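
For reference, the calling pattern these new overloads enable (and that the remaining files in this commit are converted to) looks roughly like the sketch below; the FileHandle class and loadHeader helper are hypothetical stand-ins used only for illustration and are not part of this change.

// Illustrative sketch only: FileHandle and loadHeader are hypothetical stand-ins.
class FileHandle extends AutoCloseable {
  override def close(): Unit = ()   // release the underlying resource here
}

object Example extends Arm {
  def loadHeader(h: FileHandle): String = "header"   // may throw in real code

  def openWithHeader(): FileHandle = {
    // If loadHeader throws, closeOnExcept closes the handle before rethrowing;
    // on success the handle stays open and ownership passes to the caller.
    closeOnExcept(new FileHandle) { h =>
      loadHeader(h)
      h
    }
  }
}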
@@ -328,7 +328,7 @@ class CSVPartitionReader(
maxRowsPerChunk: Integer,
maxBytesPerChunk: Long,
execMetrics: Map[String, SQLMetric])
-  extends PartitionReader[ColumnarBatch] with ScanWithMetrics {
+  extends PartitionReader[ColumnarBatch] with ScanWithMetrics with Arm {

private var batch: Option[ColumnarBatch] = None
private val lineReader = new HadoopFileLinesReader(partFile, parsedOptions.lineSeparatorInRead,
@@ -380,16 +380,11 @@
*/
private def growHostBuffer(original: HostMemoryBuffer, needed: Long): HostMemoryBuffer = {
val newSize = Math.max(original.getLength * 2, needed)
-    val result = HostMemoryBuffer.allocate(newSize)
-    try {
+    closeOnExcept(HostMemoryBuffer.allocate(newSize)) { result =>
      result.copyFromHostBuffer(0, original, 0, original.getLength)
      original.close()
-    } catch {
-      case e: Throwable =>
-        result.close()
-        throw e
+      result
    }
-    result
}

private def readPartFile(): (HostMemoryBuffer, Long, Integer) = {
@@ -26,22 +26,17 @@ import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.unsafe.types.UTF8String

-object GpuExpressionsUtils {
+object GpuExpressionsUtils extends Arm {
def evaluateBoundExpressions[A <: GpuExpression](cb: ColumnarBatch,
boundExprs: Seq[A]): Seq[GpuColumnVector] = {
val numCols = boundExprs.length
-    val resultCvs = new ArrayBuffer[GpuColumnVector](numCols)
-    try {
+    closeOnExcept(new ArrayBuffer[GpuColumnVector](numCols)) { resultCvs =>
      for (i <- 0 until numCols) {
        val ref = boundExprs(i)
        resultCvs += ref.columnarEval(cb).asInstanceOf[GpuColumnVector]
      }
-    } catch {
-      case t: Throwable =>
-        resultCvs.safeClose()
-        throw t
+      resultCvs
    }
-    resultCvs
}

def getTrimString(trimStr: Option[Expression]): String = trimStr match {
@@ -29,7 +29,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.storage.ShuffleBlockBatchId

-object MetaUtils {
+object MetaUtils extends Arm {
/**
* Build a TableMeta message from a Table in contiguous memory
*
@@ -202,17 +202,12 @@ object MetaUtils {
* @return columnar batch that must be closed by the caller
*/
def getBatchFromMeta(deviceBuffer: DeviceMemoryBuffer, meta: TableMeta): ColumnarBatch = {
-    val columns = new ArrayBuffer[GpuColumnVector](meta.columnMetasLength())
-    try {
+    closeOnExcept(new ArrayBuffer[GpuColumnVector](meta.columnMetasLength())) { columns =>
      val columnMeta = new ColumnMeta
      (0 until meta.columnMetasLength).foreach { i =>
        columns.append(makeColumn(deviceBuffer, meta.columnMetas(columnMeta, i)))
      }
      new ColumnarBatch(columns.toArray, meta.rowCount.toInt)
-    } catch {
-      case e: Exception =>
-        columns.foreach(_.close())
-        throw e
}
}

80 changes: 80 additions & 0 deletions tests/src/test/scala/com/nvidia/spark/rapids/ArmSuite.scala
@@ -0,0 +1,80 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import scala.collection.mutable.ArrayBuffer

import org.scalatest.FunSuite

class ArmSuite extends FunSuite with Arm {
class TestResource extends AutoCloseable {
var closed = false

override def close(): Unit = {
closed = true
}
}

class TestException extends RuntimeException

test("closeOnExcept single instance") {
val resource = new TestResource
closeOnExcept(resource) { r => assertResult(resource)(r) }
assertResult(false)(resource.closed)
try {
closeOnExcept(resource) { _ => throw new TestException }
} catch {
case _: TestException =>
}
assert(resource.closed)
}

test("closeOnExcept sequence") {
val resources = new Array[TestResource](3)
resources(0) = new TestResource
resources(2) = new TestResource
closeOnExcept(resources) { r => assertResult(resources)(r) }
assert(resources.forall(r => Option(r).forall(!_.closed)))
try {
closeOnExcept(resources) { _ => throw new TestException }
} catch {
case _: TestException =>
}
assert(resources.forall(r => Option(r).forall(_.closed)))
}

test("closeOnExcept arraybuffer") {
val resources = new ArrayBuffer[TestResource]
closeOnExcept(resources) { r =>
r += new TestResource
r += null
r += new TestResource
}
assertResult(3)(resources.length)
assert(resources.forall(r => Option(r).forall(!_.closed)))
try {
closeOnExcept(resources) { r =>
r += new TestResource
throw new TestException
}
} catch {
case _: TestException =>
}
assertResult(4)(resources.length)
assert(resources.forall(r => Option(r).forall(_.closed)))
}
}
@@ -46,21 +46,15 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
withResource(new RapidsDeviceMemoryStore(catalog)) { store =>
val spillPriority = 3
val bufferId = MockRapidsBufferId(7)
-      var ct: ContiguousTable = buildContiguousTable()
-      try {
+      closeOnExcept(buildContiguousTable()) { ct =>
        // store takes ownership of the table
        store.addTable(bufferId, ct.getTable, ct.getBuffer, spillPriority)
-        ct = null
-        val captor: ArgumentCaptor[RapidsBuffer] = ArgumentCaptor.forClass(classOf[RapidsBuffer])
-        verify(catalog).registerNewBuffer(captor.capture())
-        val resultBuffer = captor.getValue
-        assertResult(bufferId)(resultBuffer.id)
-        assertResult(spillPriority)(resultBuffer.getSpillPriority)
-      } finally {
-        if (ct != null) {
-          ct.close()
-        }
      }
+      val captor: ArgumentCaptor[RapidsBuffer] = ArgumentCaptor.forClass(classOf[RapidsBuffer])
+      verify(catalog).registerNewBuffer(captor.capture())
+      val resultBuffer = captor.getValue
+      assertResult(bufferId)(resultBuffer.id)
+      assertResult(spillPriority)(resultBuffer.getSpillPriority)
}
}

@@ -69,16 +63,11 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
withResource(new RapidsDeviceMemoryStore(catalog)) { store =>
val spillPriority = 3
val bufferId = MockRapidsBufferId(7)
-      val ct = buildContiguousTable()
-      val meta = try {
+      val meta = closeOnExcept(buildContiguousTable()) { ct =>
        val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct.getTable, ct.getBuffer)
        // store takes ownership of the buffer
        store.addBuffer(bufferId, ct.getBuffer, meta, spillPriority)
        meta
-      } catch {
-        case t: Throwable =>
-          ct.close()
-          throw t
}
val captor: ArgumentCaptor[RapidsBuffer] = ArgumentCaptor.forClass(classOf[RapidsBuffer])
verify(catalog).registerNewBuffer(captor.capture())
@@ -93,14 +82,12 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
val catalog = new RapidsBufferCatalog
withResource(new RapidsDeviceMemoryStore(catalog)) { store =>
val bufferId = MockRapidsBufferId(7)
-      var ct = buildContiguousTable()
-      try {
+      closeOnExcept(buildContiguousTable()) { ct =>
withResource(HostMemoryBuffer.allocate(ct.getBuffer.getLength)) { expectedHostBuffer =>
expectedHostBuffer.copyFromDeviceBuffer(ct.getBuffer)
val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct.getTable, ct.getBuffer)
// store takes ownership of the buffer
store.addBuffer(bufferId, ct.getBuffer, meta, initialSpillPriority = 3)
-          ct = null
withResource(catalog.acquireBuffer(bufferId)) { buffer =>
withResource(buffer.getMemoryBuffer.asInstanceOf[DeviceMemoryBuffer]) { devbuf =>
withResource(HostMemoryBuffer.allocate(devbuf.getLength)) { actualHostBuffer =>
@@ -110,10 +97,6 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
}
}
}
-      } finally {
-        if (ct != null) {
-          ct.close()
-        }
}
}
}
@@ -122,23 +105,17 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
val catalog = new RapidsBufferCatalog
withResource(new RapidsDeviceMemoryStore(catalog)) { store =>
val bufferId = MockRapidsBufferId(7)
-      var ct = buildContiguousTable()
-      try {
+      closeOnExcept(buildContiguousTable()) { ct =>
withResource(GpuColumnVector.from(ct.getTable)) { expectedBatch =>
val meta = MetaUtils.buildTableMeta(bufferId.tableId, ct.getTable, ct.getBuffer)
// store takes ownership of the buffer
store.addBuffer(bufferId, ct.getBuffer, meta, initialSpillPriority = 3)
-          ct = null
withResource(catalog.acquireBuffer(bufferId)) { buffer =>
withResource(buffer.getColumnarBatch) { actualBatch =>
TestUtils.compareBatches(expectedBatch, actualBatch)
}
}
}
-      } finally {
-        if (ct != null) {
-          ct.close()
-        }
}
}
}
@@ -156,15 +133,10 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
assertResult(0)(store.currentSize)
val bufferSizes = new Array[Long](2)
bufferSizes.indices.foreach { i =>
-        val ct = buildContiguousTable()
-        try {
+        closeOnExcept(buildContiguousTable()) { ct =>
bufferSizes(i) = ct.getBuffer.getLength
// store takes ownership of the table
store.addTable(MockRapidsBufferId(i), ct.getTable, ct.getBuffer, initialSpillPriority = 0)
-        } catch {
-          case t: Throwable =>
-            ct.close()
-            throw t
}
assertResult(bufferSizes.take(i+1).sum)(store.currentSize)
}
@@ -183,15 +155,10 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
withResource(new RapidsDeviceMemoryStore(catalog)) { store =>
store.setSpillStore(spillStore)
spillPriorities.indices.foreach { i =>
-        val ct = buildContiguousTable()
-        try {
+        closeOnExcept(buildContiguousTable()) { ct =>
bufferSizes(i) = ct.getBuffer.getLength
// store takes ownership of the table
store.addTable(MockRapidsBufferId(i), ct.getTable, ct.getBuffer, spillPriorities(i))
-        } catch {
-          case t: Throwable =>
-            ct.close()
-            throw t
}
}
assert(spillStore.spilledBuffers.isEmpty)
@@ -48,15 +48,11 @@ class RapidsHostMemoryStoreSuite extends FunSuite with Arm with MockitoSugar {
assertResult(hostStoreMaxSize)(hostStore.numBytesFree)
devStore.setSpillStore(hostStore)

-      val ct = buildContiguousTable()
-      val bufferSize = ct.getBuffer.getLength
-      try {
+      val bufferSize = closeOnExcept(buildContiguousTable()) { ct =>
+        val len = ct.getBuffer.getLength
        // store takes ownership of the table
        devStore.addTable(bufferId, ct.getTable, ct.getBuffer, spillPriority)
-      } catch {
-        case t: Throwable =>
-          ct.close()
-          throw t
+        len
}

devStore.synchronousSpill(0)
