diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java
index e871da18966..0a6f0a03bc8 100644
--- a/java/src/main/java/ai/rapids/cudf/ColumnView.java
+++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java
@@ -3498,6 +3498,37 @@ public final Scalar getScalarElement(int index) {
return new Scalar(getType(), getElement(getNativeView(), index));
}
+ /**
+ * Filters elements in each row of this LIST column using `booleanMaskView`
+ * LIST of booleans as a mask.
+ *
+ * Given a list-of-bools column, the function produces
+ * a new `LIST` column of the same type as this column, where each element is copied
+ * from the row *only* if the corresponding `boolean_mask` is non-null and `true`.
+ *
+ * E.g.
+ * column = { {0,1,2}, {3,4}, {5,6,7}, {8,9} };
+ * boolean_mask = { {0,1,1}, {1,0}, {1,1,1}, {0,0} };
+ * results = { {1,2}, {3}, {5,6,7}, {} };
+ *
+ * This column and `boolean_mask` must have the same number of rows.
+ * The output column has the same number of rows as this column.
+ * An element is copied to an output row *only*
+ * if the corresponding boolean_mask element is `true`.
+ * An output row is invalid only if the row is invalid.
+ *
+ * @param booleanMaskView A nullable list of bools column used to filter elements in this column
+ * @return List column of the same type as this column, containing filtered list rows
+ * @throws CudfException if `boolean_mask` is not a "lists of bools" column
+ * @throws CudfException if this column and `boolean_mask` have different number of rows
+ */
+ public final ColumnVector applyBooleanMask(ColumnView booleanMaskView) {
+ assert (getType().equals(DType.LIST));
+ assert (booleanMaskView.getType().equals(DType.LIST));
+ assert (getRowCount() == booleanMaskView.getRowCount());
+ return new ColumnVector(applyBooleanMask(getNativeView(), booleanMaskView.getNativeView()));
+ }
+
/**
* Get the number of bytes needed to allocate a validity buffer for the given number of rows.
* According to cudf::bitmask_allocation_size_bytes, the padding boundary for null mask is 64 bytes.
@@ -4176,6 +4207,8 @@ static native long makeCudfColumnView(int type, int scale, long data, long dataS
static native long generateListOffsets(long handle) throws CudfException;
+ static native long applyBooleanMask(long arrayColumnView, long booleanMaskHandle) throws CudfException;
+
/**
* A utility class to create column vector like objects without refcounts and other APIs when
* creating the device side vector from host side nested vectors. Eventually this can go away or
diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp
index e074180c312..ab2c7006be5 100644
--- a/java/src/main/native/src/ColumnViewJni.cpp
+++ b/java/src/main/native/src/ColumnViewJni.cpp
@@ -34,6 +34,7 @@
#include
#include
#include
+#include
#include
#include
#include
@@ -2223,4 +2224,25 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_repeatStringsSizes(
CATCH_STD(env, 0);
}
+JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_applyBooleanMask(
+ JNIEnv *env, jclass, jlong list_column_handle, jlong boolean_mask_list_column_handle) {
+ JNI_NULL_CHECK(env, list_column_handle, "list handle is null", 0);
+ JNI_NULL_CHECK(env, boolean_mask_list_column_handle, "boolean mask handle is null", 0);
+ try {
+ cudf::jni::auto_set_device(env);
+
+ cudf::column_view const *list_column =
+ reinterpret_cast(list_column_handle);
+ cudf::lists_column_view const list_view = cudf::lists_column_view(*list_column);
+
+ cudf::column_view const *boolean_mask_list_column =
+ reinterpret_cast(boolean_mask_list_column_handle);
+ cudf::lists_column_view const boolean_mask_list_view =
+ cudf::lists_column_view(*boolean_mask_list_column);
+
+ return release_as_jlong(cudf::lists::apply_boolean_mask(list_view, boolean_mask_list_view));
+ }
+ CATCH_STD(env, 0);
+}
+
} // extern "C"
diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
index a42846aac05..492560f7b7f 100644
--- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
+++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
@@ -6299,4 +6299,98 @@ void testGenerateListOffsets() {
assertColumnsAreEqual(expected, actual);
}
}
+
+ @Test
+ void testApplyBooleanMaskFromListOfInt() {
+ try (
+ ColumnVector elementCv = ColumnVector.fromBoxedInts(
+ 11, 12, // list1
+ 21, 22, 23, // list2
+ null, 32, 33, null, 35, // list3
+ null, 42, 43, null, 45 // list 4
+ // list5 (empty)
+ );
+ ColumnVector offsetsCv = ColumnVector.fromInts(0, 2, 5, 10, 15, 15);
+ ColumnVector listOfIntCv = elementCv.makeListFromOffsets(5, offsetsCv);
+
+ ColumnVector boolCv = ColumnVector.fromBoxedBooleans(
+ true, false, // list1
+ true, false, true, // list2
+ true, false, true, false, true, // list3
+ true, false, true, false, true // list 4
+ // list5 (empty)
+ );
+ ColumnVector listOfBoolCv = boolCv.makeListFromOffsets(5, offsetsCv);
+
+ // apply boolean mask
+ ColumnVector actualCv = listOfIntCv.applyBooleanMask(listOfBoolCv);
+
+ ColumnVector expectedElementCv = ColumnVector.fromBoxedInts(
+ 11, // list1
+ 21, 23, // list2
+ null, 33, 35, // list3
+ null, 43, 45 // list 4
+ // list5 (empty)
+ );
+ ColumnVector expectedOffsetsCv = ColumnVector.fromInts(0, 1, 3, 6, 9, 9);
+ ColumnVector expectedCv = expectedElementCv.makeListFromOffsets(5, expectedOffsetsCv)
+ ) {
+ assertColumnsAreEqual(expectedCv, actualCv);
+ }
+ }
+
+ @Test
+ void testApplyBooleanMaskFromListOfStructure() {
+ try (
+ ColumnVector keyCv = ColumnVector.fromBoxedInts(
+ 11, 12, // list1
+ 21, 22, 23, // list2
+ null, 32, 33, null, 35, // list3
+ null, 42, 43, null, 45 // list 4
+ // list5 (empty)
+ );
+ ColumnVector valCv = ColumnVector.fromBoxedInts(
+ 11, 12, // list1
+ 21, 22, 23, // list2
+ 31, 32, 33, 34, 35, // list3
+ 41, 42, 43, 44, 45 // list4
+ // list5 (empty)
+ );
+ ColumnVector structCv = ColumnVector.makeStruct(keyCv, valCv);
+ ColumnVector offsetsCv = ColumnVector.fromInts(0, 2, 5, 10, 15, 15);
+ ColumnVector listOfStructCv = structCv.makeListFromOffsets(5, offsetsCv);
+
+ ColumnVector boolCv = ColumnVector.fromBoxedBooleans(
+ true, false, // list1
+ true, false, true, // list2
+ true, false, true, false, true, // list3
+ true, false, true, false, true // list 4
+ // list5 (empty)
+ );
+ ColumnVector listOfBoolCv = boolCv.makeListFromOffsets(5, offsetsCv);
+
+ // apply boolean mask
+ ColumnVector actualCv = listOfStructCv.applyBooleanMask(listOfBoolCv);
+
+ ColumnVector expectedKeyCv = ColumnVector.fromBoxedInts(
+ 11, // list1
+ 21, 23, // list2
+ null, 33, 35, // list3
+ null, 43, 45 // list 4
+ // list5 (empty)
+ );
+ ColumnVector expectedValCv = ColumnVector.fromBoxedInts(
+ 11, // list1
+ 21, 23, // list2
+ 31, 33, 35, // list3
+ 41, 43, 45 // list4
+ // list5 (empty)
+ );
+ ColumnVector expectedStructCv = ColumnVector.makeStruct(expectedKeyCv, expectedValCv);
+ ColumnVector expectedOffsetsCv = ColumnVector.fromInts(0, 1, 3, 6, 9, 9);
+ ColumnVector expectedCv = expectedStructCv.makeListFromOffsets(5, expectedOffsetsCv)
+ ) {
+ assertColumnsAreEqual(expectedCv, actualCv);
+ }
+ }
}