From 7b9f4a17579befd902d1c30af38daa5fe493e335 Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Tue, 29 Aug 2023 20:59:04 -0700 Subject: [PATCH] Use HostMemoryAllocator in jni::allocate_host_buffer (#13975) Fixes #13940 Contributes to NVIDIA/spark-rapids#8889 - Pass an explicit host memory allocator to `jni::allocate_host_buffer` - Consistently check for errors from NewGlobalRef - Consistently guard against DeleteGlobalRef on a null Authors: - Gera Shegalov (https://github.com/gerashegalov) Approvers: - https://github.com/nvdbaranec - Jason Lowe (https://github.com/jlowe) URL: https://github.com/rapidsai/cudf/pull/13975 --- java/src/main/java/ai/rapids/cudf/Table.java | 84 +++++++++++++++---- java/src/main/native/include/jni_utils.hpp | 15 ++++ .../main/native/src/ContiguousTableJni.cpp | 10 +-- java/src/main/native/src/CudfJni.cpp | 21 ++--- java/src/main/native/src/RmmJni.cpp | 7 +- java/src/main/native/src/TableJni.cpp | 59 ++++++------- java/src/main/native/src/cudf_jni_apis.hpp | 3 +- .../main/native/src/jni_writer_data_sink.hpp | 29 +++---- 8 files changed, 133 insertions(+), 95 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 57189b052b6..b2eb33d47dc 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -336,7 +336,9 @@ private static native long writeParquetBufferBegin(String[] columnNames, boolean[] isBinaryValues, boolean[] hasParquetFieldIds, int[] parquetFieldIds, - HostBufferConsumer consumer) throws CudfException; + HostBufferConsumer consumer, + HostMemoryAllocator hostMemoryAllocator + ) throws CudfException; /** * Write out a table to an open handle. 
@@ -419,7 +421,9 @@ private static native long writeORCBufferBegin(String[] columnNames, int compression, int[] precisions, boolean[] isMapValues, - HostBufferConsumer consumer) throws CudfException; + HostBufferConsumer consumer, + HostMemoryAllocator hostMemoryAllocator + ) throws CudfException; /** * Write out a table to an open handle. @@ -447,10 +451,12 @@ private static native long writeORCBufferBegin(String[] columnNames, * Setup everything to write Arrow IPC formatted data to a buffer. * @param columnNames names that correspond to the table columns * @param consumer consumer of host buffers produced. + * @param hostMemoryAllocator allocator for host memory buffers. * @return a handle that is used in later calls to writeArrowIPCChunk and writeArrowIPCEnd. */ private static native long writeArrowIPCBufferBegin(String[] columnNames, - HostBufferConsumer consumer); + HostBufferConsumer consumer, + HostMemoryAllocator hostMemoryAllocator); /** * Convert a cudf table to an arrow table handle. 
@@ -906,7 +912,9 @@ private static native long startWriteCSVToBuffer(String[] columnNames, String trueValue, String falseValue, int quoteStyle, - HostBufferConsumer buffer) throws CudfException; + HostBufferConsumer buffer, + HostMemoryAllocator hostMemoryAllocator + ) throws CudfException; private static native void writeCSVChunkToBuffer(long writerHandle, long tableHandle); @@ -915,7 +923,8 @@ private static native long startWriteCSVToBuffer(String[] columnNames, private static class CSVTableWriter extends TableWriter { private HostBufferConsumer consumer; - private CSVTableWriter(CSVWriterOptions options, HostBufferConsumer consumer) { + private CSVTableWriter(CSVWriterOptions options, HostBufferConsumer consumer, + HostMemoryAllocator hostMemoryAllocator) { super(startWriteCSVToBuffer(options.getColumnNames(), options.getIncludeHeader(), options.getRowDelimiter(), @@ -924,7 +933,7 @@ private CSVTableWriter(CSVWriterOptions options, HostBufferConsumer consumer) { options.getTrueValue(), options.getFalseValue(), options.getQuoteStyle().nativeId, - consumer)); + consumer, hostMemoryAllocator)); this.consumer = consumer; } @@ -949,8 +958,14 @@ public void close() throws CudfException { } } - public static TableWriter getCSVBufferWriter(CSVWriterOptions options, HostBufferConsumer bufferConsumer) { - return new CSVTableWriter(options, bufferConsumer); + public static TableWriter getCSVBufferWriter(CSVWriterOptions options, + HostBufferConsumer bufferConsumer, HostMemoryAllocator hostMemoryAllocator) { + return new CSVTableWriter(options, bufferConsumer, hostMemoryAllocator); + } + + public static TableWriter getCSVBufferWriter(CSVWriterOptions options, + HostBufferConsumer bufferConsumer) { + return getCSVBufferWriter(options, bufferConsumer, DefaultHostMemoryAllocator.get()); } /** @@ -1393,7 +1408,8 @@ private ParquetTableWriter(ParquetWriterOptions options, File outputFile) { this.consumer = null; } - private ParquetTableWriter(ParquetWriterOptions options, 
HostBufferConsumer consumer) { + private ParquetTableWriter(ParquetWriterOptions options, HostBufferConsumer consumer, + HostMemoryAllocator hostMemoryAllocator) { super(writeParquetBufferBegin(options.getFlatColumnNames(), options.getTopLevelChildren(), options.getFlatNumChildren(), @@ -1408,7 +1424,7 @@ private ParquetTableWriter(ParquetWriterOptions options, HostBufferConsumer cons options.getFlatIsBinary(), options.getFlatHasParquetFieldId(), options.getFlatParquetFieldId(), - consumer)); + consumer, hostMemoryAllocator)); this.consumer = consumer; } @@ -1448,11 +1464,18 @@ public static TableWriter writeParquetChunked(ParquetWriterOptions options, File * @param options the parquet writer options. * @param consumer a class that will be called when host buffers are ready with parquet * formatted data in them. + * @param hostMemoryAllocator allocator for host memory buffers * @return a table writer to use for writing out multiple tables. */ + public static TableWriter writeParquetChunked(ParquetWriterOptions options, + HostBufferConsumer consumer, + HostMemoryAllocator hostMemoryAllocator) { + return new ParquetTableWriter(options, consumer, hostMemoryAllocator); + } + public static TableWriter writeParquetChunked(ParquetWriterOptions options, HostBufferConsumer consumer) { - return new ParquetTableWriter(options, consumer); + return writeParquetChunked(options, consumer, DefaultHostMemoryAllocator.get()); } /** @@ -1461,10 +1484,12 @@ public static TableWriter writeParquetChunked(ParquetWriterOptions options, * @param options the Parquet writer options. * @param consumer a class that will be called when host buffers are ready with Parquet * formatted data in them. + * @param hostMemoryAllocator allocator for host memory buffers * @param columnViews ColumnViews to write to Parquet */ public static void writeColumnViewsToParquet(ParquetWriterOptions options, HostBufferConsumer consumer, + HostMemoryAllocator hostMemoryAllocator, ColumnView... 
columnViews) { assert columnViews != null && columnViews.length > 0 : "ColumnViews can't be null or empty"; long rows = columnViews[0].getRowCount(); @@ -1483,7 +1508,9 @@ public static void writeColumnViewsToParquet(ParquetWriterOptions options, long nativeHandle = createCudfTableView(viewPointers); try { - try (ParquetTableWriter writer = new ParquetTableWriter(options, consumer)) { + try ( + ParquetTableWriter writer = new ParquetTableWriter(options, consumer, hostMemoryAllocator) + ) { long total = 0; for (ColumnView cv : columnViews) { total += cv.getDeviceMemorySize(); @@ -1495,6 +1522,12 @@ public static void writeColumnViewsToParquet(ParquetWriterOptions options, } } + public static void writeColumnViewsToParquet(ParquetWriterOptions options, + HostBufferConsumer consumer, + ColumnView... columnViews) { + writeColumnViewsToParquet(options, consumer, DefaultHostMemoryAllocator.get(), columnViews); + } + private static class ORCTableWriter extends TableWriter { HostBufferConsumer consumer; @@ -1512,7 +1545,8 @@ private ORCTableWriter(ORCWriterOptions options, File outputFile) { this.consumer = null; } - private ORCTableWriter(ORCWriterOptions options, HostBufferConsumer consumer) { + private ORCTableWriter(ORCWriterOptions options, HostBufferConsumer consumer, + HostMemoryAllocator hostMemoryAllocator) { super(writeORCBufferBegin(options.getFlatColumnNames(), options.getTopLevelChildren(), options.getFlatNumChildren(), @@ -1522,7 +1556,7 @@ private ORCTableWriter(ORCWriterOptions options, HostBufferConsumer consumer) { options.getCompressionType().nativeId, options.getFlatPrecision(), options.getFlatIsMap(), - consumer)); + consumer, hostMemoryAllocator)); this.consumer = consumer; } @@ -1562,10 +1596,16 @@ public static TableWriter writeORCChunked(ORCWriterOptions options, File outputF * @param options the ORC writer options. * @param consumer a class that will be called when host buffers are ready with ORC * formatted data in them. 
+ * @param hostMemoryAllocator allocator for host memory buffers * @return a table writer to use for writing out multiple tables. */ + public static TableWriter writeORCChunked(ORCWriterOptions options, HostBufferConsumer consumer, + HostMemoryAllocator hostMemoryAllocator) { + return new ORCTableWriter(options, consumer, hostMemoryAllocator); + } + public static TableWriter writeORCChunked(ORCWriterOptions options, HostBufferConsumer consumer) { - return new ORCTableWriter(options, consumer); + return writeORCChunked(options, consumer, DefaultHostMemoryAllocator.get()); } private static class ArrowIPCTableWriter extends TableWriter { @@ -1580,8 +1620,9 @@ private ArrowIPCTableWriter(ArrowIPCWriterOptions options, File outputFile) { this.maxChunkSize = options.getMaxChunkSize(); } - private ArrowIPCTableWriter(ArrowIPCWriterOptions options, HostBufferConsumer consumer) { - super(writeArrowIPCBufferBegin(options.getColumnNames(), consumer)); + private ArrowIPCTableWriter(ArrowIPCWriterOptions options, HostBufferConsumer consumer, + HostMemoryAllocator hostMemoryAllocator) { + super(writeArrowIPCBufferBegin(options.getColumnNames(), consumer, hostMemoryAllocator)); this.callback = options.getCallback(); this.consumer = consumer; this.maxChunkSize = options.getMaxChunkSize(); @@ -1629,11 +1670,18 @@ public static TableWriter writeArrowIPCChunked(ArrowIPCWriterOptions options, Fi * @param options the arrow IPC writer options. * @param consumer a class that will be called when host buffers are ready with arrow IPC * formatted data in them. + * @param hostMemoryAllocator allocator for host memory buffers * @return a table writer to use for writing out multiple tables. 
*/ + public static TableWriter writeArrowIPCChunked(ArrowIPCWriterOptions options, + HostBufferConsumer consumer, + HostMemoryAllocator hostMemoryAllocator) { + return new ArrowIPCTableWriter(options, consumer, hostMemoryAllocator); + } + public static TableWriter writeArrowIPCChunked(ArrowIPCWriterOptions options, HostBufferConsumer consumer) { - return new ArrowIPCTableWriter(options, consumer); + return writeArrowIPCChunked(options, consumer, DefaultHostMemoryAllocator.get()); } private static class ArrowReaderWrapper implements AutoCloseable { diff --git a/java/src/main/native/include/jni_utils.hpp b/java/src/main/native/include/jni_utils.hpp index ff4da893329..f342fca8933 100644 --- a/java/src/main/native/include/jni_utils.hpp +++ b/java/src/main/native/include/jni_utils.hpp @@ -786,6 +786,21 @@ inline void jni_cuda_check(JNIEnv *const env, cudaError_t cuda_status) { } } +inline auto add_global_ref(JNIEnv *env, jobject jobj) { + auto new_global_ref = env->NewGlobalRef(jobj); + if (new_global_ref == nullptr) { + throw cudf::jni::jni_exception("global ref"); + } + return new_global_ref; +} + +inline nullptr_t del_global_ref(JNIEnv *env, jobject jobj) { + if (jobj != nullptr) { + env->DeleteGlobalRef(jobj); + } + return nullptr; +} + } // namespace jni } // namespace cudf diff --git a/java/src/main/native/src/ContiguousTableJni.cpp b/java/src/main/native/src/ContiguousTableJni.cpp index 7eddea2a895..8c99c77ca1f 100644 --- a/java/src/main/native/src/ContiguousTableJni.cpp +++ b/java/src/main/native/src/ContiguousTableJni.cpp @@ -55,10 +55,7 @@ bool cache_contiguous_table_jni(JNIEnv *env) { } void release_contiguous_table_jni(JNIEnv *env) { - if (Contiguous_table_jclass != nullptr) { - env->DeleteGlobalRef(Contiguous_table_jclass); - Contiguous_table_jclass = nullptr; - } + Contiguous_table_jclass = cudf::jni::del_global_ref(env, Contiguous_table_jclass); } bool cache_contig_split_group_by_result_jni(JNIEnv *env) { @@ -87,10 +84,7 @@ bool 
cache_contig_split_group_by_result_jni(JNIEnv *env) { } void release_contig_split_group_by_result_jni(JNIEnv *env) { - if (Contig_split_group_by_result_jclass != nullptr) { - env->DeleteGlobalRef(Contig_split_group_by_result_jclass); - Contig_split_group_by_result_jclass = nullptr; - } + Contig_split_group_by_result_jclass = del_global_ref(env, Contig_split_group_by_result_jclass); } jobject contig_split_group_by_result_from(JNIEnv *env, jobjectArray &groups) { diff --git a/java/src/main/native/src/CudfJni.cpp b/java/src/main/native/src/CudfJni.cpp index acbf309b4b7..0f143086451 100644 --- a/java/src/main/native/src/CudfJni.cpp +++ b/java/src/main/native/src/CudfJni.cpp @@ -46,7 +46,6 @@ constexpr bool is_ptds_enabled{false}; #endif static jclass Host_memory_buffer_jclass; -static jmethodID Host_buffer_allocate; static jfieldID Host_buffer_address; static jfieldID Host_buffer_length; @@ -59,11 +58,6 @@ static bool cache_host_memory_buffer_jni(JNIEnv *env) { return false; } - Host_buffer_allocate = env->GetStaticMethodID(cls, "allocate", HOST_MEMORY_BUFFER_SIG("JZ")); - if (Host_buffer_allocate == nullptr) { - return false; - } - Host_buffer_address = env->GetFieldID(cls, "address", "J"); if (Host_buffer_address == nullptr) { return false; @@ -83,15 +77,16 @@ static bool cache_host_memory_buffer_jni(JNIEnv *env) { } static void release_host_memory_buffer_jni(JNIEnv *env) { - if (Host_memory_buffer_jclass != nullptr) { - env->DeleteGlobalRef(Host_memory_buffer_jclass); - Host_memory_buffer_jclass = nullptr; - } + Host_memory_buffer_jclass = del_global_ref(env, Host_memory_buffer_jclass); } -jobject allocate_host_buffer(JNIEnv *env, jlong amount, jboolean prefer_pinned) { - jobject ret = env->CallStaticObjectMethod(Host_memory_buffer_jclass, Host_buffer_allocate, amount, - prefer_pinned); +jobject allocate_host_buffer(JNIEnv *env, jlong amount, jboolean prefer_pinned, + jobject host_memory_allocator) { + auto const host_memory_allocator_class = 
env->GetObjectClass(host_memory_allocator); + auto const allocateMethodId = + env->GetMethodID(host_memory_allocator_class, "allocate", HOST_MEMORY_BUFFER_SIG("JZ")); + jobject ret = + env->CallObjectMethod(host_memory_allocator, allocateMethodId, amount, prefer_pinned); if (env->ExceptionCheck()) { throw std::runtime_error("allocateHostBuffer threw an exception"); diff --git a/java/src/main/native/src/RmmJni.cpp b/java/src/main/native/src/RmmJni.cpp index 5bbb5383d93..3c49d153cb6 100644 --- a/java/src/main/native/src/RmmJni.cpp +++ b/java/src/main/native/src/RmmJni.cpp @@ -197,10 +197,7 @@ class java_event_handler_memory_resource : public device_memory_resource { update_thresholds(env, alloc_thresholds, jalloc_thresholds); update_thresholds(env, dealloc_thresholds, jdealloc_thresholds); - handler_obj = env->NewGlobalRef(jhandler); - if (handler_obj == nullptr) { - throw cudf::jni::jni_exception("global ref"); - } + handler_obj = cudf::jni::add_global_ref(env, jhandler); } virtual ~java_event_handler_memory_resource() { @@ -209,7 +206,7 @@ class java_event_handler_memory_resource : public device_memory_resource { // already be destroyed and this thread should not try to attach to get an environment. 
JNIEnv *env = nullptr; if (jvm->GetEnv(reinterpret_cast(&env), cudf::jni::MINIMUM_JNI_VERSION) == JNI_OK) { - env->DeleteGlobalRef(handler_obj); + handler_obj = cudf::jni::del_global_ref(env, handler_obj); } handler_obj = nullptr; } diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index d6ef2a1e26c..f7ada4305db 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -224,7 +224,7 @@ class native_arrow_ipc_writer_handle final { class jni_arrow_output_stream final : public arrow::io::OutputStream { public: - explicit jni_arrow_output_stream(JNIEnv *env, jobject callback) { + explicit jni_arrow_output_stream(JNIEnv *env, jobject callback, jobject host_memory_allocator) { if (env->GetJavaVM(&jvm) < 0) { throw std::runtime_error("GetJavaVM failed"); } @@ -239,11 +239,8 @@ class jni_arrow_output_stream final : public arrow::io::OutputStream { if (handle_buffer_method == nullptr) { throw cudf::jni::jni_exception("handleBuffer method"); } - - this->callback = env->NewGlobalRef(callback); - if (this->callback == nullptr) { - throw cudf::jni::jni_exception("global ref"); - } + this->callback = add_global_ref(env, callback); + this->host_memory_allocator = add_global_ref(env, host_memory_allocator); } virtual ~jni_arrow_output_stream() { @@ -252,13 +249,13 @@ class jni_arrow_output_stream final : public arrow::io::OutputStream { // already be destroyed and this thread should not try to attach to get an environment. 
JNIEnv *env = nullptr; if (jvm->GetEnv(reinterpret_cast(&env), cudf::jni::MINIMUM_JNI_VERSION) == JNI_OK) { - env->DeleteGlobalRef(callback); - if (current_buffer != nullptr) { - env->DeleteGlobalRef(current_buffer); - } + callback = del_global_ref(env, callback); + current_buffer = del_global_ref(env, current_buffer); + host_memory_allocator = del_global_ref(env, host_memory_allocator); } callback = nullptr; current_buffer = nullptr; + host_memory_allocator = nullptr; } arrow::Status Write(const std::shared_ptr &data) override { @@ -293,10 +290,7 @@ class jni_arrow_output_stream final : public arrow::io::OutputStream { if (current_buffer_written > 0) { JNIEnv *env = cudf::jni::get_jni_env(jvm); handle_buffer(env, current_buffer, current_buffer_written); - if (current_buffer != nullptr) { - env->DeleteGlobalRef(current_buffer); - } - current_buffer = nullptr; + current_buffer = del_global_ref(env, current_buffer); current_buffer_len = 0; current_buffer_data = nullptr; current_buffer_written = 0; @@ -323,11 +317,10 @@ class jni_arrow_output_stream final : public arrow::io::OutputStream { void rotate_buffer(JNIEnv *env) { if (current_buffer != nullptr) { handle_buffer(env, current_buffer, current_buffer_written); - env->DeleteGlobalRef(current_buffer); - current_buffer = nullptr; } - jobject tmp_buffer = allocate_host_buffer(env, alloc_size, true); - current_buffer = env->NewGlobalRef(tmp_buffer); + current_buffer = del_global_ref(env, current_buffer); + jobject tmp_buffer = allocate_host_buffer(env, alloc_size, true, host_memory_allocator); + current_buffer = add_global_ref(env, tmp_buffer); current_buffer_len = get_host_buffer_length(env, current_buffer); current_buffer_data = reinterpret_cast(get_host_buffer_address(env, current_buffer)); current_buffer_written = 0; @@ -350,6 +343,7 @@ class jni_arrow_output_stream final : public arrow::io::OutputStream { int64_t total_written = 0; long alloc_size = MINIMUM_WRITE_BUFFER_SIZE; bool is_closed = false; + jobject 
host_memory_allocator; }; class jni_arrow_input_stream final : public arrow::io::InputStream { @@ -370,10 +364,7 @@ class jni_arrow_input_stream final : public arrow::io::InputStream { throw cudf::jni::jni_exception("readInto method"); } - this->callback = env->NewGlobalRef(callback); - if (this->callback == nullptr) { - throw cudf::jni::jni_exception("global ref"); - } + this->callback = add_global_ref(env, callback); } virtual ~jni_arrow_input_stream() { @@ -382,7 +373,7 @@ class jni_arrow_input_stream final : public arrow::io::InputStream { // already be destroyed and this thread should not try to attach to get an environment. JNIEnv *env = nullptr; if (jvm->GetEnv(reinterpret_cast(&env), cudf::jni::MINIMUM_JNI_VERSION) == JNI_OK) { - env->DeleteGlobalRef(callback); + callback = del_global_ref(env, callback); } callback = nullptr; } @@ -1269,7 +1260,7 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Table_writeCSVToFile( JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_startWriteCSVToBuffer( JNIEnv *env, jclass, jobjectArray j_column_names, jboolean include_header, jstring j_row_delimiter, jbyte j_field_delimiter, jstring j_null_value, jstring j_true_value, - jstring j_false_value, jint j_quote_style, jobject j_buffer) { + jstring j_false_value, jint j_quote_style, jobject j_buffer, jobject host_memory_allocator) { JNI_NULL_CHECK(env, j_column_names, "column name array cannot be null", 0); JNI_NULL_CHECK(env, j_row_delimiter, "row delimiter cannot be null", 0); JNI_NULL_CHECK(env, j_field_delimiter, "field delimiter cannot be null", 0); @@ -1279,7 +1270,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_startWriteCSVToBuffer( try { cudf::jni::auto_set_device(env); - auto data_sink = std::make_unique(env, j_buffer); + auto data_sink = + std::make_unique(env, j_buffer, host_memory_allocator); auto const n_column_names = cudf::jni::native_jstringArray{env, j_column_names}; auto const column_names = n_column_names.as_cpp_vector(); @@ -1576,7 +1568,7 @@ JNIEXPORT 
long JNICALL Java_ai_rapids_cudf_Table_writeParquetBufferBegin( jbooleanArray j_col_nullability, jobjectArray j_metadata_keys, jobjectArray j_metadata_values, jint j_compression, jint j_stats_freq, jbooleanArray j_isInt96, jintArray j_precisions, jbooleanArray j_is_map, jbooleanArray j_is_binary, jbooleanArray j_hasParquetFieldIds, - jintArray j_parquetFieldIds, jobject consumer) { + jintArray j_parquetFieldIds, jobject consumer, jobject host_memory_allocator) { JNI_NULL_CHECK(env, j_col_names, "null columns", 0); JNI_NULL_CHECK(env, j_col_nullability, "null nullability", 0); JNI_NULL_CHECK(env, j_metadata_keys, "null metadata keys", 0); @@ -1584,7 +1576,7 @@ JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeParquetBufferBegin( JNI_NULL_CHECK(env, consumer, "null consumer", 0); try { std::unique_ptr data_sink( - new cudf::jni::jni_writer_data_sink(env, consumer)); + new cudf::jni::jni_writer_data_sink(env, consumer, host_memory_allocator)); using namespace cudf::io; using namespace cudf::jni; @@ -1755,7 +1747,8 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readORC( JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeORCBufferBegin( JNIEnv *env, jclass, jobjectArray j_col_names, jint j_num_children, jintArray j_children, jbooleanArray j_col_nullability, jobjectArray j_metadata_keys, jobjectArray j_metadata_values, - jint j_compression, jintArray j_precisions, jbooleanArray j_is_map, jobject consumer) { + jint j_compression, jintArray j_precisions, jbooleanArray j_is_map, jobject consumer, + jobject host_memory_allocator) { JNI_NULL_CHECK(env, j_col_names, "null columns", 0); JNI_NULL_CHECK(env, j_col_nullability, "null nullability", 0); JNI_NULL_CHECK(env, j_metadata_keys, "null metadata keys", 0); @@ -1787,7 +1780,7 @@ JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeORCBufferBegin( [](const std::string &k, const std::string &v) { return std::make_pair(k, v); }); std::unique_ptr data_sink( - new cudf::jni::jni_writer_data_sink(env, 
consumer)); + new cudf::jni::jni_writer_data_sink(env, consumer, host_memory_allocator)); sink_info sink{data_sink.get()}; auto stats = std::make_shared(); @@ -1918,9 +1911,9 @@ JNIEXPORT jdoubleArray JNICALL Java_ai_rapids_cudf_TableWriter_getWriteStatistic CATCH_STD(env, nullptr) } -JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeArrowIPCBufferBegin(JNIEnv *env, jclass, - jobjectArray j_col_names, - jobject consumer) { +JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeArrowIPCBufferBegin( + JNIEnv *env, jclass, jobjectArray j_col_names, jobject consumer, + jobject host_memory_allocator) { JNI_NULL_CHECK(env, j_col_names, "null columns", 0); JNI_NULL_CHECK(env, consumer, "null consumer", 0); try { @@ -1928,7 +1921,7 @@ JNIEXPORT long JNICALL Java_ai_rapids_cudf_Table_writeArrowIPCBufferBegin(JNIEnv cudf::jni::native_jstringArray col_names(env, j_col_names); std::shared_ptr data_sink( - new cudf::jni::jni_arrow_output_stream(env, consumer)); + new cudf::jni::jni_arrow_output_stream(env, consumer, host_memory_allocator)); cudf::jni::native_arrow_ipc_writer_handle *ret = new cudf::jni::native_arrow_ipc_writer_handle(col_names.as_cpp_vector(), data_sink); diff --git a/java/src/main/native/src/cudf_jni_apis.hpp b/java/src/main/native/src/cudf_jni_apis.hpp index 18993aea294..867df80b722 100644 --- a/java/src/main/native/src/cudf_jni_apis.hpp +++ b/java/src/main/native/src/cudf_jni_apis.hpp @@ -100,7 +100,8 @@ jobject contig_split_group_by_result_from(JNIEnv *env, jobjectArray &groups, /** * Allocate a HostMemoryBuffer */ -jobject allocate_host_buffer(JNIEnv *env, jlong amount, jboolean prefer_pinned); +jobject allocate_host_buffer(JNIEnv *env, jlong amount, jboolean prefer_pinned, + jobject host_memory_allocator); /** * Get the address of a HostMemoryBuffer diff --git a/java/src/main/native/src/jni_writer_data_sink.hpp b/java/src/main/native/src/jni_writer_data_sink.hpp index 05fe594fcd5..efac6112c25 100644 --- 
a/java/src/main/native/src/jni_writer_data_sink.hpp +++ b/java/src/main/native/src/jni_writer_data_sink.hpp @@ -26,7 +26,7 @@ constexpr long MINIMUM_WRITE_BUFFER_SIZE = 10 * 1024 * 1024; // 10 MB class jni_writer_data_sink final : public cudf::io::data_sink { public: - explicit jni_writer_data_sink(JNIEnv *env, jobject callback) { + explicit jni_writer_data_sink(JNIEnv *env, jobject callback, jobject host_memory_allocator) { if (env->GetJavaVM(&jvm) < 0) { throw std::runtime_error("GetJavaVM failed"); } @@ -42,10 +42,8 @@ class jni_writer_data_sink final : public cudf::io::data_sink { throw cudf::jni::jni_exception("handleBuffer method"); } - this->callback = env->NewGlobalRef(callback); - if (this->callback == nullptr) { - throw cudf::jni::jni_exception("global ref"); - } + this->callback = add_global_ref(env, callback); + this->host_memory_allocator = add_global_ref(env, host_memory_allocator); } virtual ~jni_writer_data_sink() { @@ -54,13 +52,13 @@ class jni_writer_data_sink final : public cudf::io::data_sink { // already be destroyed and this thread should not try to attach to get an environment. 
JNIEnv *env = nullptr; if (jvm->GetEnv(reinterpret_cast(&env), cudf::jni::MINIMUM_JNI_VERSION) == JNI_OK) { - env->DeleteGlobalRef(callback); - if (current_buffer != nullptr) { - env->DeleteGlobalRef(current_buffer); - } + callback = del_global_ref(env, callback); + current_buffer = del_global_ref(env, current_buffer); + host_memory_allocator = del_global_ref(env, host_memory_allocator); } callback = nullptr; current_buffer = nullptr; + host_memory_allocator = nullptr; } void host_write(void const *data, size_t size) override { @@ -126,10 +124,7 @@ class jni_writer_data_sink final : public cudf::io::data_sink { if (current_buffer_written > 0) { JNIEnv *env = cudf::jni::get_jni_env(jvm); handle_buffer(env, current_buffer, current_buffer_written); - if (current_buffer != nullptr) { - env->DeleteGlobalRef(current_buffer); - } - current_buffer = nullptr; + current_buffer = del_global_ref(env, current_buffer); current_buffer_len = 0; current_buffer_data = nullptr; current_buffer_written = 0; @@ -144,11 +139,10 @@ class jni_writer_data_sink final : public cudf::io::data_sink { void rotate_buffer(JNIEnv *env) { if (current_buffer != nullptr) { handle_buffer(env, current_buffer, current_buffer_written); - env->DeleteGlobalRef(current_buffer); - current_buffer = nullptr; } - jobject tmp_buffer = allocate_host_buffer(env, alloc_size, true); - current_buffer = env->NewGlobalRef(tmp_buffer); + current_buffer = del_global_ref(env, current_buffer); + jobject tmp_buffer = allocate_host_buffer(env, alloc_size, true, host_memory_allocator); + current_buffer = add_global_ref(env, tmp_buffer); current_buffer_len = get_host_buffer_length(env, current_buffer); current_buffer_data = reinterpret_cast(get_host_buffer_address(env, current_buffer)); current_buffer_written = 0; @@ -170,6 +164,7 @@ class jni_writer_data_sink final : public cudf::io::data_sink { long current_buffer_written = 0; size_t total_written = 0; long alloc_size = MINIMUM_WRITE_BUFFER_SIZE; + jobject 
host_memory_allocator; }; } // namespace cudf::jni