diff --git a/java/src/main/java/ai/rapids/cudf/MemoryBuffer.java b/java/src/main/java/ai/rapids/cudf/MemoryBuffer.java index 9f0d9a451c0..e6b3994235d 100644 --- a/java/src/main/java/ai/rapids/cudf/MemoryBuffer.java +++ b/java/src/main/java/ai/rapids/cudf/MemoryBuffer.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,6 +28,23 @@ * subclassing beyond what is included in CUDF is not recommended and not supported. */ abstract public class MemoryBuffer implements AutoCloseable { + /** + * Interface to handle events for this MemoryBuffer. Only invoked during + * close, hence `onClosed` is the only event. + */ + public interface EventHandler { + /** + * `onClosed` is invoked with the updated `refCount` during `close`. + * The last invocation of `onClosed` will be with `refCount=0`. + * + * @note the callback is invoked with this `MemoryBuffer`'s lock held. + * + * @param refCount - the updated ref count for this MemoryBuffer at the time + * of invocation + */ + void onClosed(int refCount); + } + private static final Logger log = LoggerFactory.getLogger(MemoryBuffer.class); protected final long address; protected final long length; @@ -36,6 +53,8 @@ abstract public class MemoryBuffer implements AutoCloseable { protected final MemoryBufferCleaner cleaner; protected final long id; + private EventHandler eventHandler; + public static abstract class MemoryBufferCleaner extends MemoryCleaner.Cleaner{} private static final class SlicedBufferCleaner extends MemoryBufferCleaner { @@ -193,6 +212,27 @@ public final void copyFromMemoryBufferAsync( */ public abstract MemoryBuffer slice(long offset, long len); + /** + * Set an event handler for this buffer. This method can be invoked with null + * to unset the handler. + * + * @param newHandler - the EventHandler to use from this point forward + * @return the prior event handler, or null if not set. + */ + public synchronized EventHandler setEventHandler(EventHandler newHandler) { + EventHandler prev = this.eventHandler; + this.eventHandler = newHandler; + return prev; + } + + /** + * Returns the current event handler for this buffer or null if no handler + * is associated or this buffer is closed. + */ + public synchronized EventHandler getEventHandler() { + return this.eventHandler; + } + /** * Close this buffer and free memory */ @@ -200,6 +240,9 @@ public synchronized void close() { if (cleaner != null) { refCount--; cleaner.delRef(); + if (eventHandler != null) { + eventHandler.onClosed(refCount); + } if (refCount == 0) { cleaner.clean(false); closed = true; @@ -232,8 +275,10 @@ public synchronized void incRefCount() { cleaner.addRef(); } - // visible for testing - synchronized int getRefCount() { + /** + * Get the current reference count for this buffer. + */ + public synchronized int getRefCount() { return refCount; } } diff --git a/java/src/test/java/ai/rapids/cudf/MemoryBufferTest.java b/java/src/test/java/ai/rapids/cudf/MemoryBufferTest.java index df710c71f63..c332ce660d1 100644 --- a/java/src/test/java/ai/rapids/cudf/MemoryBufferTest.java +++ b/java/src/test/java/ai/rapids/cudf/MemoryBufferTest.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,8 @@ import org.junit.jupiter.api.Test; +import java.util.concurrent.atomic.AtomicInteger; + import static org.junit.jupiter.api.Assertions.*; public class MemoryBufferTest extends CudfTestBase { @@ -168,4 +170,49 @@ private void verifyOutput(HostMemoryBuffer out) { out.getBytes(bytes, 0, 0, 16); assertArrayEquals(EXPECTED, bytes); } + + @Test + public void testEventHandlerIsCalledForEachClose() { + final AtomicInteger onClosedWasCalled = new AtomicInteger(0); + try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) { + b.setEventHandler(refCount -> onClosedWasCalled.incrementAndGet()); + } + assertEquals(1, onClosedWasCalled.get()); + onClosedWasCalled.set(0); + + try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) { + b.setEventHandler(refCount -> onClosedWasCalled.incrementAndGet()); + DeviceMemoryBuffer sliced = b.slice(0, b.getLength()); + sliced.close(); + } + assertEquals(2, onClosedWasCalled.get()); + } + + @Test + public void testEventHandlerIsNotCalledIfNotSet() { + final AtomicInteger onClosedWasCalled = new AtomicInteger(0); + try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) { + assertNull(b.getEventHandler()); + } + assertEquals(0, onClosedWasCalled.get()); + try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) { + b.setEventHandler(refCount -> onClosedWasCalled.incrementAndGet()); + b.setEventHandler(null); + } + assertEquals(0, onClosedWasCalled.get()); + } + + @Test + public void testEventHandlerReturnsPreviousHandlerOnReset() { + try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) { + MemoryBuffer.EventHandler handler = refCount -> {}; + MemoryBuffer.EventHandler handler2 = refCount -> {}; + + assertNull(b.setEventHandler(handler)); + assertEquals(handler, b.setEventHandler(null)); + + assertNull(b.setEventHandler(handler2)); + assertEquals(handler2, b.setEventHandler(handler)); + } + } }