Skip to content

Commit

Permalink
Clean up buffers in case AssertionError (#13262)
Browse files Browse the repository at this point in the history
Authors:
  - Raza Jafri (https://github.com/razajafri)

Approvers:
  - Alessandro Bellina (https://github.com/abellina)
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #13262
  • Loading branch information
razajafri authored May 11, 2023
1 parent 6f3f507 commit e4e65a9
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 32 deletions.
4 changes: 2 additions & 2 deletions java/src/main/java/ai/rapids/cudf/BitVectorHelper.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -71,7 +71,7 @@ private static void shiftSrcLeftAndWriteToDst(HostMemoryBuffer src, HostMemoryBu
/**
* This method returns the length in bytes needed to represent X number of rows
* e.g. getValidityLengthInBytes(5) => 1 byte
* getLengthInBytes(7) => 1 byte
* getValidityLengthInBytes(7) => 1 byte
* getValidityLengthInBytes(14) => 2 bytes
*/
static long getValidityLengthInBytes(long rows) {
Expand Down
7 changes: 5 additions & 2 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,10 @@ public ColumnVector(DType type, long rows, Optional<Long> nullCount,
incRefCountInternal(true);
}

private static OffHeapState makeOffHeap(DType type, long rows, Optional<Long> nullCount,
/**
* This method is internal and exposed purely for testing purposes
*/
static OffHeapState makeOffHeap(DType type, long rows, Optional<Long> nullCount,
DeviceMemoryBuffer dataBuffer, DeviceMemoryBuffer validityBuffer,
DeviceMemoryBuffer offsetBuffer, List<DeviceMemoryBuffer> toClose, long[] childHandles) {
long viewHandle = initViewHandle(type, (int)rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(),
Expand All @@ -141,7 +144,7 @@ private static OffHeapState makeOffHeap(DType type, long rows, Optional<Long> nu
* @param offsetBuffer a host buffer required for strings and string categories. The column
* vector takes ownership of the buffer. Do not use the buffer after calling
* this.
* @param toClose List of buffers to track adn close once done, usually in case of children
* @param toClose List of buffers to track and close once done, usually in case of children
* @param childHandles array of longs for child column view handles.
*/
public ColumnVector(DType type, long rows, Optional<Long> nullCount,
Expand Down
59 changes: 51 additions & 8 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,50 @@ public class ColumnView implements AutoCloseable, BinaryOperable {
protected final ColumnVector.OffHeapState offHeap;

/**
* Constructs a Column View given a native view address
* Constructs a Column View given a native view address. This asserts that if the ColumnView is
* of nested-type it doesn't contain non-empty nulls
* @param address the view handle
* @throws AssertionError if the address points to a nested-type view with non-empty nulls
*/
ColumnView(long address) {
this.viewHandle = address;
this.type = DType.fromNative(ColumnView.getNativeTypeId(viewHandle), ColumnView.getNativeTypeScale(viewHandle));
this.rows = ColumnView.getNativeRowCount(viewHandle);
this.nullCount = ColumnView.getNativeNullCount(viewHandle);
this.offHeap = null;
AssertEmptyNulls.assertNullsAreEmpty(this);
try {
AssertEmptyNulls.assertNullsAreEmpty(this);
} catch (AssertionError ae) {
// offHeap state is null, so there is nothing to clean in offHeap
// delete ColumnView to avoid memory leak
deleteColumnView(viewHandle);
viewHandle = 0;
throw ae;
}
}


/**
* Intended to be called from ColumnVector when it is being constructed. Because state creates a
* cudf::column_view instance and will close it in all cases, we don't want to have to double
* close it.
* close it. This asserts that if the offHeapState is of nested-type it doesn't contain non-empty nulls
* @param state the state this view is based off of.
* @throws AssertionError if offHeapState points to a nested-type view with non-empty nulls
*/
protected ColumnView(ColumnVector.OffHeapState state) {
offHeap = state;
viewHandle = state.getViewHandle();
type = DType.fromNative(ColumnView.getNativeTypeId(viewHandle), ColumnView.getNativeTypeScale(viewHandle));
rows = ColumnView.getNativeRowCount(viewHandle);
nullCount = ColumnView.getNativeNullCount(viewHandle);
AssertEmptyNulls.assertNullsAreEmpty(this);
try {
AssertEmptyNulls.assertNullsAreEmpty(this);
} catch (AssertionError ae) {
// cleanup offHeap
offHeap.clean(false);
viewHandle = 0;
throw ae;
}
}

/**
Expand Down Expand Up @@ -649,8 +667,14 @@ public final ColumnVector ifElse(Scalar trueValue, Scalar falseValue) {
public final ColumnVector[] slice(int... indices) {
long[] nativeHandles = slice(this.getNativeView(), indices);
ColumnVector[] columnVectors = new ColumnVector[nativeHandles.length];
for (int i = 0; i < nativeHandles.length; i++) {
columnVectors[i] = new ColumnVector(nativeHandles[i]);
try {
for (int i = 0; i < nativeHandles.length; i++) {
columnVectors[i] = new ColumnVector(nativeHandles[i]);
nativeHandles[i] = 0;
}
} catch (Throwable t) {
cleanupColumnViews(nativeHandles, columnVectors);
throw t;
}
return columnVectors;
}
Expand Down Expand Up @@ -788,12 +812,31 @@ public final ColumnVector[] split(int... indices) {
public ColumnView[] splitAsViews(int... indices) {
long[] nativeHandles = split(this.getNativeView(), indices);
ColumnView[] columnViews = new ColumnView[nativeHandles.length];
for (int i = 0; i < nativeHandles.length; i++) {
columnViews[i] = new ColumnView(nativeHandles[i]);
try {
for (int i = 0; i < nativeHandles.length; i++) {
columnViews[i] = new ColumnView(nativeHandles[i]);
nativeHandles[i] = 0;
}
} catch (Throwable t) {
cleanupColumnViews(nativeHandles, columnViews);
throw t;
}
return columnViews;
}

/**
 * Best-effort cleanup for callers that were wrapping an array of native handles into
 * ColumnView instances and failed part way through.
 * <p>
 * Every non-null entry of {@code columnViews} is closed and every non-zero entry of
 * {@code nativeHandles} is passed to {@code deleteColumnView}. Callers zero out a handle once a
 * wrapper has taken it over, so a handle is never released through both arrays.
 * <p>
 * A failure while releasing one entry does not stop the remaining entries from being released.
 * The first failure is rethrown after everything has been attempted, with any later failures
 * attached to it as suppressed exceptions.
 *
 * @param nativeHandles handles that have not been wrapped yet; zero entries are skipped and a
 *                      null array is treated as empty.
 * @param columnViews wrappers created so far; null entries are skipped and a null array is
 *                    treated as empty.
 */
static void cleanupColumnViews(long[] nativeHandles, ColumnView[] columnViews) {
  Throwable firstFailure = null;
  if (columnViews != null) {
    for (ColumnView columnView : columnViews) {
      if (columnView == null) {
        continue;
      }
      try {
        columnView.close();
      } catch (RuntimeException | Error e) {
        // keep going so one bad close does not leak everything after it
        if (firstFailure == null) {
          firstFailure = e;
        } else if (firstFailure != e) {
          firstFailure.addSuppressed(e);
        }
      }
    }
  }
  if (nativeHandles != null) {
    for (long nativeHandle : nativeHandles) {
      if (nativeHandle == 0) {
        continue;
      }
      try {
        deleteColumnView(nativeHandle);
      } catch (RuntimeException | Error e) {
        if (firstFailure == null) {
          firstFailure = e;
        } else if (firstFailure != e) {
          firstFailure.addSuppressed(e);
        }
      }
    }
  }
  // Only unchecked throwables are ever captured above, so these casts cannot fail.
  if (firstFailure instanceof Error) {
    throw (Error) firstFailure;
  }
  if (firstFailure != null) {
    throw (RuntimeException) firstFailure;
  }
}

/**
* Create a new vector of "normalized" values, where:
* 1. All representations of NaN (and -NaN) are replaced with the normalized NaN value
Expand Down
38 changes: 19 additions & 19 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ public Table(long[] cudfColumns) {
try {
for (int i = 0; i < cudfColumns.length; i++) {
this.columns[i] = new ColumnVector(cudfColumns[i]);
cudfColumns[i] = 0;
}
long[] views = new long[columns.length];
for (int i = 0; i < columns.length; i++) {
Expand All @@ -95,13 +96,7 @@ public Table(long[] cudfColumns) {
nativeHandle = createCudfTableView(views);
this.rows = columns[0].getRowCount();
} catch (Throwable t) {
for (int i = 0; i < cudfColumns.length; i++) {
if (this.columns[i] != null) {
this.columns[i].close();
} else {
ColumnVector.deleteCudfColumn(cudfColumns[i]);
}
}
ColumnView.cleanupColumnViews(cudfColumns, this.columns);
throw t;
}
}
Expand Down Expand Up @@ -3396,8 +3391,14 @@ public static GatherMap mixedLeftAntiJoinGatherMap(Table leftKeys, Table rightKe
public ColumnVector[] convertToRows() {
long[] ptrs = convertToRows(nativeHandle);
ColumnVector[] ret = new ColumnVector[ptrs.length];
for (int i = 0; i < ptrs.length; i++) {
ret[i] = new ColumnVector(ptrs[i]);
try {
for (int i = 0; i < ptrs.length; i++) {
ret[i] = new ColumnVector(ptrs[i]);
ptrs[i] = 0;
}
} catch (Throwable t) {
ColumnView.cleanupColumnViews(ptrs, ret);
throw t;
}
return ret;
}
Expand Down Expand Up @@ -3479,8 +3480,14 @@ public ColumnVector[] convertToRows() {
public ColumnVector[] convertToRowsFixedWidthOptimized() {
long[] ptrs = convertToRowsFixedWidthOptimized(nativeHandle);
ColumnVector[] ret = new ColumnVector[ptrs.length];
for (int i = 0; i < ptrs.length; i++) {
ret[i] = new ColumnVector(ptrs[i]);
try {
for (int i = 0; i < ptrs.length; i++) {
ret[i] = new ColumnVector(ptrs[i]);
ptrs[i] = 0;
}
} catch (Throwable t) {
ColumnView.cleanupColumnViews(ptrs, ret);
throw t;
}
return ret;
}
Expand Down Expand Up @@ -3552,14 +3559,7 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data
}
result = new Table(columns);
} catch (Throwable t) {
for (int i = 0; i < columns.length; i++) {
if (columns[i] != null) {
columns[i].close();
}
if (columnViewAddresses[i] != 0) {
ColumnView.deleteColumnView(columnViewAddresses[i]);
}
}
ColumnView.cleanupColumnViews(columnViewAddresses, columns);
throw t;
}

Expand Down
50 changes: 49 additions & 1 deletion java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -6677,6 +6678,54 @@ void testApplyBooleanMaskFromListOfStructure() {
}
}

/**
 * Builds an OffHeapState for a LIST column whose validity mask marks a row that still has
 * elements as null (a "non-empty null") and checks that, when the ColumnView constructor rejects
 * it with an AssertionError, the OffHeapState has been cleaned.
 */
@Test
void testColumnViewWithNonEmptyNullsIsCleared() {
  List<Integer> list0 = Arrays.asList(1, 2, 3);
  List<Integer> list1 = Arrays.asList(4, 5, null);
  List<Integer> list2 = Arrays.asList(7, 8, 9);
  List<Integer> list3 = null;
  try (ColumnVector input = ColumnVectorTest.makeListsColumn(DType.INT32, list0, list1, list2, list3);
       BaseDeviceMemoryBuffer baseValidityBuffer = input.getDeviceBufferFor(BufferType.VALIDITY);
       BaseDeviceMemoryBuffer baseOffsetBuffer = input.getDeviceBufferFor(BufferType.OFFSET);
       HostMemoryBuffer newValidity = HostMemoryBuffer.allocate(BitVectorHelper.getValidityAllocationSizeInBytes(4))) {

    newValidity.copyFromDeviceBuffer(baseValidityBuffer);
    // we are setting list1 with 3 elements to null. This will result in a non-empty null in the
    // ColumnView at index 1
    BitVectorHelper.setNullAt(newValidity, 1);

    // Both buffers are handed to the OffHeapState below and are closed when it is cleaned.
    // They are tracked here only so that a failure before that point cannot leak them.
    DeviceMemoryBuffer validityBuffer = null;
    DeviceMemoryBuffer offsetBuffer = null;
    try {
      validityBuffer = DeviceMemoryBuffer.allocate(BitVectorHelper.getValidityAllocationSizeInBytes(4));
      offsetBuffer = DeviceMemoryBuffer.allocate(baseOffsetBuffer.getLength());
      validityBuffer.copyFromHostBuffer(newValidity);
      offsetBuffer.copyFromMemoryBuffer(0, baseOffsetBuffer, 0,
          baseOffsetBuffer.getLength(), Cuda.DEFAULT_STREAM);

      // The new offHeapState will have 2 nulls, one null at index 3 from the original ColumnVector
      // the other at index 1 which is non-empty
      ColumnVector.OffHeapState offHeapState = ColumnVector.makeOffHeap(input.type, input.rows, Optional.of(2L),
          null, validityBuffer, offsetBuffer,
          null, Arrays.stream(input.getChildColumnViews()).mapToLong((c) -> c.viewHandle).toArray());
      try {
        new ColumnView(offHeapState);
      } catch (AssertionError ae) {
        assert offHeapState.isClean();
      } finally {
        // If the constructor did not throw, or the check above failed, the state still holds
        // its native resources; release them so the test itself never leaks.
        if (!offHeapState.isClean()) {
          offHeapState.clean(false);
        }
      }
    } catch (Throwable t) {
      // Close whatever was not already closed, then let the failure surface
      // instead of silently passing the test.
      if (validityBuffer != null && !validityBuffer.closed) {
        validityBuffer.close();
      }
      if (offsetBuffer != null && !offsetBuffer.closed) {
        offsetBuffer.close();
      }
      throw t;
    }
  }
}

@Test
public void testEventHandlerIsCalledForEachClose() {
final AtomicInteger onClosedWasCalled = new AtomicInteger(0);
Expand All @@ -6700,5 +6749,4 @@ public void testEventHandlerIsNotCalledIfNotSet() {
}
assertEquals(0, onClosedWasCalled.get());
}

}

0 comments on commit e4e65a9

Please sign in to comment.