diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 52731dcf081..4d43ffcb457 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,6 +39,23 @@ * to increment the reference count. */ public final class ColumnVector extends ColumnView { + /** + * Interface to handle events for this ColumnVector. Only invoked during + * close, hence `onClosed` is the only event. + */ + public interface EventHandler { + /** + * `onClosed` is invoked with the updated `refCount` during `close`. + * The last invocation of `onClosed` will be with `refCount=0`. + * + * @note the callback is invoked with this `ColumnVector`'s lock held. + * + * @param refCount - the updated ref count for this ColumnVector at the time + * of invocation + */ + void onClosed(int refCount); + } + private static final Logger log = LoggerFactory.getLogger(ColumnVector.class); static { @@ -47,6 +64,7 @@ public final class ColumnVector extends ColumnView { private Optional nullCount = Optional.empty(); private int refCount; + private EventHandler eventHandler; /** * Wrap an existing on device cudf::column with the corresponding ColumnVector. The new @@ -200,6 +218,27 @@ static ColumnVector fromViewWithContiguousAllocation(long columnViewAddress, Dev return new ColumnVector(columnViewAddress, buffer); } + /** + * Set an event handler for this vector. This method can be invoked with null + * to unset the handler. + * + * @param newHandler - the EventHandler to use from this point forward + * @return the prior event handler, or null if not set. + */ + public synchronized EventHandler setEventHandler(EventHandler newHandler) { + EventHandler prev = this.eventHandler; + this.eventHandler = newHandler; + return prev; + } + + /** + * Returns the current event handler for this ColumnVector or null if no handler + * is associated. + */ + public synchronized EventHandler getEventHandler() { + return this.eventHandler; + } + /** * This is a really ugly API, but it is possible that the lifecycle of a column of * data may not have a clear lifecycle thanks to java and GC. This API informs the leak @@ -217,6 +256,9 @@ public void noWarnLeakExpected() { public synchronized void close() { refCount--; offHeap.delRef(); + if (eventHandler != null) { + eventHandler.onClosed(refCount); + } if (refCount == 0) { offHeap.clean(false); } else if (refCount < 0) { @@ -272,7 +314,7 @@ public long getNullCount() { /** * Returns this column's current refcount */ - synchronized int getRefCount() { + public synchronized int getRefCount() { return refCount; } diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 7cdb4538e32..6e9498acdac 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -32,7 +32,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -6676,4 +6676,29 @@ void testApplyBooleanMaskFromListOfStructure() { assertColumnsAreEqual(expectedCv, actualCv); } } + + @Test + public void testEventHandlerIsCalledForEachClose() { + final AtomicInteger onClosedWasCalled = new AtomicInteger(0); + try (ColumnVector cv = ColumnVector.fromInts(1,2,3,4)) { + cv.setEventHandler(refCount -> onClosedWasCalled.incrementAndGet()); + } + assertEquals(1, onClosedWasCalled.get()); + } + + @Test + public void testEventHandlerIsNotCalledIfNotSet() { + final AtomicInteger onClosedWasCalled = new AtomicInteger(0); + try (ColumnVector cv = ColumnVector.fromInts(1,2,3,4)) { + assertNull(cv.getEventHandler()); + } + assertEquals(0, onClosedWasCalled.get()); + + try (ColumnVector cv = ColumnVector.fromInts(1,2,3,4)) { + cv.setEventHandler(refCount -> onClosedWasCalled.incrementAndGet()); + cv.setEventHandler(null); + } + assertEquals(0, onClosedWasCalled.get()); + } + }