diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index e871da18966..0a6f0a03bc8 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -3498,6 +3498,37 @@ public final Scalar getScalarElement(int index) { return new Scalar(getType(), getElement(getNativeView(), index)); } + /** + * Filters the elements in each row of this LIST column, using `booleanMaskView` + * (a LIST of booleans column) as the mask. + * <p>

+ * Given a list-of-bools column, the function produces + * a new `LIST` column of the same type as this column, where each element is copied + * from the row *only* if the corresponding `booleanMaskView` element is non-null and `true`. + * <p>

+ * E.g. + * column = { {0,1,2}, {3,4}, {5,6,7}, {8,9} }; + * booleanMaskView = { {0,1,1}, {1,0}, {1,1,1}, {0,0} }; + * results = { {1,2}, {3}, {5,6,7}, {} }; + * <p>

+ * This column and `booleanMaskView` must have the same number of rows. + * The output column has the same number of rows as this column. + * An element is copied to an output row *only* + * if the corresponding `booleanMaskView` element is `true`. + * An output row is invalid only if the corresponding input row is invalid. + * + * @param booleanMaskView A nullable list of bools column used to filter elements in this column + * @return List column of the same type as this column, containing filtered list rows + * @throws CudfException if `booleanMaskView` is not a "lists of bools" column + * @throws CudfException if this column and `booleanMaskView` have a different number of rows + */ + public final ColumnVector applyBooleanMask(ColumnView booleanMaskView) { + assert (getType().equals(DType.LIST)); + assert (booleanMaskView.getType().equals(DType.LIST)); + assert (getRowCount() == booleanMaskView.getRowCount()); + return new ColumnVector(applyBooleanMask(getNativeView(), booleanMaskView.getNativeView())); + } + /** * Get the number of bytes needed to allocate a validity buffer for the given number of rows. * According to cudf::bitmask_allocation_size_bytes, the padding boundary for null mask is 64 bytes. @@ -4176,6 +4207,8 @@ static native long makeCudfColumnView(int type, int scale, long data, long dataS static native long generateListOffsets(long handle) throws CudfException; + static native long applyBooleanMask(long arrayColumnView, long booleanMaskHandle) throws CudfException; + /** * A utility class to create column vector like objects without refcounts and other APIs when * creating the device side vector from host side nested vectors. 
Eventually this can go away or diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index e074180c312..ab2c7006be5 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -34,6 +34,7 @@ #include #include #include +#include <cudf/lists/stream_compaction.hpp> #include #include #include @@ -2223,4 +2224,25 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_repeatStringsSizes( CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_applyBooleanMask( + JNIEnv *env, jclass, jlong list_column_handle, jlong boolean_mask_list_column_handle) { + JNI_NULL_CHECK(env, list_column_handle, "list handle is null", 0); + JNI_NULL_CHECK(env, boolean_mask_list_column_handle, "boolean mask handle is null", 0); + try { + cudf::jni::auto_set_device(env); + + cudf::column_view const *list_column = + reinterpret_cast<cudf::column_view const *>(list_column_handle); + cudf::lists_column_view const list_view = cudf::lists_column_view(*list_column); + + cudf::column_view const *boolean_mask_list_column = + reinterpret_cast<cudf::column_view const *>(boolean_mask_list_column_handle); + cudf::lists_column_view const boolean_mask_list_view = + cudf::lists_column_view(*boolean_mask_list_column); + + return release_as_jlong(cudf::lists::apply_boolean_mask(list_view, boolean_mask_list_view)); + } + CATCH_STD(env, 0); +} + } // extern "C" diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index a42846aac05..492560f7b7f 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -6299,4 +6299,98 @@ void testGenerateListOffsets() { assertColumnsAreEqual(expected, actual); } } + + @Test + void testApplyBooleanMaskFromListOfInt() { + try ( + ColumnVector elementCv = ColumnVector.fromBoxedInts( + 11, 12, // list1 + 21, 22, 23, // list2 + null, 32, 33, null, 35, // list3 + null, 42, 43, null, 45 // list4 + // list5 
(empty) + ); + ColumnVector offsetsCv = ColumnVector.fromInts(0, 2, 5, 10, 15, 15); + ColumnVector listOfIntCv = elementCv.makeListFromOffsets(5, offsetsCv); + + ColumnVector boolCv = ColumnVector.fromBoxedBooleans( + true, false, // list1 + true, false, true, // list2 + true, false, true, false, true, // list3 + true, false, true, false, true // list4 + // list5 (empty) + ); + ColumnVector listOfBoolCv = boolCv.makeListFromOffsets(5, offsetsCv); + + // apply boolean mask + ColumnVector actualCv = listOfIntCv.applyBooleanMask(listOfBoolCv); + + ColumnVector expectedElementCv = ColumnVector.fromBoxedInts( + 11, // list1 + 21, 23, // list2 + null, 33, 35, // list3 + null, 43, 45 // list4 + // list5 (empty) + ); + ColumnVector expectedOffsetsCv = ColumnVector.fromInts(0, 1, 3, 6, 9, 9); + ColumnVector expectedCv = expectedElementCv.makeListFromOffsets(5, expectedOffsetsCv) + ) { + assertColumnsAreEqual(expectedCv, actualCv); + } + } + + @Test + void testApplyBooleanMaskFromListOfStructure() { + try ( + ColumnVector keyCv = ColumnVector.fromBoxedInts( + 11, 12, // list1 + 21, 22, 23, // list2 + null, 32, 33, null, 35, // list3 + null, 42, 43, null, 45 // list4 + // list5 (empty) + ); + ColumnVector valCv = ColumnVector.fromBoxedInts( + 11, 12, // list1 + 21, 22, 23, // list2 + 31, 32, 33, 34, 35, // list3 + 41, 42, 43, 44, 45 // list4 + // list5 (empty) + ); + ColumnVector structCv = ColumnVector.makeStruct(keyCv, valCv); + ColumnVector offsetsCv = ColumnVector.fromInts(0, 2, 5, 10, 15, 15); + ColumnVector listOfStructCv = structCv.makeListFromOffsets(5, offsetsCv); + + ColumnVector boolCv = ColumnVector.fromBoxedBooleans( + true, false, // list1 + true, false, true, // list2 + true, false, true, false, true, // list3 + true, false, true, false, true // list4 + // list5 (empty) + ); + ColumnVector listOfBoolCv = boolCv.makeListFromOffsets(5, offsetsCv); + + // apply boolean mask + ColumnVector actualCv = listOfStructCv.applyBooleanMask(listOfBoolCv); + + 
ColumnVector expectedKeyCv = ColumnVector.fromBoxedInts( + 11, // list1 + 21, 23, // list2 + null, 33, 35, // list3 + null, 43, 45 // list4 + // list5 (empty) + ); + ColumnVector expectedValCv = ColumnVector.fromBoxedInts( + 11, // list1 + 21, 23, // list2 + 31, 33, 35, // list3 + 41, 43, 45 // list4 + // list5 (empty) + ); + ColumnVector expectedStructCv = ColumnVector.makeStruct(expectedKeyCv, expectedValCv); + ColumnVector expectedOffsetsCv = ColumnVector.fromInts(0, 1, 3, 6, 9, 9); + ColumnVector expectedCv = expectedStructCv.makeListFromOffsets(5, expectedOffsetsCv) + ) { + assertColumnsAreEqual(expectedCv, actualCv); + } + } }