diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 6bd4e06c47e..098c68f0596 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -3332,6 +3332,36 @@ public final ColumnVector stringContains(Scalar compString) { return new ColumnVector(stringContains(getNativeView(), compString.getScalarHandle())); } + /** + * @brief Searches for the given target strings within each string in the provided column + * + * Each column in the result table corresponds to the result for the target string at the same + * ordinal. i.e. 0th column is the BOOL8 column result for the 0th target string, 1th for 1th, + * etc. + * + * If the target is not found for a string, false is returned for that entry in the output column. + * If the target is an empty string, true is returned for all non-null entries in the output column. + * + * Any null input strings return corresponding null entries in the output columns. + * + * input = ["a", "b", "c"] + * targets = ["a", "c"] + * output is a table with two boolean columns: + * column 0: [true, false, false] + * column 1: [false, false, true] + * + * @param targets UTF-8 encoded strings to search for in each string in `input` + * @return BOOL8 columns + */ + public final ColumnVector[] stringContains(ColumnView targets) { + assert type.equals(DType.STRING) : "column type must be a String"; + assert targets.getType().equals(DType.STRING) : "targets type must be a string"; + assert targets.getNullCount() == 0 : "targets must not contain nulls"; + assert targets.getRowCount() > 0 : "targets must not be empty"; + long[] resultPointers = stringContainsMulti(getNativeView(), targets.getNativeView()); + return Arrays.stream(resultPointers).mapToObj(ColumnVector::new).toArray(ColumnVector[]::new); + } + /** * Replaces values less than `lo` in `input` with `lo`, * and values greater than `hi` with `hi`. @@ -4437,6 +4467,13 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat */ private static native long stringContains(long cudfViewHandle, long compString) throws CudfException; + /** + * Native method for searching for the given target strings within each string in the provided column. + * @param cudfViewHandle native handle of the cudf::column_view being operated on. + * @param targetViewHandle handle of the column view containing the strings being searched for. + */ + private static native long[] stringContainsMulti(long cudfViewHandle, long targetViewHandle) throws CudfException; + /** * Native method for extracting results from a regex program pattern. Returns a table handle. * diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 72f0ad19912..90902a24bbe 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -64,6 +64,7 @@ #include #include #include +#include #include #include #include @@ -2827,4 +2828,23 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_toHex(JNIEnv* env, jclass } CATCH_STD(env, 0); } + +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringContainsMulti( + JNIEnv* env, jobject j_object, jlong j_view_handle, jlong j_target_view_handle) +{ + JNI_NULL_CHECK(env, j_view_handle, "column is null", 0); + JNI_NULL_CHECK(env, j_target_view_handle, "targets is null", 0); + + try { + cudf::jni::auto_set_device(env); + auto* column_view = reinterpret_cast(j_view_handle); + auto* targets_view = reinterpret_cast(j_target_view_handle); + auto const strings_column = cudf::strings_column_view(*column_view); + auto const targets_column = cudf::strings_column_view(*targets_view); + auto contains_results = cudf::strings::contains_multiple(strings_column, targets_column); + return cudf::jni::convert_table_for_return(env, std::move(contains_results)); + } + CATCH_STD(env, 0); +} + } // extern "C" diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 14c290b300a..d1a1ff2c95c 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -3828,6 +3828,30 @@ void testStringOpsEmpty() { } } + @Test + void testStringContainsMulti() { + ColumnVector[] results = null; + try (ColumnVector haystack = ColumnVector.fromStrings("tést strings", + "Héllo cd", + "1 43 42 7", + "scala spark 42 other", + null, + ""); + ColumnVector targets = ColumnVector.fromStrings("é", "42"); + ColumnVector expected0 = ColumnVector.fromBoxedBooleans(true, true, false, false, null, false); + ColumnVector expected1 = ColumnVector.fromBoxedBooleans(false, false, true, true, null, false)) { + results = haystack.stringContains(targets); + assertColumnsAreEqual(results[0], expected0); + assertColumnsAreEqual(results[1], expected1); + } finally { + if (results != null) { + for (ColumnVector c : results) { + c.close(); + } + } + } + } + @Test void testStringFindOperations() { try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "1a\"\u0100B1", "a\"\u0100B1", "1a\"\u0100B",