diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala index 2b5e592c144..5950617ea48 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala @@ -212,29 +212,39 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar { } test("multiple calls to unspill return existing DEVICE buffer") { - val deviceStore = spy(new RapidsDeviceMemoryStore) - val mockStore = mock[RapidsBufferStore] - withResource( - new RapidsHostMemoryStore(10000, 1000)) { hostStore => - deviceStore.setSpillStore(hostStore) - hostStore.setSpillStore(mockStore) - val catalog = new RapidsBufferCatalog(deviceStore) - val handle = withResource(DeviceMemoryBuffer.allocate(1024)) { buff => - val meta = MetaUtils.getTableMetaNoTable(buff) - catalog.addBuffer( - buff, meta, -1) - } - withResource(handle) { _ => - catalog.synchronousSpill(deviceStore, 0) - val acquiredHostBuffer = catalog.acquireBuffer(handle) - val unspilled = withResource(acquiredHostBuffer) { _ => - assertResult(HOST)(acquiredHostBuffer.storageTier) - val unspilled = - catalog.unspillBufferToDeviceStore( + withResource(spy(new RapidsDeviceMemoryStore)) { deviceStore => + val mockStore = mock[RapidsBufferStore] + withResource( + new RapidsHostMemoryStore(10000, 1000)) { hostStore => + deviceStore.setSpillStore(hostStore) + hostStore.setSpillStore(mockStore) + val catalog = new RapidsBufferCatalog(deviceStore) + val handle = withResource(DeviceMemoryBuffer.allocate(1024)) { buff => + val meta = MetaUtils.getTableMetaNoTable(buff) + catalog.addBuffer( + buff, meta, -1) + } + withResource(handle) { _ => + catalog.synchronousSpill(deviceStore, 0) + val acquiredHostBuffer = catalog.acquireBuffer(handle) + val unspilled = withResource(acquiredHostBuffer) { _ => + assertResult(HOST)(acquiredHostBuffer.storageTier) + val unspilled = + catalog.unspillBufferToDeviceStore( + acquiredHostBuffer, + Cuda.DEFAULT_STREAM) + withResource(unspilled) { _ => + assertResult(DEVICE)(unspilled.storageTier) + } + val unspilledSame = catalog.unspillBufferToDeviceStore( acquiredHostBuffer, Cuda.DEFAULT_STREAM) - withResource(unspilled) { _ => - assertResult(DEVICE)(unspilled.storageTier) + withResource(unspilledSame) { _ => + assertResult(unspilled)(unspilledSame) + } + // verify that we invoked the copy function exactly once + verify(deviceStore, times(1)).copyBuffer(any(), any()) + unspilled } val unspilledSame = catalog.unspillBufferToDeviceStore( acquiredHostBuffer, @@ -244,16 +254,7 @@ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar { } // verify that we invoked the copy function exactly once verify(deviceStore, times(1)).copyBuffer(any(), any()) - unspilled - } - val unspilledSame = catalog.unspillBufferToDeviceStore( - acquiredHostBuffer, - Cuda.DEFAULT_STREAM) - withResource(unspilledSame) { _ => - assertResult(unspilled)(unspilledSame) } - // verify that we invoked the copy function exactly once - verify(deviceStore, times(1)).copyBuffer(any(), any()) } } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala index 32eef46bb0b..3b3b3820b07 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDeviceMemoryStoreSuite.scala @@ -31,14 +31,14 @@ import org.scalatest.FunSuite import org.scalatest.mockito.MockitoSugar import org.apache.spark.sql.rapids.RapidsDiskBlockManager -import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, FloatType, IntegerType, StringType} +import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType, IntegerType, StringType} class RapidsDeviceMemoryStoreSuite extends FunSuite with MockitoSugar { private def buildTable(): Table = { new Table.TestBuilder() .column(5, null.asInstanceOf[java.lang.Integer], 3, 1) .column("five", "two", null, null) - .column(5.0, 2.0, 3.0, 1.0) + .column(5.0D, 2.0D, 3.0D, 1.0D) .decimal64Column(-5, RoundingMode.UNNECESSARY, 0, null, -1.4, 10.123) .build() } @@ -96,7 +96,7 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with MockitoSugar { val table = buildTable() val handle = catalog.addTable(table, spillPriority) val types: Array[DataType] = - Seq(IntegerType, StringType, FloatType, DecimalType(10, 5)).toArray + Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) assertResult(buffSize)(store.currentSize) assertResult(buffSize)(store.currentSpillableSize) @@ -118,7 +118,7 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with MockitoSugar { val table = buildTable() val handle = catalog.addTable(table, spillPriority) val types: Array[DataType] = - Seq(IntegerType, StringType, FloatType, DecimalType(10, 5)).toArray + Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) assertResult(buffSize)(store.currentSize) assertResult(buffSize)(store.currentSpillableSize) @@ -149,7 +149,7 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with MockitoSugar { val table = buildTable() val handle = catalog.addTable(table, spillPriority) val types: Array[DataType] = - Seq(IntegerType, StringType, FloatType, DecimalType(10, 5)).toArray + Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) assertResult(buffSize)(store.currentSize) assertResult(buffSize)(store.currentSpillableSize) @@ -176,7 +176,7 @@ class RapidsDeviceMemoryStoreSuite extends FunSuite with MockitoSugar { val table = buildTable() val handle = catalog.addTable(table, spillPriority) val types: Array[DataType] = - Seq(IntegerType, StringType, FloatType, DecimalType(10, 5)).toArray + Seq(IntegerType, StringType, DoubleType, DecimalType(10, 5)).toArray val buffSize = GpuColumnVector.getTotalDeviceMemoryUsed(table) assertResult(buffSize)(store.currentSize) assertResult(buffSize)(store.currentSpillableSize)