Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add regex_program strings replacing java APIs and tests #12701

Merged
69 changes: 59 additions & 10 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2846,10 +2846,23 @@ public final ColumnVector stringReplace(Scalar target, Scalar replace) {
* @param repl The string scalar to replace for each pattern match.
* @return A new column vector containing the string results.
*/
@Deprecated
public final ColumnVector replaceRegex(String pattern, Scalar repl) {
return replaceRegex(pattern, repl, -1);
jlowe marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* For each string, replaces any character sequence matching the given regex program pattern
* using the replacement string scalar.
*
* @param regexProg The regex program with pattern to search within each string.
* @param repl The string scalar to replace for each pattern match.
* @return A new column vector containing the string results.
*/
public final ColumnVector replaceRegex(RegexProgram regexProg, Scalar repl) {
return replaceRegex(regexProg, repl, -1);
}

/**
* For each string, replaces any character sequence matching the given pattern using the
* replacement string scalar.
Expand All @@ -2859,12 +2872,27 @@ public final ColumnVector replaceRegex(String pattern, Scalar repl) {
* @param maxRepl The maximum number of times a replacement should occur within each string.
* @return A new column vector containing the string results.
*/
@Deprecated
public final ColumnVector replaceRegex(String pattern, Scalar repl, int maxRepl) {
return replaceRegex(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE), repl, maxRepl);
}

/**
* For each string, replaces any character sequence matching the given regex program pattern
* using the replacement string scalar.
*
* @param regexProg The regex program with pattern to search within each string.
* @param repl The string scalar to replace for each pattern match.
* @param maxRepl The maximum number of times a replacement should occur within each string.
* @return A new column vector containing the string results.
*/
public final ColumnVector replaceRegex(RegexProgram regexProg, Scalar repl, int maxRepl) {
if (!repl.getType().equals(DType.STRING)) {
throw new IllegalArgumentException("Replacement must be a string scalar");
}
return new ColumnVector(replaceRegex(getNativeView(), pattern, repl.getScalarHandle(),
maxRepl));
assert regexProg != null : "regex program may not be null";
return new ColumnVector(replaceRegex(getNativeView(), regexProg.pattern(), regexProg.combinedFlags(),
regexProg.capture().nativeId, repl.getScalarHandle(), maxRepl));
}

/**
Expand All @@ -2890,9 +2918,26 @@ public final ColumnVector replaceMultiRegex(String[] patterns, ColumnView repls)
* @param replace The replacement template for creating the output string.
* @return A new java column vector containing the string results.
*/
@Deprecated
public final ColumnVector stringReplaceWithBackrefs(String pattern, String replace) {
return new ColumnVector(stringReplaceWithBackrefs(getNativeView(), pattern,
replace));
return stringReplaceWithBackrefs(new RegexProgram(pattern), replace);
}

/**
* For each string, replaces any character sequence matching the given regex program
* pattern using the replace template for back-references.
*
* Any null string entries return corresponding null output column entries.
*
* @param regexProg The regex program with pattern to search within each string.
* @param replace The replacement template for creating the output string.
* @return A new java column vector containing the string results.
*/
public final ColumnVector stringReplaceWithBackrefs(RegexProgram regexProg, String replace) {
assert regexProg != null : "regex program may not be null";
return new ColumnVector(
stringReplaceWithBackrefs(getNativeView(), regexProg.pattern(), regexProg.combinedFlags(),
regexProg.capture().nativeId, replace));
}

/**
Expand Down Expand Up @@ -4025,12 +4070,14 @@ private static native long substringColumn(long columnView, long startColumn, lo
* Native method for replacing each regular expression pattern match with the specified
* replacement string.
* @param columnView native handle of the cudf::column_view being operated on.
* @param pattern The regular expression pattern to search within each string.
* @param pattern regular expression pattern to search within each string.
* @param flags regex flags setting.
* @param capture capture groups setting.
* @param repl native handle of the cudf::scalar containing the replacement string.
* @param maxRepl maximum number of times to replace the pattern within a string
* @return native handle of the resulting cudf column containing the string results.
*/
private static native long replaceRegex(long columnView, String pattern,
private static native long replaceRegex(long columnView, String pattern, int flags, int capture,
long repl, long maxRepl) throws CudfException;

/**
Expand All @@ -4044,15 +4091,17 @@ private static native long replaceMultiRegex(long columnView, String[] patterns,
long repls) throws CudfException;

/**
* Native method for replacing any character sequence matching the given pattern
* using the replace template for back-references.
* Native method for replacing any character sequence matching the given regex program
* pattern using the replace template for back-references.
* @param columnView native handle of the cudf::column_view being operated on.
* @param pattern The regular expression patterns to search within each string.
* @param flags Regex flags setting.
* @param capture Capture groups setting.
* @param replace The replacement template for creating the output string.
* @return native handle of the resulting cudf column containing the string results.
*/
private static native long stringReplaceWithBackrefs(long columnView, String pattern,
String replace) throws CudfException;
private static native long stringReplaceWithBackrefs(long columnView, String pattern, int flags,
int capture, String replace) throws CudfException;

/**
* Native method for checking if strings in a column starts with a specified comparison string.
Expand Down
41 changes: 24 additions & 17 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1566,21 +1566,24 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapContains(JNIEnv *env,
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceRegex(JNIEnv *env, jclass,
jlong j_column_view,
jstring j_pattern, jlong j_repl,
jlong j_maxrepl) {
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceRegex(
JNIEnv *env, jclass, jlong j_column_view, jstring j_pattern, jint regex_flags,
jint capture_groups, jlong j_repl, jlong j_maxrepl) {

JNI_NULL_CHECK(env, j_column_view, "column is null", 0);
JNI_NULL_CHECK(env, j_pattern, "pattern string is null", 0);
JNI_NULL_CHECK(env, j_repl, "replace scalar is null", 0);
try {
cudf::jni::auto_set_device(env);
auto cv = reinterpret_cast<cudf::column_view const *>(j_column_view);
cudf::strings_column_view scv(*cv);
cudf::jni::native_jstring pattern(env, j_pattern);
auto repl = reinterpret_cast<cudf::string_scalar const *>(j_repl);
return release_as_jlong(cudf::strings::replace_re(scv, pattern.get(), *repl, j_maxrepl));
auto const column_view = reinterpret_cast<cudf::column_view const *>(j_column_view);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to my comment on other PR, avoid to use column_view name. This also applies to below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated, thank you!

auto const strings_column = cudf::strings_column_view{*column_view};
auto const pattern = cudf::jni::native_jstring(env, j_pattern);
auto const flags = static_cast<cudf::strings::regex_flags>(regex_flags);
auto const groups = static_cast<cudf::strings::capture_groups>(capture_groups);
auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups);
auto const repl = reinterpret_cast<cudf::string_scalar const *>(j_repl);
return release_as_jlong(
cudf::strings::replace_re(strings_column, *regex_prog, *repl, j_maxrepl));
}
CATCH_STD(env, 0);
}
Expand All @@ -1606,19 +1609,23 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceMultiRegex(JNIEnv
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplaceWithBackrefs(
JNIEnv *env, jclass, jlong column_view, jstring patternObj, jstring replaceObj) {
JNIEnv *env, jclass, jlong column_view, jstring pattern_obj, jint regex_flags,
jint capture_groups, jstring replace_obj) {

JNI_NULL_CHECK(env, column_view, "column is null", 0);
JNI_NULL_CHECK(env, patternObj, "pattern string is null", 0);
JNI_NULL_CHECK(env, replaceObj, "replace string is null", 0);
JNI_NULL_CHECK(env, pattern_obj, "pattern string is null", 0);
JNI_NULL_CHECK(env, replace_obj, "replace string is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::column_view *cv = reinterpret_cast<cudf::column_view *>(column_view);
cudf::strings_column_view scv(*cv);
cudf::jni::native_jstring ss_pattern(env, patternObj);
cudf::jni::native_jstring ss_replace(env, replaceObj);
auto const cv = reinterpret_cast<cudf::column_view const *>(column_view);
auto const strings_column = cudf::strings_column_view{*cv};
auto const pattern = cudf::jni::native_jstring(env, pattern_obj);
auto const flags = static_cast<cudf::strings::regex_flags>(regex_flags);
auto const groups = static_cast<cudf::strings::capture_groups>(capture_groups);
auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups);
cudf::jni::native_jstring ss_replace(env, replace_obj);
return release_as_jlong(
cudf::strings::replace_with_backrefs(scv, ss_pattern.get(), ss_replace.get()));
cudf::strings::replace_with_backrefs(strings_column, *regex_prog, ss_replace.get()));
}
CATCH_STD(env, 0);
}
Expand Down
98 changes: 64 additions & 34 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5133,29 +5133,42 @@ void teststringReplaceThrowsException() {

@Test
void testReplaceRegex() {
try (ColumnVector v =
ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
Scalar repl = Scalar.fromString("Repl");
ColumnVector actual = v.replaceRegex("[tT]itle", repl);
ColumnVector expected =
ColumnVector.fromStrings("Repl and Repl with Repl", "nothing", null, "Repl")) {
assertColumnsAreEqual(expected, actual);
}
try (ColumnVector v = ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
Scalar repl = Scalar.fromString("Repl")) {
String pattern = "[tT]itle";
RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE);

try (ColumnVector actual = v.replaceRegex(pattern, repl);
ColumnVector expected =
ColumnVector.fromStrings("Repl and Repl with Repl", "nothing", null, "Repl")) {
assertColumnsAreEqual(expected, actual);
}

try (ColumnVector v =
ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
Scalar repl = Scalar.fromString("Repl");
ColumnVector actual = v.replaceRegex("[tT]itle", repl, 0)) {
assertColumnsAreEqual(v, actual);
}
try (ColumnVector actual = v.replaceRegex(pattern, repl, 0)) {
assertColumnsAreEqual(v, actual);
}

try (ColumnVector v =
ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title");
Scalar repl = Scalar.fromString("Repl");
ColumnVector actual = v.replaceRegex("[tT]itle", repl, 1);
ColumnVector expected =
ColumnVector.fromStrings("Repl and Title with title", "nothing", null, "Repl")) {
assertColumnsAreEqual(expected, actual);
try (ColumnVector actual = v.replaceRegex(pattern, repl, 1);
ColumnVector expected =
ColumnVector.fromStrings("Repl and Title with title", "nothing", null, "Repl")) {
assertColumnsAreEqual(expected, actual);
}

try (ColumnVector actual = v.replaceRegex(regexProg, repl);
ColumnVector expected =
ColumnVector.fromStrings("Repl and Repl with Repl", "nothing", null, "Repl")) {
assertColumnsAreEqual(expected, actual);
}

try (ColumnVector actual = v.replaceRegex(regexProg, repl, 0)) {
assertColumnsAreEqual(v, actual);
}

try (ColumnVector actual = v.replaceRegex(regexProg, repl, 1);
ColumnVector expected =
ColumnVector.fromStrings("Repl and Title with title", "nothing", null, "Repl")) {
assertColumnsAreEqual(expected, actual);
}
}
}

Expand All @@ -5174,45 +5187,55 @@ void testReplaceMultiRegex() {
@Test
void testStringReplaceWithBackrefs() {

try (ColumnVector v = ColumnVector.fromStrings("<h1>title</h1>", "<h1>another title</h1>",
null);
try (ColumnVector v = ColumnVector.fromStrings("<h1>title</h1>", "<h1>another title</h1>", null);
ColumnVector expected = ColumnVector.fromStrings("<h2>title</h2>",
"<h2>another title</h2>", null);
ColumnVector actual = v.stringReplaceWithBackrefs("<h1>(.*)</h1>", "<h2>\\1</h2>")) {
ColumnVector actual = v.stringReplaceWithBackrefs("<h1>(.*)</h1>", "<h2>\\1</h2>");
ColumnVector actualRe =
v.stringReplaceWithBackrefs(new RegexProgram("<h1>(.*)</h1>"), "<h2>\\1</h2>")) {
assertColumnsAreEqual(expected, actual);
assertColumnsAreEqual(expected, actualRe);
}

try (ColumnVector v = ColumnVector.fromStrings("2020-1-01", "2020-2-02", null);
ColumnVector expected = ColumnVector.fromStrings("2020-01-01", "2020-02-02", null);
ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])-", "-0\\1-")) {
ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])-", "-0\\1-");
ColumnVector actualRe =
v.stringReplaceWithBackrefs(new RegexProgram("-([0-9])-"), "-0\\1-")) {
assertColumnsAreEqual(expected, actual);
assertColumnsAreEqual(expected, actualRe);
}

try (ColumnVector v = ColumnVector.fromStrings("2020-01-1", "2020-02-2",
"2020-03-3invalid", null);
try (ColumnVector v = ColumnVector.fromStrings("2020-01-1", "2020-02-2", "2020-03-3invalid", null);
ColumnVector expected = ColumnVector.fromStrings("2020-01-01", "2020-02-02",
"2020-03-3invalid", null);
ColumnVector actual = v.stringReplaceWithBackrefs(
"-([0-9])$", "-0\\1")) {
ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])$", "-0\\1");
ColumnVector actualRe =
v.stringReplaceWithBackrefs(new RegexProgram("-([0-9])$"), "-0\\1")) {
assertColumnsAreEqual(expected, actual);
assertColumnsAreEqual(expected, actualRe);
}

try (ColumnVector v = ColumnVector.fromStrings("2020-01-1 random_text", "2020-02-2T12:34:56",
"2020-03-3invalid", null);
"2020-03-3invalid", null);
ColumnVector expected = ColumnVector.fromStrings("2020-01-01 random_text",
"2020-02-02T12:34:56", "2020-03-3invalid", null);
ColumnVector actual = v.stringReplaceWithBackrefs(
"-([0-9])([ T])", "-0\\1\\2")) {
ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])([ T])", "-0\\1\\2");
ColumnVector actualRe =
v.stringReplaceWithBackrefs(new RegexProgram("-([0-9])([ T])"), "-0\\1\\2")) {
assertColumnsAreEqual(expected, actual);
assertColumnsAreEqual(expected, actualRe);
}

// test zero as group index
try (ColumnVector v = ColumnVector.fromStrings("aa-11 b2b-345", "aa-11a 1c-2b2 b2-c3", "11-aa", null);
ColumnVector expected = ColumnVector.fromStrings("aa-11:aa:11; b2b-345:b:345;",
"aa-11:aa:11;a 1c-2:c:2;b2 b2-c3", "11-aa", null);
ColumnVector actual = v.stringReplaceWithBackrefs(
"([a-z]+)-([0-9]+)", "${0}:${1}:${2};")) {
ColumnVector actual = v.stringReplaceWithBackrefs("([a-z]+)-([0-9]+)", "${0}:${1}:${2};");
ColumnVector actualRe =
v.stringReplaceWithBackrefs(new RegexProgram("([a-z]+)-([0-9]+)"), "${0}:${1}:${2};")) {
assertColumnsAreEqual(expected, actual);
assertColumnsAreEqual(expected, actualRe);
}

// group index exceeds group count
Expand All @@ -5222,6 +5245,13 @@ void testStringReplaceWithBackrefs() {
}
});

// group index exceeds group count
assertThrows(CudfException.class, () -> {
try (ColumnVector v = ColumnVector.fromStrings("ABC123defgh");
ColumnVector r =
v.stringReplaceWithBackrefs(new RegexProgram("([A-Z]+)([0-9]+)([a-z]+)"), "\\4")) {
}
});
}

@Test
Expand Down