From 436bc8da3c4035c33c63a25899682eb2e3e75caa Mon Sep 17 00:00:00 2001 From: Paul Mattione Date: Tue, 11 Jun 2024 10:25:52 -0400 Subject: [PATCH] Remove device_memory_resource initial pass --- .../rapids/cudf/RmmDeviceMemoryResource.java | 2 +- java/src/main/native/src/RmmJni.cpp | 312 ++++++++++++------ java/src/main/native/src/TableJni.cpp | 4 +- 3 files changed, 210 insertions(+), 108 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/RmmDeviceMemoryResource.java b/java/src/main/java/ai/rapids/cudf/RmmDeviceMemoryResource.java index f44631396df..ff02c952d0b 100644 --- a/java/src/main/java/ai/rapids/cudf/RmmDeviceMemoryResource.java +++ b/java/src/main/java/ai/rapids/cudf/RmmDeviceMemoryResource.java @@ -22,7 +22,7 @@ */ public interface RmmDeviceMemoryResource extends AutoCloseable { /** - * Returns a pointer to the underlying C++ class that implements rmm::mr::device_memory_resource + * Returns a pointer to the underlying C++ class that implements rmm::device_async_resource_ref */ long getHandle(); diff --git a/java/src/main/native/src/RmmJni.cpp b/java/src/main/native/src/RmmJni.cpp index fa78f6ca4e2..3d47bd326f8 100644 --- a/java/src/main/native/src/RmmJni.cpp +++ b/java/src/main/native/src/RmmJni.cpp @@ -39,7 +39,6 @@ #include #include -using rmm::mr::device_memory_resource; using rmm::mr::logging_resource_adaptor; using rmm_pinned_pool_t = rmm::mr::pool_memory_resource; @@ -51,7 +50,7 @@ constexpr char const* RMM_EXCEPTION_CLASS = "ai/rapids/cudf/RmmException"; * @brief Base class so we can template tracking_resource_adaptor but * still hold all instances of it without issues. */ -class base_tracking_resource_adaptor : public device_memory_resource { +class base_tracking_resource_adaptor { public: virtual std::size_t get_total_allocated() = 0; @@ -60,6 +59,31 @@ class base_tracking_resource_adaptor : public device_memory_resource { virtual void reset_scoped_max_total_allocated(std::size_t initial_value) = 0; virtual std::size_t get_scoped_max_total_allocated() = 0; + + /** + * Sync allocation method required to satisfy cuda::mr::resource concept + */ + void* allocate(std::size_t num_bytes, std::size_t alignment) = 0; + + /** + * Sync deallocation method required to satisfy cuda::mr::resource concept + */ + void deallocate(void* p, std::size_t num_bytes, std::size_t alignment) = 0; + + /** + * Async allocation method required to satisfy cuda::mr::async_resource concept + */ + void* allocate_async(std::size_t num_bytes, + std::size_t alignment, + rmm::cuda_stream_view stream) = 0; + + /** + * Async deallocation method required to satisfy cuda::mr::async_resource concept + */ + void deallocate_async(void* p, + std::size_t size, + std::size_t alignment, + rmm::cuda_stream_view stream) = 0; }; /** @@ -82,7 +106,7 @@ class tracking_resource_adaptor final : public base_tracking_resource_adaptor { * @param size_alignment The alignment to which the `mr` resource will * round up all memory allocation size requests. */ - tracking_resource_adaptor(Upstream* mr, std::size_t size_alignment) + tracking_resource_adaptor(Upstream mr, std::size_t size_alignment) : resource{mr}, size_align{size_alignment} { } @@ -106,8 +130,56 @@ class tracking_resource_adaptor final : public base_tracking_resource_adaptor { return scoped_max_total_allocated; } + void* allocate(std::size_t num_bytes, std::size_t) + { + auto result = resource.allocate(num_bytes, size_align, stream); + if (result) { bookkeep_allocation(num_bytes); } + return result; + } + + void deallocate(void* p, std::size_t size, std::size_t) + { + resource.deallocate(p, size, size_align, stream); + if (p) { bookkeep_deallocation(size); } + } + + /** + * Equality comparison method required to satisfy cuda::mr::resource concept + */ + friend bool operator==(const tracking_resource_adaptor& lhs, const tracking_resource_adaptor& rhs) + { + return (lhs.resource == rhs.resource); + } + + /** + * Equality comparison method required to satisfy cuda::mr::resource concept + */ + friend bool operator!=(const tracking_resource_adaptor& lhs, const tracking_resource_adaptor& rhs) + { + return !(lhs == rhs); + } + + /** + * Async allocation method required to satisfy cuda::mr::async_resource concept + */ + void* allocate_async(std::size_t num_bytes, std::size_t, rmm::cuda_stream_view stream) + { + auto result = resource.allocate_async(num_bytes, size_align, stream); + if (result) { bookkeep_allocation(num_bytes); } + return result; + } + + /** + * Async deallocation method required to satisfy cuda::mr::async_resource concept + */ + void deallocate_async(void* p, std::size_t size, std::size_t, rmm::cuda_stream_view stream) + { + resource.deallocate_async(p, size, size_align, stream); + if (p) { bookkeep_deallocation(size); } + } + private: - Upstream* const resource; + Upstream const resource; std::size_t const size_align; // sum of what is currently allocated std::atomic_size_t total_allocated{0}; @@ -125,37 +197,30 @@ class tracking_resource_adaptor final : public base_tracking_resource_adaptor { std::mutex max_total_allocated_mutex; - void* do_allocate(std::size_t num_bytes, rmm::cuda_stream_view stream) override + // adjust size of allocation based on specified size alignment + std::size_t adjust_for_alignment(std::size_t num_bytes) { - // adjust size of allocation based on specified size alignment - num_bytes = (num_bytes + size_align - 1) / size_align * size_align; - - auto result = resource->allocate(num_bytes, stream); - if (result) { - total_allocated += num_bytes; - scoped_allocated += num_bytes; - std::scoped_lock lock(max_total_allocated_mutex); - max_total_allocated = std::max(total_allocated.load(), max_total_allocated); - scoped_max_total_allocated = std::max(scoped_allocated.load(), scoped_max_total_allocated); - } - return result; + return (num_bytes + size_align - 1) / size_align * size_align; } - void do_deallocate(void* p, std::size_t size, rmm::cuda_stream_view stream) override + void bookkeep_allocation(std::size_t num_bytes) { - size = (size + size_align - 1) / size_align * size_align; - - resource->deallocate(p, size, stream); + total_allocated += num_bytes; + scoped_allocated += num_bytes; + std::scoped_lock lock(max_total_allocated_mutex); + max_total_allocated = std::max(total_allocated.load(), max_total_allocated); + scoped_max_total_allocated = std::max(scoped_allocated.load(), scoped_max_total_allocated); + } - if (p) { - total_allocated -= size; - scoped_allocated -= size; - } + void bookkeep_deallocation(std::size_t num_bytes) + { + total_allocated -= num_bytes; + scoped_allocated -= num_bytes; } }; template -tracking_resource_adaptor* make_tracking_adaptor(Upstream* upstream, +tracking_resource_adaptor* make_tracking_adaptor(Upstream upstream, std::size_t size_alignment) { return new tracking_resource_adaptor{upstream, size_alignment}; @@ -165,13 +230,13 @@ tracking_resource_adaptor* make_tracking_adaptor(Upstream* upstream, * @brief An RMM device memory resource adaptor that delegates to the wrapped resource * for most operations but will call Java to handle certain situations (e.g.: allocation failure). */ -class java_event_handler_memory_resource : public device_memory_resource { +class java_event_handler_memory_resource { public: java_event_handler_memory_resource(JNIEnv* env, jobject jhandler, jlongArray jalloc_thresholds, jlongArray jdealloc_thresholds, - device_memory_resource* resource_to_wrap, + device_async_resource_ref resource_to_wrap, base_tracking_resource_adaptor* tracker) : resource(resource_to_wrap), tracker(tracker) { @@ -216,10 +281,96 @@ class java_event_handler_memory_resource : public device_memory_resource { handler_obj = nullptr; } - device_memory_resource* get_wrapped_resource() { return resource; } + device_async_resource_ref get_wrapped_resource() { return resource; } + + void* allocate(std::size_t, std::size_t) + { + // Sync allocations not supported + return nullptr; + } + + void deallocate(void*, std::size_t, std::size_t) + { + // Sync deallocations not supported + } + + /** + * Equality comparison method required to satisfy cuda::mr::resource concept + */ + friend bool operator==(const java_event_handler_memory_resource& lhs, + const java_event_handler_memory_resource& rhs) + { + return (lhs.resource == rhs.resource); + } + + /** + * Equality comparison method required to satisfy cuda::mr::resource concept + */ + friend bool operator!=(const java_event_handler_memory_resource& lhs, + const java_event_handler_memory_resource& rhs) + { + return !(lhs == rhs); + } + + /** + * Async allocation method required to satisfy cuda::mr::async_resource concept + */ + void* allocate_async(std::size_t num_bytes, std::size_t alignment, rmm::cuda_stream_view stream) + { + std::size_t total_before; + void* result; + // a non-zero retry_count signifies that the `on_alloc_fail` + // callback is being invoked while re-attempting an allocation + // that had previously failed. + int retry_count = 0; + while (true) { + try { + total_before = tracker->get_total_allocated(); + result = resource.allocate(num_bytes, alignment, stream); + break; + } catch (rmm::out_of_memory const& e) { + if (!on_alloc_fail(num_bytes, retry_count++)) { throw; } + } + } + auto total_after = tracker->get_total_allocated(); + + try { + check_for_threshold_callback(total_before, + total_after, + alloc_thresholds, + on_alloc_threshold_method, + "onAllocThreshold", + total_after); + } catch (std::exception const& e) { + // Free the allocation as app will think the exception means the memory was not allocated. + resource.deallocate(result, num_bytes, alignment, stream); + throw; + } + + return result; + } + + /** + * Async deallocation method required to satisfy cuda::mr::async_resource concept + */ + void deallocate_async(void* p, + std::size_t size, + std::size_t alignment, + rmm::cuda_stream_view stream) + { + auto total_before = tracker->get_total_allocated(); + resource.deallocate(p, size, alignment, stream); + auto total_after = tracker->get_total_allocated(); + check_for_threshold_callback(total_after, + total_before, + dealloc_thresholds, + on_dealloc_threshold_method, + "onDeallocThreshold", + total_after); + } private: - device_memory_resource* const resource; + device_async_resource_ref const resource; base_tracking_resource_adaptor* const tracker; jmethodID on_alloc_fail_method; bool use_old_alloc_fail_interface; @@ -289,54 +440,6 @@ class java_event_handler_memory_resource : public device_memory_resource { protected: JavaVM* jvm; jobject handler_obj; - - void* do_allocate(std::size_t num_bytes, rmm::cuda_stream_view stream) override - { - std::size_t total_before; - void* result; - // a non-zero retry_count signifies that the `on_alloc_fail` - // callback is being invoked while re-attempting an allocation - // that had previously failed. - int retry_count = 0; - while (true) { - try { - total_before = tracker->get_total_allocated(); - result = resource->allocate(num_bytes, stream); - break; - } catch (rmm::out_of_memory const& e) { - if (!on_alloc_fail(num_bytes, retry_count++)) { throw; } - } - } - auto total_after = tracker->get_total_allocated(); - - try { - check_for_threshold_callback(total_before, - total_after, - alloc_thresholds, - on_alloc_threshold_method, - "onAllocThreshold", - total_after); - } catch (std::exception const& e) { - // Free the allocation as app will think the exception means the memory was not allocated. - resource->deallocate(result, num_bytes, stream); - throw; - } - - return result; - } - - void do_deallocate(void* p, std::size_t size, rmm::cuda_stream_view stream) override - { - auto total_before = tracker->get_total_allocated(); - resource->deallocate(p, size, stream); - auto total_after = tracker->get_total_allocated(); - check_for_threshold_callback(total_after, - total_before, - dealloc_thresholds, - on_dealloc_threshold_method, - "onDeallocThreshold", - total_after); - } }; class java_debug_event_handler_memory_resource final : public java_event_handler_memory_resource { @@ -345,7 +448,7 @@ class java_debug_event_handler_memory_resource final : public java_event_handler jobject jhandler, jlongArray jalloc_thresholds, jlongArray jdealloc_thresholds, - device_memory_resource* resource_to_wrap, + rmm::device_async_resource_ref resource_to_wrap, base_tracking_resource_adaptor* tracker) : java_event_handler_memory_resource( env, jhandler, jalloc_thresholds, jdealloc_thresholds, resource_to_wrap, tracker) @@ -730,9 +833,9 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_newPoolMemoryResource( JNI_NULL_CHECK(env, child, "child is null", 0); try { cudf::jni::auto_set_device(env); - auto wrapped = reinterpret_cast(child); + auto wrapped = reinterpret_cast(child); auto ret = - new rmm::mr::pool_memory_resource(wrapped, init, max); + new rmm::mr::pool_memory_resource(wrapped, init, max); return reinterpret_cast(ret); } CATCH_STD(env, 0) @@ -744,8 +847,7 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_releasePoolMemoryResource(JNIEnv* { try { cudf::jni::auto_set_device(env); - auto mr = - reinterpret_cast*>(ptr); + auto mr = reinterpret_cast*>(ptr); delete mr; } CATCH_STD(env, ) @@ -757,8 +859,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_newArenaMemoryResource( JNI_NULL_CHECK(env, child, "child is null", 0); try { cudf::jni::auto_set_device(env); - auto wrapped = reinterpret_cast(child); - auto ret = new rmm::mr::arena_memory_resource( + auto wrapped = reinterpret_cast(child); + auto ret = new rmm::mr::arena_memory_resource( wrapped, init, dump_on_oom); return reinterpret_cast(ret); } @@ -772,7 +874,7 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_releaseArenaMemoryResource(JNIEnv try { cudf::jni::auto_set_device(env); auto mr = - reinterpret_cast*>(ptr); + reinterpret_cast*>(ptr); delete mr; } CATCH_STD(env, ) @@ -809,9 +911,9 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_newLimitingResourceAdaptor( JNI_NULL_CHECK(env, child, "child is null", 0); try { cudf::jni::auto_set_device(env); - auto wrapped = reinterpret_cast(child); - auto ret = new rmm::mr::limiting_resource_adaptor( - wrapped, limit, align); + auto wrapped = reinterpret_cast(child); + auto ret = + new rmm::mr::limiting_resource_adaptor(wrapped, limit, align); return reinterpret_cast(ret); } CATCH_STD(env, 0) @@ -824,7 +926,7 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_releaseLimitingResourceAdaptor(JN try { cudf::jni::auto_set_device(env); auto mr = - reinterpret_cast*>(ptr); + reinterpret_cast*>(ptr); delete mr; } CATCH_STD(env, ) @@ -836,24 +938,24 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_newLoggingResourceAdaptor( JNI_NULL_CHECK(env, child, "child is null", 0); try { cudf::jni::auto_set_device(env); - auto wrapped = reinterpret_cast(child); + auto wrapped = reinterpret_cast(child); switch (type) { case 1: // File { cudf::jni::native_jstring path(env, jpath); - auto ret = new logging_resource_adaptor( + auto ret = new logging_resource_adaptor( wrapped, path.get(), auto_flush); return reinterpret_cast(ret); } case 2: // stdout { - auto ret = new logging_resource_adaptor( + auto ret = new logging_resource_adaptor( wrapped, std::cout, auto_flush); return reinterpret_cast(ret); } case 3: // stderr { - auto ret = new logging_resource_adaptor( + auto ret = new logging_resource_adaptor( wrapped, std::cerr, auto_flush); return reinterpret_cast(ret); } @@ -870,7 +972,7 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_releaseLoggingResourceAdaptor(JNI try { cudf::jni::auto_set_device(env); auto mr = - reinterpret_cast*>(ptr); + reinterpret_cast*>(ptr); delete mr; } CATCH_STD(env, ) @@ -884,8 +986,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_newTrackingResourceAdaptor(JNIEn JNI_NULL_CHECK(env, child, "child is null", 0); try { cudf::jni::auto_set_device(env); - auto wrapped = reinterpret_cast(child); - auto ret = new tracking_resource_adaptor(wrapped, align); + auto wrapped = reinterpret_cast(child); + auto ret = new tracking_resource_adaptor(wrapped, align); return reinterpret_cast(ret); } CATCH_STD(env, 0) @@ -897,7 +999,7 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_releaseTrackingResourceAdaptor(JN { try { cudf::jni::auto_set_device(env); - auto mr = reinterpret_cast*>(ptr); + auto mr = reinterpret_cast*>(ptr); delete mr; } CATCH_STD(env, ) @@ -910,7 +1012,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_nativeGetTotalBytesAllocated(JNI JNI_NULL_CHECK(env, ptr, "adaptor is null", 0); try { cudf::jni::auto_set_device(env); - auto mr = reinterpret_cast*>(ptr); + auto mr = reinterpret_cast*>(ptr); return mr->get_total_allocated(); } CATCH_STD(env, 0) @@ -923,7 +1025,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_nativeGetMaxTotalBytesAllocated( JNI_NULL_CHECK(env, ptr, "adaptor is null", 0); try { cudf::jni::auto_set_device(env); - auto mr = reinterpret_cast*>(ptr); + auto mr = reinterpret_cast*>(ptr); return mr->get_max_total_allocated(); } CATCH_STD(env, 0) @@ -937,7 +1039,7 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_nativeResetScopedMaxTotalBytesAll JNI_NULL_CHECK(env, ptr, "adaptor is null", ); try { cudf::jni::auto_set_device(env); - auto mr = reinterpret_cast*>(ptr); + auto mr = reinterpret_cast*>(ptr); mr->reset_scoped_max_total_allocated(init); } CATCH_STD(env, ) @@ -950,7 +1052,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Rmm_nativeGetScopedMaxTotalBytesAllo JNI_NULL_CHECK(env, ptr, "adaptor is null", 0); try { cudf::jni::auto_set_device(env); - auto mr = reinterpret_cast*>(ptr); + auto mr = reinterpret_cast*>(ptr); return mr->get_scoped_max_total_allocated(); } CATCH_STD(env, 0) @@ -969,8 +1071,8 @@ Java_ai_rapids_cudf_Rmm_newEventHandlerResourceAdaptor(JNIEnv* env, JNI_NULL_CHECK(env, child, "child is null", 0); JNI_NULL_CHECK(env, tracker, "tracker is null", 0); try { - auto wrapped = reinterpret_cast(child); - auto t = reinterpret_cast*>(tracker); + auto wrapped = reinterpret_cast(child); + auto t = reinterpret_cast*>(tracker); if (enable_debug) { auto ret = new java_debug_event_handler_memory_resource( env, handler_obj, jalloc_thresholds, jdealloc_thresholds, wrapped, t); @@ -1006,7 +1108,7 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Rmm_setCurrentDeviceResourceInternal( { try { cudf::jni::auto_set_device(env); - auto mr = reinterpret_cast(new_handle); + auto mr = reinterpret_cast(new_handle); rmm::mr::set_current_device_resource(mr); } CATCH_STD(env, ) diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index e411b1d5362..6a4b9df8287 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -49,7 +49,7 @@ #include #include -#include +#include #include @@ -4082,7 +4082,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_makeChunkedPack( // `temp_mr` is the memory resource that `cudf::chunked_pack` will use to create temporary // and scratch memory only. auto temp_mr = memoryResourceHandle != 0 - ? reinterpret_cast(memoryResourceHandle) + ? reinterpret_cast(memoryResourceHandle) : rmm::mr::get_current_device_resource(); auto chunked_pack = cudf::chunked_pack::create(*n_table, bounce_buffer_size, temp_mr); return reinterpret_cast(chunked_pack.release());