Skip to content

Commit

Permalink
Merge branch 'add-tz-convert' of github.com:shwina/cudf into add-tz-c…
Browse files Browse the repository at this point in the history
…onvert
  • Loading branch information
shwina committed May 11, 2023
2 parents 5e745fb + 904f3d3 commit 40fad81
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 46 deletions.
4 changes: 2 additions & 2 deletions java/src/main/java/ai/rapids/cudf/BitVectorHelper.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -71,7 +71,7 @@ private static void shiftSrcLeftAndWriteToDst(HostMemoryBuffer src, HostMemoryBu
/**
* This method returns the length in bytes needed to represent X number of rows
* e.g. getValidityLengthInBytes(5) => 1 byte
* getLengthInBytes(7) => 1 byte
* getValidityLengthInBytes(7) => 1 byte
* getValidityLengthInBytes(14) => 2 bytes
*/
static long getValidityLengthInBytes(long rows) {
Expand Down
7 changes: 5 additions & 2 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,10 @@ public ColumnVector(DType type, long rows, Optional<Long> nullCount,
incRefCountInternal(true);
}

private static OffHeapState makeOffHeap(DType type, long rows, Optional<Long> nullCount,
/**
* This method is internal and exposed purely for testing purposes
*/
static OffHeapState makeOffHeap(DType type, long rows, Optional<Long> nullCount,
DeviceMemoryBuffer dataBuffer, DeviceMemoryBuffer validityBuffer,
DeviceMemoryBuffer offsetBuffer, List<DeviceMemoryBuffer> toClose, long[] childHandles) {
long viewHandle = initViewHandle(type, (int)rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(),
Expand All @@ -141,7 +144,7 @@ private static OffHeapState makeOffHeap(DType type, long rows, Optional<Long> nu
* @param offsetBuffer a host buffer required for strings and string categories. The column
* vector takes ownership of the buffer. Do not use the buffer after calling
* this.
* @param toClose List of buffers to track adn close once done, usually in case of children
* @param toClose List of buffers to track and close once done, usually in case of children
* @param childHandles array of longs for child column view handles.
*/
public ColumnVector(DType type, long rows, Optional<Long> nullCount,
Expand Down
59 changes: 51 additions & 8 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,50 @@ public class ColumnView implements AutoCloseable, BinaryOperable {
protected final ColumnVector.OffHeapState offHeap;

/**
* Constructs a Column View given a native view address
* Constructs a Column View given a native view address. This asserts that if the ColumnView is
* of nested-type it doesn't contain non-empty nulls
* @param address the view handle
* @throws AssertionError if the address points to a nested-type view with non-empty nulls
*/
ColumnView(long address) {
this.viewHandle = address;
this.type = DType.fromNative(ColumnView.getNativeTypeId(viewHandle), ColumnView.getNativeTypeScale(viewHandle));
this.rows = ColumnView.getNativeRowCount(viewHandle);
this.nullCount = ColumnView.getNativeNullCount(viewHandle);
this.offHeap = null;
AssertEmptyNulls.assertNullsAreEmpty(this);
try {
AssertEmptyNulls.assertNullsAreEmpty(this);
} catch (AssertionError ae) {
// offHeap state is null, so there is nothing to clean in offHeap
// delete ColumnView to avoid memory leak
deleteColumnView(viewHandle);
viewHandle = 0;
throw ae;
}
}


/**
* Intended to be called from ColumnVector when it is being constructed. Because state creates a
* cudf::column_view instance and will close it in all cases, we don't want to have to double
* close it.
* close it. This asserts that if the offHeapState is of nested-type it doesn't contain non-empty nulls
* @param state the state this view is based off of.
* @throws AssertionError if offHeapState points to a nested-type view with non-empty nulls
*/
protected ColumnView(ColumnVector.OffHeapState state) {
offHeap = state;
viewHandle = state.getViewHandle();
type = DType.fromNative(ColumnView.getNativeTypeId(viewHandle), ColumnView.getNativeTypeScale(viewHandle));
rows = ColumnView.getNativeRowCount(viewHandle);
nullCount = ColumnView.getNativeNullCount(viewHandle);
AssertEmptyNulls.assertNullsAreEmpty(this);
try {
AssertEmptyNulls.assertNullsAreEmpty(this);
} catch (AssertionError ae) {
// cleanup offHeap
offHeap.clean(false);
viewHandle = 0;
throw ae;
}
}

/**
Expand Down Expand Up @@ -649,8 +667,14 @@ public final ColumnVector ifElse(Scalar trueValue, Scalar falseValue) {
public final ColumnVector[] slice(int... indices) {
long[] nativeHandles = slice(this.getNativeView(), indices);
ColumnVector[] columnVectors = new ColumnVector[nativeHandles.length];
for (int i = 0; i < nativeHandles.length; i++) {
columnVectors[i] = new ColumnVector(nativeHandles[i]);
try {
for (int i = 0; i < nativeHandles.length; i++) {
columnVectors[i] = new ColumnVector(nativeHandles[i]);
nativeHandles[i] = 0;
}
} catch (Throwable t) {
cleanupColumnViews(nativeHandles, columnVectors);
throw t;
}
return columnVectors;
}
Expand Down Expand Up @@ -788,12 +812,31 @@ public final ColumnVector[] split(int... indices) {
public ColumnView[] splitAsViews(int... indices) {
long[] nativeHandles = split(this.getNativeView(), indices);
ColumnView[] columnViews = new ColumnView[nativeHandles.length];
for (int i = 0; i < nativeHandles.length; i++) {
columnViews[i] = new ColumnView(nativeHandles[i]);
try {
for (int i = 0; i < nativeHandles.length; i++) {
columnViews[i] = new ColumnView(nativeHandles[i]);
nativeHandles[i] = 0;
}
} catch (Throwable t) {
cleanupColumnViews(nativeHandles, columnViews);
throw t;
}
return columnViews;
}

/**
 * Best-effort cleanup used on error paths that wrap an array of native handles one at a time.
 * Every non-null entry of {@code columnViews} is closed, and every non-zero entry of
 * {@code nativeHandles} is passed to {@code deleteColumnView}. Callers are expected to zero out
 * a handle as soon as a wrapper has been constructed for it, so that no handle is released twice.
 *
 * A failure while releasing one entry no longer aborts the cleanup: the remaining entries are
 * still released, the first failure is rethrown once everything has been attempted, and any
 * later failures are attached to it as suppressed exceptions.
 *
 * NOTE(review): some callers pass handles that they wrap with {@code new ColumnVector(handle)};
 * confirm that {@code deleteColumnView} is the right release call for those handles.
 *
 * @param nativeHandles handles not yet owned by a wrapper; zero entries are skipped, and a
 *                      null array is treated as empty
 * @param columnViews wrappers constructed so far; null entries are skipped, and a null array
 *                    is treated as empty
 */
static void cleanupColumnViews(long[] nativeHandles, ColumnView[] columnViews) {
  RuntimeException firstFailure = null;
  if (columnViews != null) {
    for (ColumnView columnView : columnViews) {
      if (columnView == null) {
        continue;
      }
      try {
        columnView.close();
      } catch (RuntimeException e) {
        // keep going so one bad close does not leak everything after it
        if (firstFailure == null) {
          firstFailure = e;
        } else if (firstFailure != e) {
          firstFailure.addSuppressed(e);
        }
      }
    }
  }
  if (nativeHandles != null) {
    for (long nativeHandle : nativeHandles) {
      if (nativeHandle == 0) {
        // zero means a wrapper already owns (and has released) this handle
        continue;
      }
      try {
        deleteColumnView(nativeHandle);
      } catch (RuntimeException e) {
        if (firstFailure == null) {
          firstFailure = e;
        } else if (firstFailure != e) {
          firstFailure.addSuppressed(e);
        }
      }
    }
  }
  if (firstFailure != null) {
    throw firstFailure;
  }
}

/**
* Create a new vector of "normalized" values, where:
* 1. All representations of NaN (and -NaN) are replaced with the normalized NaN value
Expand Down
38 changes: 19 additions & 19 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ public Table(long[] cudfColumns) {
try {
for (int i = 0; i < cudfColumns.length; i++) {
this.columns[i] = new ColumnVector(cudfColumns[i]);
cudfColumns[i] = 0;
}
long[] views = new long[columns.length];
for (int i = 0; i < columns.length; i++) {
Expand All @@ -95,13 +96,7 @@ public Table(long[] cudfColumns) {
nativeHandle = createCudfTableView(views);
this.rows = columns[0].getRowCount();
} catch (Throwable t) {
for (int i = 0; i < cudfColumns.length; i++) {
if (this.columns[i] != null) {
this.columns[i].close();
} else {
ColumnVector.deleteCudfColumn(cudfColumns[i]);
}
}
ColumnView.cleanupColumnViews(cudfColumns, this.columns);
throw t;
}
}
Expand Down Expand Up @@ -3396,8 +3391,14 @@ public static GatherMap mixedLeftAntiJoinGatherMap(Table leftKeys, Table rightKe
public ColumnVector[] convertToRows() {
long[] ptrs = convertToRows(nativeHandle);
ColumnVector[] ret = new ColumnVector[ptrs.length];
for (int i = 0; i < ptrs.length; i++) {
ret[i] = new ColumnVector(ptrs[i]);
try {
for (int i = 0; i < ptrs.length; i++) {
ret[i] = new ColumnVector(ptrs[i]);
ptrs[i] = 0;
}
} catch (Throwable t) {
ColumnView.cleanupColumnViews(ptrs, ret);
throw t;
}
return ret;
}
Expand Down Expand Up @@ -3479,8 +3480,14 @@ public ColumnVector[] convertToRows() {
public ColumnVector[] convertToRowsFixedWidthOptimized() {
long[] ptrs = convertToRowsFixedWidthOptimized(nativeHandle);
ColumnVector[] ret = new ColumnVector[ptrs.length];
for (int i = 0; i < ptrs.length; i++) {
ret[i] = new ColumnVector(ptrs[i]);
try {
for (int i = 0; i < ptrs.length; i++) {
ret[i] = new ColumnVector(ptrs[i]);
ptrs[i] = 0;
}
} catch (Throwable t) {
ColumnView.cleanupColumnViews(ptrs, ret);
throw t;
}
return ret;
}
Expand Down Expand Up @@ -3552,14 +3559,7 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data
}
result = new Table(columns);
} catch (Throwable t) {
for (int i = 0; i < columns.length; i++) {
if (columns[i] != null) {
columns[i].close();
}
if (columnViewAddresses[i] != 0) {
ColumnView.deleteColumnView(columnViewAddresses[i]);
}
}
ColumnView.cleanupColumnViews(columnViewAddresses, columns);
throw t;
}

Expand Down
50 changes: 49 additions & 1 deletion java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -6677,6 +6678,54 @@ void testApplyBooleanMaskFromListOfStructure() {
}
}

/**
 * Builds an OffHeapState for a LIST column in which a row that still has child elements is
 * marked null (a "non-empty null"), then constructs a ColumnView from it. If the constructor
 * rejects the state with an AssertionError, the test checks that the OffHeapState reports
 * itself clean.
 *
 * NOTE(review): nothing is asserted when no AssertionError is thrown, and the outer
 * catch (Exception) blocks close the buffers without rethrowing, so the test passes in
 * those cases as well. Confirm that this is intended.
 */
@Test
void testColumnViewWithNonEmptyNullsIsCleared() {
// four rows: three populated lists and one null list (list3, row index 3)
List<Integer> list0 = Arrays.asList(1, 2, 3);
List<Integer> list1 = Arrays.asList(4, 5, null);
List<Integer> list2 = Arrays.asList(7, 8, 9);
List<Integer> list3 = null;
try (ColumnVector input = ColumnVectorTest.makeListsColumn(DType.INT32, list0, list1, list2, list3);
BaseDeviceMemoryBuffer baseValidityBuffer = input.getDeviceBufferFor(BufferType.VALIDITY);
BaseDeviceMemoryBuffer baseOffsetBuffer = input.getDeviceBufferFor(BufferType.OFFSET);
HostMemoryBuffer newValidity = HostMemoryBuffer.allocate(BitVectorHelper.getValidityAllocationSizeInBytes(4))) {

// start from the input column's validity mask so the existing null row is preserved
newValidity.copyFromDeviceBuffer(baseValidityBuffer);
// we are setting list1 with 3 elements to null. This will result in a non-empty null in the
// ColumnView at index 1
BitVectorHelper.setNullAt(newValidity, 1);
// validityBuffer will be closed by offHeapState later
DeviceMemoryBuffer validityBuffer = DeviceMemoryBuffer.allocate(BitVectorHelper.getValidityAllocationSizeInBytes(4));
try {
// offsetBuffer will be closed by offHeapState later
DeviceMemoryBuffer offsetBuffer = DeviceMemoryBuffer.allocate(baseOffsetBuffer.getLength());
try {
// copy the modified validity mask and the unchanged offsets into the new device buffers
validityBuffer.copyFromHostBuffer(newValidity);
offsetBuffer.copyFromMemoryBuffer(0, baseOffsetBuffer, 0,
baseOffsetBuffer.length, Cuda.DEFAULT_STREAM);

// The new offHeapState will have 2 nulls, one null at index 3 (list3) from the original
// ColumnVector, the other at index 1 which is non-empty
ColumnVector.OffHeapState offHeapState = ColumnVector.makeOffHeap(input.type, input.rows, Optional.of(2L),
null, validityBuffer, offsetBuffer,
null, Arrays.stream(input.getChildColumnViews()).mapToLong((c) -> c.viewHandle).toArray());
try {
// the constructed view is intentionally not kept; only the failure path is checked
new ColumnView(offHeapState);
} catch (AssertionError ae) {
// presumably the constructor cleans the state before rethrowing; verified via isClean()
assert offHeapState.isClean();
}
} catch (Exception e) {
// NOTE(review): the exception is not rethrown after the buffer is released
if (!offsetBuffer.closed) {
offsetBuffer.close();
}
}
} catch (Exception e) {
// NOTE(review): the exception is not rethrown after the buffer is released
if (!validityBuffer.closed) {
validityBuffer.close();
}
}
}
}

@Test
public void testEventHandlerIsCalledForEachClose() {
final AtomicInteger onClosedWasCalled = new AtomicInteger(0);
Expand All @@ -6700,5 +6749,4 @@ public void testEventHandlerIsNotCalledIfNotSet() {
}
assertEquals(0, onClosedWasCalled.get());
}

}
35 changes: 21 additions & 14 deletions python/cudf/cudf/core/column/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ def contains(
4 False
dtype: bool
The ``pat`` may also be a list of strings in which case
The ``pat`` may also be a sequence of strings in which case
the individual strings are searched in corresponding rows.
>>> s2 = cudf.Series(['house', 'dog', 'and', '', ''])
Expand All @@ -756,8 +756,6 @@ def contains(
4 <NA>
dtype: bool
""" # noqa W605
if case is not True:
raise NotImplementedError("`case` parameter is not yet supported")
if na is not np.nan:
raise NotImplementedError("`na` parameter is not yet supported")
if regex and isinstance(pat, re.Pattern):
Expand All @@ -767,22 +765,31 @@ def contains(
raise NotImplementedError(
"unsupported value for `flags` parameter"
)

if pat is None:
result_col = column.column_empty(
len(self._column), dtype="bool", masked=True
if regex and not case:
raise NotImplementedError(
"`case=False` only supported when `regex=False`"
)
elif is_scalar(pat):

if is_scalar(pat):
if regex:
result_col = libstrings.contains_re(self._column, pat, flags)
else:
result_col = libstrings.contains(
self._column, cudf.Scalar(pat, "str")
)
if case is False:
input_column = libstrings.to_lower(self._column)
pat = cudf.Scalar(pat.lower(), dtype="str") # type: ignore
else:
input_column = self._column
pat = cudf.Scalar(pat, dtype="str") # type: ignore
result_col = libstrings.contains(input_column, pat)
else:
result_col = libstrings.contains_multiple(
self._column, column.as_column(pat, dtype="str")
)
# TODO: we silently ignore the `regex=` flag here
if case is False:
input_column = libstrings.to_lower(self._column)
pat = libstrings.to_lower(column.as_column(pat, dtype="str"))
else:
input_column = self._column
pat = column.as_column(pat, dtype="str")
result_col = libstrings.contains_multiple(input_column, pat)
return self._return_or_inplace(result_col)

def like(self, pat: str, esc: str = None) -> SeriesOrIndex:
Expand Down
11 changes: 11 additions & 0 deletions python/cudf/cudf/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,17 @@ def test_string_contains(ps_gs, pat, regex, flags, flags_raise, na, na_raise):
assert_eq(expect, got)


def test_string_contains_case(ps_gs):
    """Check case-insensitive ``str.contains`` against pandas for literal patterns.

    With ``regex`` left at its default, ``case=False`` is expected to raise
    ``NotImplementedError``. With ``regex=False``, both an upper-case and a
    lower-case pattern must give the same result as pandas.
    """
    pandas_series, cudf_series = ps_gs

    # case-insensitive matching without regex=False is expected to be rejected
    with pytest.raises(NotImplementedError):
        cudf_series.str.contains("A", case=False)

    expected = pandas_series.str.contains("A", regex=False, case=False)
    # the pattern's own case must not matter when case=False
    for pattern in ("A", "a"):
        got = cudf_series.str.contains(pattern, regex=False, case=False)
        assert_eq(expected, got)


@pytest.mark.parametrize(
"pat,esc,expect",
[
Expand Down

0 comments on commit 40fad81

Please sign in to comment.