Skip to content

Commit

Permalink
Host Memory OOM handling for RowToColumnarIterator (#10617)
Browse files Browse the repository at this point in the history
Signed-off-by: Jim Brennan <[email protected]>
  • Loading branch information
jbrennan333 authored Apr 1, 2024
1 parent 0747506 commit c28c7fa
Show file tree
Hide file tree
Showing 7 changed files with 350 additions and 22 deletions.
11 changes: 8 additions & 3 deletions integration_tests/src/main/python/row_conversion_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,7 +28,12 @@
# to be brought back to the CPU (rows) to be returned.
# So we just need a very simple operation in the middle that
# can be done on the GPU.
def test_row_conversions():
@pytest.mark.parametrize('override_batch_size_bytes', [None, '4mb', '1kb'], ids=idfn)
def test_row_conversions(override_batch_size_bytes):
conf = {}
if override_batch_size_bytes is not None:
conf["spark.rapids.sql.batchSizeBytes"] = override_batch_size_bytes

gens = [["a", byte_gen], ["b", short_gen], ["c", int_gen], ["d", long_gen],
["e", float_gen], ["f", double_gen], ["g", string_gen], ["h", boolean_gen],
["i", timestamp_gen], ["j", date_gen], ["k", ArrayGen(byte_gen)],
Expand All @@ -40,7 +45,7 @@ def test_row_conversions():
["s", null_gen], ["t", decimal_gen_64bit], ["u", decimal_gen_32bit],
["v", decimal_gen_128bit]]
assert_gpu_and_cpu_are_equal_collect(
lambda spark : gen_df(spark, gens).selectExpr("*", "a as a_again"))
lambda spark : gen_df(spark, gens).selectExpr("*", "a as a_again"), conf=conf)

def test_row_conversions_fixed_width():
gens = [["a", byte_gen], ["b", short_gen], ["c", int_gen], ["d", long_gen],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.nvidia.spark.rapids;

import ai.rapids.cudf.*;
import com.nvidia.spark.Retryable;
import com.nvidia.spark.rapids.shims.GpuTypeShims;
import org.apache.arrow.memory.ReferenceManager;

Expand Down Expand Up @@ -234,7 +235,8 @@ public void close() {
}
}

public static final class GpuColumnarBatchBuilder extends GpuColumnarBatchBuilderBase {
public static final class GpuColumnarBatchBuilder extends GpuColumnarBatchBuilderBase
implements Retryable {
private final RapidsHostColumnBuilder[] builders;
private ai.rapids.cudf.HostColumnVector[] hostColumns;

Expand Down Expand Up @@ -266,6 +268,45 @@ public GpuColumnarBatchBuilder(StructType schema, int rows) {
}
}

/**
* A collection of builders for building up columnar data.
* @param schema the schema of the batch.
* @param rows the maximum number of rows in this batch.
* @param spillableHostBuf single spillable host buffer to slice up among columns
* @param bufferSizes an array of sizes for each column
*/
public GpuColumnarBatchBuilder(StructType schema, int rows,
SpillableHostBuffer spillableHostBuf, long[] bufferSizes) {
fields = schema.fields();
int len = fields.length;
builders = new RapidsHostColumnBuilder[len];
boolean success = false;
try (SpillableHostBuffer sBuf = spillableHostBuf;
HostMemoryBuffer hBuf =
RmmRapidsRetryIterator.withRetryNoSplit(() -> sBuf.getHostBuffer());) {
long offset = 0;
for (int i = 0; i < len; i++) {
StructField field = fields[i];
try (HostMemoryBuffer columnBuffer = hBuf.slice(offset, bufferSizes[i]);) {
offset += bufferSizes[i];
builders[i] =
new RapidsHostColumnBuilder(convertFrom(field.dataType(), field.nullable()), rows)
.preAllocateBuffers(columnBuffer, 0);
}
}
success = true;
} finally {
if (!success) {
for (RapidsHostColumnBuilder b: builders) {
if (b != null) {
b.close();
}
}
}
}
}


@Override
public void copyColumnar(ColumnVector cv, int colNum, int rows) {
if (builders.length > 0) {
Expand Down Expand Up @@ -337,6 +378,32 @@ public void close() {
}
}
}

@Override
public void checkpoint() {
for (RapidsHostColumnBuilder b: builders) {
if (b != null) {
b.checkpoint();
}
}
}

@Override
public void restore() {
for (RapidsHostColumnBuilder b: builders) {
if (b != null) {
b.restore();
}
}
}

public void setAllowGrowth(boolean enable) {
for (RapidsHostColumnBuilder b: builders) {
if (b != null) {
b.setAllowGrowth(enable);
}
}
}
}

private static final class ArrowBufReferenceHolder {
Expand Down
Loading

0 comments on commit c28c7fa

Please sign in to comment.