Skip to content

Commit

Permalink
Java bindings for regex replace (#8847)
Browse files Browse the repository at this point in the history
This adds Java bindings for `cudf::strings::replace_re`.

Authors:
  - Jason Lowe (https://github.com/jlowe)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #8847
  • Loading branch information
jlowe authored Jul 27, 2021
1 parent 5d5bb2c commit d7ba345
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 0 deletions.
64 changes: 64 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2465,6 +2465,48 @@ public final ColumnVector stringReplace(Scalar target, Scalar replace) {
replace.getScalarHandle()));
}

/**
 * Applies a regular expression replacement to every string in this column without limiting
 * the number of replacements per string. This is a convenience overload that forwards to
 * {@link #replaceRegex(String, Scalar, int)} with a replacement limit of {@code -1}.
 *
 * @param pattern The regular expression pattern to search within each string.
 * @param repl The string scalar substituted for each pattern match.
 * @return A new column vector containing the string results.
 */
public final ColumnVector replaceRegex(String pattern, Scalar repl) {
  // A limit of -1 is forwarded unchanged to the three-argument overload.
  final int noReplacementLimit = -1;
  return replaceRegex(pattern, repl, noReplacementLimit);
}

/**
 * Applies a regular expression replacement to every string in this column, replacing at most
 * {@code maxRepl} matches per string.
 *
 * @param pattern The regular expression pattern to search within each string.
 * @param repl The string scalar substituted for each pattern match.
 * @param maxRepl The maximum number of times a replacement should occur within each string.
 * @return A new column vector containing the string results.
 * @throws IllegalArgumentException if {@code repl} is not a string scalar.
 */
public final ColumnVector replaceRegex(String pattern, Scalar repl, int maxRepl) {
  // Reject non-string replacements on the Java side before crossing into native code.
  final DType replType = repl.getType();
  if (!replType.equals(DType.STRING)) {
    throw new IllegalArgumentException("Replacement must be a string scalar");
  }

  final long inputHandle = getNativeView();
  final long replHandle = repl.getScalarHandle();
  // maxRepl is widened to long to match the native method's signature.
  final long resultHandle = replaceRegex(inputHandle, pattern, replHandle, maxRepl);
  return new ColumnVector(resultHandle);
}

/**
 * For each string, replaces any character sequence matching any of the regular expression
 * patterns with the corresponding replacement string.
 *
 * @param patterns The regular expression patterns to search within each string.
 * @param repls A column of replacement strings, one for each corresponding pattern.
 * @return A new column vector containing the string results.
 */
public final ColumnVector replaceMultiRegex(String[] patterns, ColumnView repls) {
  final long inputHandle = getNativeView();
  final long replsHandle = repls.getNativeView();
  final long resultHandle = replaceMultiRegex(inputHandle, patterns, replsHandle);
  return new ColumnVector(resultHandle);
}

/**
* For each string, replaces any character sequence matching the given pattern
* using the replace template for back-references.
Expand Down Expand Up @@ -3241,6 +3283,28 @@ private static native long substringColumn(long columnView, long startColumn, lo
*/
private static native long stringReplace(long columnView, long target, long repl) throws CudfException;

/**
 * Native method for replacing each regular expression pattern match with the specified
 * replacement string.
 * @param columnView native handle of the cudf::column_view being operated on.
 * @param pattern The regular expression pattern to search within each string.
 * @param repl native handle of the cudf::scalar containing the replacement string.
 * @param maxRepl maximum number of times to replace the pattern within a string. Declared as
 *                a long here; the public wrapper supplies an int, and its two-argument
 *                overload supplies -1.
 * @return native handle of the resulting cudf column containing the string results.
 * @throws CudfException declared by this native method for failures in native code.
 */
private static native long replaceRegex(long columnView, String pattern,
long repl, long maxRepl) throws CudfException;

/**
 * Native method for multiple instance regular expression replacement.
 * @param columnView native handle of the cudf::column_view being operated on.
 * @param patterns array of regular expression patterns to search within each string.
 * @param repls native handle of the cudf::column_view containing the replacement strings
 *              that correspond to the patterns.
 * @return native handle of the resulting cudf column containing the string results.
 * @throws CudfException declared by this native method for failures in native code.
 */
private static native long replaceMultiRegex(long columnView, String[] patterns,
long repls) throws CudfException;

/**
* Native method for replacing any character sequence matching the given pattern
* using the replace template for back-references.
Expand Down
45 changes: 45 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,51 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapContains(JNIEnv *env,
CATCH_STD(env, 0);
}

// JNI entry point for ColumnView.replaceRegex(long, String, long, long).
// Interprets the incoming handles as a cudf::column_view and a cudf::string_scalar, calls
// cudf::strings::replace_re with the single pattern, and hands ownership of the resulting
// column back to Java as a raw handle. The null-check and catch macros are given 0 as the
// value to return on failure.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceRegex(JNIEnv *env, jclass,
                                                                    jlong j_column_view,
                                                                    jstring j_pattern, jlong j_repl,
                                                                    jlong j_maxrepl) {
  JNI_NULL_CHECK(env, j_column_view, "column is null", 0);
  JNI_NULL_CHECK(env, j_pattern, "pattern string is null", 0);
  JNI_NULL_CHECK(env, j_repl, "replace scalar is null", 0);
  try {
    cudf::jni::auto_set_device(env);

    // Re-interpret the Java-side handles as the native objects they refer to.
    auto const *input_view = reinterpret_cast<cudf::column_view const *>(j_column_view);
    cudf::strings_column_view strings_input(*input_view);
    cudf::jni::native_jstring regex(env, j_pattern);
    auto const *replacement = reinterpret_cast<cudf::string_scalar const *>(j_repl);

    auto replaced = cudf::strings::replace_re(strings_input, regex.get(), *replacement, j_maxrepl);
    // Release ownership: the returned handle now refers to the result column.
    return reinterpret_cast<jlong>(replaced.release());
  }
  CATCH_STD(env, 0);
}

// JNI entry point for ColumnView.replaceMultiRegex(long, String[], long).
// Interprets the incoming handles as cudf::column_view objects (input strings and replacement
// strings), converts the Java pattern array to a C++ vector, calls the multi-pattern overload
// of cudf::strings::replace_re, and hands ownership of the resulting column back to Java as a
// raw handle. The null-check and catch macros are given 0 as the value to return on failure.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceMultiRegex(JNIEnv *env, jclass,
                                                                         jlong j_column_view,
                                                                         jobjectArray j_patterns,
                                                                         jlong j_repls) {
  JNI_NULL_CHECK(env, j_column_view, "column is null", 0);
  JNI_NULL_CHECK(env, j_patterns, "patterns is null", 0);
  JNI_NULL_CHECK(env, j_repls, "repls is null", 0);
  try {
    cudf::jni::auto_set_device(env);

    // Re-interpret the Java-side handles as the native objects they refer to.
    auto const *input_view = reinterpret_cast<cudf::column_view const *>(j_column_view);
    cudf::strings_column_view strings_input(*input_view);
    cudf::jni::native_jstringArray regexes(env, j_patterns);
    auto const *repls_view = reinterpret_cast<cudf::column_view const *>(j_repls);
    cudf::strings_column_view strings_repls(*repls_view);

    auto replaced =
        cudf::strings::replace_re(strings_input, regexes.as_cpp_vector(), strings_repls);
    // Release ownership: the returned handle now refers to the result column.
    return reinterpret_cast<jlong>(replaced.release());
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplaceWithBackrefs(
JNIEnv *env, jclass, jlong column_view, jstring patternObj, jstring replaceObj) {

Expand Down
40 changes: 40 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4479,6 +4479,46 @@ void teststringReplaceThrowsException() {
});
}

@Test
void testReplaceRegex() {
  // No limit supplied: every match in every row is expected to be replaced.
  try (ColumnVector input =
           ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
       Scalar replacement = Scalar.fromString("Repl");
       ColumnVector result = input.replaceRegex("[tT]itle", replacement);
       ColumnVector expected =
           ColumnVector.fromStrings("Repl and Repl with Repl", "nothing", null, "Repl")) {
    assertColumnsAreEqual(expected, result);
  }

  // Limit of zero: the output is expected to equal the input column.
  try (ColumnVector input =
           ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
       Scalar replacement = Scalar.fromString("Repl");
       ColumnVector result = input.replaceRegex("[tT]itle", replacement, 0)) {
    assertColumnsAreEqual(input, result);
  }

  // Limit of one: only the first match in each row is expected to be replaced.
  try (ColumnVector input =
           ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
       Scalar replacement = Scalar.fromString("Repl");
       ColumnVector result = input.replaceRegex("[tT]itle", replacement, 1);
       ColumnVector expected =
           ColumnVector.fromStrings("Repl and Title with title", "nothing", null, "Repl")) {
    assertColumnsAreEqual(expected, result);
  }
}

@Test
void testReplaceMultiRegex() {
  // Two patterns paired positionally with two replacement strings.
  final String[] regexes = new String[] { "[tT]itle", "and|th" };
  try (ColumnVector input =
           ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
       ColumnVector replacements = ColumnVector.fromStrings("Repl", "**");
       ColumnVector result = input.replaceMultiRegex(regexes, replacements);
       ColumnVector expected =
           ColumnVector.fromStrings("Repl ** Repl wi** Repl", "no**ing", null, "Repl")) {
    assertColumnsAreEqual(expected, result);
  }
}

@Test
void testStringReplaceWithBackrefs() {

Expand Down

0 comments on commit d7ba345

Please sign in to comment.