Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JNI support for apply_boolean_mask #10812

Merged
merged 3 commits into from
May 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -3498,6 +3498,37 @@ public final Scalar getScalarElement(int index) {
return new Scalar(getType(), getElement(getNativeView(), index));
}

/**
* Filters elements in each row of this LIST column using `booleanMaskView`
* LIST of booleans as a mask.
* <p>
* Given a list-of-bools column, the function produces
* a new `LIST` column of the same type as this column, where each element is copied
* from the row *only* if the corresponding `boolean_mask` is non-null and `true`.
* <p>
* E.g.
* column = { {0,1,2}, {3,4}, {5,6,7}, {8,9} };
* boolean_mask = { {0,1,1}, {1,0}, {1,1,1}, {0,0} };
* results = { {1,2}, {3}, {5,6,7}, {} };
* <p>
* This column and `boolean_mask` must have the same number of rows.
* The output column has the same number of rows as this column.
* An element is copied to an output row *only*
* if the corresponding boolean_mask element is `true`.
* An output row is invalid only if the row is invalid.
*
* @param booleanMaskView A nullable list of bools column used to filter elements in this column
* @return List column of the same type as this column, containing filtered list rows
* @throws CudfException if `boolean_mask` is not a "lists of bools" column
* @throws CudfException if this column and `boolean_mask` have different number of rows
*/
public final ColumnVector applyBooleanMask(ColumnView booleanMaskView) {
assert (getType().equals(DType.LIST));
assert (booleanMaskView.getType().equals(DType.LIST));
assert (getRowCount() == booleanMaskView.getRowCount());
return new ColumnVector(applyBooleanMask(getNativeView(), booleanMaskView.getNativeView()));
firestarman marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Get the number of bytes needed to allocate a validity buffer for the given number of rows.
* According to cudf::bitmask_allocation_size_bytes, the padding boundary for null mask is 64 bytes.
Expand Down Expand Up @@ -4176,6 +4207,8 @@ static native long makeCudfColumnView(int type, int scale, long data, long dataS

static native long generateListOffsets(long handle) throws CudfException;

static native long applyBooleanMask(long arrayColumnView, long booleanMaskHandle) throws CudfException;

/**
* A utility class to create column vector like objects without refcounts and other APIs when
* creating the device side vector from host side nested vectors. Eventually this can go away or
Expand Down
22 changes: 22 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <cudf/lists/gather.hpp>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/lists/sorting.hpp>
#include <cudf/lists/stream_compaction.hpp>
#include <cudf/null_mask.hpp>
#include <cudf/quantiles.hpp>
#include <cudf/reduction.hpp>
Expand Down Expand Up @@ -2223,4 +2224,25 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_repeatStringsSizes(
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_applyBooleanMask(
JNIEnv *env, jclass, jlong list_column_handle, jlong boolean_mask_list_column_handle) {
JNI_NULL_CHECK(env, list_column_handle, "list handle is null", 0);
JNI_NULL_CHECK(env, boolean_mask_list_column_handle, "boolean mask handle is null", 0);
try {
cudf::jni::auto_set_device(env);

cudf::column_view const *list_column =
reinterpret_cast<cudf::column_view const *>(list_column_handle);
cudf::lists_column_view const list_view = cudf::lists_column_view(*list_column);

cudf::column_view const *boolean_mask_list_column =
reinterpret_cast<cudf::column_view const *>(boolean_mask_list_column_handle);
cudf::lists_column_view const boolean_mask_list_view =
cudf::lists_column_view(*boolean_mask_list_column);

return release_as_jlong(cudf::lists::apply_boolean_mask(list_view, boolean_mask_list_view));
}
CATCH_STD(env, 0);
}

} // extern "C"
94 changes: 94 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6299,4 +6299,98 @@ void testGenerateListOffsets() {
assertColumnsAreEqual(expected, actual);
}
}

@Test
void testApplyBooleanMaskFromListOfInt() {
try (
ColumnVector elementCv = ColumnVector.fromBoxedInts(
11, 12, // list1
21, 22, 23, // list2
null, 32, 33, null, 35, // list3
null, 42, 43, null, 45 // list 4
// list5 (empty)
);
ColumnVector offsetsCv = ColumnVector.fromInts(0, 2, 5, 10, 15, 15);
ColumnVector listOfIntCv = elementCv.makeListFromOffsets(5, offsetsCv);

ColumnVector boolCv = ColumnVector.fromBoxedBooleans(
true, false, // list1
true, false, true, // list2
true, false, true, false, true, // list3
true, false, true, false, true // list 4
// list5 (empty)
);
ColumnVector listOfBoolCv = boolCv.makeListFromOffsets(5, offsetsCv);

// apply boolean mask
ColumnVector actualCv = listOfIntCv.applyBooleanMask(listOfBoolCv);

ColumnVector expectedElementCv = ColumnVector.fromBoxedInts(
11, // list1
21, 23, // list2
null, 33, 35, // list3
null, 43, 45 // list 4
// list5 (empty)
);
ColumnVector expectedOffsetsCv = ColumnVector.fromInts(0, 1, 3, 6, 9, 9);
ColumnVector expectedCv = expectedElementCv.makeListFromOffsets(5, expectedOffsetsCv)
) {
assertColumnsAreEqual(expectedCv, actualCv);
}
}

@Test
void testApplyBooleanMaskFromListOfStructure() {
try (
ColumnVector keyCv = ColumnVector.fromBoxedInts(
11, 12, // list1
21, 22, 23, // list2
null, 32, 33, null, 35, // list3
null, 42, 43, null, 45 // list 4
// list5 (empty)
);
ColumnVector valCv = ColumnVector.fromBoxedInts(
11, 12, // list1
21, 22, 23, // list2
31, 32, 33, 34, 35, // list3
41, 42, 43, 44, 45 // list4
// list5 (empty)
);
ColumnVector structCv = ColumnVector.makeStruct(keyCv, valCv);
ColumnVector offsetsCv = ColumnVector.fromInts(0, 2, 5, 10, 15, 15);
ColumnVector listOfStructCv = structCv.makeListFromOffsets(5, offsetsCv);

ColumnVector boolCv = ColumnVector.fromBoxedBooleans(
true, false, // list1
true, false, true, // list2
true, false, true, false, true, // list3
true, false, true, false, true // list 4
// list5 (empty)
);
ColumnVector listOfBoolCv = boolCv.makeListFromOffsets(5, offsetsCv);

// apply boolean mask
ColumnVector actualCv = listOfStructCv.applyBooleanMask(listOfBoolCv);

ColumnVector expectedKeyCv = ColumnVector.fromBoxedInts(
11, // list1
21, 23, // list2
null, 33, 35, // list3
null, 43, 45 // list 4
// list5 (empty)
);
ColumnVector expectedValCv = ColumnVector.fromBoxedInts(
11, // list1
21, 23, // list2
31, 33, 35, // list3
41, 43, 45 // list4
// list5 (empty)
);
ColumnVector expectedStructCv = ColumnVector.makeStruct(expectedKeyCv, expectedValCv);
ColumnVector expectedOffsetsCv = ColumnVector.fromInts(0, 1, 3, 6, 9, 9);
ColumnVector expectedCv = expectedStructCv.makeListFromOffsets(5, expectedOffsetsCv)
) {
assertColumnsAreEqual(expectedCv, actualCv);
}
}
}