Skip to content

Commit

Permalink
Add regex_program strings replacing java APIs and tests (#12701)
Browse files Browse the repository at this point in the history
This PR adds Java APIs that accept a `regex_program`, together with unit tests, for the libcudf [replace_re, replace_with_backrefs](https://docs.rapids.ai/api/libcudf/nightly/replace__re_8hpp.html) functions.
Part of work for NVIDIA/spark-rapids#7295.

Authors:
  - Cindy Jiang (https://github.com/cindyyuanjiang)

Approvers:
  - Jason Lowe (https://github.com/jlowe)
  - Nghia Truong (https://github.com/ttnghia)

URL: #12701
  • Loading branch information
cindyyuanjiang authored Feb 8, 2023
1 parent d3f9daf commit 0161ba8
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 63 deletions.
71 changes: 60 additions & 11 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2922,8 +2922,21 @@ public final ColumnVector stringReplace(Scalar target, Scalar replace) {
* @param repl The string scalar to replace for each pattern match.
* @return A new column vector containing the string results.
*/
@Deprecated
public final ColumnVector replaceRegex(String pattern, Scalar repl) {
return replaceRegex(pattern, repl, -1);
return replaceRegex(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE), repl);
}

/**
 * For each string, replaces any character sequence matching the given regex program pattern
 * using the replacement string scalar.
 *
 * This overload delegates to {@link #replaceRegex(RegexProgram, Scalar, int)} with a
 * maxRepl of -1. NOTE(review): -1 appears to mean "no limit" (the unit test for this
 * overload expects every match in a string to be replaced) -- confirm against libcudf.
 *
 * @param regexProg The regex program with pattern to search within each string.
 * @param repl The string scalar to replace for each pattern match.
 * @return A new column vector containing the string results.
 */
public final ColumnVector replaceRegex(RegexProgram regexProg, Scalar repl) {
// Forward -1 as maxRepl; argument validation happens in the three-argument overload.
return replaceRegex(regexProg, repl, -1);
}

/**
Expand All @@ -2935,12 +2948,27 @@ public final ColumnVector replaceRegex(String pattern, Scalar repl) {
* @param maxRepl The maximum number of times a replacement should occur within each string.
* @return A new column vector containing the string results.
*/
@Deprecated
public final ColumnVector replaceRegex(String pattern, Scalar repl, int maxRepl) {
return replaceRegex(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE), repl, maxRepl);
}

/**
* For each string, replaces any character sequence matching the given regex program pattern
* using the replacement string scalar.
*
* @param regexProg The regex program with pattern to search within each string.
* @param repl The string scalar to replace for each pattern match.
* @param maxRepl The maximum number of times a replacement should occur within each string.
* @return A new column vector containing the string results.
*/
public final ColumnVector replaceRegex(RegexProgram regexProg, Scalar repl, int maxRepl) {
if (!repl.getType().equals(DType.STRING)) {
throw new IllegalArgumentException("Replacement must be a string scalar");
}
return new ColumnVector(replaceRegex(getNativeView(), pattern, repl.getScalarHandle(),
maxRepl));
assert regexProg != null : "regex program may not be null";
return new ColumnVector(replaceRegex(getNativeView(), regexProg.pattern(), regexProg.combinedFlags(),
regexProg.capture().nativeId, repl.getScalarHandle(), maxRepl));
}

/**
Expand All @@ -2966,9 +2994,26 @@ public final ColumnVector replaceMultiRegex(String[] patterns, ColumnView repls)
* @param replace The replacement template for creating the output string.
* @return A new java column vector containing the string results.
*/
@Deprecated
public final ColumnVector stringReplaceWithBackrefs(String pattern, String replace) {
return new ColumnVector(stringReplaceWithBackrefs(getNativeView(), pattern,
replace));
return stringReplaceWithBackrefs(new RegexProgram(pattern), replace);
}

/**
 * For each string, replaces any character sequence matching the given regex program
 * pattern using the replace template for back-references.
 *
 * Any null string entries return corresponding null output column entries.
 *
 * The unit tests exercise two template forms for referring to capture groups:
 * {@code \1} and {@code ${1}}. With the latter form, index 0 is also accepted and
 * yields the whole match.
 *
 * @param regexProg The regex program with pattern to search within each string.
 * @param replace The replacement template for creating the output string.
 * @return A new java column vector containing the string results.
 * @throws CudfException if the native call fails, e.g. when the template refers to a
 *         group index that exceeds the number of groups in the pattern.
 */
public final ColumnVector stringReplaceWithBackrefs(RegexProgram regexProg, String replace) {
// This null check is a Java assert, so it only runs when assertions are enabled (-ea).
assert regexProg != null : "regex program may not be null";
// The native method takes the program unpacked into pattern, combined flags and
// capture-group id rather than the RegexProgram object itself.
return new ColumnVector(
stringReplaceWithBackrefs(getNativeView(), regexProg.pattern(), regexProg.combinedFlags(),
regexProg.capture().nativeId, replace));
}

/**
Expand Down Expand Up @@ -4129,12 +4174,14 @@ private static native long substringColumn(long columnView, long startColumn, lo
* Native method for replacing each regular expression pattern match with the specified
* replacement string.
* @param columnView native handle of the cudf::column_view being operated on.
* @param pattern The regular expression pattern to search within each string.
* @param pattern regular expression pattern to search within each string.
* @param flags regex flags setting.
* @param capture capture groups setting.
* @param repl native handle of the cudf::scalar containing the replacement string.
* @param maxRepl maximum number of times to replace the pattern within a string
* @return native handle of the resulting cudf column containing the string results.
*/
private static native long replaceRegex(long columnView, String pattern,
private static native long replaceRegex(long columnView, String pattern, int flags, int capture,
long repl, long maxRepl) throws CudfException;

/**
Expand All @@ -4148,15 +4195,17 @@ private static native long replaceMultiRegex(long columnView, String[] patterns,
long repls) throws CudfException;

/**
* Native method for replacing any character sequence matching the given pattern
* using the replace template for back-references.
* Native method for replacing any character sequence matching the given regex program
* pattern using the replace template for back-references.
* @param columnView native handle of the cudf::column_view being operated on.
* @param pattern The regular expression pattern to search within each string.
* @param flags Regex flags setting.
* @param capture Capture groups setting.
* @param replace The replacement template for creating the output string.
* @return native handle of the resulting cudf column containing the string results.
*/
private static native long stringReplaceWithBackrefs(long columnView, String pattern,
String replace) throws CudfException;
private static native long stringReplaceWithBackrefs(long columnView, String pattern, int flags,
int capture, String replace) throws CudfException;

/**
* Native method for checking if strings in a column starts with a specified comparison string.
Expand Down
43 changes: 25 additions & 18 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1606,21 +1606,24 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapContains(JNIEnv *env,
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceRegex(JNIEnv *env, jclass,
jlong j_column_view,
jstring j_pattern, jlong j_repl,
jlong j_maxrepl) {
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceRegex(
JNIEnv *env, jclass, jlong j_column_view, jstring j_pattern, jint regex_flags,
jint capture_groups, jlong j_repl, jlong j_maxrepl) {

JNI_NULL_CHECK(env, j_column_view, "column is null", 0);
JNI_NULL_CHECK(env, j_pattern, "pattern string is null", 0);
JNI_NULL_CHECK(env, j_repl, "replace scalar is null", 0);
try {
cudf::jni::auto_set_device(env);
auto cv = reinterpret_cast<cudf::column_view const *>(j_column_view);
cudf::strings_column_view scv(*cv);
cudf::jni::native_jstring pattern(env, j_pattern);
auto repl = reinterpret_cast<cudf::string_scalar const *>(j_repl);
return release_as_jlong(cudf::strings::replace_re(scv, pattern.get(), *repl, j_maxrepl));
auto const cv = reinterpret_cast<cudf::column_view const *>(j_column_view);
auto const strings_column = cudf::strings_column_view{*cv};
auto const pattern = cudf::jni::native_jstring(env, j_pattern);
auto const flags = static_cast<cudf::strings::regex_flags>(regex_flags);
auto const groups = static_cast<cudf::strings::capture_groups>(capture_groups);
auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups);
auto const repl = reinterpret_cast<cudf::string_scalar const *>(j_repl);
return release_as_jlong(
cudf::strings::replace_re(strings_column, *regex_prog, *repl, j_maxrepl));
}
CATCH_STD(env, 0);
}
Expand All @@ -1646,19 +1649,23 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceMultiRegex(JNIEnv
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplaceWithBackrefs(
JNIEnv *env, jclass, jlong column_view, jstring patternObj, jstring replaceObj) {
JNIEnv *env, jclass, jlong j_column_view, jstring pattern_obj, jint regex_flags,
jint capture_groups, jstring replace_obj) {

JNI_NULL_CHECK(env, column_view, "column is null", 0);
JNI_NULL_CHECK(env, patternObj, "pattern string is null", 0);
JNI_NULL_CHECK(env, replaceObj, "replace string is null", 0);
JNI_NULL_CHECK(env, j_column_view, "column is null", 0);
JNI_NULL_CHECK(env, pattern_obj, "pattern string is null", 0);
JNI_NULL_CHECK(env, replace_obj, "replace string is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::column_view *cv = reinterpret_cast<cudf::column_view *>(column_view);
cudf::strings_column_view scv(*cv);
cudf::jni::native_jstring ss_pattern(env, patternObj);
cudf::jni::native_jstring ss_replace(env, replaceObj);
auto const cv = reinterpret_cast<cudf::column_view const *>(j_column_view);
auto const strings_column = cudf::strings_column_view{*cv};
auto const pattern = cudf::jni::native_jstring(env, pattern_obj);
auto const flags = static_cast<cudf::strings::regex_flags>(regex_flags);
auto const groups = static_cast<cudf::strings::capture_groups>(capture_groups);
auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups);
cudf::jni::native_jstring ss_replace(env, replace_obj);
return release_as_jlong(
cudf::strings::replace_with_backrefs(scv, ss_pattern.get(), ss_replace.get()));
cudf::strings::replace_with_backrefs(strings_column, *regex_prog, ss_replace.get()));
}
CATCH_STD(env, 0);
}
Expand Down
98 changes: 64 additions & 34 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5147,29 +5147,42 @@ void teststringReplaceThrowsException() {

@Test
void testReplaceRegex() {
try (ColumnVector v =
ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
Scalar repl = Scalar.fromString("Repl");
ColumnVector actual = v.replaceRegex("[tT]itle", repl);
ColumnVector expected =
ColumnVector.fromStrings("Repl and Repl with Repl", "nothing", null, "Repl")) {
assertColumnsAreEqual(expected, actual);
}
try (ColumnVector v = ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
Scalar repl = Scalar.fromString("Repl")) {
String pattern = "[tT]itle";
RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE);

try (ColumnVector v =
ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
Scalar repl = Scalar.fromString("Repl");
ColumnVector actual = v.replaceRegex("[tT]itle", repl, 0)) {
assertColumnsAreEqual(v, actual);
}
try (ColumnVector actual = v.replaceRegex(pattern, repl);
ColumnVector expected =
ColumnVector.fromStrings("Repl and Repl with Repl", "nothing", null, "Repl")) {
assertColumnsAreEqual(expected, actual);
}

try (ColumnVector v =
ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
Scalar repl = Scalar.fromString("Repl");
ColumnVector actual = v.replaceRegex("[tT]itle", repl, 1);
ColumnVector expected =
ColumnVector.fromStrings("Repl and Title with title", "nothing", null, "Repl")) {
assertColumnsAreEqual(expected, actual);
try (ColumnVector actual = v.replaceRegex(pattern, repl, 0)) {
assertColumnsAreEqual(v, actual);
}

try (ColumnVector actual = v.replaceRegex(pattern, repl, 1);
ColumnVector expected =
ColumnVector.fromStrings("Repl and Title with title", "nothing", null, "Repl")) {
assertColumnsAreEqual(expected, actual);
}

try (ColumnVector actual = v.replaceRegex(regexProg, repl);
ColumnVector expected =
ColumnVector.fromStrings("Repl and Repl with Repl", "nothing", null, "Repl")) {
assertColumnsAreEqual(expected, actual);
}

try (ColumnVector actual = v.replaceRegex(regexProg, repl, 0)) {
assertColumnsAreEqual(v, actual);
}

try (ColumnVector actual = v.replaceRegex(regexProg, repl, 1);
ColumnVector expected =
ColumnVector.fromStrings("Repl and Title with title", "nothing", null, "Repl")) {
assertColumnsAreEqual(expected, actual);
}
}
}

Expand All @@ -5188,45 +5201,55 @@ void testReplaceMultiRegex() {
@Test
void testStringReplaceWithBackrefs() {

try (ColumnVector v = ColumnVector.fromStrings("<h1>title</h1>", "<h1>another title</h1>",
null);
try (ColumnVector v = ColumnVector.fromStrings("<h1>title</h1>", "<h1>another title</h1>", null);
ColumnVector expected = ColumnVector.fromStrings("<h2>title</h2>",
"<h2>another title</h2>", null);
ColumnVector actual = v.stringReplaceWithBackrefs("<h1>(.*)</h1>", "<h2>\\1</h2>")) {
ColumnVector actual = v.stringReplaceWithBackrefs("<h1>(.*)</h1>", "<h2>\\1</h2>");
ColumnVector actualRe =
v.stringReplaceWithBackrefs(new RegexProgram("<h1>(.*)</h1>"), "<h2>\\1</h2>")) {
assertColumnsAreEqual(expected, actual);
assertColumnsAreEqual(expected, actualRe);
}

try (ColumnVector v = ColumnVector.fromStrings("2020-1-01", "2020-2-02", null);
ColumnVector expected = ColumnVector.fromStrings("2020-01-01", "2020-02-02", null);
ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])-", "-0\\1-")) {
ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])-", "-0\\1-");
ColumnVector actualRe =
v.stringReplaceWithBackrefs(new RegexProgram("-([0-9])-"), "-0\\1-")) {
assertColumnsAreEqual(expected, actual);
assertColumnsAreEqual(expected, actualRe);
}

try (ColumnVector v = ColumnVector.fromStrings("2020-01-1", "2020-02-2",
"2020-03-3invalid", null);
try (ColumnVector v = ColumnVector.fromStrings("2020-01-1", "2020-02-2", "2020-03-3invalid", null);
ColumnVector expected = ColumnVector.fromStrings("2020-01-01", "2020-02-02",
"2020-03-3invalid", null);
ColumnVector actual = v.stringReplaceWithBackrefs(
"-([0-9])$", "-0\\1")) {
ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])$", "-0\\1");
ColumnVector actualRe =
v.stringReplaceWithBackrefs(new RegexProgram("-([0-9])$"), "-0\\1")) {
assertColumnsAreEqual(expected, actual);
assertColumnsAreEqual(expected, actualRe);
}

try (ColumnVector v = ColumnVector.fromStrings("2020-01-1 random_text", "2020-02-2T12:34:56",
"2020-03-3invalid", null);
"2020-03-3invalid", null);
ColumnVector expected = ColumnVector.fromStrings("2020-01-01 random_text",
"2020-02-02T12:34:56", "2020-03-3invalid", null);
ColumnVector actual = v.stringReplaceWithBackrefs(
"-([0-9])([ T])", "-0\\1\\2")) {
ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])([ T])", "-0\\1\\2");
ColumnVector actualRe =
v.stringReplaceWithBackrefs(new RegexProgram("-([0-9])([ T])"), "-0\\1\\2")) {
assertColumnsAreEqual(expected, actual);
assertColumnsAreEqual(expected, actualRe);
}

// test zero as group index
try (ColumnVector v = ColumnVector.fromStrings("aa-11 b2b-345", "aa-11a 1c-2b2 b2-c3", "11-aa", null);
ColumnVector expected = ColumnVector.fromStrings("aa-11:aa:11; b2b-345:b:345;",
"aa-11:aa:11;a 1c-2:c:2;b2 b2-c3", "11-aa", null);
ColumnVector actual = v.stringReplaceWithBackrefs(
"([a-z]+)-([0-9]+)", "${0}:${1}:${2};")) {
ColumnVector actual = v.stringReplaceWithBackrefs("([a-z]+)-([0-9]+)", "${0}:${1}:${2};");
ColumnVector actualRe =
v.stringReplaceWithBackrefs(new RegexProgram("([a-z]+)-([0-9]+)"), "${0}:${1}:${2};")) {
assertColumnsAreEqual(expected, actual);
assertColumnsAreEqual(expected, actualRe);
}

// group index exceeds group count
Expand All @@ -5236,6 +5259,13 @@ void testStringReplaceWithBackrefs() {
}
});

// group index exceeds group count
assertThrows(CudfException.class, () -> {
try (ColumnVector v = ColumnVector.fromStrings("ABC123defgh");
ColumnVector r =
v.stringReplaceWithBackrefs(new RegexProgram("([A-Z]+)([0-9]+)([a-z]+)"), "\\4")) {
}
});
}

@Test
Expand Down

0 comments on commit 0161ba8

Please sign in to comment.