Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JNI] Adds an EventHandler to Java MemoryBuffer to be invoked on close #12125

Merged
merged 2 commits into from
Nov 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions java/src/main/java/ai/rapids/cudf/MemoryBuffer.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,6 +28,23 @@
* subclassing beyond what is included in CUDF is not recommended and not supported.
*/
abstract public class MemoryBuffer implements AutoCloseable {
/**
* Interface to handle events for this MemoryBuffer. Only invoked during
* close, hence `onClosed` is the only event.
*/
public interface EventHandler {
/**
* `onClosed` is invoked with the updated `refCount` during `close`.
* The last invocation of `onClosed` will be with `refCount=0`.
*
* @note the callback is invoked with this `MemoryBuffer`'s lock held.
*
* @param refCount - the updated ref count for this MemoryBuffer at the time
* of invocation
*/
void onClosed(int refCount);
}

private static final Logger log = LoggerFactory.getLogger(MemoryBuffer.class);
protected final long address;
protected final long length;
Expand All @@ -36,6 +53,8 @@ abstract public class MemoryBuffer implements AutoCloseable {
protected final MemoryBufferCleaner cleaner;
protected final long id;

private EventHandler eventHandler;

public static abstract class MemoryBufferCleaner extends MemoryCleaner.Cleaner{}

private static final class SlicedBufferCleaner extends MemoryBufferCleaner {
Expand Down Expand Up @@ -193,13 +212,37 @@ public final void copyFromMemoryBufferAsync(
*/
public abstract MemoryBuffer slice(long offset, long len);

/**
* Set an event handler for this buffer. This method can be invoked with null
* to unset the handler.
*
* @param newHandler - the EventHandler to use from this point forward
* @return the prior event handler, or null if not set.
*/
public synchronized EventHandler setEventHandler(EventHandler newHandler) {
EventHandler prev = this.eventHandler;
this.eventHandler = newHandler;
return prev;
}

/**
* Returns the current event handler for this buffer or null if no handler
* is associated or this buffer is closed.
*/
public synchronized EventHandler getEventHandler() {
return this.eventHandler;
}

/**
* Close this buffer and free memory
*/
public synchronized void close() {
if (cleaner != null) {
refCount--;
cleaner.delRef();
if (eventHandler != null) {
eventHandler.onClosed(refCount);
}
if (refCount == 0) {
cleaner.clean(false);
closed = true;
Expand Down Expand Up @@ -232,8 +275,10 @@ public synchronized void incRefCount() {
cleaner.addRef();
}

// visible for testing
synchronized int getRefCount() {
/**
* Get the current reference count for this buffer.
*/
public synchronized int getRefCount() {
return refCount;
}
}
49 changes: 48 additions & 1 deletion java/src/test/java/ai/rapids/cudf/MemoryBufferTest.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,8 @@

import org.junit.jupiter.api.Test;

import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.jupiter.api.Assertions.*;

public class MemoryBufferTest extends CudfTestBase {
Expand Down Expand Up @@ -168,4 +170,49 @@ private void verifyOutput(HostMemoryBuffer out) {
out.getBytes(bytes, 0, 0, 16);
assertArrayEquals(EXPECTED, bytes);
}

@Test
public void testEventHandlerIsCalledForEachClose() {
final AtomicInteger onClosedWasCalled = new AtomicInteger(0);
try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) {
b.setEventHandler(refCount -> onClosedWasCalled.incrementAndGet());
}
assertEquals(1, onClosedWasCalled.get());
onClosedWasCalled.set(0);

try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) {
b.setEventHandler(refCount -> onClosedWasCalled.incrementAndGet());
DeviceMemoryBuffer sliced = b.slice(0, b.getLength());
sliced.close();
}
assertEquals(2, onClosedWasCalled.get());
}

@Test
public void testEventHandlerIsNotCalledIfNotSet() {
final AtomicInteger onClosedWasCalled = new AtomicInteger(0);
try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) {
assertNull(b.getEventHandler());
}
assertEquals(0, onClosedWasCalled.get());
try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) {
b.setEventHandler(refCount -> onClosedWasCalled.incrementAndGet());
b.setEventHandler(null);
}
assertEquals(0, onClosedWasCalled.get());
}

@Test
public void testEventHandlerReturnsPreviousHandlerOnReset() {
try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) {
MemoryBuffer.EventHandler handler = refCount -> {};
MemoryBuffer.EventHandler handler2 = refCount -> {};

assertNull(b.setEventHandler(handler));
assertEquals(handler, b.setEventHandler(null));

assertNull(b.setEventHandler(handler2));
assertEquals(handler2, b.setEventHandler(handler));
}
}
}