diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
index d4b0e31a01f..39f98e9685f 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
@@ -174,6 +174,11 @@ object RapidsBufferCatalog extends Logging with Arm {
     }
   }
 
+  // For testing
+  def setDeviceStorage(rdms: RapidsDeviceMemoryStore): Unit = {
+    deviceStorage = rdms
+  }
+
   def init(rapidsConf: RapidsConf): Unit = {
     // We are going to re-initialize so make sure all of the old things were closed...
     closeImpl()
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
index 5d3959bd0e6..b9fe46efe79 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
@@ -33,6 +33,7 @@ import com.nvidia.spark.rapids.shims.{ShimBroadcastExchangeLike, ShimUnaryExecNo
 
 import org.apache.spark.SparkException
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
 import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -87,18 +88,21 @@ object SerializedHostTableUtils extends Arm {
   }
 }
 
+// scalastyle:off no.finalize
 @SerialVersionUID(100L)
 class SerializeConcatHostBuffersDeserializeBatch(
     data: Array[SerializeBatchDeserializeHostBuffer],
     output: Seq[Attribute])
-  extends Serializable with Arm with AutoCloseable {
+  extends Serializable with Arm with AutoCloseable with Logging {
   @transient private var dataTypes = output.map(_.dataType).toArray
   @transient private var headers = data.map(_.header)
   @transient private var buffers = data.map(_.buffer)
-  @transient private var batchInternal: ColumnarBatch = null
-
-  def batch: ColumnarBatch = this.synchronized {
-    if (batchInternal == null) {
+  // used for memoization of deserialization to GPU on Executor
+  @transient private var batchInternal: SpillableColumnarBatch = null
+
+  def batch: SpillableColumnarBatch = this.synchronized {
+    Option(batchInternal).getOrElse {
       if (headers.length > 1) {
         // This should only happen if the driver is trying to access the batch. That should not be
         // a common occurrence, so for simplicity just round-trip this through the serialization.
@@ -111,30 +115,34 @@ class SerializeConcatHostBuffersDeserializeBatch(
       }
       assert(headers.length <= 1 && buffers.length <= 1)
       withResource(new NvtxRange("broadcast manifest batch", NvtxColor.PURPLE)) { _ =>
-        if (headers.isEmpty) {
-          batchInternal = GpuColumnVector.emptyBatchFromTypes(dataTypes)
-          GpuColumnVector.extractBases(batchInternal).foreach(_.noWarnLeakExpected())
-        } else {
-          withResource(JCudfSerialization.readTableFrom(headers.head, buffers.head)) { tableInfo =>
-            val table = tableInfo.getContiguousTable
-            if (table == null) {
-              val numRows = tableInfo.getNumRows
-              batchInternal = new ColumnarBatch(new Array[ColumnVector](0), numRows)
-            } else {
-              batchInternal = GpuColumnVectorFromBuffer.from(table, dataTypes)
-              GpuColumnVector.extractBases(batchInternal).foreach(_.noWarnLeakExpected())
-              table.getBuffer.noWarnLeakExpected()
+        try {
+          val res = if (headers.isEmpty) {
+            SpillableColumnarBatch(GpuColumnVector.emptyBatchFromTypes(dataTypes),
+              SpillPriorities.ACTIVE_BATCHING_PRIORITY, RapidsBuffer.defaultSpillCallback)
+          } else {
+            withResource(JCudfSerialization.readTableFrom(headers.head, buffers.head)) {
+              tableInfo =>
+                val table = tableInfo.getContiguousTable
+                if (table == null) {
+                  val numRows = tableInfo.getNumRows
+                  SpillableColumnarBatch(new ColumnarBatch(Array.empty[ColumnVector], numRows),
+                    SpillPriorities.ACTIVE_BATCHING_PRIORITY, RapidsBuffer.defaultSpillCallback)
+                } else {
+                  SpillableColumnarBatch(table, dataTypes,
+                    SpillPriorities.ACTIVE_BATCHING_PRIORITY, RapidsBuffer.defaultSpillCallback)
+                }
             }
           }
+          batchInternal = res
+          res
+        } finally {
+          // At this point we no longer need the host data and should not need to touch it again.
+          buffers.safeClose()
+          headers = null
+          buffers = null
         }
-
-        // At this point we no longer need the host data and should not need to touch it again.
-        buffers.safeClose()
-        headers = null
-        buffers = null
       }
     }
-    batchInternal
   }
 
   /**
@@ -145,32 +153,35 @@ class SerializeConcatHostBuffersDeserializeBatch(
    * NOTE: The caller is responsible to release these host columnar batches.
    */
   def hostBatches: Array[ColumnarBatch] = this.synchronized {
-    batchInternal match {
-      case batch if batch == null =>
-        withResource(new NvtxRange("broadcast manifest batch", NvtxColor.PURPLE)) { _ =>
-          val columnBatches = new mutable.ArrayBuffer[ColumnarBatch]()
-          closeOnExcept(columnBatches) { cBatches =>
-            headers.zip(buffers).foreach { case (header, buffer) =>
-              val hostColumns = SerializedHostTableUtils.buildHostColumns(
-                header, buffer, dataTypes)
-              val rowCount = header.getNumRows
-              cBatches += new ColumnarBatch(hostColumns.toArray, rowCount)
-            }
-            columnBatches.toArray
+    Option(batchInternal).map { spillable =>
+      withResource(spillable.getColumnarBatch()) { batch =>
+        val hostColumns: Array[ColumnVector] = GpuColumnVector
+          .extractColumns(batch)
+          .safeMap(_.copyToHost())
+        Array(new ColumnarBatch(hostColumns, numRows))
+      }
+    }.getOrElse {
+      withResource(new NvtxRange("broadcast manifest batch", NvtxColor.PURPLE)) { _ =>
+        val columnBatches = new mutable.ArrayBuffer[ColumnarBatch]()
+        closeOnExcept(columnBatches) { cBatches =>
+          headers.zip(buffers).foreach { case (header, buffer) =>
+            val hostColumns = SerializedHostTableUtils.buildHostColumns(
+              header, buffer, dataTypes)
+            val rowCount = header.getNumRows
+            cBatches += new ColumnarBatch(hostColumns.toArray, rowCount)
           }
+          columnBatches.toArray
         }
-      case batch =>
-        val hostColumns = GpuColumnVector.extractColumns(batch).map(_.copyToHost())
-        Array(new ColumnarBatch(hostColumns.toArray, batch.numRows()))
+      }
     }
   }
 
   private def writeObject(out: ObjectOutputStream): Unit = {
-    if (batchInternal != null) {
-      val table = GpuColumnVector.from(batchInternal)
+    Option(batchInternal).map { spillable =>
+      val table = withResource(spillable.getColumnarBatch())(GpuColumnVector.from)
       JCudfSerialization.writeToStream(table, out, 0, table.getRowCount)
       out.writeObject(dataTypes)
-    } else {
+    }.getOrElse {
       if (headers.length == 0) {
         // We didn't get any data back, but we need to write out an empty table that matches
         withResource(GpuColumnVector.emptyHostColumns(dataTypes)) { hostVectors =>
@@ -201,35 +212,25 @@ class SerializeConcatHostBuffersDeserializeBatch(
     }
   }
 
-  def numRows: Int = {
-    if (batchInternal != null) {
-      batchInternal.numRows()
-    } else {
-      headers.map(_.getNumRows).sum
-    }
-  }
+  def numRows: Int = Option(batchInternal)
+    .map(_.numRows())
+    .getOrElse(headers.map(_.getNumRows).sum)
 
-  def dataSize: Long = {
-    if (batchInternal != null) {
-      val bases = GpuColumnVector.extractBases(batchInternal).map(_.copyToHost())
-      try {
-        JCudfSerialization.getSerializedSizeInBytes(bases, 0, batchInternal.numRows())
-      } finally {
-        bases.safeClose()
-      }
-    } else {
-      buffers.map(_.getLength).sum
-    }
-  }
+  def dataSize: Long = Option(batchInternal)
+    .map(_.sizeInBytes)
+    .getOrElse(buffers.map(_.getLength).sum)
 
   override def close(): Unit = this.synchronized {
     buffers.safeClose()
-    if (batchInternal != null) {
-      batchInternal.close()
-      batchInternal = null
-    }
+    Option(batchInternal).foreach(_.close())
+  }
+
+  override def finalize(): Unit = {
+    super.finalize()
+    close()
   }
 }
+// scalastyle:on no.finalize
 
 @SerialVersionUID(100L)
 class SerializeBatchDeserializeHostBuffer(batch: ColumnarBatch)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala
index 988d855dbf7..092341fbbbf 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala
@@ -41,9 +41,7 @@ object GpuBroadcastHelper {
       broadcastSchema: StructType): ColumnarBatch = {
     broadcastRelation.value match {
       case broadcastBatch: SerializeConcatHostBuffersDeserializeBatch =>
-        val builtBatch = broadcastBatch.batch
-        GpuColumnVector.incRefCounts(builtBatch)
-        builtBatch
+        broadcastBatch.batch.getColumnarBatch()
       case v if SparkShimImpl.isEmptyRelation(v) =>
         GpuColumnVector.emptyBatch(broadcastSchema)
       case t =>
@@ -67,7 +65,7 @@ object GpuBroadcastHelper {
   def getBroadcastBatchNumRows(broadcastRelation: Broadcast[Any]): Int = {
     broadcastRelation.value match {
       case broadcastBatch: SerializeConcatHostBuffersDeserializeBatch =>
-        broadcastBatch.batch.numRows()
+        broadcastBatch.numRows
       case v if SparkShimImpl.isEmptyRelation(v) =>
         0
       case t =>
         throw new IllegalStateException(s"Invalid broadcast batch received $t")
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala
index 55906f78732..817e69e43fb 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/SerializationSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,14 +18,20 @@ package com.nvidia.spark.rapids
 
 import ai.rapids.cudf.Table
 import org.apache.commons.lang3.SerializationUtils
-import org.scalatest.FunSuite
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.rapids.execution.{SerializeBatchDeserializeHostBuffer, SerializeConcatHostBuffersDeserializeBatch}
 import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
-class SerializationSuite extends FunSuite with Arm {
+class SerializationSuite extends FunSuite
+    with BeforeAndAfterAll with Arm {
+
+  override def beforeAll(): Unit = {
+    RapidsBufferCatalog.setDeviceStorage(new RapidsDeviceMemoryStore())
+  }
+
   private def buildBatch(): ColumnarBatch = {
     withResource(new Table.TestBuilder()
       .column(5, null.asInstanceOf[java.lang.Integer], 3, 1, 1, 1, 1, 1, 1, 1)
@@ -68,12 +74,14 @@ class SerializationSuite extends FunSuite with Arm {
     val buffer = createDeserializedHostBuffer(gpuExpected)
     val hostBatch = new SerializeConcatHostBuffersDeserializeBatch(Array(buffer), attrs)
     withResource(hostBatch) { _ =>
-      val gpuBatch = hostBatch.batch
-      TestUtils.compareBatches(gpuExpected, gpuBatch)
+      withResource(hostBatch.batch.getColumnarBatch()) { gpuBatch =>
+        TestUtils.compareBatches(gpuExpected, gpuBatch)
+      }
       // clone via serialization after manifesting the GPU batch
       withResource(SerializationUtils.clone(hostBatch)) { clonedObj =>
-        val gpuClonedBatch = clonedObj.batch
-        TestUtils.compareBatches(gpuExpected, gpuClonedBatch)
+        withResource(clonedObj.batch.getColumnarBatch()) { gpuClonedBatch =>
+          TestUtils.compareBatches(gpuExpected, gpuClonedBatch)
+        }
         // try to clone it again from the cloned object
         SerializationUtils.clone(clonedObj).close()
       }
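Reviewer note, not part of the patch: a minimal sketch of the consumer-side pattern this change establishes, assuming the spark-rapids Arm trait and the SerializeConcatHostBuffersDeserializeBatch class shown above. BroadcastBatchSketch and readBroadcastNumRows are hypothetical names used only for illustration. The broadcast value now memoizes a SpillableColumnarBatch, so callers materialize a ColumnarBatch on demand with getColumnarBatch() and close it when done, while numRows no longer needs to manifest anything on the GPU.

import com.nvidia.spark.rapids.Arm

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.rapids.execution.SerializeConcatHostBuffersDeserializeBatch

// Hypothetical helper, for illustration only: how executor-side code is
// expected to consume the broadcast batch after this change.
object BroadcastBatchSketch extends Arm {
  def readBroadcastNumRows(broadcastRelation: Broadcast[Any]): Int = {
    broadcastRelation.value match {
      case broadcastBatch: SerializeConcatHostBuffersDeserializeBatch =>
        // numRows reads the spillable batch's metadata (or the serialized
        // headers on the driver) without deserializing to the GPU.
        val rows = broadcastBatch.numRows
        // batch returns the memoized SpillableColumnarBatch; getColumnarBatch()
        // hands back a ColumnarBatch that the caller owns and must close.
        withResource(broadcastBatch.batch.getColumnarBatch()) { cb =>
          assert(cb.numRows() == rows)
        }
        rows
      case t =>
        throw new IllegalStateException(s"Invalid broadcast batch received $t")
    }
  }
}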