diff --git a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala index e74f2c6753b..dd4bd4d0a28 100644 --- a/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala +++ b/shims/spark300/src/main/scala/com/nvidia/spark/rapids/shims/spark300/Spark300Shims.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.ListBuffer import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.spark300.RapidsShuffleManager +import org.apache.arrow.memory.ReferenceManager import org.apache.arrow.vector.ValueVector import org.apache.hadoop.fs.Path @@ -481,16 +482,19 @@ class Spark300Shims extends SparkShims { } // Arrow version changed between Spark versions - override def getArrowDataBuf(vec: ValueVector): ByteBuffer = { - vec.getDataBuffer().nioBuffer() + override def getArrowDataBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) = { + val arrowBuf = vec.getDataBuffer() + (arrowBuf.nioBuffer(), arrowBuf.getReferenceManager) } - override def getArrowValidityBuf(vec: ValueVector): ByteBuffer = { - vec.getValidityBuffer().nioBuffer() + override def getArrowValidityBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) = { + val arrowBuf = vec.getValidityBuffer() + (arrowBuf.nioBuffer(), arrowBuf.getReferenceManager) } - override def getArrowOffsetsBuf(vec: ValueVector): ByteBuffer = { - vec.getOffsetBuffer().nioBuffer() + override def getArrowOffsetsBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) = { + val arrowBuf = vec.getOffsetBuffer() + (arrowBuf.nioBuffer(), arrowBuf.getReferenceManager) } override def replaceWithAlluxioPathIfNeeded( diff --git a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala index 4d090113eca..557c3fb9fd2 100644 --- a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala +++ b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.shims.spark301.Spark301Shims import com.nvidia.spark.rapids.spark311.RapidsShuffleManager +import org.apache.arrow.memory.ReferenceManager import org.apache.arrow.vector.ValueVector import org.apache.spark.SparkEnv @@ -424,16 +425,19 @@ class Spark311Shims extends Spark301Shims { } // Arrow version changed between Spark versions - override def getArrowDataBuf(vec: ValueVector): ByteBuffer = { - vec.getDataBuffer.nioBuffer() + override def getArrowDataBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) = { + val arrowBuf = vec.getDataBuffer() + (arrowBuf.nioBuffer(), arrowBuf.getReferenceManager) } - override def getArrowValidityBuf(vec: ValueVector): ByteBuffer = { - vec.getValidityBuffer.nioBuffer() + override def getArrowValidityBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) = { + val arrowBuf = vec.getValidityBuffer + (arrowBuf.nioBuffer(), arrowBuf.getReferenceManager) } - override def getArrowOffsetsBuf(vec: ValueVector): ByteBuffer = { - vec.getOffsetBuffer.nioBuffer() + override def getArrowOffsetsBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) = { + val arrowBuf = vec.getOffsetBuffer + (arrowBuf.nioBuffer(), arrowBuf.getReferenceManager) } /** matches SPARK-33008 fix in 3.1.1 */ diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java index 064cc437a30..e24cc05cd85 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java @@ -24,6 +24,7 @@ import ai.rapids.cudf.Schema; import ai.rapids.cudf.Table; +import org.apache.arrow.memory.ReferenceManager; import org.apache.spark.sql.catalyst.expressions.Attribute; import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnVector; @@ -249,6 +250,8 @@ public ColumnarBatch build(int rows) { public static final class GpuArrowColumnarBatchBuilder extends GpuColumnarBatchBuilderBase { private final ai.rapids.cudf.ArrowColumnBuilder[] builders; + private final ArrowBufReferenceHolder[] referenceHolders; + /** * A collection of builders for building up columnar data from Arrow data. * @param schema the schema of the batch. @@ -262,21 +265,19 @@ public GpuArrowColumnarBatchBuilder(StructType schema, int rows, ColumnarBatch b fields = schema.fields(); int len = fields.length; builders = new ai.rapids.cudf.ArrowColumnBuilder[len]; + referenceHolders = new ArrowBufReferenceHolder[len]; boolean success = false; try { for (int i = 0; i < len; i++) { StructField field = fields[i]; builders[i] = new ArrowColumnBuilder(convertFrom(field.dataType(), field.nullable())); + referenceHolders[i] = new ArrowBufReferenceHolder(); } success = true; } finally { if (!success) { - for (ai.rapids.cudf.ArrowColumnBuilder b: builders) { - if (b != null) { - b.close(); - } - } + close(); } } } @@ -288,12 +289,15 @@ protected int buildersLength() { protected ColumnVector buildAndPutOnDevice(int builderIndex) { ai.rapids.cudf.ColumnVector cv = builders[builderIndex].buildAndPutOnDevice(); GpuColumnVector gcv = new GpuColumnVector(fields[builderIndex].dataType(), cv); + referenceHolders[builderIndex].releaseReferences(); builders[builderIndex] = null; return gcv; } public void copyColumnar(ColumnVector cv, int colNum, boolean nullable, int rows) { - HostColumnarToGpu.arrowColumnarCopy(cv, builder(colNum), nullable, rows); + referenceHolders[colNum].addReferences( + HostColumnarToGpu.arrowColumnarCopy(cv, builder(colNum), nullable, rows) + ); } public ai.rapids.cudf.ArrowColumnBuilder builder(int i) { @@ -307,6 +311,9 @@ public void close() { b.close(); } } + for (ArrowBufReferenceHolder holder: referenceHolders) { + holder.releaseReferences(); + } } } @@ -394,6 +401,25 @@ public void close() { } } + private static final class ArrowBufReferenceHolder { + private List references = new ArrayList<>(); + + public void addReferences(List refs) { + references.addAll(refs); + refs.forEach(ReferenceManager::retain); + } + + public void releaseReferences() { + if (references.isEmpty()) { + return; + } + for (ReferenceManager ref: references) { + ref.release(); + } + references.clear(); + } + } + private static DType toRapidsOrNull(DataType type) { if (type instanceof LongType) { return DType.INT64; diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala index 894c8ffc7fc..e0a61082691 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala @@ -16,8 +16,13 @@ package com.nvidia.spark.rapids +import java.{util => ju} import java.nio.ByteBuffer +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.arrow.memory.ReferenceManager import org.apache.arrow.vector.ValueVector import org.apache.spark.TaskContext @@ -63,7 +68,7 @@ object HostColumnarToGpu extends Logging { cv: ColumnVector, ab: ai.rapids.cudf.ArrowColumnBuilder, nullable: Boolean, - rows: Int): Unit = { + rows: Int): ju.List[ReferenceManager] = { val valVector = cv match { case v: ArrowColumnVector => try { @@ -78,18 +83,28 @@ object HostColumnarToGpu extends Logging { case _ => throw new IllegalStateException(s"Illegal column vector type: ${cv.getClass}") } + + val referenceManagers = new mutable.ListBuffer[ReferenceManager] + + def getBufferAndAddReference(getter: => (ByteBuffer, ReferenceManager)): ByteBuffer = { + val (buf, ref) = getter + referenceManagers += ref + buf + } + val nullCount = valVector.getNullCount() - val dataBuf = ShimLoader.getSparkShims.getArrowDataBuf(valVector) - val validity = ShimLoader.getSparkShims.getArrowValidityBuf(valVector) + val dataBuf = getBufferAndAddReference(ShimLoader.getSparkShims.getArrowDataBuf(valVector)) + val validity = getBufferAndAddReference(ShimLoader.getSparkShims.getArrowValidityBuf(valVector)) // this is a bit ugly, not all Arrow types have the offsets buffer var offsets: ByteBuffer = null try { - offsets = ShimLoader.getSparkShims.getArrowOffsetsBuf(valVector) + offsets = getBufferAndAddReference(ShimLoader.getSparkShims.getArrowOffsetsBuf(valVector)) } catch { case e: UnsupportedOperationException => // swallow the exception and assume no offsets buffer } ab.addBatch(rows, nullCount, dataBuf, validity, offsets) + referenceManagers.result().asJava } def columnarCopy(cv: ColumnVector, b: ai.rapids.cudf.HostColumnVector.ColumnBuilder, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala index 9991171ced1..c44f7f00ef4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -18,6 +18,7 @@ package com.nvidia.spark.rapids import java.nio.ByteBuffer +import org.apache.arrow.memory.ReferenceManager import org.apache.arrow.vector.ValueVector import org.apache.hadoop.fs.Path @@ -186,9 +187,9 @@ trait SparkShims { def shouldIgnorePath(path: String): Boolean - def getArrowDataBuf(vec: ValueVector): ByteBuffer - def getArrowValidityBuf(vec: ValueVector): ByteBuffer - def getArrowOffsetsBuf(vec: ValueVector): ByteBuffer + def getArrowDataBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) + def getArrowValidityBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) + def getArrowOffsetsBuf(vec: ValueVector): (ByteBuffer, ReferenceManager) def replaceWithAlluxioPathIfNeeded( conf: RapidsConf, diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesSuite.scala index a45fd32e8c9..ec5a6f7755a 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuCoalesceBatchesSuite.scala @@ -247,6 +247,54 @@ class GpuCoalesceBatchesSuite extends SparkQueryCompareTestSuite { isInstanceOf[GpuColumnVector.GpuColumnarBatchBuilder]) } + test("test GpuArrowColumnarBatchBuilder retains reference of ArrowBuf") { + val rootAllocator = new RootAllocator(Long.MaxValue) + val allocator = rootAllocator.newChildAllocator("int", 0, Long.MaxValue) + val vector1 = toArrowField("int", IntegerType, nullable = true, null) + .createVector(allocator).asInstanceOf[IntVector] + val vector2 = toArrowField("int", IntegerType, nullable = true, null) + .createVector(allocator).asInstanceOf[IntVector] + vector1.allocateNew(10) + vector2.allocateNew(10) + (0 until 10).foreach { i => + vector1.setSafe(i, i) + vector2.setSafe(i, i) + } + val schema = StructType(Seq(StructField("int", IntegerType))) + val batches = Seq( + new ColumnarBatch(Array(new ArrowColumnVector(vector1)), vector1.getValueCount), + new ColumnarBatch(Array(new ArrowColumnVector(vector2)), vector1.getValueCount) + ) + val hostToGpuCoalesceIterator = new HostToGpuCoalesceIterator(batches.iterator, + TargetSize(1024), + schema: StructType, + WrappedGpuMetric(new SQLMetric("t1", 0)), + WrappedGpuMetric(new SQLMetric("t2", 0)), + WrappedGpuMetric(new SQLMetric("t3", 0)), + WrappedGpuMetric(new SQLMetric("t4", 0)), + WrappedGpuMetric(new SQLMetric("t5", 0)), + WrappedGpuMetric(new SQLMetric("t6", 0)), + WrappedGpuMetric(new SQLMetric("t7", 0)), + WrappedGpuMetric(new SQLMetric("t8", 0)), + "testcoalesce", + useArrowCopyOpt = true) + + val allocatedMemory = allocator.getAllocatedMemory + hostToGpuCoalesceIterator.initNewBatch(batches.head) + hostToGpuCoalesceIterator.addBatchToConcat(batches.head) + hostToGpuCoalesceIterator.addBatchToConcat(batches(1)) + + // Close columnar batches + batches.foreach(cb => cb.close()) + + // Verify that buffers are not deallocated + assertResult(allocatedMemory)(allocator.getAllocatedMemory) + + // Verify that buffers are deallocated after concat is done + hostToGpuCoalesceIterator.cleanupConcatIsDone() + assertResult(0L)(allocator.getAllocatedMemory) + } + test("test HostToGpuCoalesceIterator with arrow config off") { val (batch, schema) = setupArrowBatch() val iter = Iterator.single(batch)