Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JNI support for apply_boolean_mask #10812

Merged
merged 3 commits into from
May 12, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -3498,6 +3498,10 @@ public final Scalar getScalarElement(int index) {
return new Scalar(getType(), getElement(getNativeView(), index));
}

public final ColumnVector applyBooleanMask(ColumnView booleanMaskView) {
return new ColumnVector(applyBooleanMask(getNativeView(), booleanMaskView.getNativeView()));
firestarman marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Get the number of bytes needed to allocate a validity buffer for the given number of rows.
* According to cudf::bitmask_allocation_size_bytes, the padding boundary for null mask is 64 bytes.
Expand Down Expand Up @@ -4176,6 +4180,8 @@ static native long makeCudfColumnView(int type, int scale, long data, long dataS

static native long generateListOffsets(long handle) throws CudfException;

static native long applyBooleanMask(long arrayColumnView, long booleanMaskHandle);

/**
* A utility class to create column vector like objects without refcounts and other APIs when
* creating the device side vector from host side nested vectors. Eventually this can go away or
Expand Down
21 changes: 21 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <cudf/lists/gather.hpp>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/lists/sorting.hpp>
#include <cudf/lists/stream_compaction.hpp>
#include <cudf/null_mask.hpp>
#include <cudf/quantiles.hpp>
#include <cudf/reduction.hpp>
Expand Down Expand Up @@ -2223,4 +2224,24 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_repeatStringsSizes(
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_applyBooleanMask(
JNIEnv *env, jclass, jlong list_column_handle, jlong boolean_mask_list_column_handle) {
JNI_NULL_CHECK(env, list_column_handle, "list handle is null", 0);
JNI_NULL_CHECK(env, boolean_mask_list_column_handle, "boolean mask handle is null", 0);
try {
cudf::jni::auto_set_device(env);

cudf::column_view *list_column = reinterpret_cast<cudf::column_view *>(list_column_handle);
cudf::lists_column_view list_view = cudf::lists_column_view(*list_column);

cudf::column_view *boolean_mask_list_column =
reinterpret_cast<cudf::column_view *>(boolean_mask_list_column_handle);
cudf::lists_column_view boolean_mask_list_view =
cudf::lists_column_view(*boolean_mask_list_column);
jlowe marked this conversation as resolved.
Show resolved Hide resolved

return release_as_jlong(cudf::lists::apply_boolean_mask(list_view, boolean_mask_list_view));
}
CATCH_STD(env, 0);
}

} // extern "C"
94 changes: 94 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6299,4 +6299,98 @@ void testGenerateListOffsets() {
assertColumnsAreEqual(expected, actual);
}
}

@Test
void testApplyBooleanMaskFromListOfInt() {
try (
ColumnVector elementCv = ColumnVector.fromBoxedInts(
11, 12, // list1
21, 22, 23, // list2
null, 32, 33, null, 35, // list3
null, 42, 43, null, 45 // list 4
// list5 (empty)
);
ColumnVector offsetsCv = ColumnVector.fromInts(0, 2, 5, 10, 15, 15);
ColumnVector listOfIntCv = elementCv.makeListFromOffsets(5, offsetsCv);

ColumnVector boolCv = ColumnVector.fromBoxedBooleans(
true, false, // list1
true, false, true, // list2
true, false, true, false, true, // list3
true, false, true, false, true // list 4
// list5 (empty)
);
ColumnVector listOfBoolCv = boolCv.makeListFromOffsets(5, offsetsCv);

// apply boolean mask
ColumnVector actualCv = listOfIntCv.applyBooleanMask(listOfBoolCv);

ColumnVector expectedElementCv = ColumnVector.fromBoxedInts(
11, // list1
21, 23, // list2
null, 33, 35, // list3
null, 43, 45 // list 4
// list5 (empty)
);
ColumnVector expectedOffsetsCv = ColumnVector.fromInts(0, 1, 3, 6, 9, 9);
ColumnVector expectedCv = expectedElementCv.makeListFromOffsets(5, expectedOffsetsCv)
) {
assertColumnsAreEqual(expectedCv, actualCv);
}
}

@Test
void testApplyBooleanMaskFromListOfStructure() {
try (
ColumnVector keyCv = ColumnVector.fromBoxedInts(
11, 12, // list1
21, 22, 23, // list2
null, 32, 33, null, 35, // list3
null, 42, 43, null, 45 // list 4
// list5 (empty)
);
ColumnVector valCv = ColumnVector.fromBoxedInts(
11, 12, // list1
21, 22, 23, // list2
31, 32, 33, 34, 35, // list3
41, 42, 43, 44, 45 // list4
// list5 (empty)
);
ColumnVector structCv = ColumnVector.makeStruct(keyCv, valCv);
ColumnVector offsetsCv = ColumnVector.fromInts(0, 2, 5, 10, 15, 15);
ColumnVector listOfStructCv = structCv.makeListFromOffsets(5, offsetsCv);

ColumnVector boolCv = ColumnVector.fromBoxedBooleans(
true, false, // list1
true, false, true, // list2
true, false, true, false, true, // list3
true, false, true, false, true // list 4
// list5 (empty)
);
ColumnVector listOfBoolCv = boolCv.makeListFromOffsets(5, offsetsCv);

// apply boolean mask
ColumnVector actualCv = listOfStructCv.applyBooleanMask(listOfBoolCv);

ColumnVector expectedKeyCv = ColumnVector.fromBoxedInts(
11, // list1
21, 23, // list2
null, 33, 35, // list3
null, 43, 45 // list 4
// list5 (empty)
);
ColumnVector expectedValCv = ColumnVector.fromBoxedInts(
11, // list1
21, 23, // list2
31, 33, 35, // list3
41, 43, 45 // list4
// list5 (empty)
);
ColumnVector expectedStructCv = ColumnVector.makeStruct(expectedKeyCv, expectedValCv);
ColumnVector expectedOffsetsCv = ColumnVector.fromInts(0, 1, 3, 6, 9, 9);
ColumnVector expectedCv = expectedStructCv.makeListFromOffsets(5, expectedOffsetsCv)
) {
assertColumnsAreEqual(expectedCv, actualCv);
}
}
}