Skip to content

Commit

Permalink
Support serializing tables directly for shuffle write (#37)
Browse files Browse the repository at this point in the history
* Support serializing packed tables directly
---------

Signed-off-by: Firestarman <[email protected]>
  • Loading branch information
firestarman authored and nvliyuan committed May 7, 2024
1 parent 6cafecf commit f103c4e
Show file tree
Hide file tree
Showing 10 changed files with 630 additions and 126 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* This file was derived from OptimizeWriteExchange.scala
* in the Delta Lake project at https://github.com/delta-io/delta
Expand Down Expand Up @@ -97,8 +97,12 @@ case class GpuOptimizeWriteExchangeExec(
) ++ additionalMetrics
}

private lazy val serializer: Serializer =
new GpuColumnarBatchSerializer(gpuLongMetric("dataSize"))
private lazy val sparkTypes: Array[DataType] = child.output.map(_.dataType).toArray

private lazy val serializer: Serializer = new GpuColumnarBatchSerializer(
gpuLongMetric("dataSize"), allMetrics("rapidsShuffleSerializationTime"),
allMetrics("rapidsShuffleDeserializationTime"),
partitioning.serializingOnGPU, sparkTypes)

@transient lazy val inputRDD: RDD[ColumnarBatch] = child.executeColumnar()

Expand All @@ -116,7 +120,7 @@ case class GpuOptimizeWriteExchangeExec(
inputRDD,
child.output,
partitioning,
child.output.map(_.dataType).toArray,
sparkTypes,
serializer,
useGPUShuffle=partitioning.usesGPUShuffle,
useMultiThreadedShuffle=partitioning.usesMultiThreadedShuffle,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids;

import ai.rapids.cudf.ContiguousTable;
import ai.rapids.cudf.DeviceMemoryBuffer;
import ai.rapids.cudf.HostMemoryBuffer;
import com.nvidia.spark.rapids.format.TableMeta;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarArray;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.types.UTF8String;

/**
* A column vector that tracks a packed (or compressed) table on host. Unlike a normal
* host column vector, the columnar data within cannot be accessed directly.
* This is intended to only be used during shuffle after the data is partitioned and
* before it is serialized.
*/
public final class PackedTableHostColumnVector extends ColumnVector {

private static final String BAD_ACCESS_MSG = "Column is packed";

private final TableMeta tableMeta;
private final HostMemoryBuffer tableBuffer;

PackedTableHostColumnVector(TableMeta tableMeta, HostMemoryBuffer tableBuffer) {
super(DataTypes.NullType);
long rows = tableMeta.rowCount();
int batchRows = (int) rows;
if (rows != batchRows) {
throw new IllegalStateException("Cannot support a batch larger that MAX INT rows");
}
this.tableMeta = tableMeta;
this.tableBuffer = tableBuffer;
}

private static ColumnarBatch from(TableMeta meta, DeviceMemoryBuffer devBuf) {
HostMemoryBuffer tableBuf;
try(HostMemoryBuffer buf = HostMemoryBuffer.allocate(devBuf.getLength())) {
buf.copyFromDeviceBuffer(devBuf);
buf.incRefCount();
tableBuf = buf;
}
ColumnVector column = new PackedTableHostColumnVector(meta, tableBuf);
return new ColumnarBatch(new ColumnVector[] { column }, (int) meta.rowCount());
}

/** Both the input table and output batch should be closed. */
public static ColumnarBatch from(CompressedTable table) {
return from(table.meta(), table.buffer());
}

/** Both the input table and output batch should be closed. */
public static ColumnarBatch from(ContiguousTable table) {
return from(MetaUtils.buildTableMeta(0, table), table.getBuffer());
}

/** Returns true if this columnar batch uses a packed table on host */
public static boolean isBatchPackedOnHost(ColumnarBatch batch) {
return batch.numCols() == 1 && batch.column(0) instanceof PackedTableHostColumnVector;
}

public TableMeta getTableMeta() {
return tableMeta;
}

public HostMemoryBuffer getTableBuffer() {
return tableBuffer;
}

@Override
public void close() {
tableBuffer.close();
}

@Override
public boolean hasNull() {
throw new IllegalStateException(BAD_ACCESS_MSG);
}

@Override
public int numNulls() {
throw new IllegalStateException(BAD_ACCESS_MSG);
}

@Override
public boolean isNullAt(int rowId) {
throw new IllegalStateException(BAD_ACCESS_MSG);
}

@Override
public boolean getBoolean(int rowId) {
throw new IllegalStateException(BAD_ACCESS_MSG);
}

@Override
public byte getByte(int rowId) {
throw new IllegalStateException(BAD_ACCESS_MSG);
}

@Override
public short getShort(int rowId) {
throw new IllegalStateException(BAD_ACCESS_MSG);
}

@Override
public int getInt(int rowId) {
throw new IllegalStateException(BAD_ACCESS_MSG);
}

@Override
public long getLong(int rowId) {
throw new IllegalStateException(BAD_ACCESS_MSG);
}

@Override
public float getFloat(int rowId) {
throw new IllegalStateException(BAD_ACCESS_MSG);
}

@Override
public double getDouble(int rowId) {
throw new IllegalStateException(BAD_ACCESS_MSG);
}

@Override
public ColumnarArray getArray(int rowId) {
throw new IllegalStateException(BAD_ACCESS_MSG);
}

@Override
public ColumnarMap getMap(int rowId) {
throw new IllegalStateException(BAD_ACCESS_MSG);
}

@Override
public Decimal getDecimal(int rowId, int precision, int scale) {
throw new IllegalStateException(BAD_ACCESS_MSG);
}

@Override
public UTF8String getUTF8String(int rowId) {
throw new IllegalStateException(BAD_ACCESS_MSG);
}

@Override
public byte[] getBinary(int rowId) {
throw new IllegalStateException(BAD_ACCESS_MSG);
}

@Override
public ColumnVector getChild(int ordinal) {
throw new IllegalStateException(BAD_ACCESS_MSG);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -462,7 +462,7 @@ abstract class AbstractGpuCoalesceIterator(
// If we have reached the cuDF limit once, proactively filter batches
// after that first limit is reached.
GpuFilter.filterAndClose(cbFromIter, inputFilterTier.get,
NoopMetric, NoopMetric, opTime)
NoopMetric, NoopMetric, NoopMetric)
} else {
Iterator(cbFromIter)
}
Expand Down Expand Up @@ -499,7 +499,7 @@ abstract class AbstractGpuCoalesceIterator(
var filteredBytes = 0L
if (hasAnyToConcat) {
val filteredDowIter = GpuFilter.filterAndClose(concatAllAndPutOnGPU(),
filterTier, NoopMetric, NoopMetric, opTime)
filterTier, NoopMetric, NoopMetric, NoopMetric)
while (filteredDowIter.hasNext) {
closeOnExcept(filteredDowIter.next()) { filteredDownCb =>
filteredNumRows += filteredDownCb.numRows()
Expand All @@ -512,7 +512,7 @@ abstract class AbstractGpuCoalesceIterator(
// filterAndClose takes ownership of CB so we should not close it on a failure
// anymore...
val filteredCbIter = GpuFilter.filterAndClose(cb.release, filterTier,
NoopMetric, NoopMetric, opTime)
NoopMetric, NoopMetric, NoopMetric)
while (filteredCbIter.hasNext) {
closeOnExcept(filteredCbIter.next()) { filteredCb =>
val filteredWouldBeRows = filteredNumRows + filteredCb.numRows()
Expand Down
Loading

0 comments on commit f103c4e

Please sign in to comment.