Skip to content

Commit

Permalink
Adds an EventHandler to Java MemoryBuffer to be invoked on close (#12125
Browse files Browse the repository at this point in the history
)

This PR adds an EventHandler to `MemoryBuffer` with a single method `onClosed`. This is invoked during the `close` call, but after the `refCount` has been updated.

I am also making `getRefCount` public in this PR. Spill code in the RAPIDS Accelerator for Spark could likely assert/require that refCount==1 when taking in a new buffer to be spillable. This last change is a nice to have.

Authors:
  - Alessandro Bellina (https://github.com/abellina)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)
  - Jim Brennan (https://github.com/jbrennan333)
  - Jason Lowe (https://github.com/jlowe)

URL: #12125
  • Loading branch information
abellina authored Nov 11, 2022
1 parent d335aa3 commit 8668752
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 4 deletions.
51 changes: 48 additions & 3 deletions java/src/main/java/ai/rapids/cudf/MemoryBuffer.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,6 +28,23 @@
* subclassing beyond what is included in CUDF is not recommended and not supported.
*/
abstract public class MemoryBuffer implements AutoCloseable {
/**
* Interface to handle events for this MemoryBuffer. Only invoked during
* close, hence `onClosed` is the only event.
*/
public interface EventHandler {
/**
* `onClosed` is invoked with the updated `refCount` during `close`.
* The last invocation of `onClosed` will be with `refCount=0`.
*
* @note the callback is invoked with this `MemoryBuffer`'s lock held.
*
* @param refCount - the updated ref count for this MemoryBuffer at the time
* of invocation
*/
void onClosed(int refCount);
}

private static final Logger log = LoggerFactory.getLogger(MemoryBuffer.class);
protected final long address;
protected final long length;
Expand All @@ -36,6 +53,8 @@ abstract public class MemoryBuffer implements AutoCloseable {
protected final MemoryBufferCleaner cleaner;
protected final long id;

private EventHandler eventHandler;

public static abstract class MemoryBufferCleaner extends MemoryCleaner.Cleaner{}

private static final class SlicedBufferCleaner extends MemoryBufferCleaner {
Expand Down Expand Up @@ -193,13 +212,37 @@ public final void copyFromMemoryBufferAsync(
*/
public abstract MemoryBuffer slice(long offset, long len);

/**
* Set an event handler for this buffer. This method can be invoked with null
* to unset the handler.
*
* @param newHandler - the EventHandler to use from this point forward
* @return the prior event handler, or null if not set.
*/
public synchronized EventHandler setEventHandler(EventHandler newHandler) {
EventHandler prev = this.eventHandler;
this.eventHandler = newHandler;
return prev;
}

/**
* Returns the current event handler for this buffer or null if no handler
* is associated or this buffer is closed.
*/
public synchronized EventHandler getEventHandler() {
return this.eventHandler;
}

/**
* Close this buffer and free memory
*/
public synchronized void close() {
if (cleaner != null) {
refCount--;
cleaner.delRef();
if (eventHandler != null) {
eventHandler.onClosed(refCount);
}
if (refCount == 0) {
cleaner.clean(false);
closed = true;
Expand Down Expand Up @@ -232,8 +275,10 @@ public synchronized void incRefCount() {
cleaner.addRef();
}

// visible for testing
synchronized int getRefCount() {
/**
* Get the current reference count for this buffer.
*/
public synchronized int getRefCount() {
return refCount;
}
}
49 changes: 48 additions & 1 deletion java/src/test/java/ai/rapids/cudf/MemoryBufferTest.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,8 @@

import org.junit.jupiter.api.Test;

import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.jupiter.api.Assertions.*;

public class MemoryBufferTest extends CudfTestBase {
Expand Down Expand Up @@ -168,4 +170,49 @@ private void verifyOutput(HostMemoryBuffer out) {
out.getBytes(bytes, 0, 0, 16);
assertArrayEquals(EXPECTED, bytes);
}

@Test
public void testEventHandlerIsCalledForEachClose() {
final AtomicInteger onClosedWasCalled = new AtomicInteger(0);
try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) {
b.setEventHandler(refCount -> onClosedWasCalled.incrementAndGet());
}
assertEquals(1, onClosedWasCalled.get());
onClosedWasCalled.set(0);

try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) {
b.setEventHandler(refCount -> onClosedWasCalled.incrementAndGet());
DeviceMemoryBuffer sliced = b.slice(0, b.getLength());
sliced.close();
}
assertEquals(2, onClosedWasCalled.get());
}

@Test
public void testEventHandlerIsNotCalledIfNotSet() {
final AtomicInteger onClosedWasCalled = new AtomicInteger(0);
try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) {
assertNull(b.getEventHandler());
}
assertEquals(0, onClosedWasCalled.get());
try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) {
b.setEventHandler(refCount -> onClosedWasCalled.incrementAndGet());
b.setEventHandler(null);
}
assertEquals(0, onClosedWasCalled.get());
}

@Test
public void testEventHandlerReturnsPreviousHandlerOnReset() {
try (DeviceMemoryBuffer b = DeviceMemoryBuffer.allocate(256)) {
MemoryBuffer.EventHandler handler = refCount -> {};
MemoryBuffer.EventHandler handler2 = refCount -> {};

assertNull(b.setEventHandler(handler));
assertEquals(handler, b.setEventHandler(null));

assertNull(b.setEventHandler(handler2));
assertEquals(handler2, b.setEventHandler(handler));
}
}
}

0 comments on commit 8668752

Please sign in to comment.