From 4852bd7a51fbb9472e85583b5999290b50d73eff Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 16 Jun 2020 12:20:17 -0500 Subject: [PATCH 1/2] Add RapidsBufferCatalog tests --- .../ai/rapids/spark/RapidsBufferCatalog.scala | 13 -- .../ai/rapids/spark/RapidsBufferStore.scala | 2 - .../spark/RapidsBufferCatalogSuite.scala | 115 ++++++++++++++++++ 3 files changed, 115 insertions(+), 15 deletions(-) create mode 100644 tests/src/test/scala/ai/rapids/spark/RapidsBufferCatalogSuite.scala diff --git a/sql-plugin/src/main/scala/ai/rapids/spark/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/ai/rapids/spark/RapidsBufferCatalog.scala index 3a0c74d1843..45431fc4c87 100644 --- a/sql-plugin/src/main/scala/ai/rapids/spark/RapidsBufferCatalog.scala +++ b/sql-plugin/src/main/scala/ai/rapids/spark/RapidsBufferCatalog.scala @@ -28,22 +28,9 @@ import org.apache.spark.internal.Logging /** Catalog for lookup of buffers by ID */ class RapidsBufferCatalog extends Logging { - /** Tracks all buffer stores using this catalog */ - private[this] val stores = new ArrayBuffer[RapidsBufferStore] - /** Map of buffer IDs to buffers */ private[this] val bufferMap = new ConcurrentHashMap[RapidsBufferId, RapidsBuffer] - /** - * Register a buffer store that is using this catalog. - * NOTE: It is assumed all stores are registered before any buffers are added to the catalog. - * @param store buffer store - */ - def registerStore(store: RapidsBufferStore): Unit = { - require(store.currentSize == 0, "Store must not have any buffers when registered") - stores.append(store) - } - /** * Lookup the buffer that corresponds to the specified buffer ID and acquire it. * NOTE: It is the responsibility of the caller to close the buffer. 
diff --git a/sql-plugin/src/main/scala/ai/rapids/spark/RapidsBufferStore.scala b/sql-plugin/src/main/scala/ai/rapids/spark/RapidsBufferStore.scala index 6e7db0202c9..9c4cd8b9ea8 100644 --- a/sql-plugin/src/main/scala/ai/rapids/spark/RapidsBufferStore.scala +++ b/sql-plugin/src/main/scala/ai/rapids/spark/RapidsBufferStore.scala @@ -106,8 +106,6 @@ abstract class RapidsBufferStore( private[this] val nvtxSyncSpillName: String = name + " sync spill" - catalog.registerStore(this) - /** Return the current byte total of buffers in this store. */ def currentSize: Long = buffers.getTotalBytes diff --git a/tests/src/test/scala/ai/rapids/spark/RapidsBufferCatalogSuite.scala b/tests/src/test/scala/ai/rapids/spark/RapidsBufferCatalogSuite.scala new file mode 100644 index 00000000000..cf83932cd2c --- /dev/null +++ b/tests/src/test/scala/ai/rapids/spark/RapidsBufferCatalogSuite.scala @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package ai.rapids.spark + +import java.io.File +import java.util.NoSuchElementException + +import ai.rapids.spark.StorageTier.StorageTier +import ai.rapids.spark.format.TableMeta +import org.mockito.Mockito._ +import org.scalatest.FunSuite +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark.sql.rapids.RapidsDiskBlockManager + +/** Exercises RapidsBufferCatalog lookup, acquire, metadata, map-update and remove calls using Mockito mocks of RapidsBuffer. */ class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar { + /* An ID that was never registered: both lookups are expected to throw NoSuchElementException. */ test("lookup unknown buffer") { + val catalog = new RapidsBufferCatalog + val bufferId = new RapidsBufferId { + override val tableId: Int = 10 + override def getDiskPath(m: RapidsDiskBlockManager): File = null + } + assertThrows[NoSuchElementException](catalog.acquireBuffer(bufferId)) + assertThrows[NoSuchElementException](catalog.getBufferMeta(bufferId)) + } + + /* Acquiring through an equal but separately constructed MockBufferId is expected to return the registered mock and call addReference() on it once. */ test("acquire buffer") { + val catalog = new RapidsBufferCatalog + val bufferId = MockBufferId(5) + val buffer = mockBuffer(bufferId) + catalog.registerNewBuffer(buffer) + val acquired = catalog.acquireBuffer(MockBufferId(5)) + assertResult(5)(acquired.id.tableId) + assertResult(buffer)(acquired) + verify(buffer).addReference() + } + + /* addReference() is stubbed to return false 8 times and then true, so acquireBuffer is expected to call it 9 times in total. */ test("acquire buffer retries automatically") { + val catalog = new RapidsBufferCatalog + val bufferId = MockBufferId(5) + val buffer = mockBuffer(bufferId, acquireAttempts = 9) + catalog.registerNewBuffer(buffer) + val acquired = catalog.acquireBuffer(MockBufferId(5)) + assertResult(5)(acquired.id.tableId) + assertResult(buffer)(acquired) + verify(buffer, times(9)).addReference() + } + + /* getBufferMeta is expected to hand back the TableMeta instance the registered mock reports. */ test("get buffer meta") { + val catalog = new RapidsBufferCatalog + val bufferId = MockBufferId(5) + val expectedMeta = new TableMeta + val buffer = mockBuffer(bufferId, meta = expectedMeta) + catalog.registerNewBuffer(buffer) + val meta = catalog.getBufferMeta(bufferId) + assertResult(expectedMeta)(meta) + } + + /* Registers a HOST-tier mock, then expects updateBufferMap(StorageTier.HOST, ...) to install the DEVICE-tier mock, and a second call passing the HOST-tier mock to leave the DEVICE-tier one in place. NOTE(review): the meaning of the first updateBufferMap argument is not visible here; confirm against RapidsBufferCatalog. */ test("update buffer map only updates for faster tier") { + val catalog = new RapidsBufferCatalog + val bufferId = MockBufferId(5) + val buffer1 = 
mockBuffer(bufferId, tier = StorageTier.HOST) + catalog.registerNewBuffer(buffer1) + val buffer2 = mockBuffer(bufferId, tier = StorageTier.DEVICE) + catalog.updateBufferMap(StorageTier.HOST, buffer2) + var resultBuffer = catalog.acquireBuffer(bufferId) + assertResult(buffer2)(resultBuffer) + catalog.updateBufferMap(StorageTier.HOST, buffer1) + resultBuffer = catalog.acquireBuffer(bufferId) + assertResult(buffer2)(resultBuffer) + } + + /* removeBuffer is expected to call free() on the registered mock. */ test("remove buffer releases buffer resources") { + val catalog = new RapidsBufferCatalog + val bufferId = MockBufferId(5) + val buffer = mockBuffer(bufferId) + catalog.registerNewBuffer(buffer) + catalog.removeBuffer(bufferId) + verify(buffer).free() + } + + /** Builds a Mockito mock RapidsBuffer that reports the given id, meta and storage tier; its addReference() is stubbed to return false for the first acquireAttempts - 1 calls and true after that. */ private def mockBuffer( + bufferId: RapidsBufferId, + meta: TableMeta = null, + tier: StorageTier = StorageTier.DEVICE, + acquireAttempts: Int = 1): RapidsBuffer = { + val buffer = mock[RapidsBuffer] + when(buffer.id).thenReturn(bufferId) + when(buffer.storageTier).thenReturn(tier) + when(buffer.meta).thenReturn(meta) + /* chain one thenReturn(false) per failed attempt onto the same stubbing, then finish the chain with true */ var stub = when(buffer.addReference()) + (0 until acquireAttempts - 1).foreach(_ => stub = stub.thenReturn(false)) + stub.thenReturn(true) + buffer + } +} + +/** Minimal RapidsBufferId for tests, identified only by tableId; getDiskPath always throws UnsupportedOperationException. */ case class MockBufferId(override val tableId: Int) extends RapidsBufferId { + override def getDiskPath(dbm: RapidsDiskBlockManager): File = + throw new UnsupportedOperationException +} From 66437582533d2a97fbecaf354e7f14b73d2c0ad9 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 17 Jun 2020 09:58:39 -0500 Subject: [PATCH 2/2] Update to new packaging names --- .../com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala index cf83932cd2c..11092b1d68d 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala +++ 
b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala @@ -14,13 +14,13 @@ * limitations under the License. */ -package ai.rapids.spark +package com.nvidia.spark.rapids import java.io.File import java.util.NoSuchElementException -import ai.rapids.spark.StorageTier.StorageTier -import ai.rapids.spark.format.TableMeta +import com.nvidia.spark.rapids.StorageTier.StorageTier +import com.nvidia.spark.rapids.format.TableMeta import org.mockito.Mockito._ import org.scalatest.FunSuite import org.scalatest.mockito.MockitoSugar