diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 61981b34615..cb3234bf706 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -45,7 +45,6 @@ public final class ColumnVector extends ColumnView { NativeDepsLoader.loadNativeDeps(); } - private final OffHeapState offHeap; private Optional nullCount = Optional.empty(); private int refCount; @@ -56,14 +55,23 @@ public final class ColumnVector extends ColumnView { * owned by this instance. */ public ColumnVector(long nativePointer) { - super(getColumnViewFromColumn(nativePointer)); + super(new OffHeapState(nativePointer)); assert nativePointer != 0; - offHeap = new OffHeapState(nativePointer); MemoryCleaner.register(this, offHeap); this.refCount = 0; incRefCountInternal(true); } + private static OffHeapState makeOffHeap(DType type, long rows, Optional nullCount, + DeviceMemoryBuffer dataBuffer, DeviceMemoryBuffer validityBuffer, + DeviceMemoryBuffer offsetBuffer) { + long viewHandle = initViewHandle( + type, (int)rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(), + dataBuffer, validityBuffer, offsetBuffer, null); + return new OffHeapState(type, (int) rows, dataBuffer, validityBuffer, + offsetBuffer, null, viewHandle); + } + /** * Create a new column vector based off of data already on the device. * @param type the type of the vector @@ -81,24 +89,29 @@ public ColumnVector(long nativePointer) { public ColumnVector(DType type, long rows, Optional nullCount, DeviceMemoryBuffer dataBuffer, DeviceMemoryBuffer validityBuffer, DeviceMemoryBuffer offsetBuffer) { - super(ColumnVector.initViewHandle( - type, (int)rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(), - dataBuffer, validityBuffer, offsetBuffer, null)); + super(makeOffHeap(type, rows, nullCount, dataBuffer, validityBuffer, offsetBuffer)); assert !type.equals(DType.LIST) : "This constructor should not be used for list type"; if (!type.equals(DType.STRING)) { assert offsetBuffer == null : "offsets are only supported for STRING"; } assert (nullCount.isPresent() && nullCount.get() <= Integer.MAX_VALUE) || !nullCount.isPresent(); - offHeap = new OffHeapState(type, (int) rows, dataBuffer, validityBuffer, - offsetBuffer, null, viewHandle); MemoryCleaner.register(this, offHeap); this.nullCount = nullCount; - this.refCount = 0; incRefCountInternal(true); } + private static OffHeapState makeOffHeap(DType type, long rows, Optional nullCount, + DeviceMemoryBuffer dataBuffer, DeviceMemoryBuffer validityBuffer, + DeviceMemoryBuffer offsetBuffer, List toClose, long[] childHandles) { + long viewHandle = initViewHandle(type, (int)rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(), + dataBuffer, validityBuffer, + offsetBuffer, childHandles); + return new OffHeapState(type, (int) rows, dataBuffer, validityBuffer, offsetBuffer, + toClose, viewHandle); + } + /** * Create a new column vector based off of data already on the device with child columns. * @param type the type of the vector, typically a nested type @@ -118,16 +131,12 @@ public ColumnVector(DType type, long rows, Optional nullCount, public ColumnVector(DType type, long rows, Optional nullCount, DeviceMemoryBuffer dataBuffer, DeviceMemoryBuffer validityBuffer, DeviceMemoryBuffer offsetBuffer, List toClose, long[] childHandles) { - super(initViewHandle(type, (int)rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(), - dataBuffer, validityBuffer, - offsetBuffer, childHandles)); + super(makeOffHeap(type, rows, nullCount, dataBuffer, validityBuffer, offsetBuffer, toClose, childHandles)); if (!type.equals(DType.STRING) && !type.equals(DType.LIST)) { assert offsetBuffer == null : "offsets are only supported for STRING, LISTS"; } assert (nullCount.isPresent() && nullCount.get() <= Integer.MAX_VALUE) || !nullCount.isPresent(); - offHeap = new OffHeapState(type, (int) rows, dataBuffer, validityBuffer, offsetBuffer, - toClose, viewHandle); MemoryCleaner.register(this, offHeap); this.refCount = 0; @@ -143,8 +152,7 @@ public ColumnVector(DType type, long rows, Optional nullCount, * @param contiguousBuffer the buffer that this is based off of. */ private ColumnVector(long viewAddress, DeviceMemoryBuffer contiguousBuffer) { - super(viewAddress); - offHeap = new OffHeapState(viewAddress, contiguousBuffer); + super(new OffHeapState(viewAddress, contiguousBuffer)); MemoryCleaner.register(this, offHeap); // TODO we may want to ask for the null count anyways... this.nullCount = Optional.empty(); diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index cc1fa46becb..8155fe79080 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -40,6 +40,7 @@ public class ColumnView implements AutoCloseable, BinaryOperable { protected final DType type; protected final long rows; protected final long nullCount; + protected final ColumnVector.OffHeapState offHeap; /** * Constructs a Column View given a native view address @@ -50,6 +51,22 @@ public class ColumnView implements AutoCloseable, BinaryOperable { this.type = DType.fromNative(ColumnView.getNativeTypeId(viewHandle), ColumnView.getNativeTypeScale(viewHandle)); this.rows = ColumnView.getNativeRowCount(viewHandle); this.nullCount = ColumnView.getNativeNullCount(viewHandle); + this.offHeap = null; + } + + + /** + * Intended to be called from ColumnVector when it is being constructed. Because state creates a + * cudf::column_view instance and will close it in all cases, we don't want to have to double + * close it. + * @param state the state this view is based off of. + */ + protected ColumnView(ColumnVector.OffHeapState state) { + offHeap = state; + viewHandle = state.getViewHandle(); + type = DType.fromNative(ColumnView.getNativeTypeId(viewHandle), ColumnView.getNativeTypeScale(viewHandle)); + rows = ColumnView.getNativeRowCount(viewHandle); + nullCount = ColumnView.getNativeNullCount(viewHandle); } /** @@ -265,7 +282,10 @@ public long getDeviceMemorySize() { @Override public void close() { - ColumnView.deleteColumnView(viewHandle); + // close the view handle so long as offHeap is not going to do it for us. + if (offHeap == null) { + ColumnView.deleteColumnView(viewHandle); + } viewHandle = 0; }