diff --git a/java/src/main/java/ai/rapids/cudf/MemoryBuffer.java b/java/src/main/java/ai/rapids/cudf/MemoryBuffer.java index 75c7a2dc22c..e6b3994235d 100644 --- a/java/src/main/java/ai/rapids/cudf/MemoryBuffer.java +++ b/java/src/main/java/ai/rapids/cudf/MemoryBuffer.java @@ -36,6 +36,9 @@ public interface EventHandler { /** * `onClosed` is invoked with the updated `refCount` during `close`. * The last invocation of `onClosed` will be with `refCount=0`. + * + * @note the callback is invoked with this `MemoryBuffer`'s lock held. + * * @param refCount - the updated ref count for this MemoryBuffer at the time * of invocation */ @@ -212,12 +215,14 @@ public final void copyFromMemoryBufferAsync( /** * Set an event handler for this buffer. This method can be invoked with null * to unset the handler. + * + * @param newHandler - the EventHandler to use from this point forward + * @return the prior event handler, or null if not set. */ - public synchronized void setEventHandler(EventHandler handler) { - if (this.eventHandler != null && handler != null) { - throw new IllegalStateException("EventHandler is already set for this buffer"); - } - this.eventHandler = handler; + public synchronized EventHandler setEventHandler(EventHandler newHandler) { + EventHandler prev = this.eventHandler; + this.eventHandler = newHandler; + return prev; } /** diff --git a/java/src/test/java/ai/rapids/cudf/MemoryBufferTest.java b/java/src/test/java/ai/rapids/cudf/MemoryBufferTest.java index f7bd70a0c8f..c332ce660d1 100644 --- a/java/src/test/java/ai/rapids/cudf/MemoryBufferTest.java +++ b/java/src/test/java/ai/rapids/cudf/MemoryBufferTest.java @@ -203,14 +203,16 @@ public void testEventHandlerIsNotCalledIfNotSet() { } @Test - public void testEventHandlerDisallowsResetting() { + public void testEventHandlerReturnsPreviousHandlerOnReset() { try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) { - b.setEventHandler(refCount -> {}); - b.setEventHandler(null); // ok - unsets it + MemoryBuffer.EventHandler handler = refCount -> {}; + MemoryBuffer.EventHandler handler2 = refCount -> {}; - b.setEventHandler(refCount -> {}); // ok - resets it because it was null before - // we cannot reset the handler without having set it to null first - assertThrows(IllegalStateException.class, () -> b.setEventHandler(refCount -> {})); + assertNull(b.setEventHandler(handler)); + assertEquals(handler, b.setEventHandler(null)); + + assertNull(b.setEventHandler(handler2)); + assertEquals(handler2, b.setEventHandler(handler)); } } }