diff --git a/java/src/main/java/ai/rapids/cudf/BitVectorHelper.java b/java/src/main/java/ai/rapids/cudf/BitVectorHelper.java index cdb7e9e4418..fa96e833b90 100644 --- a/java/src/main/java/ai/rapids/cudf/BitVectorHelper.java +++ b/java/src/main/java/ai/rapids/cudf/BitVectorHelper.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2019, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -71,7 +71,7 @@ private static void shiftSrcLeftAndWriteToDst(HostMemoryBuffer src, HostMemoryBu /** * This method returns the length in bytes needed to represent X number of rows * e.g. getValidityLengthInBytes(5) => 1 byte - * getLengthInBytes(7) => 1 byte + * getValidityLengthInBytes(7) => 1 byte * getValidityLengthInBytes(14) => 2 bytes */ static long getValidityLengthInBytes(long rows) { diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 4d43ffcb457..fecb13e1921 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -119,7 +119,10 @@ public ColumnVector(DType type, long rows, Optional nullCount, incRefCountInternal(true); } - private static OffHeapState makeOffHeap(DType type, long rows, Optional nullCount, + /** + * This method is internal and exposed purely for testing purposes + */ + static OffHeapState makeOffHeap(DType type, long rows, Optional nullCount, DeviceMemoryBuffer dataBuffer, DeviceMemoryBuffer validityBuffer, DeviceMemoryBuffer offsetBuffer, List toClose, long[] childHandles) { long viewHandle = initViewHandle(type, (int)rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(), @@ -141,7 +144,7 @@ private static OffHeapState makeOffHeap(DType type, long rows, Optional nu * @param offsetBuffer a host buffer required for strings and string categories. The column * vector takes ownership of the buffer. Do not use the buffer after calling * this. - * @param toClose List of buffers to track adn close once done, usually in case of children + * @param toClose List of buffers to track and close once done, usually in case of children * @param childHandles array of longs for child column view handles. */ public ColumnVector(DType type, long rows, Optional nullCount, diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 67ad9166fe0..8b59ea68972 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -43,8 +43,10 @@ public class ColumnView implements AutoCloseable, BinaryOperable { protected final ColumnVector.OffHeapState offHeap; /** - * Constructs a Column View given a native view address + * Constructs a Column View given a native view address. This asserts that if the ColumnView is + * of nested-type it doesn't contain non-empty nulls * @param address the view handle + * @throws AssertionError if the address points to a nested-type view with non-empty nulls */ ColumnView(long address) { this.viewHandle = address; @@ -52,15 +54,24 @@ public class ColumnView implements AutoCloseable, BinaryOperable { this.rows = ColumnView.getNativeRowCount(viewHandle); this.nullCount = ColumnView.getNativeNullCount(viewHandle); this.offHeap = null; - AssertEmptyNulls.assertNullsAreEmpty(this); + try { + AssertEmptyNulls.assertNullsAreEmpty(this); + } catch (AssertionError ae) { + // offHeap state is null, so there is nothing to clean in offHeap + // delete ColumnView to avoid memory leak + deleteColumnView(viewHandle); + viewHandle = 0; + throw ae; + } } /** * Intended to be called from ColumnVector when it is being constructed. Because state creates a * cudf::column_view instance and will close it in all cases, we don't want to have to double - * close it. + * close it. This asserts that if the offHeapState is of nested-type it doesn't contain non-empty nulls * @param state the state this view is based off of. + * @throws AssertionError if offHeapState points to a nested-type view with non-empty nulls */ protected ColumnView(ColumnVector.OffHeapState state) { offHeap = state; @@ -68,7 +79,14 @@ protected ColumnView(ColumnVector.OffHeapState state) { type = DType.fromNative(ColumnView.getNativeTypeId(viewHandle), ColumnView.getNativeTypeScale(viewHandle)); rows = ColumnView.getNativeRowCount(viewHandle); nullCount = ColumnView.getNativeNullCount(viewHandle); - AssertEmptyNulls.assertNullsAreEmpty(this); + try { + AssertEmptyNulls.assertNullsAreEmpty(this); + } catch (AssertionError ae) { + // cleanup offHeap + offHeap.clean(false); + viewHandle = 0; + throw ae; + } } /** @@ -649,8 +667,14 @@ public final ColumnVector ifElse(Scalar trueValue, Scalar falseValue) { public final ColumnVector[] slice(int... indices) { long[] nativeHandles = slice(this.getNativeView(), indices); ColumnVector[] columnVectors = new ColumnVector[nativeHandles.length]; - for (int i = 0; i < nativeHandles.length; i++) { - columnVectors[i] = new ColumnVector(nativeHandles[i]); + try { + for (int i = 0; i < nativeHandles.length; i++) { + columnVectors[i] = new ColumnVector(nativeHandles[i]); + nativeHandles[i] = 0; + } + } catch (Throwable t) { + cleanupColumnViews(nativeHandles, columnVectors); + throw t; } return columnVectors; } @@ -788,12 +812,31 @@ public final ColumnVector[] split(int... indices) { public ColumnView[] splitAsViews(int... indices) { long[] nativeHandles = split(this.getNativeView(), indices); ColumnView[] columnViews = new ColumnView[nativeHandles.length]; - for (int i = 0; i < nativeHandles.length; i++) { - columnViews[i] = new ColumnView(nativeHandles[i]); + try { + for (int i = 0; i < nativeHandles.length; i++) { + columnViews[i] = new ColumnView(nativeHandles[i]); + nativeHandles[i] = 0; + } + } catch (Throwable t) { + cleanupColumnViews(nativeHandles, columnViews); + throw t; } return columnViews; } + static void cleanupColumnViews(long[] nativeHandles, ColumnView[] columnViews) { + for (ColumnView columnView: columnViews) { + if (columnView != null) { + columnView.close(); + } + } + for (long nativeHandle: nativeHandles) { + if (nativeHandle != 0) { + deleteColumnView(nativeHandle); + } + } + } + /** * Create a new vector of "normalized" values, where: * 1. All representations of NaN (and -NaN) are replaced with the normalized NaN value diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 9abc2dbcd7c..93cb1acfae4 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -87,6 +87,7 @@ public Table(long[] cudfColumns) { try { for (int i = 0; i < cudfColumns.length; i++) { this.columns[i] = new ColumnVector(cudfColumns[i]); + cudfColumns[i] = 0; } long[] views = new long[columns.length]; for (int i = 0; i < columns.length; i++) { @@ -95,13 +96,7 @@ public Table(long[] cudfColumns) { nativeHandle = createCudfTableView(views); this.rows = columns[0].getRowCount(); } catch (Throwable t) { - for (int i = 0; i < cudfColumns.length; i++) { - if (this.columns[i] != null) { - this.columns[i].close(); - } else { - ColumnVector.deleteCudfColumn(cudfColumns[i]); - } - } + ColumnView.cleanupColumnViews(cudfColumns, this.columns); throw t; } } @@ -3396,8 +3391,14 @@ public static GatherMap mixedLeftAntiJoinGatherMap(Table leftKeys, Table rightKe public ColumnVector[] convertToRows() { long[] ptrs = convertToRows(nativeHandle); ColumnVector[] ret = new ColumnVector[ptrs.length]; - for (int i = 0; i < ptrs.length; i++) { - ret[i] = new ColumnVector(ptrs[i]); + try { + for (int i = 0; i < ptrs.length; i++) { + ret[i] = new ColumnVector(ptrs[i]); + ptrs[i] = 0; + } + } catch (Throwable t) { + ColumnView.cleanupColumnViews(ptrs, ret); + throw t; } return ret; } @@ -3479,8 +3480,14 @@ public ColumnVector[] convertToRows() { public ColumnVector[] convertToRowsFixedWidthOptimized() { long[] ptrs = convertToRowsFixedWidthOptimized(nativeHandle); ColumnVector[] ret = new ColumnVector[ptrs.length]; - for (int i = 0; i < ptrs.length; i++) { - ret[i] = new ColumnVector(ptrs[i]); + try { + for (int i = 0; i < ptrs.length; i++) { + ret[i] = new ColumnVector(ptrs[i]); + ptrs[i] = 0; + } + } catch (Throwable t) { + ColumnView.cleanupColumnViews(ptrs, ret); + throw t; } return ret; } @@ -3552,14 +3559,7 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data } result = new Table(columns); } catch (Throwable t) { - for (int i = 0; i < columns.length; i++) { - if (columns[i] != null) { - columns[i].close(); - } - if (columnViewAddresses[i] != 0) { - ColumnView.deleteColumnView(columnViewAddresses[i]); - } - } + ColumnView.cleanupColumnViews(columnViewAddresses, columns); throw t; } diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 6e9498acdac..59b4c9f9f67 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -32,6 +32,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -6677,6 +6678,54 @@ void testApplyBooleanMaskFromListOfStructure() { } } + @Test + void testColumnViewWithNonEmptyNullsIsCleared() { + List list0 = Arrays.asList(1, 2, 3); + List list1 = Arrays.asList(4, 5, null); + List list2 = Arrays.asList(7, 8, 9); + List list3 = null; + try (ColumnVector input = ColumnVectorTest.makeListsColumn(DType.INT32, list0, list1, list2, list3); + BaseDeviceMemoryBuffer baseValidityBuffer = input.getDeviceBufferFor(BufferType.VALIDITY); + BaseDeviceMemoryBuffer baseOffsetBuffer = input.getDeviceBufferFor(BufferType.OFFSET); + HostMemoryBuffer newValidity = HostMemoryBuffer.allocate(BitVectorHelper.getValidityAllocationSizeInBytes(4))) { + + newValidity.copyFromDeviceBuffer(baseValidityBuffer); + // we are setting list1 with 3 elements to null. This will result in a non-empty null in the + // ColumnView at index 1 + BitVectorHelper.setNullAt(newValidity, 1); + // validityBuffer will be closed by offHeapState later + DeviceMemoryBuffer validityBuffer = DeviceMemoryBuffer.allocate(BitVectorHelper.getValidityAllocationSizeInBytes(4)); + try { + // offsetBuffer will be closed by offHeapState later + DeviceMemoryBuffer offsetBuffer = DeviceMemoryBuffer.allocate(baseOffsetBuffer.getLength()); + try { + validityBuffer.copyFromHostBuffer(newValidity); + offsetBuffer.copyFromMemoryBuffer(0, baseOffsetBuffer, 0, + baseOffsetBuffer.length, Cuda.DEFAULT_STREAM); + + // The new offHeapState will have 2 nulls, one null at index 4 from the original ColumnVector + // the other at index 1 which is non-empty + ColumnVector.OffHeapState offHeapState = ColumnVector.makeOffHeap(input.type, input.rows, Optional.of(2L), + null, validityBuffer, offsetBuffer, + null, Arrays.stream(input.getChildColumnViews()).mapToLong((c) -> c.viewHandle).toArray()); + try { + new ColumnView(offHeapState); + } catch (AssertionError ae) { + assert offHeapState.isClean(); + } + } catch (Exception e) { + if (!offsetBuffer.closed) { + offsetBuffer.close(); + } + } + } catch (Exception e) { + if (!validityBuffer.closed) { + validityBuffer.close(); + } + } + } + } + @Test public void testEventHandlerIsCalledForEachClose() { final AtomicInteger onClosedWasCalled = new AtomicInteger(0); @@ -6700,5 +6749,4 @@ public void testEventHandlerIsNotCalledIfNotSet() { } assertEquals(0, onClosedWasCalled.get()); } - }