Skip to content

Commit

Permalink
Fix test issues
Browse files Browse the repository at this point in the history
  • Loading branch information
abellina committed May 22, 2023
1 parent 3fd1e28 commit a5b008f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,29 +212,39 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar {
}

test("multiple calls to unspill return existing DEVICE buffer") {
val deviceStore = spy(new RapidsDeviceMemoryStore)
val mockStore = mock[RapidsBufferStore]
withResource(
new RapidsHostMemoryStore(10000, 1000)) { hostStore =>
deviceStore.setSpillStore(hostStore)
hostStore.setSpillStore(mockStore)
val catalog = new RapidsBufferCatalog(deviceStore)
val handle = withResource(DeviceMemoryBuffer.allocate(1024)) { buff =>
val meta = MetaUtils.getTableMetaNoTable(buff)
catalog.addBuffer(
buff, meta, -1)
}
withResource(handle) { _ =>
catalog.synchronousSpill(deviceStore, 0)
val acquiredHostBuffer = catalog.acquireBuffer(handle)
val unspilled = withResource(acquiredHostBuffer) { _ =>
assertResult(HOST)(acquiredHostBuffer.storageTier)
val unspilled =
catalog.unspillBufferToDeviceStore(
withResource(spy(new RapidsDeviceMemoryStore)) { deviceStore =>
val mockStore = mock[RapidsBufferStore]
withResource(
new RapidsHostMemoryStore(10000, 1000)) { hostStore =>
deviceStore.setSpillStore(hostStore)
hostStore.setSpillStore(mockStore)
val catalog = new RapidsBufferCatalog(deviceStore)
val handle = withResource(DeviceMemoryBuffer.allocate(1024)) { buff =>
val meta = MetaUtils.getTableMetaNoTable(buff)
catalog.addBuffer(
buff, meta, -1)
}
withResource(handle) { _ =>
catalog.synchronousSpill(deviceStore, 0)
val acquiredHostBuffer = catalog.acquireBuffer(handle)
val unspilled = withResource(acquiredHostBuffer) { _ =>
assertResult(HOST)(acquiredHostBuffer.storageTier)
val unspilled =
catalog.unspillBufferToDeviceStore(
acquiredHostBuffer,
Cuda.DEFAULT_STREAM)
withResource(unspilled) { _ =>
assertResult(DEVICE)(unspilled.storageTier)
}
val unspilledSame = catalog.unspillBufferToDeviceStore(
acquiredHostBuffer,
Cuda.DEFAULT_STREAM)
withResource(unspilled) { _ =>
assertResult(DEVICE)(unspilled.storageTier)
withResource(unspilledSame) { _ =>
assertResult(unspilled)(unspilledSame)
}
// verify that we invoked the copy function exactly once
verify(deviceStore, times(1)).copyBuffer(any(), any())
unspilled
}
val unspilledSame = catalog.unspillBufferToDeviceStore(
acquiredHostBuffer,
Expand All @@ -244,16 +254,7 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar {
}
// verify that we invoked the copy function exactly once
verify(deviceStore, times(1)).copyBuffer(any(), any())
unspilled
}
val unspilledSame = catalog.unspillBufferToDeviceStore(
acquiredHostBuffer,
Cuda.DEFAULT_STREAM)
withResource(unspilledSame) { _ =>
assertResult(unspilled)(unspilledSame)
}
// verify that we invoked the copy function exactly once
verify(deviceStore, times(1)).copyBuffer(any(), any())
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ import org.scalatest.FunSuite
import org.scalatest.mockito.MockitoSugar

import org.apache.spark.sql.rapids.RapidsDiskBlockManager
import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, FloatType, IntegerType, StringType}
import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, StringType}

class RapidsDeviceMemoryStoreSuite extends FunSuite with MockitoSugar {
private def buildTable(): Table = {
new Table.TestBuilder()
.column(5, null.asInstanceOf[java.lang.Integer], 3, 1)
.column("five", "two", null, null)
.column(5.0, 2.0, 3.0, 1.0)
.column(5.0D, 2.0D, 3.0D, 1.0D)
.decimal64Column(-5, RoundingMode.UNNECESSARY, 0, null, -1.4, 10.123)
.build()
}
Expand Down Expand Up @@ -96,7 +96,7 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with MockitoSugar {
val table = buildTable()
val handle = catalog.addTable(table, spillPriority)
val types: Array[DataType] =
Seq(IntegerType, StringType, FloatType, DecimalType(10, 5)).toArray
Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray
val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table)
assertResult(buffSize)(store.currentSize)
assertResult(buffSize)(store.currentSpillableSize)
Expand All @@ -118,7 +118,7 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with MockitoSugar {
val table = buildTable()
val handle = catalog.addTable(table, spillPriority)
val types: Array[DataType] =
Seq(IntegerType, StringType, FloatType, DecimalType(10, 5)).toArray
Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray
val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table)
assertResult(buffSize)(store.currentSize)
assertResult(buffSize)(store.currentSpillableSize)
Expand Down Expand Up @@ -149,7 +149,7 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with MockitoSugar {
val table = buildTable()
val handle = catalog.addTable(table, spillPriority)
val types: Array[DataType] =
Seq(IntegerType, StringType, FloatType, DecimalType(10, 5)).toArray
Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray
val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table)
assertResult(buffSize)(store.currentSize)
assertResult(buffSize)(store.currentSpillableSize)
Expand All @@ -176,7 +176,7 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with MockitoSugar {
val table = buildTable()
val handle = catalog.addTable(table, spillPriority)
val types: Array[DataType] =
Seq(IntegerType, StringType, FloatType, DecimalType(10, 5)).toArray
Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray
val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table)
assertResult(buffSize)(store.currentSize)
assertResult(buffSize)(store.currentSpillableSize)
Expand Down

0 comments on commit a5b008f

Please sign in to comment.