From 17554adf507cd43f9fbfe51820163751d220188c Mon Sep 17 00:00:00 2001 From: Cindy Jiang <47068112+cindyyuanjiang@users.noreply.github.com> Date: Fri, 3 Feb 2023 13:19:30 -0800 Subject: [PATCH] Add `regex_program` searching APIs and related java classes (#12666) --- .../java/ai/rapids/cudf/CaptureGroups.java | 36 +++++ .../main/java/ai/rapids/cudf/ColumnView.java | 108 ++++++++++--- .../main/java/ai/rapids/cudf/RegexFlag.java | 37 +++++ .../java/ai/rapids/cudf/RegexProgram.java | 134 +++++++++++++++++ java/src/main/native/src/ColumnViewJni.cpp | 58 ++++--- .../java/ai/rapids/cudf/ColumnVectorTest.java | 142 +++++++++++------- 6 files changed, 422 insertions(+), 93 deletions(-) create mode 100644 java/src/main/java/ai/rapids/cudf/CaptureGroups.java create mode 100644 java/src/main/java/ai/rapids/cudf/RegexFlag.java create mode 100644 java/src/main/java/ai/rapids/cudf/RegexProgram.java diff --git a/java/src/main/java/ai/rapids/cudf/CaptureGroups.java b/java/src/main/java/ai/rapids/cudf/CaptureGroups.java new file mode 100644 index 00000000000..2ab778dbc35 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/CaptureGroups.java @@ -0,0 +1,36 @@ +/* + * + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * Capture groups setting, closely following cudf::strings::capture_groups. + * + * For processing a regex pattern containing capture groups. These can be used + * to optimize the generated regex instructions where the capture groups do not + * require extracting the groups. + */ +public enum CaptureGroups { + EXTRACT(0), // capture groups processed normally for extract + NON_CAPTURE(1); // convert all capture groups to non-capture groups + + final int nativeId; // Native id, for use with libcudf. + private CaptureGroups(int nativeId) { // Only constant values should be used + this.nativeId = nativeId; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 47d6b7573cd..b3111cec77b 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -3153,8 +3153,8 @@ public final ColumnVector clamp(Scalar lo, Scalar loReplace, Scalar hi, Scalar h * match the given regex pattern but only at the beginning of the string. * * ``` - * cv = ["abc","123","def456"] - * result = cv.matches_re("\\d+") + * cv = ["abc", "123", "def456"] + * result = cv.matchesRe("\\d+") * r is now [false, true, false] * ``` * Any null string entries return corresponding null output column entries. @@ -3164,11 +3164,34 @@ public final ColumnVector clamp(Scalar lo, Scalar loReplace, Scalar hi, Scalar h * @param pattern Regex pattern to match to each string. * @return New ColumnVector of boolean results for each string. */ + @Deprecated public final ColumnVector matchesRe(String pattern) { + return matchesRe(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE)); + } + + /** + * Returns a boolean ColumnVector identifying rows which + * match the given regex program pattern but only at the beginning of the string. + * + * ``` + * cv = ["abc", "123", "def456"] + * p = new RegexProgram("\\d+", CaptureGroups.NON_CAPTURE) + * r = cv.matchesRe(p) + * r is now [false, true, false] + * ``` + * Any null string entries return corresponding null output column entries. + * For supported regex patterns refer to: + * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html + * + * @param regexProg Regex program to match to each string. + * @return New ColumnVector of boolean results for each string. + */ + public final ColumnVector matchesRe(RegexProgram regexProg) { assert type.equals(DType.STRING) : "column type must be a String"; - assert pattern != null : "pattern may not be null"; - assert !pattern.isEmpty() : "pattern string may not be empty"; - return new ColumnVector(matchesRe(getNativeView(), pattern)); + assert regexProg != null : "regex program may not be null"; + assert !regexProg.pattern().isEmpty() : "pattern string may not be empty"; + return new ColumnVector(matchesRe(getNativeView(), regexProg.pattern(), + regexProg.combinedFlags(), regexProg.capture().nativeId)); } /** @@ -3176,8 +3199,8 @@ public final ColumnVector matchesRe(String pattern) { * match the given regex pattern starting at any location. * * ``` - * cv = ["abc","123","def456"] - * result = cv.matches_re("\\d+") + * cv = ["abc", "123", "def456"] + * r = cv.containsRe("\\d+") * r is now [false, true, true] * ``` * Any null string entries return corresponding null output column entries. @@ -3187,11 +3210,34 @@ public final ColumnVector matchesRe(String pattern) { * @param pattern Regex pattern to match to each string. * @return New ColumnVector of boolean results for each string. */ + @Deprecated public final ColumnVector containsRe(String pattern) { + return containsRe(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE)); + } + + /** + * Returns a boolean ColumnVector identifying rows which + * match the given RegexProgram pattern starting at any location. + * + * ``` + * cv = ["abc", "123", "def456"] + * p = new RegexProgram("\\d+", CaptureGroups.NON_CAPTURE) + * r = cv.containsRe(p) + * r is now [false, true, true] + * ``` + * Any null string entries return corresponding null output column entries. + * For supported regex patterns refer to: + * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html + * + * @param regexProg Regex program to match to each string. + * @return New ColumnVector of boolean results for each string. + */ + public final ColumnVector containsRe(RegexProgram regexProg) { assert type.equals(DType.STRING) : "column type must be a String"; - assert pattern != null : "pattern may not be null"; - assert !pattern.isEmpty() : "pattern string may not be empty"; - return new ColumnVector(containsRe(getNativeView(), pattern)); + assert regexProg != null : "regex program may not be null"; + assert !regexProg.pattern().isEmpty() : "pattern string may not be empty"; + return new ColumnVector(containsRe(getNativeView(), regexProg.pattern(), + regexProg.combinedFlags(), regexProg.capture().nativeId)); } /** @@ -3222,11 +3268,31 @@ public final Table extractRe(String pattern) throws CudfException { * @param idx The regex group index * @return A new column vector of extracted matches */ + @Deprecated public final ColumnVector extractAllRecord(String pattern, int idx) { + if (idx == 0) { + return extractAllRecord(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE), idx); + } + return extractAllRecord(new RegexProgram(pattern), idx); + } + + /** + * Extracts all strings that match the given regex program pattern and corresponds to the + * regular expression group index. Any null inputs also result in null output entries. + * + * For supported regex patterns refer to: + * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html + * @param regexProg The regex program + * @param idx The regex group index + * @return A new column vector of extracted matches + */ + public final ColumnVector extractAllRecord(RegexProgram regexProg, int idx) { assert type.equals(DType.STRING) : "column type must be a String"; assert idx >= 0 : "group index must be at least 0"; - - return new ColumnVector(extractAllRecord(this.getNativeView(), pattern, idx)); + assert regexProg != null : "regex program may not be null"; + return new ColumnVector( + extractAllRecord(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), + regexProg.capture().nativeId, idx)); } /** @@ -3995,21 +4061,25 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long stringStrip(long columnView, int type, long toStrip) throws CudfException; /** - * Native method for checking if strings match the passed in regex pattern from the + * Native method for checking if strings match the passed in regex program pattern from the * beginning of the string. * @param cudfViewHandle native handle of the cudf::column_view being operated on. * @param pattern string regex pattern. + * @param flags regex flags setting. + * @param capture capture groups setting. * @return native handle of the resulting cudf column containing the boolean results. */ - private static native long matchesRe(long cudfViewHandle, String pattern) throws CudfException; + private static native long matchesRe(long cudfViewHandle, String pattern, int flags, int capture) throws CudfException; /** - * Native method for checking if strings match the passed in regex pattern starting at any location. + * Native method for checking if strings match the passed in regex program pattern starting at any location. * @param cudfViewHandle native handle of the cudf::column_view being operated on. * @param pattern string regex pattern. + * @param flags regex flags setting. + * @param capture capture groups setting. * @return native handle of the resulting cudf column containing the boolean results. */ - private static native long containsRe(long cudfViewHandle, String pattern) throws CudfException; + private static native long containsRe(long cudfViewHandle, String pattern, int flags, int capture) throws CudfException; /** * Native method for checking if strings match the passed in like pattern @@ -4035,14 +4105,16 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long[] extractRe(long cudfViewHandle, String pattern) throws CudfException; /** - * Native method for extracting all results corresponding to group idx from a regular expression. + * Native method for extracting all results corresponding to group idx from a regex program pattern. * * @param nativeHandle Native handle of the cudf::column_view being operated on. * @param pattern String regex pattern. + * @param flags Regex flags setting. + * @param capture Capture groups setting. * @param idx Regex group index. A 0 value means matching the entire regex. * @return Native handle of a string column of the result. */ - private static native long extractAllRecord(long nativeHandle, String pattern, int idx); + private static native long extractAllRecord(long nativeHandle, String pattern, int flags, int capture, int idx); private static native long urlDecode(long cudfViewHandle); diff --git a/java/src/main/java/ai/rapids/cudf/RegexFlag.java b/java/src/main/java/ai/rapids/cudf/RegexFlag.java new file mode 100644 index 00000000000..7ed8e0354c9 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RegexFlag.java @@ -0,0 +1,37 @@ +/* + * + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * Regex flags setting, closely following cudf::strings::regex_flags. + * + * These types can be or'd to combine them. The values are chosen to + * leave room for future flags and to match the Python flag values. + */ +public enum RegexFlag { + DEFAULT(0), // default + MULTILINE(8), // the '^' and '$' honor new-line characters + DOTALL(16), // the '.' matching includes new-line characters + ASCII(256); // use only ASCII when matching built-in character classes + + final int nativeId; // Native id, for use with libcudf. + private RegexFlag(int nativeId) { // Only constant values should be used + this.nativeId = nativeId; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/RegexProgram.java b/java/src/main/java/ai/rapids/cudf/RegexProgram.java new file mode 100644 index 00000000000..191a0b95ff3 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RegexProgram.java @@ -0,0 +1,134 @@ +/* + * + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +import java.util.EnumSet; + +/** + * Regex program class, closely following cudf::strings::regex_program. + */ +public class RegexProgram { + private String pattern; // regex pattern + // regex flags for interpreting special characters in the pattern + private EnumSet flags; + // controls how capture groups in the pattern are used + // default is to extract a capture group + private CaptureGroups capture; + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern + */ + public RegexProgram(String pattern) { + this(pattern, EnumSet.of(RegexFlag.DEFAULT), CaptureGroups.EXTRACT); + } + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern + * @param flags Regex flags setting + */ + public RegexProgram(String pattern, EnumSet flags) { + this(pattern, flags, CaptureGroups.EXTRACT); + } + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern setting + * @param capture Capture groups setting + */ + public RegexProgram(String pattern, CaptureGroups capture) { + this(pattern, EnumSet.of(RegexFlag.DEFAULT), capture); + } + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern + * @param flags Regex flags setting + * @param capture Capture groups setting + */ + public RegexProgram(String pattern, EnumSet flags, CaptureGroups capture) { + assert pattern != null : "pattern may not be null"; + this.pattern = pattern; + this.flags = flags; + this.capture = capture; + } + + /** + * Get the pattern used to create this instance + * + * @param return A regex pattern as a string + */ + public String pattern() { + return pattern; + } + + /** + * Get the regex flags setting used to create this instance + * + * @param return Regex flags setting + */ + public EnumSet flags() { + return flags; + } + + /** + * Reset the regex flags setting for this instance + * + * @param flags Regex flags setting + */ + public void setFlags(EnumSet flags) { + this.flags = flags; + } + + /** + * Get the capture groups setting used to create this instance + * + * @param return Capture groups setting + */ + public CaptureGroups capture() { + return capture; + } + + /** + * Reset the capture groups setting for this instance + * + * @param capture Capture groups setting + */ + public void setCapture(CaptureGroups capture) { + this.capture = capture; + } + + /** + * Combine the regex flags using 'or' + * + * @param return An integer representing the value of combined (or'ed) flags + */ + public int combinedFlags() { + int allFlags = 0; + for (RegexFlag flag : flags) { + allFlags |= flag.nativeId; + } + return allFlags; + } +} diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index b48ddae196b..ff07a6786c1 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -62,6 +62,7 @@ #include #include #include +#include #include #include #include @@ -1290,32 +1291,42 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringContains(JNIEnv *en JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_matchesRe(JNIEnv *env, jobject j_object, jlong j_view_handle, - jstring patternObj) { + jstring pattern_obj, + jint regex_flags, + jint capture_groups) { JNI_NULL_CHECK(env, j_view_handle, "column is null", false); - JNI_NULL_CHECK(env, patternObj, "pattern is null", false); + JNI_NULL_CHECK(env, pattern_obj, "pattern is null", false); try { cudf::jni::auto_set_device(env); - cudf::column_view *column_view = reinterpret_cast(j_view_handle); - cudf::strings_column_view strings_column(*column_view); - cudf::jni::native_jstring pattern(env, patternObj); - return release_as_jlong(cudf::strings::matches_re(strings_column, pattern.get())); + auto const column_view = reinterpret_cast(j_view_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); + return release_as_jlong(cudf::strings::matches_re(strings_column, *regex_prog)); } CATCH_STD(env, 0); } JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_containsRe(JNIEnv *env, jobject j_object, jlong j_view_handle, - jstring patternObj) { + jstring pattern_obj, + jint regex_flags, + jint capture_groups) { JNI_NULL_CHECK(env, j_view_handle, "column is null", false); - JNI_NULL_CHECK(env, patternObj, "pattern is null", false); + JNI_NULL_CHECK(env, pattern_obj, "pattern is null", false); try { cudf::jni::auto_set_device(env); - cudf::column_view *column_view = reinterpret_cast(j_view_handle); - cudf::strings_column_view strings_column(*column_view); - cudf::jni::native_jstring pattern(env, patternObj); - return release_as_jlong(cudf::strings::contains_re(strings_column, pattern.get())); + auto const column_view = reinterpret_cast(j_view_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const capture = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, capture); + return release_as_jlong(cudf::strings::contains_re(strings_column, *regex_prog)); } CATCH_STD(env, 0); } @@ -1679,21 +1690,22 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_extractRe(JNIEnv *en CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractAllRecord(JNIEnv *env, jclass, - jlong j_view_handle, - jstring pattern_obj, - jint idx) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractAllRecord( + JNIEnv *env, jclass, jlong j_view_handle, jstring pattern_obj, jint regex_flags, + jint capture_groups, jint idx) { JNI_NULL_CHECK(env, j_view_handle, "column is null", 0); + JNI_NULL_CHECK(env, pattern_obj, "pattern is null", 0); try { cudf::jni::auto_set_device(env); - cudf::strings_column_view const strings_column{ - *reinterpret_cast(j_view_handle)}; - cudf::jni::native_jstring pattern(env, pattern_obj); - - auto result = (idx == 0) ? cudf::strings::findall(strings_column, pattern.get()) : - cudf::strings::extract_all_record(strings_column, pattern.get()); - + auto const column_view = reinterpret_cast(j_view_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); + auto result = (idx == 0) ? cudf::strings::findall(strings_column, *regex_prog) : + cudf::strings::extract_all_record(strings_column, *regex_prog); return release_as_jlong(result); } CATCH_STD(env, 0); diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index fc0a542e0a7..26817281c2e 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -4053,31 +4053,38 @@ void testExtractRe() { @Test void testExtractAllRecord() { String pattern = "([ab])(\\d)"; + RegexProgram regexProg = new RegexProgram(pattern); try (ColumnVector v = ColumnVector.fromStrings("a1", "b2", "c3", null, "a1b1c3a2"); - ColumnVector expectedIdx0 = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("a1"), - Arrays.asList("b2"), - Arrays.asList(), - null, - Arrays.asList("a1", "b1", "a2")); - ColumnVector expectedIdx12 = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("a", "1"), - Arrays.asList("b", "2"), - null, - null, - Arrays.asList("a", "1", "b", "1", "a", "2")); - - ColumnVector resultIdx0 = v.extractAllRecord(pattern, 0); - ColumnVector resultIdx1 = v.extractAllRecord(pattern, 1); - ColumnVector resultIdx2 = v.extractAllRecord(pattern, 2); - ) { - assertColumnsAreEqual(expectedIdx0, resultIdx0); - assertColumnsAreEqual(expectedIdx12, resultIdx1); - assertColumnsAreEqual(expectedIdx12, resultIdx2); + ColumnVector expectedIdx0 = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("a1"), + Arrays.asList("b2"), + Arrays.asList(), + null, + Arrays.asList("a1", "b1", "a2")); + ColumnVector expectedIdx12 = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("a", "1"), + Arrays.asList("b", "2"), + null, + null, + Arrays.asList("a", "1", "b", "1", "a", "2"))) { + try (ColumnVector resultIdx0 = v.extractAllRecord(pattern, 0); + ColumnVector resultIdx1 = v.extractAllRecord(pattern, 1); + ColumnVector resultIdx2 = v.extractAllRecord(pattern, 2)) { + assertColumnsAreEqual(expectedIdx0, resultIdx0); + assertColumnsAreEqual(expectedIdx12, resultIdx1); + assertColumnsAreEqual(expectedIdx12, resultIdx2); + } + try (ColumnVector resultIdx0 = v.extractAllRecord(regexProg, 0); + ColumnVector resultIdx1 = v.extractAllRecord(regexProg, 1); + ColumnVector resultIdx2 = v.extractAllRecord(regexProg, 2)) { + assertColumnsAreEqual(expectedIdx0, resultIdx0); + assertColumnsAreEqual(expectedIdx12, resultIdx1); + assertColumnsAreEqual(expectedIdx12, resultIdx2); + } } } @@ -4087,26 +4094,39 @@ void testMatchesRe() { String patternString2 = "[A-Za-z]+\\s@[A-Za-z]+"; String patternString3 = ".*"; String patternString4 = ""; + RegexProgram regexProg1 = new RegexProgram(patternString1, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg2 = new RegexProgram(patternString2, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg3 = new RegexProgram(patternString3, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg4 = new RegexProgram(patternString4, CaptureGroups.NON_CAPTURE); try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", - "lazy @dog", "1234", "00:0:00"); - ColumnVector res1 = testStrings.matchesRe(patternString1); - ColumnVector res2 = testStrings.matchesRe(patternString2); - ColumnVector res3 = testStrings.matchesRe(patternString3); + "lazy @dog", "1234", "00:0:00"); ColumnVector expected1 = ColumnVector.fromBoxedBooleans(false, null, false, false, false, - true, true); + true, true); ColumnVector expected2 = ColumnVector.fromBoxedBooleans(false, null, false, false, true, - false, false); + false, false); ColumnVector expected3 = ColumnVector.fromBoxedBooleans(true, null, true, true, true, - true, true)) { - assertColumnsAreEqual(expected1, res1); - assertColumnsAreEqual(expected2, res2); - assertColumnsAreEqual(expected3, res3); + true, true)) { + try (ColumnVector res1 = testStrings.matchesRe(patternString1); + ColumnVector res2 = testStrings.matchesRe(patternString2); + ColumnVector res3 = testStrings.matchesRe(patternString3)) { + assertColumnsAreEqual(expected1, res1); + assertColumnsAreEqual(expected2, res2); + assertColumnsAreEqual(expected3, res3); + } + try (ColumnVector res1 = testStrings.matchesRe(regexProg1); + ColumnVector res2 = testStrings.matchesRe(regexProg2); + ColumnVector res3 = testStrings.matchesRe(regexProg3)) { + assertColumnsAreEqual(expected1, res1); + assertColumnsAreEqual(expected2, res2); + assertColumnsAreEqual(expected3, res3); + } + assertThrows(AssertionError.class, () -> { + try (ColumnVector res = testStrings.matchesRe(patternString4)) {} + }); + assertThrows(AssertionError.class, () -> { + try (ColumnVector res = testStrings.matchesRe(regexProg4)) {} + }); } - assertThrows(AssertionError.class, () -> { - try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", - "lazy @dog", "1234", "00:0:00"); - ColumnVector res = testStrings.matchesRe(patternString4)) {} - }); } @Test @@ -4115,36 +4135,54 @@ void testContainsRe() { String patternString2 = "[A-Za-z]+\\s@[A-Za-z]+"; String patternString3 = ".*"; String patternString4 = ""; + RegexProgram regexProg1 = new RegexProgram(patternString1, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg2 = new RegexProgram(patternString2, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg3 = new RegexProgram(patternString3, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg4 = new RegexProgram(patternString4, CaptureGroups.NON_CAPTURE); try (ColumnVector testStrings = ColumnVector.fromStrings(null, "abCD", "ovér the", - "lazy @dog", "1234", "00:0:00", "abc1234abc", "there @are 2 lazy @dogs"); - ColumnVector res1 = testStrings.containsRe(patternString1); - ColumnVector res2 = testStrings.containsRe(patternString2); - ColumnVector res3 = testStrings.containsRe(patternString3); + "lazy @dog", "1234", "00:0:00", "abc1234abc", "there @are 2 lazy @dogs"); ColumnVector expected1 = ColumnVector.fromBoxedBooleans(null, false, false, false, true, true, true, true); ColumnVector expected2 = ColumnVector.fromBoxedBooleans(null, false, false, true, false, false, false, true); ColumnVector expected3 = ColumnVector.fromBoxedBooleans(null, true, true, true, true, true, true, true)) { - assertColumnsAreEqual(expected1, res1); - assertColumnsAreEqual(expected2, res2); - assertColumnsAreEqual(expected3, res3); + try (ColumnVector res1 = testStrings.containsRe(patternString1); + ColumnVector res2 = testStrings.containsRe(patternString2); + ColumnVector res3 = testStrings.containsRe(patternString3)) { + assertColumnsAreEqual(expected1, res1); + assertColumnsAreEqual(expected2, res2); + assertColumnsAreEqual(expected3, res3); + } + try (ColumnVector res1 = testStrings.containsRe(regexProg1); + ColumnVector res2 = testStrings.containsRe(regexProg2); + ColumnVector res3 = testStrings.containsRe(regexProg3)) { + assertColumnsAreEqual(expected1, res1); + assertColumnsAreEqual(expected2, res2); + assertColumnsAreEqual(expected3, res3); + } + } + try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", + "lazy @dog", "1234", "00:0:00", "abc1234abc", "there @are 2 lazy @dogs")) { + assertThrows(AssertionError.class, () -> { + try (ColumnVector res = testStrings.containsRe(patternString4)) {} + }); + assertThrows(AssertionError.class, () -> { + try (ColumnVector res = testStrings.containsRe(regexProg4)) {} + }); } - assertThrows(AssertionError.class, () -> { - try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", - "lazy @dog", "1234", "00:0:00", "abc1234abc", "there @are 2 lazy @dogs"); - ColumnVector res = testStrings.containsRe(patternString4)) {} - }); } @Test - @Disabled("Needs fix for https://github.com/rapidsai/cudf/issues/4671") void testContainsReEmptyInput() { String patternString1 = ".*"; + RegexProgram regexProg1 = new RegexProgram(patternString1, CaptureGroups.NON_CAPTURE); try (ColumnVector testStrings = ColumnVector.fromStrings(""); ColumnVector res1 = testStrings.containsRe(patternString1); + ColumnVector resReProg1 = testStrings.containsRe(regexProg1); ColumnVector expected1 = ColumnVector.fromBoxedBooleans(true)) { assertColumnsAreEqual(expected1, res1); + assertColumnsAreEqual(expected1, resReProg1); } }