diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastToCpuExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastToCpuExec.scala index 1e6818b0c4a..91f56271b7a 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastToCpuExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastToCpuExec.scala @@ -15,17 +15,14 @@ */ package org.apache.spark.sql.rapids.execution -import java.util -import java.util.Optional import java.util.concurrent.{Callable, Future} import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal -import ai.rapids.cudf.{DType, HostColumnVector, HostColumnVectorCore, HostMemoryBuffer, NvtxColor} -import ai.rapids.cudf.JCudfSerialization.{SerializedColumnHeader, SerializedTableHeader} -import com.nvidia.spark.rapids.{Arm, GpuMetric, MetricRange, NvtxWithMetrics, RapidsHostColumnVector, ShimLoader} +import ai.rapids.cudf.{HostMemoryBuffer, JCudfSerialization, NvtxColor} +import ai.rapids.cudf.JCudfSerialization.SerializedTableHeader +import com.nvidia.spark.rapids.{Arm, GpuMetric, NvtxWithMetrics, RapidsHostColumnVector, ShimLoader} import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingArray import org.apache.spark.SparkException @@ -168,136 +165,16 @@ case class GpuBroadcastToCpuExec(mode: BroadcastMode, child: SparkPlan) } object GpuBroadcastToCpuExec extends Arm { - case class ColumnOffsets(validity: Long, offsets: Long, data: Long, dataLen: Long) - private def buildHostColumns( header: SerializedTableHeader, buffer: HostMemoryBuffer, dataTypes: Array[DataType]): Array[RapidsHostColumnVector] = { assert(dataTypes.length == header.getNumColumns) - val columnOffsets = buildColumnOffsets(header, buffer) - closeOnExcept(new ArrayBuffer[RapidsHostColumnVector](header.getNumColumns)) { hostColumns => - (0 until header.getNumColumns).foreach { i => - val columnHeader = header.getColumnHeader(i) - val hcv = buildHostColumn(columnHeader, columnOffsets, buffer) - hostColumns += new RapidsHostColumnVector(dataTypes(i), hcv) + closeOnExcept(JCudfSerialization.unpackHostColumnVectors(header, buffer)) { hostColumns => + assert(hostColumns.length == dataTypes.length) + dataTypes.zip(hostColumns).safeMap { case (dataType, hostColumn) => + new RapidsHostColumnVector(dataType, hostColumn) } - assert(columnOffsets.isEmpty) - hostColumns.toArray } } - - // TODO: The methods below either replicate private functionality in cudf - // or should be moved to cudf. - - private def buildHostColumn( - columnHeader: SerializedColumnHeader, - columnOffsets: util.ArrayDeque[ColumnOffsets], - buffer: HostMemoryBuffer): HostColumnVector = { - val offsetsInfo = columnOffsets.remove() - closeOnExcept(new ArrayBuffer[HostColumnVectorCore](columnHeader.getNumChildren)) { children => - val childHeaders = columnHeader.getChildren - if (childHeaders != null) { - childHeaders.foreach { childHeader => - children += buildHostColumn(childHeader, columnOffsets, buffer) - } - } - val dtype = columnHeader.getType - val numRows = columnHeader.getRowCount - val nullCount = columnHeader.getNullCount - - // Slice up the host buffer for this column vector's buffers. - val dataBuffer = if (dtype.isNestedType) { - null - } else { - buffer.slice(offsetsInfo.data, offsetsInfo.dataLen) - } - val validityBuffer = if (nullCount > 0) { - // one bit per row - val validitySize = (numRows + 7) / 8 - buffer.slice(offsetsInfo.validity, validitySize) - } else { - null - } - val offsetsBuffer = if (dtype.hasOffsets) { - // one 32-bit integer offset per row plus one additional offset at the end - val offsetsSize = if (numRows > 0) (numRows + 1) * Integer.BYTES else 0 - buffer.slice(offsetsInfo.offsets, offsetsSize) - } else { - null - } - - new HostColumnVector(dtype, numRows, Optional.of(nullCount), dataBuffer, validityBuffer, - offsetsBuffer, children.asJava) - } - } - - /** Build a list of column offset descriptors using a pre-order traversal of the columns. */ - private def buildColumnOffsets( - header: SerializedTableHeader, - buffer: HostMemoryBuffer): util.ArrayDeque[ColumnOffsets] = { - val numTopColumns = header.getNumColumns - val offsets = new util.ArrayDeque[ColumnOffsets] - var bufferOffset = 0L - (0 until numTopColumns).foreach { i => - val columnHeader = header.getColumnHeader(i) - bufferOffset = buildColumnOffsetsForColumn(columnHeader, buffer, offsets, bufferOffset) - } - offsets - } - - /** Append a list of column offset descriptors using a pre-order traversal of the column. */ - private def buildColumnOffsetsForColumn( - columnHeader: SerializedColumnHeader, - buffer: HostMemoryBuffer, - offsetsList: util.ArrayDeque[ColumnOffsets], - startBufferOffset: Long): Long = { - var bufferOffset = startBufferOffset - val rowCount = columnHeader.getRowCount - var validity = 0L - var offsets = 0L - var data = 0L - var dataLen = 0L - if (columnHeader.getNullCount > 0) { - val validityLen = padFor64ByteAlignment((rowCount + 7) / 8) - validity = bufferOffset - bufferOffset += validityLen - } - - val dtype = columnHeader.getType - if (dtype.hasOffsets) { - if (rowCount > 0) { - val offsetsLen = (rowCount + 1) * Integer.BYTES - offsets = bufferOffset - val startOffset = buffer.getInt(bufferOffset) - val endOffset = buffer.getInt(bufferOffset + (rowCount * Integer.BYTES)) - bufferOffset += padFor64ByteAlignment(offsetsLen) - if (dtype.equals(DType.STRING)) { - dataLen = endOffset - startOffset - data = bufferOffset - bufferOffset += padFor64ByteAlignment(dataLen) - } - } - } else if (dtype.getSizeInBytes > 0) { - dataLen = dtype.getSizeInBytes * rowCount - data = bufferOffset - bufferOffset += padFor64ByteAlignment(dataLen) - } - offsetsList.add(ColumnOffsets( - validity = validity, - offsets = offsets, - data = data, - dataLen = dataLen)) - - val children = columnHeader.getChildren - if (children != null) { - children.foreach { child => - bufferOffset = buildColumnOffsetsForColumn(child, buffer, offsetsList, bufferOffset) - } - } - - bufferOffset - } - - private def padFor64ByteAlignment(addr: Long): Long = ((addr + 63) / 64) * 64 }