Skip to content

Commit

Permalink
Add JNI support for apply_boolean_mask (#10812)
Browse files Browse the repository at this point in the history
Contributes to #10650

Add JNI support for `apply_boolean_mask`

Refer to the descriptions of PR #10773

Signed-off-by: Chong Gao <[email protected]>

Authors:
  - Chong Gao (https://github.com/res-life)

Approvers:
  - Liangcai Li (https://github.com/firestarman)
  - Jason Lowe (https://github.com/jlowe)

URL: #10812
  • Loading branch information
res-life authored May 12, 2022
1 parent 3e1a345 commit e0d94f3
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 0 deletions.
33 changes: 33 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -3499,6 +3499,37 @@ public final Scalar getScalarElement(int index) {
return new Scalar(getType(), getElement(getNativeView(), index));
}

/**
 * Filters the elements inside each row of this LIST column, keeping only the elements
 * selected by the matching row of `booleanMaskView`, a LIST-of-booleans column.
 * <p>
 * The result is a new LIST column of the same type as this column. An element is copied
 * into the output row *only* if the mask element at the same position is non-null and `true`.
 * <p>
 * E.g.
 * <pre>{@code
 * column          = { {0,1,2}, {3,4}, {5,6,7}, {8,9} };
 * booleanMaskView = { {0,1,1}, {1,0}, {1,1,1}, {0,0} };
 * results         = { {1,2},   {3},   {5,6,7}, {} };
 * }</pre>
 * <p>
 * This column and `booleanMaskView` must have the same number of rows, and the output
 * column has that same number of rows. An output row is invalid (null) only if the
 * corresponding row of this column is invalid.
 * <p>
 * Note that the Java-side type and row-count checks are `assert` statements, so they
 * are only evaluated when assertions are enabled.
 *
 * @param booleanMaskView A nullable list of bools column used to filter elements in this column
 * @return List column of the same type as this column, containing filtered list rows
 * @throws CudfException if `booleanMaskView` is not a "lists of bools" column
 * @throws CudfException if this column and `booleanMaskView` have different number of rows
 */
public final ColumnVector applyBooleanMask(ColumnView booleanMaskView) {
  assert getType().equals(DType.LIST);
  assert booleanMaskView.getType().equals(DType.LIST);
  assert getRowCount() == booleanMaskView.getRowCount();
  // The native call hands back a handle that the new ColumnVector wraps.
  final long filteredHandle =
      applyBooleanMask(getNativeView(), booleanMaskView.getNativeView());
  return new ColumnVector(filteredHandle);
}

/**
* Get the number of bytes needed to allocate a validity buffer for the given number of rows.
* According to cudf::bitmask_allocation_size_bytes, the padding boundary for null mask is 64 bytes.
Expand Down Expand Up @@ -4177,6 +4208,8 @@ static native long makeCudfColumnView(int type, int scale, long data, long dataS

static native long generateListOffsets(long handle) throws CudfException;

/**
 * Native entry point for filtering the rows of a LIST column with a LIST-of-booleans mask.
 *
 * @param arrayColumnView   native view handle of the LIST column to filter
 * @param booleanMaskHandle native view handle of the LIST-of-booleans mask column
 * @return native handle of the newly created, filtered LIST column
 * @throws CudfException if the native operation fails
 */
static native long applyBooleanMask(long arrayColumnView, long booleanMaskHandle) throws CudfException;

/**
* A utility class to create column vector like objects without refcounts and other APIs when
* creating the device side vector from host side nested vectors. Eventually this can go away or
Expand Down
22 changes: 22 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <cudf/lists/gather.hpp>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/lists/sorting.hpp>
#include <cudf/lists/stream_compaction.hpp>
#include <cudf/null_mask.hpp>
#include <cudf/quantiles.hpp>
#include <cudf/reduction.hpp>
Expand Down Expand Up @@ -2226,4 +2227,25 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_repeatStringsSizes(
CATCH_STD(env, 0);
}

// JNI binding for ColumnView.applyBooleanMask: filters each list row of the input column
// using the matching list-of-bools row of the mask column.
// Both jlong arguments are interpreted as pointers to cudf::column_view objects.
// Returns the result of cudf::lists::apply_boolean_mask released as a jlong handle; the
// null-check and catch macros are given 0 as their fallback return value.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_applyBooleanMask(
    JNIEnv *env, jclass, jlong list_column_handle, jlong boolean_mask_list_column_handle) {
  JNI_NULL_CHECK(env, list_column_handle, "list handle is null", 0);
  JNI_NULL_CHECK(env, boolean_mask_list_column_handle, "boolean mask handle is null", 0);
  try {
    cudf::jni::auto_set_device(env);

    // Wrap the input column first, then the mask, as list views.
    auto const *input_column = reinterpret_cast<cudf::column_view const *>(list_column_handle);
    cudf::lists_column_view const input_lists(*input_column);

    auto const *mask_column =
        reinterpret_cast<cudf::column_view const *>(boolean_mask_list_column_handle);
    cudf::lists_column_view const mask_lists(*mask_column);

    auto filtered = cudf::lists::apply_boolean_mask(input_lists, mask_lists);
    return release_as_jlong(std::move(filtered));
  }
  CATCH_STD(env, 0);
}

} // extern "C"
94 changes: 94 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6299,4 +6299,98 @@ void testGenerateListOffsets() {
assertColumnsAreEqual(expected, actual);
}
}

/**
 * Checks applyBooleanMask on a LIST-of-INT32 column: five rows (the last one empty),
 * masked by a LIST-of-BOOL column that shares the same offsets.
 */
@Test
void testApplyBooleanMaskFromListOfInt() {
  // All columns are declared as try-with-resources so each one is closed automatically.
  try (
      // Flattened child values of the input lists.
      ColumnVector intChildren = ColumnVector.fromBoxedInts(
          11, 12,                  // row 0
          21, 22, 23,              // row 1
          null, 32, 33, null, 35,  // row 2
          null, 42, 43, null, 45   // row 3
                                   // row 4 is empty
      );
      // The same offsets carve both the values and the mask into rows.
      ColumnVector sharedOffsets = ColumnVector.fromInts(0, 2, 5, 10, 15, 15);
      ColumnVector inputLists = intChildren.makeListFromOffsets(5, sharedOffsets);

      // Flattened mask flags, one per child value above.
      ColumnVector maskChildren = ColumnVector.fromBoxedBooleans(
          true, false,                     // row 0
          true, false, true,               // row 1
          true, false, true, false, true,  // row 2
          true, false, true, false, true   // row 3
                                           // row 4 is empty
      );
      ColumnVector maskLists = maskChildren.makeListFromOffsets(5, sharedOffsets);

      // Operation under test.
      ColumnVector filtered = inputLists.applyBooleanMask(maskLists);

      // Only the values whose mask flag is true remain.
      ColumnVector wantChildren = ColumnVector.fromBoxedInts(
          11,            // row 0
          21, 23,        // row 1
          null, 33, 35,  // row 2
          null, 43, 45   // row 3
                         // row 4 is empty
      );
      ColumnVector wantOffsets = ColumnVector.fromInts(0, 1, 3, 6, 9, 9);
      ColumnVector wantLists = wantChildren.makeListFromOffsets(5, wantOffsets)
  ) {
    assertColumnsAreEqual(wantLists, filtered);
  }
}

/**
 * Checks applyBooleanMask on a LIST-of-STRUCT column: each struct has a nullable INT32 key
 * and an INT32 value, and whole structs are kept or dropped according to the mask.
 */
@Test
void testApplyBooleanMaskFromListOfStructure() {
  // All columns are declared as try-with-resources so each one is closed automatically.
  try (
      // First struct field; contains nulls.
      ColumnVector keyChildren = ColumnVector.fromBoxedInts(
          11, 12,                  // row 0
          21, 22, 23,              // row 1
          null, 32, 33, null, 35,  // row 2
          null, 42, 43, null, 45   // row 3
                                   // row 4 is empty
      );
      // Second struct field; no nulls.
      ColumnVector valueChildren = ColumnVector.fromBoxedInts(
          11, 12,              // row 0
          21, 22, 23,          // row 1
          31, 32, 33, 34, 35,  // row 2
          41, 42, 43, 44, 45   // row 3
                               // row 4 is empty
      );
      ColumnVector structChildren = ColumnVector.makeStruct(keyChildren, valueChildren);
      // The same offsets carve both the structs and the mask into rows.
      ColumnVector sharedOffsets = ColumnVector.fromInts(0, 2, 5, 10, 15, 15);
      ColumnVector inputLists = structChildren.makeListFromOffsets(5, sharedOffsets);

      // Flattened mask flags, one per struct above.
      ColumnVector maskChildren = ColumnVector.fromBoxedBooleans(
          true, false,                     // row 0
          true, false, true,               // row 1
          true, false, true, false, true,  // row 2
          true, false, true, false, true   // row 3
                                           // row 4 is empty
      );
      ColumnVector maskLists = maskChildren.makeListFromOffsets(5, sharedOffsets);

      // Operation under test.
      ColumnVector filtered = inputLists.applyBooleanMask(maskLists);

      // Only the structs whose mask flag is true remain.
      ColumnVector wantKeys = ColumnVector.fromBoxedInts(
          11,            // row 0
          21, 23,        // row 1
          null, 33, 35,  // row 2
          null, 43, 45   // row 3
                         // row 4 is empty
      );
      ColumnVector wantValues = ColumnVector.fromBoxedInts(
          11,          // row 0
          21, 23,      // row 1
          31, 33, 35,  // row 2
          41, 43, 45   // row 3
                       // row 4 is empty
      );
      ColumnVector wantStructs = ColumnVector.makeStruct(wantKeys, wantValues);
      ColumnVector wantOffsets = ColumnVector.fromInts(0, 1, 3, 6, 9, 9);
      ColumnVector wantLists = wantStructs.makeListFromOffsets(5, wantOffsets)
  ) {
    assertColumnsAreEqual(wantLists, filtered);
  }
}
}

0 comments on commit e0d94f3

Please sign in to comment.