Skip to content

Commit

Permalink
Make broadcast tables spillable (NVIDIA#6604)
Browse files Browse the repository at this point in the history
Fixes NVIDIA#836. This is an MVP that uses the SpillableColumnarBatch wrapper.

Signed-off-by: Gera Shegalov <[email protected]>
  • Loading branch information
gerashegalov authored Oct 3, 2022
1 parent 67fea32 commit 99e5ca9
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ object RapidsBufferCatalog extends Logging with Arm {
}
}

// For testing: replaces `deviceStorage` with the supplied store.
// This is a plain assignment; nothing in this method closes or otherwise
// releases the store that was installed before the call.
def setDeviceStorage(rdms: RapidsDeviceMemoryStore): Unit = {
deviceStorage = rdms
}

def init(rapidsConf: RapidsConf): Unit = {
// We are going to re-initialize so make sure all of the old things were closed...
closeImpl()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import com.nvidia.spark.rapids.shims.{ShimBroadcastExchangeLike, ShimUnaryExecNo

import org.apache.spark.SparkException
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
Expand Down Expand Up @@ -87,18 +88,21 @@ object SerializedHostTableUtils extends Arm {
}
}

// scalastyle:off no.finalize
@SerialVersionUID(100L)
class SerializeConcatHostBuffersDeserializeBatch(
data: Array[SerializeBatchDeserializeHostBuffer],
output: Seq[Attribute])
extends Serializable with Arm with AutoCloseable {
extends Serializable with Arm with AutoCloseable with Logging {
@transient private var dataTypes = output.map(_.dataType).toArray
@transient private var headers = data.map(_.header)
@transient private var buffers = data.map(_.buffer)
@transient private var batchInternal: ColumnarBatch = null

def batch: ColumnarBatch = this.synchronized {
if (batchInternal == null) {
// used for memoization of deserialization to GPU on Executor
@transient private var batchInternal: SpillableColumnarBatch = null

def batch: SpillableColumnarBatch = this.synchronized {
Option(batchInternal).getOrElse {
if (headers.length > 1) {
// This should only happen if the driver is trying to access the batch. That should not be
// a common occurrence, so for simplicity just round-trip this through the serialization.
Expand All @@ -111,30 +115,34 @@ class SerializeConcatHostBuffersDeserializeBatch(
}
assert(headers.length <= 1 && buffers.length <= 1)
withResource(new NvtxRange("broadcast manifest batch", NvtxColor.PURPLE)) { _ =>
if (headers.isEmpty) {
batchInternal = GpuColumnVector.emptyBatchFromTypes(dataTypes)
GpuColumnVector.extractBases(batchInternal).foreach(_.noWarnLeakExpected())
} else {
withResource(JCudfSerialization.readTableFrom(headers.head, buffers.head)) { tableInfo =>
val table = tableInfo.getContiguousTable
if (table == null) {
val numRows = tableInfo.getNumRows
batchInternal = new ColumnarBatch(new Array[ColumnVector](0), numRows)
} else {
batchInternal = GpuColumnVectorFromBuffer.from(table, dataTypes)
GpuColumnVector.extractBases(batchInternal).foreach(_.noWarnLeakExpected())
table.getBuffer.noWarnLeakExpected()
try {
val res = if (headers.isEmpty) {
SpillableColumnarBatch(GpuColumnVector.emptyBatchFromTypes(dataTypes),
SpillPriorities.ACTIVE_BATCHING_PRIORITY, RapidsBuffer.defaultSpillCallback)
} else {
withResource(JCudfSerialization.readTableFrom(headers.head, buffers.head)) {
tableInfo =>
val table = tableInfo.getContiguousTable
if (table == null) {
val numRows = tableInfo.getNumRows
SpillableColumnarBatch(new ColumnarBatch(Array.empty[ColumnVector], numRows),
SpillPriorities.ACTIVE_BATCHING_PRIORITY, RapidsBuffer.defaultSpillCallback)
} else {
SpillableColumnarBatch(table, dataTypes,
SpillPriorities.ACTIVE_BATCHING_PRIORITY, RapidsBuffer.defaultSpillCallback)
}
}
}
batchInternal = res
res
} finally {
// At this point we no longer need the host data and should not need to touch it again.
buffers.safeClose()
headers = null
buffers = null
}

// At this point we no longer need the host data and should not need to touch it again.
buffers.safeClose()
headers = null
buffers = null
}
}
batchInternal
}

/**
Expand All @@ -145,32 +153,35 @@ class SerializeConcatHostBuffersDeserializeBatch(
* NOTE: The caller is responsible for releasing these host columnar batches.
*/
def hostBatches: Array[ColumnarBatch] = this.synchronized {
batchInternal match {
case batch if batch == null =>
withResource(new NvtxRange("broadcast manifest batch", NvtxColor.PURPLE)) { _ =>
val columnBatches = new mutable.ArrayBuffer[ColumnarBatch]()
closeOnExcept(columnBatches) { cBatches =>
headers.zip(buffers).foreach { case (header, buffer) =>
val hostColumns = SerializedHostTableUtils.buildHostColumns(
header, buffer, dataTypes)
val rowCount = header.getNumRows
cBatches += new ColumnarBatch(hostColumns.toArray, rowCount)
}
Option(batchInternal).map { spillable =>
withResource(spillable.getColumnarBatch()) { batch =>
val hostColumns: Array[ColumnVector] = GpuColumnVector
.extractColumns(batch)
.safeMap(_.copyToHost())
Array(new ColumnarBatch(hostColumns, numRows))
}
}.getOrElse {
withResource(new NvtxRange("broadcast manifest batch", NvtxColor.PURPLE)) { _ =>
val columnBatches = new mutable.ArrayBuffer[ColumnarBatch]()
closeOnExcept(columnBatches) { cBatches =>
headers.zip(buffers).foreach { case (header, buffer) =>
val hostColumns = SerializedHostTableUtils.buildHostColumns(
header, buffer, dataTypes)
val rowCount = header.getNumRows
cBatches += new ColumnarBatch(hostColumns.toArray, rowCount)
}
columnBatches.toArray
}
case batch =>
val hostColumns = GpuColumnVector.extractColumns(batch).map(_.copyToHost())
Array(new ColumnarBatch(hostColumns.toArray, batch.numRows()))
columnBatches.toArray
}
}
}

private def writeObject(out: ObjectOutputStream): Unit = {
if (batchInternal != null) {
val table = GpuColumnVector.from(batchInternal)
Option(batchInternal).map { spillable =>
val table = withResource(spillable.getColumnarBatch())(GpuColumnVector.from)
JCudfSerialization.writeToStream(table, out, 0, table.getRowCount)
out.writeObject(dataTypes)
} else {
}.getOrElse {
if (headers.length == 0) {
// We didn't get any data back, but we need to write out an empty table that matches
withResource(GpuColumnVector.emptyHostColumns(dataTypes)) { hostVectors =>
Expand Down Expand Up @@ -201,35 +212,25 @@ class SerializeConcatHostBuffersDeserializeBatch(
}
}

def numRows: Int = {
if (batchInternal != null) {
batchInternal.numRows()
} else {
headers.map(_.getNumRows).sum
}
}
/**
 * Number of rows in the broadcast data.
 *
 * Uses the memoized spillable batch when one has been built; otherwise sums the row counts
 * recorded in the serialized headers.
 */
def numRows: Int = Option(batchInternal) match {
  case Some(spillable) => spillable.numRows()
  case None => headers.map(_.getNumRows).sum
}

def dataSize: Long = {
if (batchInternal != null) {
val bases = GpuColumnVector.extractBases(batchInternal).map(_.copyToHost())
try {
JCudfSerialization.getSerializedSizeInBytes(bases, 0, batchInternal.numRows())
} finally {
bases.safeClose()
}
} else {
buffers.map(_.getLength).sum
}
}
/**
 * Size of the broadcast data in bytes.
 *
 * Uses the size reported by the memoized spillable batch when one has been built; otherwise
 * sums the lengths of the serialized host buffers.
 */
def dataSize: Long = Option(batchInternal) match {
  case Some(spillable) => spillable.sizeInBytes
  case None => buffers.map(_.getLength).sum
}

override def close(): Unit = this.synchronized {
buffers.safeClose()
if (batchInternal != null) {
batchInternal.close()
batchInternal = null
}
Option(batchInternal).foreach(_.close())
}

// Finalizer hook: runs the superclass finalizer and then close(), so that close() is still
// invoked when this object is finalized even if no caller closed it explicitly.
override def finalize(): Unit = {
super.finalize()
close()
}
}
// scalastyle:on no.finalize

@SerialVersionUID(100L)
class SerializeBatchDeserializeHostBuffer(batch: ColumnarBatch)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ object GpuBroadcastHelper {
broadcastSchema: StructType): ColumnarBatch = {
broadcastRelation.value match {
case broadcastBatch: SerializeConcatHostBuffersDeserializeBatch =>
val builtBatch = broadcastBatch.batch
GpuColumnVector.incRefCounts(builtBatch)
builtBatch
broadcastBatch.batch.getColumnarBatch()
case v if SparkShimImpl.isEmptyRelation(v) =>
GpuColumnVector.emptyBatch(broadcastSchema)
case t =>
Expand All @@ -67,7 +65,7 @@ object GpuBroadcastHelper {
def getBroadcastBatchNumRows(broadcastRelation: Broadcast[Any]): Int = {
broadcastRelation.value match {
case broadcastBatch: SerializeConcatHostBuffersDeserializeBatch =>
broadcastBatch.batch.numRows()
broadcastBatch.numRows
case v if SparkShimImpl.isEmptyRelation(v) => 0
case t =>
throw new IllegalStateException(s"Invalid broadcast batch received $t")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,14 +18,20 @@ package com.nvidia.spark.rapids

import ai.rapids.cudf.Table
import org.apache.commons.lang3.SerializationUtils
import org.scalatest.FunSuite
import org.scalatest.{BeforeAndAfterAll, FunSuite}

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.rapids.execution.{SerializeBatchDeserializeHostBuffer, SerializeConcatHostBuffersDeserializeBatch}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType}
import org.apache.spark.sql.vectorized.ColumnarBatch

class SerializationSuite extends FunSuite with Arm {
class SerializationSuite extends FunSuite
with BeforeAndAfterAll with Arm {

/**
 * Installs a fresh RapidsDeviceMemoryStore in RapidsBufferCatalog once, before any test in
 * this suite runs.
 *
 * NOTE(review): presumably needed because the tests now manifest SpillableColumnarBatch
 * instances; no matching teardown is visible in this view -- confirm whether one is required.
 */
override def beforeAll(): Unit = {
  // Chain to super first so setup from any other mixed-in BeforeAndAfterAll trait still runs.
  super.beforeAll()
  RapidsBufferCatalog.setDeviceStorage(new RapidsDeviceMemoryStore())
}

private def buildBatch(): ColumnarBatch = {
withResource(new Table.TestBuilder()
.column(5, null.asInstanceOf[java.lang.Integer], 3, 1, 1, 1, 1, 1, 1, 1)
Expand Down Expand Up @@ -68,12 +74,14 @@ class SerializationSuite extends FunSuite with Arm {
val buffer = createDeserializedHostBuffer(gpuExpected)
val hostBatch = new SerializeConcatHostBuffersDeserializeBatch(Array(buffer), attrs)
withResource(hostBatch) { _ =>
val gpuBatch = hostBatch.batch
TestUtils.compareBatches(gpuExpected, gpuBatch)
withResource(hostBatch.batch.getColumnarBatch()) { gpuBatch =>
TestUtils.compareBatches(gpuExpected, gpuBatch)
}
// clone via serialization after manifesting the GPU batch
withResource(SerializationUtils.clone(hostBatch)) { clonedObj =>
val gpuClonedBatch = clonedObj.batch
TestUtils.compareBatches(gpuExpected, gpuClonedBatch)
withResource(clonedObj.batch.getColumnarBatch()) { gpuClonedBatch =>
TestUtils.compareBatches(gpuExpected, gpuClonedBatch)
}
// try to clone it again from the cloned object
SerializationUtils.clone(clonedObj).close()
}
Expand Down

0 comments on commit 99e5ca9

Please sign in to comment.