Skip to content

Commit

Permalink
Use HostMemoryAllocator in jni::allocate_host_buffer (#13975)
Browse files Browse the repository at this point in the history
Fixes #13940 
Contributes to NVIDIA/spark-rapids#8889

- Pass an explicit host memory allocator to  `jni::allocate_host_buffer`
- Consistently check for errors from NewGlobalRef
- Consistently guard against DeleteGlobalRef on a null

Authors:
  - Gera Shegalov (https://github.com/gerashegalov)

Approvers:
  - https://github.com/nvdbaranec
  - Jason Lowe (https://github.com/jlowe)

URL: #13975
  • Loading branch information
gerashegalov authored Aug 30, 2023
1 parent 1452200 commit 7b9f4a1
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 95 deletions.
84 changes: 66 additions & 18 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,9 @@ private static native long writeParquetBufferBegin(String[] columnNames,
boolean[] isBinaryValues,
boolean[] hasParquetFieldIds,
int[] parquetFieldIds,
HostBufferConsumer consumer) throws CudfException;
HostBufferConsumer consumer,
HostMemoryAllocator hostMemoryAllocator
) throws CudfException;

/**
* Write out a table to an open handle.
Expand Down Expand Up @@ -419,7 +421,9 @@ private static native long writeORCBufferBegin(String[] columnNames,
int compression,
int[] precisions,
boolean[] isMapValues,
HostBufferConsumer consumer) throws CudfException;
HostBufferConsumer consumer,
HostMemoryAllocator hostMemoryAllocator
) throws CudfException;

/**
* Write out a table to an open handle.
Expand Down Expand Up @@ -447,10 +451,12 @@ private static native long writeORCBufferBegin(String[] columnNames,
* Setup everything to write Arrow IPC formatted data to a buffer.
* @param columnNames names that correspond to the table columns
* @param consumer consumer of host buffers produced.
* @param hostMemoryAllocator allocator for host memory buffers.
* @return a handle that is used in later calls to writeArrowIPCChunk and writeArrowIPCEnd.
*/
private static native long writeArrowIPCBufferBegin(String[] columnNames,
HostBufferConsumer consumer);
HostBufferConsumer consumer,
HostMemoryAllocator hostMemoryAllocator);

/**
* Convert a cudf table to an arrow table handle.
Expand Down Expand Up @@ -906,7 +912,9 @@ private static native long startWriteCSVToBuffer(String[] columnNames,
String trueValue,
String falseValue,
int quoteStyle,
HostBufferConsumer buffer) throws CudfException;
HostBufferConsumer buffer,
HostMemoryAllocator hostMemoryAllocator
) throws CudfException;

private static native void writeCSVChunkToBuffer(long writerHandle, long tableHandle);

Expand All @@ -915,7 +923,8 @@ private static native long startWriteCSVToBuffer(String[] columnNames,
private static class CSVTableWriter extends TableWriter {
private HostBufferConsumer consumer;

private CSVTableWriter(CSVWriterOptions options, HostBufferConsumer consumer) {
private CSVTableWriter(CSVWriterOptions options, HostBufferConsumer consumer,
HostMemoryAllocator hostMemoryAllocator) {
super(startWriteCSVToBuffer(options.getColumnNames(),
options.getIncludeHeader(),
options.getRowDelimiter(),
Expand All @@ -924,7 +933,7 @@ private CSVTableWriter(CSVWriterOptions options, HostBufferConsumer consumer) {
options.getTrueValue(),
options.getFalseValue(),
options.getQuoteStyle().nativeId,
consumer));
consumer, hostMemoryAllocator));
this.consumer = consumer;
}

Expand All @@ -949,8 +958,14 @@ public void close() throws CudfException {
}
}

public static TableWriter getCSVBufferWriter(CSVWriterOptions options, HostBufferConsumer bufferConsumer) {
return new CSVTableWriter(options, bufferConsumer);
/**
 * Get a writer that streams CSV-formatted data to a host buffer consumer.
 *
 * @param options settings controlling how the CSV data is written
 * @param bufferConsumer callback invoked as host buffers of CSV data become ready
 * @param hostMemoryAllocator allocator used for the host buffers handed to the consumer
 * @return a table writer to use for writing out multiple tables
 */
public static TableWriter getCSVBufferWriter(CSVWriterOptions options,
HostBufferConsumer bufferConsumer, HostMemoryAllocator hostMemoryAllocator) {
return new CSVTableWriter(options, bufferConsumer, hostMemoryAllocator);
}

/**
 * Get a writer that streams CSV-formatted data to a host buffer consumer,
 * using the process-wide default host memory allocator.
 *
 * @param options settings controlling how the CSV data is written
 * @param bufferConsumer callback invoked as host buffers of CSV data become ready
 * @return a table writer to use for writing out multiple tables
 */
public static TableWriter getCSVBufferWriter(CSVWriterOptions options,
HostBufferConsumer bufferConsumer) {
return getCSVBufferWriter(options, bufferConsumer, DefaultHostMemoryAllocator.get());
}

/**
Expand Down Expand Up @@ -1393,7 +1408,8 @@ private ParquetTableWriter(ParquetWriterOptions options, File outputFile) {
this.consumer = null;
}

private ParquetTableWriter(ParquetWriterOptions options, HostBufferConsumer consumer) {
private ParquetTableWriter(ParquetWriterOptions options, HostBufferConsumer consumer,
HostMemoryAllocator hostMemoryAllocator) {
super(writeParquetBufferBegin(options.getFlatColumnNames(),
options.getTopLevelChildren(),
options.getFlatNumChildren(),
Expand All @@ -1408,7 +1424,7 @@ private ParquetTableWriter(ParquetWriterOptions options, HostBufferConsumer cons
options.getFlatIsBinary(),
options.getFlatHasParquetFieldId(),
options.getFlatParquetFieldId(),
consumer));
consumer, hostMemoryAllocator));
this.consumer = consumer;
}

Expand Down Expand Up @@ -1448,11 +1464,18 @@ public static TableWriter writeParquetChunked(ParquetWriterOptions options, File
* @param options the parquet writer options.
* @param consumer a class that will be called when host buffers are ready with parquet
* formatted data in them.
* @param hostMemoryAllocator allocator for host memory buffers
* @return a table writer to use for writing out multiple tables.
*/
public static TableWriter writeParquetChunked(ParquetWriterOptions options,
HostBufferConsumer consumer,
HostMemoryAllocator hostMemoryAllocator) {
// Host buffers handed to the consumer are allocated with hostMemoryAllocator.
return new ParquetTableWriter(options, consumer, hostMemoryAllocator);
}

/**
 * Get a table writer to write parquet data to a buffer, using the
 * process-wide default host memory allocator for the buffers handed
 * to the consumer.
 *
 * @param options the parquet writer options.
 * @param consumer a class that will be called when host buffers are ready with parquet
 *                 formatted data in them.
 * @return a table writer to use for writing out multiple tables.
 */
public static TableWriter writeParquetChunked(ParquetWriterOptions options,
                                              HostBufferConsumer consumer) {
  // Delegate to the allocator-aware overload. (The stale direct
  // construction that preceded this delegation was unreachable and
  // has been removed.)
  return writeParquetChunked(options, consumer, DefaultHostMemoryAllocator.get());
}

/**
Expand All @@ -1461,10 +1484,12 @@ public static TableWriter writeParquetChunked(ParquetWriterOptions options,
* @param options the Parquet writer options.
* @param consumer a class that will be called when host buffers are ready with Parquet
* formatted data in them.
* @param hostMemoryAllocator allocator for host memory buffers
* @param columnViews ColumnViews to write to Parquet
*/
public static void writeColumnViewsToParquet(ParquetWriterOptions options,
HostBufferConsumer consumer,
HostMemoryAllocator hostMemoryAllocator,
ColumnView... columnViews) {
assert columnViews != null && columnViews.length > 0 : "ColumnViews can't be null or empty";
long rows = columnViews[0].getRowCount();
Expand All @@ -1483,7 +1508,9 @@ public static void writeColumnViewsToParquet(ParquetWriterOptions options,

long nativeHandle = createCudfTableView(viewPointers);
try {
try (ParquetTableWriter writer = new ParquetTableWriter(options, consumer)) {
try (
ParquetTableWriter writer = new ParquetTableWriter(options, consumer, hostMemoryAllocator)
) {
long total = 0;
for (ColumnView cv : columnViews) {
total += cv.getDeviceMemorySize();
Expand All @@ -1495,6 +1522,12 @@ public static void writeColumnViewsToParquet(ParquetWriterOptions options,
}
}

/**
 * Write out ColumnViews as a single table to a Parquet buffer consumer,
 * using the process-wide default host memory allocator.
 *
 * @param options the Parquet writer options.
 * @param consumer a class that will be called when host buffers are ready with Parquet
 *                 formatted data in them.
 * @param columnViews ColumnViews to write to Parquet
 */
public static void writeColumnViewsToParquet(ParquetWriterOptions options,
HostBufferConsumer consumer,
ColumnView... columnViews) {
writeColumnViewsToParquet(options, consumer, DefaultHostMemoryAllocator.get(), columnViews);
}

private static class ORCTableWriter extends TableWriter {
HostBufferConsumer consumer;

Expand All @@ -1512,7 +1545,8 @@ private ORCTableWriter(ORCWriterOptions options, File outputFile) {
this.consumer = null;
}

private ORCTableWriter(ORCWriterOptions options, HostBufferConsumer consumer) {
private ORCTableWriter(ORCWriterOptions options, HostBufferConsumer consumer,
HostMemoryAllocator hostMemoryAllocator) {
super(writeORCBufferBegin(options.getFlatColumnNames(),
options.getTopLevelChildren(),
options.getFlatNumChildren(),
Expand All @@ -1522,7 +1556,7 @@ private ORCTableWriter(ORCWriterOptions options, HostBufferConsumer consumer) {
options.getCompressionType().nativeId,
options.getFlatPrecision(),
options.getFlatIsMap(),
consumer));
consumer, hostMemoryAllocator));
this.consumer = consumer;
}

Expand Down Expand Up @@ -1562,10 +1596,16 @@ public static TableWriter writeORCChunked(ORCWriterOptions options, File outputF
* @param options the ORC writer options.
* @param consumer a class that will be called when host buffers are ready with ORC
* formatted data in them.
* @param hostMemoryAllocator allocator for host memory buffers
* @return a table writer to use for writing out multiple tables.
*/
public static TableWriter writeORCChunked(ORCWriterOptions options, HostBufferConsumer consumer,
HostMemoryAllocator hostMemoryAllocator) {
// Host buffers handed to the consumer are allocated with hostMemoryAllocator.
return new ORCTableWriter(options, consumer, hostMemoryAllocator);
}

/**
 * Get a table writer to write ORC data to a buffer, using the
 * process-wide default host memory allocator for the buffers handed
 * to the consumer.
 *
 * @param options the ORC writer options.
 * @param consumer a class that will be called when host buffers are ready with ORC
 *                 formatted data in them.
 * @return a table writer to use for writing out multiple tables.
 */
public static TableWriter writeORCChunked(ORCWriterOptions options, HostBufferConsumer consumer) {
  // Delegate to the allocator-aware overload; the stale direct
  // construction that preceded this was unreachable and has been removed.
  return writeORCChunked(options, consumer, DefaultHostMemoryAllocator.get());
}

private static class ArrowIPCTableWriter extends TableWriter {
Expand All @@ -1580,8 +1620,9 @@ private ArrowIPCTableWriter(ArrowIPCWriterOptions options, File outputFile) {
this.maxChunkSize = options.getMaxChunkSize();
}

private ArrowIPCTableWriter(ArrowIPCWriterOptions options, HostBufferConsumer consumer) {
super(writeArrowIPCBufferBegin(options.getColumnNames(), consumer));
private ArrowIPCTableWriter(ArrowIPCWriterOptions options, HostBufferConsumer consumer,
HostMemoryAllocator hostMemoryAllocator) {
super(writeArrowIPCBufferBegin(options.getColumnNames(), consumer, hostMemoryAllocator));
this.callback = options.getCallback();
this.consumer = consumer;
this.maxChunkSize = options.getMaxChunkSize();
Expand Down Expand Up @@ -1629,11 +1670,18 @@ public static TableWriter writeArrowIPCChunked(ArrowIPCWriterOptions options, Fi
* @param options the arrow IPC writer options.
* @param consumer a class that will be called when host buffers are ready with arrow IPC
* formatted data in them.
* @param hostMemoryAllocator allocator for host memory buffers
* @return a table writer to use for writing out multiple tables.
*/
public static TableWriter writeArrowIPCChunked(ArrowIPCWriterOptions options,
HostBufferConsumer consumer,
HostMemoryAllocator hostMemoryAllocator) {
// Host buffers handed to the consumer are allocated with hostMemoryAllocator.
return new ArrowIPCTableWriter(options, consumer, hostMemoryAllocator);
}

/**
 * Get a table writer to write arrow IPC data to a buffer, using the
 * process-wide default host memory allocator for the buffers handed
 * to the consumer.
 *
 * @param options the arrow IPC writer options.
 * @param consumer a class that will be called when host buffers are ready with arrow IPC
 *                 formatted data in them.
 * @return a table writer to use for writing out multiple tables.
 */
public static TableWriter writeArrowIPCChunked(ArrowIPCWriterOptions options,
                                               HostBufferConsumer consumer) {
  // Delegate to the allocator-aware overload; the stale direct
  // construction that preceded this was unreachable and has been removed.
  return writeArrowIPCChunked(options, consumer, DefaultHostMemoryAllocator.get());
}

private static class ArrowReaderWrapper implements AutoCloseable {
Expand Down
15 changes: 15 additions & 0 deletions java/src/main/native/include/jni_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,21 @@ inline void jni_cuda_check(JNIEnv *const env, cudaError_t cuda_status) {
}
}

/**
 * @brief Promote a JNI reference to a global reference, failing loudly.
 *
 * @throws cudf::jni::jni_exception if the JVM could not create the
 * global reference (e.g. out of memory).
 * @return the newly created global reference (never null).
 */
inline auto add_global_ref(JNIEnv *env, jobject jobj) {
  auto const global_ref = env->NewGlobalRef(jobj);
  if (global_ref == nullptr) {
    throw cudf::jni::jni_exception("global ref");
  }
  return global_ref;
}

inline nullptr_t del_global_ref(JNIEnv *env, jobject jobj) {
if (jobj != nullptr) {
env->DeleteGlobalRef(jobj);
}
return nullptr;
}

} // namespace jni
} // namespace cudf

Expand Down
10 changes: 2 additions & 8 deletions java/src/main/native/src/ContiguousTableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@ bool cache_contiguous_table_jni(JNIEnv *env) {
}

// Releases the cached ContiguousTable class global reference.
// del_global_ref is null-safe and returns nullptr, so the cached class
// can be released and cleared in a single statement; the older manual
// null-check/DeleteGlobalRef/null-out sequence that preceded the helper
// call was redundant and has been removed.
void release_contiguous_table_jni(JNIEnv *env) {
  Contiguous_table_jclass = cudf::jni::del_global_ref(env, Contiguous_table_jclass);
}

bool cache_contig_split_group_by_result_jni(JNIEnv *env) {
Expand Down Expand Up @@ -87,10 +84,7 @@ bool cache_contig_split_group_by_result_jni(JNIEnv *env) {
}

// Releases the cached ContigSplitGroupByResult class global reference.
// del_global_ref is null-safe and returns nullptr, so the cached class
// can be released and cleared in a single statement; the older manual
// null-check/DeleteGlobalRef/null-out sequence was redundant and has
// been removed.
void release_contig_split_group_by_result_jni(JNIEnv *env) {
  Contig_split_group_by_result_jclass = del_global_ref(env, Contig_split_group_by_result_jclass);
}

jobject contig_split_group_by_result_from(JNIEnv *env, jobjectArray &groups) {
Expand Down
21 changes: 8 additions & 13 deletions java/src/main/native/src/CudfJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ constexpr bool is_ptds_enabled{false};
#endif

static jclass Host_memory_buffer_jclass;
static jmethodID Host_buffer_allocate;
static jfieldID Host_buffer_address;
static jfieldID Host_buffer_length;

Expand All @@ -59,11 +58,6 @@ static bool cache_host_memory_buffer_jni(JNIEnv *env) {
return false;
}

Host_buffer_allocate = env->GetStaticMethodID(cls, "allocate", HOST_MEMORY_BUFFER_SIG("JZ"));
if (Host_buffer_allocate == nullptr) {
return false;
}

Host_buffer_address = env->GetFieldID(cls, "address", "J");
if (Host_buffer_address == nullptr) {
return false;
Expand All @@ -83,15 +77,16 @@ static bool cache_host_memory_buffer_jni(JNIEnv *env) {
}

// Releases the cached HostMemoryBuffer class global reference.
// del_global_ref is null-safe and returns nullptr, so the cached class
// can be released and cleared in a single statement; the older manual
// null-check/DeleteGlobalRef/null-out sequence was redundant and has
// been removed.
static void release_host_memory_buffer_jni(JNIEnv *env) {
  Host_memory_buffer_jclass = del_global_ref(env, Host_memory_buffer_jclass);
}

jobject allocate_host_buffer(JNIEnv *env, jlong amount, jboolean prefer_pinned) {
jobject ret = env->CallStaticObjectMethod(Host_memory_buffer_jclass, Host_buffer_allocate, amount,
prefer_pinned);
jobject allocate_host_buffer(JNIEnv *env, jlong amount, jboolean prefer_pinned,
jobject host_memory_allocator) {
auto const host_memory_allocator_class = env->GetObjectClass(host_memory_allocator);
auto const allocateMethodId =
env->GetMethodID(host_memory_allocator_class, "allocate", HOST_MEMORY_BUFFER_SIG("JZ"));
jobject ret =
env->CallObjectMethod(host_memory_allocator, allocateMethodId, amount, prefer_pinned);

if (env->ExceptionCheck()) {
throw std::runtime_error("allocateHostBuffer threw an exception");
Expand Down
7 changes: 2 additions & 5 deletions java/src/main/native/src/RmmJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,7 @@ class java_event_handler_memory_resource : public device_memory_resource {
update_thresholds(env, alloc_thresholds, jalloc_thresholds);
update_thresholds(env, dealloc_thresholds, jdealloc_thresholds);

handler_obj = env->NewGlobalRef(jhandler);
if (handler_obj == nullptr) {
throw cudf::jni::jni_exception("global ref");
}
handler_obj = cudf::jni::add_global_ref(env, jhandler);
}

virtual ~java_event_handler_memory_resource() {
Expand All @@ -209,7 +206,7 @@ class java_event_handler_memory_resource : public device_memory_resource {
// already be destroyed and this thread should not try to attach to get an environment.
JNIEnv *env = nullptr;
if (jvm->GetEnv(reinterpret_cast<void **>(&env), cudf::jni::MINIMUM_JNI_VERSION) == JNI_OK) {
env->DeleteGlobalRef(handler_obj);
handler_obj = cudf::jni::del_global_ref(env, handler_obj);
}
handler_obj = nullptr;
}
Expand Down
Loading

0 comments on commit 7b9f4a1

Please sign in to comment.