From ef529dea09aad6d16a9bc0f2486468eab3091d09 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 18 Oct 2022 10:50:37 -0500 Subject: [PATCH 1/4] Adds isRetry to RmmEventHandler.onAllocFailure --- .../java/ai/rapids/cudf/RmmEventHandler.java | 23 ++++++++++++-- java/src/main/native/src/RmmJni.cpp | 30 +++++++++++++++---- .../src/test/java/ai/rapids/cudf/RmmTest.java | 7 ++++- 3 files changed, 51 insertions(+), 9 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/RmmEventHandler.java b/java/src/main/java/ai/rapids/cudf/RmmEventHandler.java index 85442402403..9abc0a52653 100644 --- a/java/src/main/java/ai/rapids/cudf/RmmEventHandler.java +++ b/java/src/main/java/ai/rapids/cudf/RmmEventHandler.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,9 +22,28 @@ public interface RmmEventHandler { /** * Invoked on a memory allocation failure. * @param sizeRequested number of bytes that failed to allocate + * @deprecated deprecated in favor of onAllocFailure(long, boolean) * @return true if the memory allocation should be retried or false if it should fail */ - boolean onAllocFailure(long sizeRequested); + default boolean onAllocFailure(long sizeRequested) { + // this should not be called since it was the previous interface, + // and it was abstract before. + return false; + } + + /** + * Invoked on a memory allocation failure. 
+ * @param sizeRequested number of bytes that failed to allocate + * @param isRetry whether this failure happened while retrying an allocation + * that had previously failed + * @return true if the memory allocation should be retried or false if it should fail + */ + default boolean onAllocFailure(long sizeRequested, boolean isRetry) { + // newer code should override this implementation of `onAllocFailure` to handle + // the `isRetry` flag. Otherwise, we call the prior implementation to not + // break existing code. + return onAllocFailure(sizeRequested); + } /** * Get the memory thresholds that will trigger {@link #onAllocThreshold(long)} diff --git a/java/src/main/native/src/RmmJni.cpp b/java/src/main/native/src/RmmJni.cpp index ce3e6ffb285..e2020833c20 100644 --- a/java/src/main/native/src/RmmJni.cpp +++ b/java/src/main/native/src/RmmJni.cpp @@ -150,9 +150,15 @@ class java_event_handler_memory_resource final : public device_memory_resource { if (cls == nullptr) { throw cudf::jni::jni_exception("class not found"); } - on_alloc_fail_method = env->GetMethodID(cls, "onAllocFailure", "(J)Z"); + on_alloc_fail_method = env->GetMethodID(cls, "onAllocFailure", "(JZ)Z"); if (on_alloc_fail_method == nullptr) { - throw cudf::jni::jni_exception("onAllocFailure method"); + use_old_alloc_fail_interface = true; + on_alloc_fail_method = env->GetMethodID(cls, "onAllocFailure", "(J)Z"); + if (on_alloc_fail_method == nullptr) { + throw cudf::jni::jni_exception("onAllocFailure method"); + } + } else { + use_old_alloc_fail_interface = false; } on_alloc_threshold_method = env->GetMethodID(cls, "onAllocThreshold", "(J)V"); if (on_alloc_threshold_method == nullptr) { @@ -190,6 +196,7 @@ class java_event_handler_memory_resource final : public device_memory_resource { JavaVM *jvm; jobject handler_obj; jmethodID on_alloc_fail_method; + bool use_old_alloc_fail_interface; jmethodID on_alloc_threshold_method; jmethodID on_dealloc_threshold_method; @@ -209,10 +216,17 @@ class 
java_event_handler_memory_resource final : public device_memory_resource { } } - bool on_alloc_fail(std::size_t num_bytes) { + bool on_alloc_fail(std::size_t num_bytes, bool is_retry) { JNIEnv *env = cudf::jni::get_jni_env(jvm); - jboolean result = - env->CallBooleanMethod(handler_obj, on_alloc_fail_method, static_cast<jlong>(num_bytes)); + jboolean result = false; + if (!use_old_alloc_fail_interface) { + result = env->CallBooleanMethod(handler_obj, on_alloc_fail_method, + static_cast<jlong>(num_bytes), static_cast<jboolean>(is_retry)); + + } else { + result = env->CallBooleanMethod(handler_obj, on_alloc_fail_method, + static_cast<jlong>(num_bytes)); + } if (env->ExceptionCheck()) { throw std::runtime_error("onAllocFailure handler threw an exception"); } @@ -240,15 +254,19 @@ class java_event_handler_memory_resource final : public device_memory_resource { void *do_allocate(std::size_t num_bytes, rmm::cuda_stream_view stream) override { std::size_t total_before; void *result; + bool is_retry = false; while (true) { try { total_before = get_total_bytes_allocated(); result = resource->allocate(num_bytes, stream); break; } catch (rmm::out_of_memory const &e) { - if (!on_alloc_fail(num_bytes)) { + if (!on_alloc_fail(num_bytes, is_retry)) { throw; } + // tells the handling code that this failure was on a + // retry + is_retry = true; } } auto total_after = get_total_bytes_allocated(); diff --git a/java/src/test/java/ai/rapids/cudf/RmmTest.java b/java/src/test/java/ai/rapids/cudf/RmmTest.java index c56b131de86..6258f746af7 100644 --- a/java/src/test/java/ai/rapids/cudf/RmmTest.java +++ b/java/src/test/java/ai/rapids/cudf/RmmTest.java @@ -73,11 +73,15 @@ public void testTotalAllocated(int rmmAllocMode) { public void testEventHandler(int rmmAllocMode) { AtomicInteger invokedCount = new AtomicInteger(); AtomicLong amountRequested = new AtomicLong(); + AtomicInteger amountRetried = new AtomicInteger(); RmmEventHandler handler = new BaseRmmEventHandler() { @Override - public boolean onAllocFailure(long
sizeRequested) { + public boolean onAllocFailure(long sizeRequested, boolean isRetry) { int count = invokedCount.incrementAndGet(); + if (isRetry) { + amountRetried.getAndIncrement(); + } amountRequested.set(sizeRequested); return count != 3; } @@ -100,6 +104,7 @@ public boolean onAllocFailure(long sizeRequested) { } assertEquals(3, invokedCount.get()); + assertEquals(2, amountRetried.get()); assertEquals(requested, amountRequested.get()); // verify after a failure we can still allocate something more reasonable From 9407a67af5e56daf2dca9c9e6acddf5e960d368c Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 18 Oct 2022 13:24:45 -0500 Subject: [PATCH 2/4] Address review comments --- .../java/ai/rapids/cudf/RmmEventHandler.java | 12 +++++----- java/src/main/native/src/RmmJni.cpp | 16 +++++++------- .../src/test/java/ai/rapids/cudf/RmmTest.java | 22 +++++++++---------- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/RmmEventHandler.java b/java/src/main/java/ai/rapids/cudf/RmmEventHandler.java index 9abc0a52653..19707b85bcb 100644 --- a/java/src/main/java/ai/rapids/cudf/RmmEventHandler.java +++ b/java/src/main/java/ai/rapids/cudf/RmmEventHandler.java @@ -27,20 +27,20 @@ public interface RmmEventHandler { */ default boolean onAllocFailure(long sizeRequested) { // this should not be called since it was the previous interface, - // and it was abstract before. - return false; + // and it was abstract before, throwing by default for good measure. + throw new UnsupportedOperationException( + "Unexpected invocation of deprecated onAllocFailure without retry count."); } /** * Invoked on a memory allocation failure. 
* @param sizeRequested number of bytes that failed to allocate - * @param isRetry whether this failure happened while retrying an allocation - * that had previously failed + * @param retryCount number of times this allocation has been retried after failure * @return true if the memory allocation should be retried or false if it should fail */ - default boolean onAllocFailure(long sizeRequested, boolean isRetry) { + default boolean onAllocFailure(long sizeRequested, int retryCount) { // newer code should override this implementation of `onAllocFailure` to handle - // the `isRetry` flag. Otherwise, we call the prior implementation to not + // `retryCount`. Otherwise, we call the prior implementation to not // break existing code. return onAllocFailure(sizeRequested); } diff --git a/java/src/main/native/src/RmmJni.cpp b/java/src/main/native/src/RmmJni.cpp index e2020833c20..1d6590c8316 100644 --- a/java/src/main/native/src/RmmJni.cpp +++ b/java/src/main/native/src/RmmJni.cpp @@ -150,7 +150,7 @@ class java_event_handler_memory_resource final : public device_memory_resource { if (cls == nullptr) { throw cudf::jni::jni_exception("class not found"); } - on_alloc_fail_method = env->GetMethodID(cls, "onAllocFailure", "(JZ)Z"); + on_alloc_fail_method = env->GetMethodID(cls, "onAllocFailure", "(JI)Z"); if (on_alloc_fail_method == nullptr) { use_old_alloc_fail_interface = true; on_alloc_fail_method = env->GetMethodID(cls, "onAllocFailure", "(J)Z"); @@ -216,12 +216,12 @@ class java_event_handler_memory_resource final : public device_memory_resource { } } - bool on_alloc_fail(std::size_t num_bytes, bool is_retry) { + bool on_alloc_fail(std::size_t num_bytes, int retry_count) { JNIEnv *env = cudf::jni::get_jni_env(jvm); jboolean result = false; if (!use_old_alloc_fail_interface) { result = env->CallBooleanMethod(handler_obj, on_alloc_fail_method, - static_cast<jlong>(num_bytes), static_cast<jboolean>(is_retry)); + static_cast<jlong>(num_bytes), static_cast<jint>(retry_count)); } else { result =
env->CallBooleanMethod(handler_obj, on_alloc_fail_method, @@ -254,19 +254,19 @@ class java_event_handler_memory_resource final : public device_memory_resource { void *do_allocate(std::size_t num_bytes, rmm::cuda_stream_view stream) override { std::size_t total_before; void *result; - bool is_retry = false; + // a positive retry_count signifies that the `on_alloc_fail` + // callback is being invoked while re-attempting an allocation + // that had previously failed. + int retry_count = 0; while (true) { try { total_before = get_total_bytes_allocated(); result = resource->allocate(num_bytes, stream); break; } catch (rmm::out_of_memory const &e) { - if (!on_alloc_fail(num_bytes, is_retry)) { + if (!on_alloc_fail(num_bytes, retry_count++)) { throw; } - // tells the handling code that this failure was on a - // retry - is_retry = true; } } auto total_after = get_total_bytes_allocated(); diff --git a/java/src/test/java/ai/rapids/cudf/RmmTest.java b/java/src/test/java/ai/rapids/cudf/RmmTest.java index 6258f746af7..09fbedd8a1c 100644 --- a/java/src/test/java/ai/rapids/cudf/RmmTest.java +++ b/java/src/test/java/ai/rapids/cudf/RmmTest.java @@ -73,15 +73,13 @@ public void testTotalAllocated(int rmmAllocMode) { public void testEventHandler(int rmmAllocMode) { AtomicInteger invokedCount = new AtomicInteger(); AtomicLong amountRequested = new AtomicLong(); - AtomicInteger amountRetried = new AtomicInteger(); + AtomicInteger timesRetried = new AtomicInteger(); RmmEventHandler handler = new BaseRmmEventHandler() { @Override - public boolean onAllocFailure(long sizeRequested, boolean isRetry) { + public boolean onAllocFailure(long sizeRequested, int retryCount) { int count = invokedCount.incrementAndGet(); - if (isRetry) { - amountRetried.getAndIncrement(); - } + timesRetried.set(retryCount); amountRequested.set(sizeRequested); return count != 3; } @@ -104,7 +102,7 @@ public boolean onAllocFailure(long sizeRequested, boolean isRetry) { } assertEquals(3, invokedCount.get()); - 
assertEquals(2, amountRetried.get()); + assertEquals(2, timesRetried.get()); assertEquals(requested, amountRequested.get()); // verify after a failure we can still allocate something more reasonable @@ -119,7 +117,7 @@ public void testSetEventHandlerTwice() { // installing an event handler the first time should not be an error Rmm.setEventHandler(new BaseRmmEventHandler() { @Override - public boolean onAllocFailure(long sizeRequested) { + public boolean onAllocFailure(long sizeRequested, int retryCount) { return false; } }); @@ -127,7 +125,7 @@ public boolean onAllocFailure(long sizeRequested) { // installing a second event handler is an error RmmEventHandler otherHandler = new BaseRmmEventHandler() { @Override - public boolean onAllocFailure(long sizeRequested) { + public boolean onAllocFailure(long sizeRequested, int retryCount) { return true; } }; @@ -143,7 +141,7 @@ public void testClearEventHandler() { // create an event handler that will always retry RmmEventHandler retryHandler = new BaseRmmEventHandler() { @Override - public boolean onAllocFailure(long sizeRequested) { + public boolean onAllocFailure(long sizeRequested, int retryCount) { return true; } }; @@ -170,7 +168,7 @@ public void testAllocOnlyThresholds() { RmmEventHandler handler = new RmmEventHandler() { @Override - public boolean onAllocFailure(long sizeRequested) { + public boolean onAllocFailure(long sizeRequested, int retryCount) { return false; } @@ -233,7 +231,7 @@ public void testThresholds() { RmmEventHandler handler = new RmmEventHandler() { @Override - public boolean onAllocFailure(long sizeRequested) { + public boolean onAllocFailure(long sizeRequested, int retryCount) { return false; } @@ -313,7 +311,7 @@ public void testExceptionHandling() { RmmEventHandler handler = new RmmEventHandler() { @Override - public boolean onAllocFailure(long sizeRequested) { + public boolean onAllocFailure(long sizeRequested, int retryCount) { throw new AllocFailException(); } From 
777d04629f002b7068e237a713bebe1cbcbcbdcc Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 18 Oct 2022 13:30:23 -0500 Subject: [PATCH 3/4] Fix code style --- java/src/main/native/src/RmmJni.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/java/src/main/native/src/RmmJni.cpp b/java/src/main/native/src/RmmJni.cpp index 1d6590c8316..361c17db92c 100644 --- a/java/src/main/native/src/RmmJni.cpp +++ b/java/src/main/native/src/RmmJni.cpp @@ -220,12 +220,13 @@ class java_event_handler_memory_resource final : public device_memory_resource { JNIEnv *env = cudf::jni::get_jni_env(jvm); jboolean result = false; if (!use_old_alloc_fail_interface) { - result = env->CallBooleanMethod(handler_obj, on_alloc_fail_method, - static_cast<jlong>(num_bytes), static_cast<jint>(retry_count)); + result = + env->CallBooleanMethod(handler_obj, on_alloc_fail_method, static_cast<jlong>(num_bytes), + static_cast<jint>(retry_count)); } else { - result = env->CallBooleanMethod(handler_obj, on_alloc_fail_method, - static_cast<jlong>(num_bytes)); + result = + env->CallBooleanMethod(handler_obj, on_alloc_fail_method, static_cast<jlong>(num_bytes)); } if (env->ExceptionCheck()) { throw std::runtime_error("onAllocFailure handler threw an exception"); From 40866b7c429f36dc1726e969e7a074a095c2047a Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 18 Oct 2022 13:52:49 -0500 Subject: [PATCH 4/4] Fix nit: positive->non-zero Co-authored-by: Jason Lowe --- java/src/main/native/src/RmmJni.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/src/main/native/src/RmmJni.cpp b/java/src/main/native/src/RmmJni.cpp index 361c17db92c..2b4c5ae59f5 100644 --- a/java/src/main/native/src/RmmJni.cpp +++ b/java/src/main/native/src/RmmJni.cpp @@ -255,7 +255,7 @@ class java_event_handler_memory_resource final : public device_memory_resource { void *do_allocate(std::size_t num_bytes, rmm::cuda_stream_view stream) override { std::size_t total_before; void *result; - // a positive retry_count
signifies that the `on_alloc_fail` + // a non-zero retry_count signifies that the `on_alloc_fail` // callback is being invoked while re-attempting an allocation // that had previously failed. int retry_count = 0;