Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JNI] Adds retryCount to RmmEventHandler.onAllocFailure #11940

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions java/src/main/java/ai/rapids/cudf/RmmEventHandler.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,9 +22,28 @@ public interface RmmEventHandler {
/**
* Invoked on a memory allocation failure.
* @param sizeRequested number of bytes that failed to allocate
* @deprecated deprecated in favor of onAllocFailure(long, boolean)
* @return true if the memory allocation should be retried or false if it should fail
*/
boolean onAllocFailure(long sizeRequested);
default boolean onAllocFailure(long sizeRequested) {
// this should not be called since it was the previous interface,
// and it was abstract before, throwing by default for good measure.
throw new UnsupportedOperationException(
"Unexpected invocation of deprecated onAllocFailure without retry count.");
}

/**
* Invoked on a memory allocation failure.
* @param sizeRequested number of bytes that failed to allocate
* @param retryCount number of times this allocation has been retried after failure
* @return true if the memory allocation should be retried or false if it should fail
*/
default boolean onAllocFailure(long sizeRequested, int retryCount) {
// newer code should override this implementation of `onAllocFailure` to handle
// `retryCount`. Otherwise, we call the prior implementation to not
// break existing code.
return onAllocFailure(sizeRequested);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to just call the function above (which returns false all the time)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have code that implemented onAllocFailure (all existing code), no it just calls that impl. So this should be clearer given an exception in the deprecated code perhaps.

}

/**
* Get the memory thresholds that will trigger {@link #onAllocThreshold(long)}
Expand Down
31 changes: 25 additions & 6 deletions java/src/main/native/src/RmmJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,15 @@ class java_event_handler_memory_resource final : public device_memory_resource {
if (cls == nullptr) {
throw cudf::jni::jni_exception("class not found");
}
on_alloc_fail_method = env->GetMethodID(cls, "onAllocFailure", "(J)Z");
on_alloc_fail_method = env->GetMethodID(cls, "onAllocFailure", "(JI)Z");
if (on_alloc_fail_method == nullptr) {
throw cudf::jni::jni_exception("onAllocFailure method");
use_old_alloc_fail_interface = true;
on_alloc_fail_method = env->GetMethodID(cls, "onAllocFailure", "(J)Z");
if (on_alloc_fail_method == nullptr) {
throw cudf::jni::jni_exception("onAllocFailure method");
}
} else {
use_old_alloc_fail_interface = false;
}
on_alloc_threshold_method = env->GetMethodID(cls, "onAllocThreshold", "(J)V");
if (on_alloc_threshold_method == nullptr) {
Expand Down Expand Up @@ -190,6 +196,7 @@ class java_event_handler_memory_resource final : public device_memory_resource {
JavaVM *jvm;
jobject handler_obj;
jmethodID on_alloc_fail_method;
bool use_old_alloc_fail_interface;
jmethodID on_alloc_threshold_method;
jmethodID on_dealloc_threshold_method;

Expand All @@ -209,10 +216,18 @@ class java_event_handler_memory_resource final : public device_memory_resource {
}
}

bool on_alloc_fail(std::size_t num_bytes) {
bool on_alloc_fail(std::size_t num_bytes, int retry_count) {
JNIEnv *env = cudf::jni::get_jni_env(jvm);
jboolean result =
env->CallBooleanMethod(handler_obj, on_alloc_fail_method, static_cast<jlong>(num_bytes));
jboolean result = false;
if (!use_old_alloc_fail_interface) {
result =
env->CallBooleanMethod(handler_obj, on_alloc_fail_method, static_cast<jlong>(num_bytes),
static_cast<jint>(retry_count));

} else {
result =
env->CallBooleanMethod(handler_obj, on_alloc_fail_method, static_cast<jlong>(num_bytes));
}
if (env->ExceptionCheck()) {
throw std::runtime_error("onAllocFailure handler threw an exception");
}
Expand Down Expand Up @@ -240,13 +255,17 @@ class java_event_handler_memory_resource final : public device_memory_resource {
void *do_allocate(std::size_t num_bytes, rmm::cuda_stream_view stream) override {
std::size_t total_before;
void *result;
// a positive retry_count signifies that the `on_alloc_fail`
abellina marked this conversation as resolved.
Show resolved Hide resolved
// callback is being invoked while re-attempting an allocation
// that had previously failed.
int retry_count = 0;
while (true) {
try {
total_before = get_total_bytes_allocated();
result = resource->allocate(num_bytes, stream);
break;
} catch (rmm::out_of_memory const &e) {
if (!on_alloc_fail(num_bytes)) {
if (!on_alloc_fail(num_bytes, retry_count++)) {
throw;
}
}
Expand Down
17 changes: 10 additions & 7 deletions java/src/test/java/ai/rapids/cudf/RmmTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,13 @@ public void testTotalAllocated(int rmmAllocMode) {
public void testEventHandler(int rmmAllocMode) {
AtomicInteger invokedCount = new AtomicInteger();
AtomicLong amountRequested = new AtomicLong();
AtomicInteger timesRetried = new AtomicInteger();

RmmEventHandler handler = new BaseRmmEventHandler() {
@Override
public boolean onAllocFailure(long sizeRequested) {
public boolean onAllocFailure(long sizeRequested, int retryCount) {
int count = invokedCount.incrementAndGet();
timesRetried.set(retryCount);
amountRequested.set(sizeRequested);
return count != 3;
}
Expand All @@ -100,6 +102,7 @@ public boolean onAllocFailure(long sizeRequested) {
}

assertEquals(3, invokedCount.get());
assertEquals(2, timesRetried.get());
assertEquals(requested, amountRequested.get());

// verify after a failure we can still allocate something more reasonable
Expand All @@ -114,15 +117,15 @@ public void testSetEventHandlerTwice() {
// installing an event handler the first time should not be an error
Rmm.setEventHandler(new BaseRmmEventHandler() {
@Override
public boolean onAllocFailure(long sizeRequested) {
public boolean onAllocFailure(long sizeRequested, int retryCount) {
return false;
}
});

// installing a second event handler is an error
RmmEventHandler otherHandler = new BaseRmmEventHandler() {
@Override
public boolean onAllocFailure(long sizeRequested) {
public boolean onAllocFailure(long sizeRequested, int retryCount) {
return true;
}
};
Expand All @@ -138,7 +141,7 @@ public void testClearEventHandler() {
// create an event handler that will always retry
RmmEventHandler retryHandler = new BaseRmmEventHandler() {
@Override
public boolean onAllocFailure(long sizeRequested) {
public boolean onAllocFailure(long sizeRequested, int retryCount) {
return true;
}
};
Expand All @@ -165,7 +168,7 @@ public void testAllocOnlyThresholds() {

RmmEventHandler handler = new RmmEventHandler() {
@Override
public boolean onAllocFailure(long sizeRequested) {
public boolean onAllocFailure(long sizeRequested, int retryCount) {
return false;
}

Expand Down Expand Up @@ -228,7 +231,7 @@ public void testThresholds() {

RmmEventHandler handler = new RmmEventHandler() {
@Override
public boolean onAllocFailure(long sizeRequested) {
public boolean onAllocFailure(long sizeRequested, int retryCount) {
return false;
}

Expand Down Expand Up @@ -308,7 +311,7 @@ public void testExceptionHandling() {

RmmEventHandler handler = new RmmEventHandler() {
@Override
public boolean onAllocFailure(long sizeRequested) {
public boolean onAllocFailure(long sizeRequested, int retryCount) {
throw new AllocFailException();
}

Expand Down