Skip to content

Commit

Permalink
Add JNI method for strings::replace multi variety (#12979)
Browse files Browse the repository at this point in the history
Adds the JNI API for `stringReplace` using column vector arguments for `targets` and `repls` (to make this consistent with the C++ API). Also adds unit tests for the new API.
Part of the work for NVIDIA/spark-rapids#7907.

Authors:
  - Navin Kumar (https://github.com/NVnavkumar)

Approvers:
  - Nghia Truong (https://github.com/ttnghia)

URL: #12979
  • Loading branch information
NVnavkumar authored Mar 23, 2023
1 parent 4ab227d commit dd5252b
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 0 deletions.
43 changes: 43 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2914,6 +2914,41 @@ public final ColumnVector stringReplace(Scalar target, Scalar replace) {
replace.getScalarHandle()));
}

/**
 * Replaces target strings found within each string of this column by their corresponding
 * replacement strings and returns the result as a new strings column. Every string in this
 * column is searched for each entry of targets; when an entry is found it is replaced by
 * the corresponding entry in the repls column. All occurrences found in each string are
 * replaced. The repls argument may alternatively contain a single string, in which case
 * every matching target substring is replaced by that one string.
 *
 * Example:
 * cv = ["hello", "goodbye"]
 * targets = ["e","o"]
 * repls = ["EE","OO"]
 * r1 = cv.stringReplace(targets, repls)
 * r1 is now ["hEEllOO", "gOOOOdbyEE"]
 *
 * targets = ["e", "o"]
 * repls = ["_"]
 * r2 = cv.stringReplace(targets, repls)
 * r2 is now ["h_ll_", "g__dby_"]
 *
 * @param targets Strings to search for in each string.
 * @param repls Corresponding replacement strings for target strings.
 * @return A new java column vector containing the replaced strings.
 */
public final ColumnVector stringReplace(ColumnView targets, ColumnView repls) {
  // Preconditions are expressed as Java assertions, so they only fire when run with -ea.
  assert type.equals(DType.STRING) : "column type must be a String";
  assert targets != null : "target list may not be null";
  assert targets.getType().equals(DType.STRING) : "target list must be a string column";
  assert repls != null : "replacement list may not be null";
  assert repls.getType().equals(DType.STRING) : "replacement list must be a string column";

  // Resolve the native view handles in argument order before crossing into native code.
  final long inputHandle = getNativeView();
  final long targetsHandle = targets.getNativeView();
  final long replsHandle = repls.getNativeView();

  final long resultHandle = stringReplaceMulti(inputHandle, targetsHandle, replsHandle);
  return new ColumnVector(resultHandle);
}

/**
* For each string, replaces any character sequence matching the given pattern using the
* replacement string scalar.
Expand Down Expand Up @@ -4170,6 +4205,14 @@ private static native long substringColumn(long columnView, long startColumn, lo
*/
private static native long stringReplace(long columnView, long target, long repl) throws CudfException;

/**
 * Native method to replace target strings by corresponding repl strings.
 * Called by stringReplace(ColumnView, ColumnView), which passes the native view handles of
 * this column, the targets column and the repls column.
 * @param inputCV native handle of the cudf::column_view being operated on.
 * @param targetsCV native handle of the column containing the strings to search for.
 * @param replsCV native handle of the column containing the replacement strings (can optionally
 *                contain a single string).
 * @return native handle of the resulting column; the caller wraps it in a new ColumnVector.
 * @throws CudfException declared for native failures (presumably raised by the JNI layer when
 *                       the underlying cudf call fails; confirm against the native code).
 */
private static native long stringReplaceMulti(long inputCV, long targetsCV, long replsCV) throws CudfException;

/**
* Native method for replacing each regular expression pattern match with the specified
* replacement string.
Expand Down
20 changes: 20 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,26 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplace(JNIEnv *env
CATCH_STD(env, 0);
}

// JNI entry point for ColumnView.stringReplaceMulti: replaces occurrences of the target
// strings in the input strings column with the corresponding replacement strings by
// delegating to cudf::strings::replace. All three arguments are column_view pointers
// passed from Java as jlong handles.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplaceMulti(JNIEnv *env, jclass,
                                                                          jlong inputs_cv,
                                                                          jlong targets_cv,
                                                                          jlong repls_cv) {
  // Reject null handles up front; the trailing 0 is the value handed back to Java in that case.
  JNI_NULL_CHECK(env, inputs_cv, "column is null", 0);
  JNI_NULL_CHECK(env, targets_cv, "targets string column view is null", 0);
  JNI_NULL_CHECK(env, repls_cv, "repls string column view is null", 0);
  try {
    cudf::jni::auto_set_device(env);

    // Recover each column_view from its handle and wrap it as a strings column view,
    // in the order: input, targets, replacements.
    auto input_view = reinterpret_cast<cudf::column_view *>(inputs_cv);
    cudf::strings_column_view const input_strings(*input_view);

    auto targets_view = reinterpret_cast<cudf::column_view *>(targets_cv);
    cudf::strings_column_view const target_strings(*targets_view);

    auto repls_view = reinterpret_cast<cudf::column_view *>(repls_cv);
    cudf::strings_column_view const repl_strings(*repls_view);

    auto replaced = cudf::strings::replace(input_strings, target_strings, repl_strings);
    return release_as_jlong(replaced);
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapLookupForKeys(JNIEnv *env, jclass,
jlong map_column_view,
jlong lookup_keys) {
Expand Down
21 changes: 21 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5146,6 +5146,27 @@ void teststringReplaceThrowsException() {
});
}

@Test
void teststringReplaceMulti() {
  // Replaces "ss" with "S" and "é" with "e" in a single call; the null and empty rows
  // are expected to come back unchanged.
  try (ColumnVector input = ColumnVector.fromStrings("Héllo", "thésssé", null, "", "ARé", "sssstrings");
       ColumnVector expected = ColumnVector.fromStrings("Hello", "theSse", null, "", "ARe", "SStrings");
       ColumnVector searchFor = ColumnVector.fromStrings("ss", "é");
       ColumnVector replaceWith = ColumnVector.fromStrings("S", "e");
       ColumnVector actual = input.stringReplace(searchFor, replaceWith)) {
    assertColumnsAreEqual(expected, actual);
  }
}

@Test
void teststringReplaceMultiThrowsException() {
  // stringReplace validates its arguments with Java assertions, so passing a non-string
  // targets column (and a null repls column) must surface as an AssertionError.
  assertThrows(AssertionError.class, () -> {
    try (ColumnVector input = ColumnVector.fromStrings("Héllo", "thésé", null, "", "ARé", "strings");
         ColumnVector nonStringTargets = ColumnVector.fromInts(0, 1);
         ColumnVector missingRepls = null;
         ColumnVector unexpected = input.stringReplace(nonStringTargets, missingRepls)) {
      // Intentionally empty: the stringReplace call above is expected to throw.
    }
  });
}

@Test
void testReplaceRegex() {
try (ColumnVector v = ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
Expand Down

0 comments on commit dd5252b

Please sign in to comment.