From dd5252b892adb43dc362cceaed4fffd9c6329269 Mon Sep 17 00:00:00 2001 From: Navin Kumar <97137715+NVnavkumar@users.noreply.github.com> Date: Thu, 23 Mar 2023 15:25:02 -0700 Subject: [PATCH] Add JNI method for strings::replace multi variety (#12979) Adds the JNI API for `stringReplace` using column vector arguments for `targets` and `repls` (to make this consistent with the C++ API). Also adds unit tests for the new API. Part of the work for https://github.com/NVIDIA/spark-rapids/issues/7907. Authors: - Navin Kumar (https://github.com/NVnavkumar) Approvers: - Nghia Truong (https://github.com/ttnghia) URL: https://github.com/rapidsai/cudf/pull/12979 --- .../main/java/ai/rapids/cudf/ColumnView.java | 43 +++++++++++++++++++ java/src/main/native/src/ColumnViewJni.cpp | 20 +++++++++ .../java/ai/rapids/cudf/ColumnVectorTest.java | 21 +++++++++ 3 files changed, 84 insertions(+) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 84183819854..7d93438d72e 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -2914,6 +2914,41 @@ public final ColumnVector stringReplace(Scalar target, Scalar replace) { replace.getScalarHandle())); } + /** + * Returns a new strings column where target strings with each string are replaced with + * corresponding replacement strings. For each string in the column, the list of targets + * is searched within that string. If a target string is found, it is replaced by the + * corresponding entry in the repls column. All occurrences found in each string are replaced. + * The repls argument can optionally contain a single string. In this case, all matching + * target substrings will be replaced by that single string. + * + * Example: + * cv = ["hello", "goodbye"] + * targets = ["e","o"] + * repls = ["EE","OO"] + * r1 = cv.stringReplace(targets, repls) + * r1 is now ["hEEllO", "gOOOOdbyEE"] + * + * targets = ["e", "o"] + * repls = ["_"] + * r2 = cv.stringReplace(targets, repls) + * r2 is now ["h_ll_", "g__dby_"] + * + * @param targets Strings to search for in each string. + * @param repls Corresponding replacement strings for target strings. + * @return A new java column vector containing the replaced strings. + */ + public final ColumnVector stringReplace(ColumnView targets, ColumnView repls) { + assert type.equals(DType.STRING) : "column type must be a String"; + assert targets != null : "target list may not be null"; + assert targets.getType().equals(DType.STRING) : "target list must be a string column"; + assert repls != null : "replacement list may not be null"; + assert repls.getType().equals(DType.STRING) : "replacement list must be a string column"; + + return new ColumnVector(stringReplaceMulti(getNativeView(), targets.getNativeView(), + repls.getNativeView())); + } + /** * For each string, replaces any character sequence matching the given pattern using the * replacement string scalar. @@ -4170,6 +4205,14 @@ private static native long substringColumn(long columnView, long startColumn, lo */ private static native long stringReplace(long columnView, long target, long repl) throws CudfException; + /** + * Native method to replace target strings by corresponding repl strings. + * @param inputCV native handle of the cudf::column_view being operated on. + * @param targetsCV handle of column containing the strings being searched. + * @param replsCV handle of column containing the strings to replace (can optionally contain a single string). + */ + private static native long stringReplaceMulti(long inputCV, long targetsCV, long replsCV) throws CudfException; + /** * Native method for replacing each regular expression pattern match with the specified * replacement string. diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index f2c361c5e8c..1213ab305fe 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -1546,6 +1546,26 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplace(JNIEnv *env CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplaceMulti(JNIEnv *env, jclass, + jlong inputs_cv, + jlong targets_cv, + jlong repls_cv) { + JNI_NULL_CHECK(env, inputs_cv, "column is null", 0); + JNI_NULL_CHECK(env, targets_cv, "targets string column view is null", 0); + JNI_NULL_CHECK(env, repls_cv, "repls string column view is null", 0); + try { + cudf::jni::auto_set_device(env); + cudf::column_view *cv = reinterpret_cast(inputs_cv); + cudf::strings_column_view scv(*cv); + cudf::column_view *cvtargets = reinterpret_cast(targets_cv); + cudf::strings_column_view scvtargets(*cvtargets); + cudf::column_view *cvrepls = reinterpret_cast(repls_cv); + cudf::strings_column_view scvrepls(*cvrepls); + return release_as_jlong(cudf::strings::replace(scv, scvtargets, scvrepls)); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapLookupForKeys(JNIEnv *env, jclass, jlong map_column_view, jlong lookup_keys) { diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 7848807dab8..8e19c543ee5 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -5146,6 +5146,27 @@ void teststringReplaceThrowsException() { }); } + @Test + void teststringReplaceMulti() { + try (ColumnVector v = ColumnVector.fromStrings("Héllo", "thésssé", null, "", "ARé", "sssstrings"); + ColumnVector e_allParameters = ColumnVector.fromStrings("Hello", "theSse", null, "", "ARe", "SStrings"); + ColumnVector targets = ColumnVector.fromStrings("ss", "é"); + ColumnVector repls = ColumnVector.fromStrings("S", "e"); + ColumnVector replace_allParameters = v.stringReplace(targets, repls)) { + assertColumnsAreEqual(e_allParameters, replace_allParameters); + } + } + + @Test + void teststringReplaceMultiThrowsException() { + assertThrows(AssertionError.class, () -> { + try (ColumnVector testStrings = ColumnVector.fromStrings("Héllo", "thésé", null, "", "ARé", "strings"); + ColumnVector targets = ColumnVector.fromInts(0, 1); + ColumnVector repls = null; + ColumnVector result = testStrings.stringReplace(targets,repls)){} + }); + } + @Test void testReplaceRegex() { try (ColumnVector v = ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");