diff --git a/java/src/main/java/ai/rapids/cudf/CaptureGroups.java b/java/src/main/java/ai/rapids/cudf/CaptureGroups.java new file mode 100644 index 00000000000..2ab778dbc35 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/CaptureGroups.java @@ -0,0 +1,36 @@ +/* + * + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * Capture groups setting, closely following cudf::strings::capture_groups. + * + * For processing a regex pattern containing capture groups. These can be used + * to optimize the generated regex instructions where the capture groups do not + * require extracting the groups. + */ +public enum CaptureGroups { + EXTRACT(0), // capture groups processed normally for extract + NON_CAPTURE(1); // convert all capture groups to non-capture groups + + final int nativeId; // Native id, for use with libcudf. + private CaptureGroups(int nativeId) { // Only constant values should be used + this.nativeId = nativeId; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 47d6b7573cd..8ffe5b4aa09 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -2531,12 +2531,35 @@ public final ColumnVector stringLocate(Scalar substring, int start, int end) { * regular expression pattern or just by a string literal delimiter. * @return list of strings columns as a table. */ + @Deprecated public final Table stringSplit(String pattern, int limit, boolean splitByRegex) { + if (splitByRegex) { + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + return stringSplit(regexProg, limit); + } else { + return stringSplit(pattern, limit); + } + } + + /** + * Returns a list of columns by splitting each string using the specified regex program. The + * number of rows in the output columns will be the same as the input column. Null entries + * are added for a row where split results have been exhausted. Null input entries result in + * all nulls in the corresponding rows of the output columns. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @param limit the maximum size of the list resulting from splitting each input string, + * or -1 for all possible splits. Note that limit = 0 (all possible splits without + * trailing empty strings) and limit = 1 (no split at all) are not supported. + * @return list of strings columns as a table. + */ + public final Table stringSplit(RegexProgram regexProg, int limit) { assert type.equals(DType.STRING) : "column type must be a String"; - assert pattern != null : "pattern is null"; - assert pattern.length() > 0 : "empty pattern is not supported"; + assert regexProg != null : "regex program is null"; assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; - return new Table(stringSplit(this.getNativeView(), pattern, limit, splitByRegex)); + return new Table(stringSplit(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), + regexProg.capture().nativeId, limit, true)); } /** @@ -2550,6 +2573,7 @@ public final Table stringSplit(String pattern, int limit, boolean splitByRegex) * regular expression pattern or just by a string literal delimiter. * @return list of strings columns as a table. */ + @Deprecated public final Table stringSplit(String pattern, boolean splitByRegex) { return stringSplit(pattern, -1, splitByRegex); } @@ -2567,7 +2591,11 @@ public final Table stringSplit(String pattern, boolean splitByRegex) { * @return list of strings columns as a table. */ public final Table stringSplit(String delimiter, int limit) { - return stringSplit(delimiter, limit, false); + assert type.equals(DType.STRING) : "column type must be a String"; + assert delimiter != null : "delimiter is null"; + assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; + return new Table(stringSplit(this.getNativeView(), delimiter, RegexFlag.DEFAULT.nativeId, + CaptureGroups.NON_CAPTURE.nativeId, limit, false)); } /** @@ -2580,7 +2608,21 @@ public final Table stringSplit(String delimiter, int limit) { * @return list of strings columns as a table. */ public final Table stringSplit(String delimiter) { - return stringSplit(delimiter, -1, false); + return stringSplit(delimiter, -1); + } + + /** + * Returns a list of columns by splitting each string using the specified regex program with + * string literal delimiter. The number of rows in the output columns will be the same as the + * input column. Null entries are added for a row where split results have been exhausted. + * Null input entries result in all nulls in the corresponding rows of the output columns. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @return list of strings columns as a table. + */ + public final Table stringSplit(RegexProgram regexProg) { + return stringSplit(regexProg, -1); } /** @@ -2595,13 +2637,34 @@ public final Table stringSplit(String delimiter) { * regular expression pattern or just by a string literal delimiter. * @return a LIST column of string elements. */ + @Deprecated public final ColumnVector stringSplitRecord(String pattern, int limit, boolean splitByRegex) { + if (splitByRegex) { + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + return stringSplitRecord(regexProg, limit); + } else { + return stringSplitRecord(pattern, limit); + } + } + + /** + * Returns a column that are lists of strings in which each list is made by splitting the + * corresponding input string using the specified regex program pattern. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @param limit the maximum size of the list resulting from splitting each input string, + * or -1 for all possible splits. Note that limit = 0 (all possible splits without + * trailing empty strings) and limit = 1 (no split at all) are not supported. + * @return a LIST column of string elements. + */ + public final ColumnVector stringSplitRecord(RegexProgram regexProg, int limit) { assert type.equals(DType.STRING) : "column type must be String"; - assert pattern != null : "pattern is null"; - assert pattern.length() > 0 : "empty pattern is not supported"; + assert regexProg != null : "regex program is null"; assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; return new ColumnVector( - stringSplitRecord(this.getNativeView(), pattern, limit, splitByRegex)); + stringSplitRecord(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), + regexProg.capture().nativeId, limit, true)); } /** @@ -2613,6 +2676,7 @@ public final ColumnVector stringSplitRecord(String pattern, int limit, boolean s * regular expression pattern or just by a string literal delimiter. * @return a LIST column of string elements. */ + @Deprecated public final ColumnVector stringSplitRecord(String pattern, boolean splitByRegex) { return stringSplitRecord(pattern, -1, splitByRegex); } @@ -2628,7 +2692,12 @@ public final ColumnVector stringSplitRecord(String pattern, boolean splitByRegex * @return a LIST column of string elements. */ public final ColumnVector stringSplitRecord(String delimiter, int limit) { - return stringSplitRecord(delimiter, limit, false); + assert type.equals(DType.STRING) : "column type must be String"; + assert delimiter != null : "delimiter is null"; + assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported"; + return new ColumnVector( + stringSplitRecord(this.getNativeView(), delimiter, RegexFlag.DEFAULT.nativeId, + CaptureGroups.NON_CAPTURE.nativeId, limit, false)); } /** @@ -2639,7 +2708,19 @@ public final ColumnVector stringSplitRecord(String delimiter, int limit) { * @return a LIST column of string elements. */ public final ColumnVector stringSplitRecord(String delimiter) { - return stringSplitRecord(delimiter, -1, false); + return stringSplitRecord(delimiter, -1); + } + + /** + * Returns a column that are lists of strings in which each list is made by splitting the + * corresponding input string using the specified regex program with string literal delimiter. + * + * @param regexProg the regex program with UTF-8 encoded string identifying the split pattern + * for each input string. + * @return a LIST column of string elements. + */ + public final ColumnVector stringSplitRecord(RegexProgram regexProg) { + return stringSplitRecord(regexProg, -1); } /** @@ -2846,10 +2927,23 @@ public final ColumnVector stringReplace(Scalar target, Scalar replace) { * @param repl The string scalar to replace for each pattern match. * @return A new column vector containing the string results. */ + @Deprecated public final ColumnVector replaceRegex(String pattern, Scalar repl) { return replaceRegex(pattern, repl, -1); } + /** + * For each string, replaces any character sequence matching the given regex program pattern + * using the replacement string scalar. + * + * @param regexProg The regex program with pattern to search within each string. + * @param repl The string scalar to replace for each pattern match. + * @return A new column vector containing the string results. + */ + public final ColumnVector replaceRegex(RegexProgram regexProg, Scalar repl) { + return replaceRegex(regexProg, repl, -1); + } + /** * For each string, replaces any character sequence matching the given pattern using the * replacement string scalar. @@ -2859,12 +2953,27 @@ public final ColumnVector replaceRegex(String pattern, Scalar repl) { * @param maxRepl The maximum number of times a replacement should occur within each string. * @return A new column vector containing the string results. */ + @Deprecated public final ColumnVector replaceRegex(String pattern, Scalar repl, int maxRepl) { + return replaceRegex(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE), repl, maxRepl); + } + + /** + * For each string, replaces any character sequence matching the given regex program pattern + * using the replacement string scalar. + * + * @param regexProg The regex program with pattern to search within each string. + * @param repl The string scalar to replace for each pattern match. + * @param maxRepl The maximum number of times a replacement should occur within each string. + * @return A new column vector containing the string results. + */ + public final ColumnVector replaceRegex(RegexProgram regexProg, Scalar repl, int maxRepl) { if (!repl.getType().equals(DType.STRING)) { throw new IllegalArgumentException("Replacement must be a string scalar"); } - return new ColumnVector(replaceRegex(getNativeView(), pattern, repl.getScalarHandle(), - maxRepl)); + assert regexProg != null : "regex program may not be null"; + return new ColumnVector(replaceRegex(getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), + regexProg.capture().nativeId, repl.getScalarHandle(), maxRepl)); } /** @@ -2890,9 +2999,25 @@ public final ColumnVector replaceMultiRegex(String[] patterns, ColumnView repls) * @param replace The replacement template for creating the output string. * @return A new java column vector containing the string results. */ + @Deprecated public final ColumnVector stringReplaceWithBackrefs(String pattern, String replace) { - return new ColumnVector(stringReplaceWithBackrefs(getNativeView(), pattern, - replace)); + return stringReplaceWithBackrefs(new RegexProgram(pattern), replace); + } + + /** + * For each string, replaces any character sequence matching the given regex program + * pattern using the replace template for back-references. + * + * Any null string entries return corresponding null output column entries. + * + * @param regexProg The regex program with pattern to search within each string. + * @param replace The replacement template for creating the output string. + * @return A new java column vector containing the string results. + */ + public final ColumnVector stringReplaceWithBackrefs(RegexProgram regexProg, String replace) { + assert regexProg != null : "regex program may not be null"; + return new ColumnVector(stringReplaceWithBackrefs(getNativeView(), regexProg.pattern(), + regexProg.combinedFlags(), regexProg.capture().nativeId, replace)); } /** @@ -3164,11 +3289,32 @@ public final ColumnVector clamp(Scalar lo, Scalar loReplace, Scalar hi, Scalar h * @param pattern Regex pattern to match to each string. * @return New ColumnVector of boolean results for each string. */ + @Deprecated public final ColumnVector matchesRe(String pattern) { + return matchesRe(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE)); + } + + /** + * Returns a boolean ColumnVector identifying rows which + * match the given regex program but only at the beginning of the string. + * + * ``` + * cv = ["abc","123","def456"] + * result = cv.matches_re("\\d+") + * r is now [false, true, false] + * ``` + * Any null string entries return corresponding null output column entries. + * For supported regex patterns refer to: + * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html + * + * @param regexProg Regex program to match to each string. + * @return New ColumnVector of boolean results for each string. + */ + public final ColumnVector matchesRe(RegexProgram regexProg) { assert type.equals(DType.STRING) : "column type must be a String"; - assert pattern != null : "pattern may not be null"; - assert !pattern.isEmpty() : "pattern string may not be empty"; - return new ColumnVector(matchesRe(getNativeView(), pattern)); + assert regexProg != null : "regex program may not be null"; + assert !regexProg.pattern().isEmpty() : "pattern string may not be empty"; + return new ColumnVector(matchesRe(getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), regexProg.capture().nativeId)); } /** @@ -3177,7 +3323,7 @@ public final ColumnVector matchesRe(String pattern) { * * ``` * cv = ["abc","123","def456"] - * result = cv.matches_re("\\d+") + * result = cv.contains_re("\\d+") * r is now [false, true, true] * ``` * Any null string entries return corresponding null output column entries. @@ -3187,11 +3333,32 @@ public final ColumnVector matchesRe(String pattern) { * @param pattern Regex pattern to match to each string. * @return New ColumnVector of boolean results for each string. */ + @Deprecated public final ColumnVector containsRe(String pattern) { + return containsRe(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE)); + } + + /** + * Returns a boolean ColumnVector identifying rows which + * match the given RegexProgram object starting at any location. + * + * ``` + * cv = ["abc","123","def456"] + * result = cv.contains_re("\\d+") + * r is now [false, true, true] + * ``` + * Any null string entries return corresponding null output column entries. + * For supported regex patterns refer to: + * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html + * + * @param regexProg Regex program to match to each string. + * @return New ColumnVector of boolean results for each string. + */ + public final ColumnVector containsRe(RegexProgram regexProg) { assert type.equals(DType.STRING) : "column type must be a String"; - assert pattern != null : "pattern may not be null"; - assert !pattern.isEmpty() : "pattern string may not be empty"; - return new ColumnVector(containsRe(getNativeView(), pattern)); + assert regexProg != null : "regex program may not be null"; + assert !regexProg.pattern().isEmpty() : "pattern string may not be empty"; + return new ColumnVector(containsRe(getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), regexProg.capture().nativeId)); } /** @@ -3206,10 +3373,27 @@ public final ColumnVector containsRe(String pattern) { * @throws CudfException if any error happens including if the RE does * not contain any capture groups. */ + @Deprecated public final Table extractRe(String pattern) throws CudfException { + return extractRe(new RegexProgram(pattern)); + } + + /** + * For each captured group specified in the given regex program + * return a column in the table. Null entries are added if the string + * does not match. Any null inputs also result in null output entries. + * + * For supported regex patterns refer to: + * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html + * @param regexProg the regex program to use + * @return the table of extracted matches + * @throws CudfException if any error happens including if the RE does + * not contain any capture groups. + */ + public final Table extractRe(RegexProgram regexProg) throws CudfException { assert type.equals(DType.STRING) : "column type must be a String"; - assert pattern != null : "pattern may not be null"; - return new Table(extractRe(this.getNativeView(), pattern)); + assert regexProg != null : "regex program may not be null"; + return new Table(extractRe(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), regexProg.capture().nativeId)); } /** @@ -3222,11 +3406,28 @@ public final Table extractRe(String pattern) throws CudfException { * @param idx The regex group index * @return A new column vector of extracted matches */ + @Deprecated public final ColumnVector extractAllRecord(String pattern, int idx) { + return extractAllRecord(new RegexProgram(pattern), idx); + } + + /** + * Extracts all strings that match the given regex program and corresponds to the + * regular expression group index. Any null inputs also result in null output entries. + * + * For supported regex patterns refer to: + * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html + * @param regexProg The regex program + * @param idx The regex group index + * @return A new column vector of extracted matches + */ + public final ColumnVector extractAllRecord(RegexProgram regexProg, int idx) { assert type.equals(DType.STRING) : "column type must be a String"; assert idx >= 0 : "group index must be at least 0"; - - return new ColumnVector(extractAllRecord(this.getNativeView(), pattern, idx)); + assert regexProg != null : "regex program may not be null"; + return new ColumnVector( + extractAllRecord(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(), + regexProg.capture().nativeId, idx)); } /** @@ -3881,14 +4082,16 @@ private static native long repeatStringsWithColumnRepeatTimes(long stringsHandle * * @param nativeHandle native handle of the input strings column that being operated on. * @param pattern UTF-8 encoded string identifying the split pattern for each input string. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param limit the maximum size of the list resulting from splitting each input string, * or -1 for all possible splits. Note that limit = 0 (all possible splits without * trailing empty strings) and limit = 1 (no split at all) are not supported. * @param splitByRegex a boolean flag indicating whether the input strings will be split by a * regular expression pattern or just by a string literal delimiter. */ - private static native long[] stringSplit(long nativeHandle, String pattern, int limit, - boolean splitByRegex); + private static native long[] stringSplit(long nativeHandle, String pattern, int flags, + int capture, int limit, boolean splitByRegex); /** * Returns a column that are lists of strings in which each list is made by splitting the @@ -3896,14 +4099,16 @@ private static native long[] stringSplit(long nativeHandle, String pattern, int * * @param nativeHandle native handle of the input strings column that being operated on. * @param pattern UTF-8 encoded string identifying the split pattern for each input string. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param limit the maximum size of the list resulting from splitting each input string, * or -1 for all possible splits. Note that limit = 0 (all possible splits without * trailing empty strings) and limit = 1 (no split at all) are not supported. * @param splitByRegex a boolean flag indicating whether the input strings will be split by a * regular expression pattern or just by a string literal delimiter. */ - private static native long stringSplitRecord(long nativeHandle, String pattern, int limit, - boolean splitByRegex); + private static native long stringSplitRecord(long nativeHandle, String pattern, int flags, + int capture, int limit, boolean splitByRegex); /** * Native method to calculate substring from a given string column. 0 indexing. @@ -3941,12 +4146,14 @@ private static native long substringColumn(long columnView, long startColumn, lo * Native method for replacing each regular expression pattern match with the specified * replacement string. * @param columnView native handle of the cudf::column_view being operated on. - * @param pattern The regular expression pattern to search within each string. + * @param pattern regular expression pattern to search within each string. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param repl native handle of the cudf::scalar containing the replacement string. * @param maxRepl maximum number of times to replace the pattern within a string * @return native handle of the resulting cudf column containing the string results. */ - private static native long replaceRegex(long columnView, String pattern, + private static native long replaceRegex(long columnView, String pattern, int flags, int capture, long repl, long maxRepl) throws CudfException; /** @@ -3960,15 +4167,17 @@ private static native long replaceMultiRegex(long columnView, String[] patterns, long repls) throws CudfException; /** - * Native method for replacing any character sequence matching the given pattern - * using the replace template for back-references. + * Native method for replacing any character sequence matching the given regex program + * pattern using the replace template for back-references. * @param columnView native handle of the cudf::column_view being operated on. - * @param pattern The regular expression patterns to search within each string. + * @param pattern regular expression pattern to search within each string. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param replace The replacement template for creating the output string. * @return native handle of the resulting cudf column containing the string results. */ - private static native long stringReplaceWithBackrefs(long columnView, String pattern, - String replace) throws CudfException; + private static native long stringReplaceWithBackrefs(long columnView, String pattern, int flags, + int capture, String replace) throws CudfException; /** * Native method for checking if strings in a column starts with a specified comparison string. @@ -3995,21 +4204,25 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long stringStrip(long columnView, int type, long toStrip) throws CudfException; /** - * Native method for checking if strings match the passed in regex pattern from the + * Native method for checking if strings match the passed in regex program from the * beginning of the string. * @param cudfViewHandle native handle of the cudf::column_view being operated on. * @param pattern string regex pattern. + * @param flags regex flags setting. + * @param capture capture groups setting. * @return native handle of the resulting cudf column containing the boolean results. */ - private static native long matchesRe(long cudfViewHandle, String pattern) throws CudfException; + private static native long matchesRe(long cudfViewHandle, String pattern, int flags, int capture) throws CudfException; /** - * Native method for checking if strings match the passed in regex pattern starting at any location. + * Native method for checking if strings match the passed in regex program starting at any location. * @param cudfViewHandle native handle of the cudf::column_view being operated on. * @param pattern string regex pattern. + * @param flags regex flags setting. + * @param capture capture groups setting. * @return native handle of the resulting cudf column containing the boolean results. */ - private static native long containsRe(long cudfViewHandle, String pattern) throws CudfException; + private static native long containsRe(long cudfViewHandle, String pattern, int flags, int capture) throws CudfException; /** * Native method for checking if strings match the passed in like pattern @@ -4030,19 +4243,21 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long stringContains(long cudfViewHandle, long compString) throws CudfException; /** - * Native method for extracting results from an regular expressions. Returns a table handle. + * Native method for extracting results from a regex program. Returns a table handle. */ - private static native long[] extractRe(long cudfViewHandle, String pattern) throws CudfException; + private static native long[] extractRe(long cudfViewHandle, String pattern, int flags, int capture) throws CudfException; /** - * Native method for extracting all results corresponding to group idx from a regular expression. + * Native method for extracting all results corresponding to group idx from a regex program. * * @param nativeHandle Native handle of the cudf::column_view being operated on. - * @param pattern String regex pattern. + * @param pattern string regex pattern. + * @param flags regex flags setting. + * @param capture capture groups setting. * @param idx Regex group index. A 0 value means matching the entire regex. * @return Native handle of a string column of the result. */ - private static native long extractAllRecord(long nativeHandle, String pattern, int idx); + private static native long extractAllRecord(long nativeHandle, String pattern, int flags, int capture, int idx); private static native long urlDecode(long cudfViewHandle); diff --git a/java/src/main/java/ai/rapids/cudf/RegexFlag.java b/java/src/main/java/ai/rapids/cudf/RegexFlag.java new file mode 100644 index 00000000000..7ed8e0354c9 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RegexFlag.java @@ -0,0 +1,37 @@ +/* + * + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package ai.rapids.cudf; + +/** + * Regex flags setting, closely following cudf::strings::regex_flags. + * + * These types can be or'd to combine them. The values are chosen to + * leave room for future flags and to match the Python flag values. + */ +public enum RegexFlag { + DEFAULT(0), // default + MULTILINE(8), // the '^' and '$' honor new-line characters + DOTALL(16), // the '.' matching includes new-line characters + ASCII(256); // use only ASCII when matching built-in character classes + + final int nativeId; // Native id, for use with libcudf. + private RegexFlag(int nativeId) { // Only constant values should be used + this.nativeId = nativeId; + } +} diff --git a/java/src/main/java/ai/rapids/cudf/RegexProgram.java b/java/src/main/java/ai/rapids/cudf/RegexProgram.java new file mode 100644 index 00000000000..358eea8ba43 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/RegexProgram.java @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.rapids.cudf; + +import java.util.EnumSet; + +/** + * Regex program class, closely following cudf::strings::regex_program. + */ +public class RegexProgram { + private String pattern; // regex pattern + private EnumSet flags; // regex flags for interpreting special characters in the pattern + // controls how capture groups in the pattern are used + // default is to extract a capture group + private CaptureGroups capture; + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern + */ + public RegexProgram(String pattern) { + this(pattern, EnumSet.of(RegexFlag.DEFAULT), CaptureGroups.EXTRACT); + } + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern + * @param flags Regex flags setting + */ + public RegexProgram(String pattern, EnumSet flags) { + this(pattern, flags, CaptureGroups.EXTRACT); + } + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern setting + * @param capture Capture groups setting + */ + public RegexProgram(String pattern, CaptureGroups capture) { + this(pattern, EnumSet.of(RegexFlag.DEFAULT), capture); + } + + /** + * Constructor for RegexProgram + * + * @param pattern Regex pattern + * @param flags Regex flags setting + * @param capture Capture groups setting + */ + public RegexProgram(String pattern, EnumSet flags, CaptureGroups capture) { + assert pattern != null : "pattern may not be null"; + this.pattern = pattern; + this.flags = flags; + this.capture = capture; + } + + /** + * Get the pattern used to create this instance + * + * @param return A regex pattern as a string + */ + public String pattern() { + return pattern; + } + + /** + * Get the regex flags setting used to create this instance + * + * @param return Regex flags setting + */ + public EnumSet flags() { + return flags; + } + + /** + * Reset the regex flags setting for this instance + * + * @param flags Regex flags setting + */ + public void setFlags(EnumSet flags) { + this.flags = flags; + } + + /** + * Get the capture groups setting used to create this instance + * + * @param return Capture groups setting + */ + public CaptureGroups capture() { + return capture; + } + + /** + * Reset the capture groups setting for this instance + * + * @param capture Capture groups setting + */ + public void setCapture(CaptureGroups capture) { + this.capture = capture; + } + + /** + * Combine the regex flags using 'or' + * + * @param return An integer representing the value of combined (or'ed) flags + */ + public int combinedFlags() { + int allFlags = 0; + for (RegexFlag flag : flags) { + allFlags = allFlags | flag.nativeId; + } + return allFlags; + } +} diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index b48ddae196b..c17e16bce73 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -62,6 +62,7 @@ #include #include #include +#include #include #include #include @@ -678,11 +679,9 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_reverseStringsOrLists(JNI CATCH_STD(env, 0); } -JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv *env, jclass, - jlong input_handle, - jstring pattern_obj, - jint limit, - jboolean split_by_regex) { +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit( + JNIEnv *env, jclass, jlong input_handle, jstring pattern_obj, jint regex_flags, + jint capture_groups, jint limit, jboolean split_by_regex) { JNI_NULL_CHECK(env, input_handle, "input_handle is null", 0); if (limit == 0 || limit == 1) { @@ -696,31 +695,25 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv * try { cudf::jni::auto_set_device(env); - auto const input = reinterpret_cast(input_handle); - auto const strs_input = cudf::strings_column_view{*input}; - + auto const column_view = reinterpret_cast(input_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; auto const pattern_jstr = cudf::jni::native_jstring(env, pattern_obj); - if (pattern_jstr.is_empty()) { - // Java's split API produces different behaviors than cudf when splitting with empty - // pattern. - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Empty pattern is not supported", 0); - } - auto const pattern = std::string(pattern_jstr.get(), pattern_jstr.size_bytes()); auto const max_split = limit > 1 ? limit - 1 : limit; + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern, flags, groups); auto result = split_by_regex ? - cudf::strings::split_re(strs_input, pattern, max_split) : - cudf::strings::split(strs_input, cudf::string_scalar{pattern}, max_split); + cudf::strings::split_re(strings_column, *regex_prog, max_split) : + cudf::strings::split(strings_column, cudf::string_scalar{pattern}, max_split); return cudf::jni::convert_table_for_return(env, std::move(result)); } CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRecord(JNIEnv *env, jclass, - jlong input_handle, - jstring pattern_obj, - jint limit, - jboolean split_by_regex) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRecord( + JNIEnv *env, jclass, jlong input_handle, jstring pattern_obj, jint regex_flags, + jint capture_groups, jint limit, jboolean split_by_regex) { JNI_NULL_CHECK(env, input_handle, "input_handle is null", 0); if (limit == 0 || limit == 1) { @@ -734,22 +727,18 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRecord(JNIEnv try { cudf::jni::auto_set_device(env); - auto const input = reinterpret_cast(input_handle); - auto const strs_input = cudf::strings_column_view{*input}; - + auto const column_view = reinterpret_cast(input_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; auto const pattern_jstr = cudf::jni::native_jstring(env, pattern_obj); - if (pattern_jstr.is_empty()) { - // Java's split API produces different behaviors than cudf when splitting with empty - // pattern. - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Empty pattern is not supported", 0); - } - auto const pattern = std::string(pattern_jstr.get(), pattern_jstr.size_bytes()); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern, flags, groups); auto const max_split = limit > 1 ? limit - 1 : limit; auto result = split_by_regex ? - cudf::strings::split_record_re(strs_input, pattern, max_split) : - cudf::strings::split_record(strs_input, cudf::string_scalar{pattern}, max_split); + cudf::strings::split_record_re(strings_column, *regex_prog, max_split) : + cudf::strings::split_record(strings_column, cudf::string_scalar{pattern}, max_split); return release_as_jlong(result); } CATCH_STD(env, 0); @@ -1290,32 +1279,42 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringContains(JNIEnv *en JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_matchesRe(JNIEnv *env, jobject j_object, jlong j_view_handle, - jstring patternObj) { + jstring pattern_obj, + jint regex_flags, + jint capture_groups) { JNI_NULL_CHECK(env, j_view_handle, "column is null", false); - JNI_NULL_CHECK(env, patternObj, "pattern is null", false); + JNI_NULL_CHECK(env, pattern_obj, "pattern is null", false); try { cudf::jni::auto_set_device(env); - cudf::column_view *column_view = reinterpret_cast(j_view_handle); - cudf::strings_column_view strings_column(*column_view); - cudf::jni::native_jstring pattern(env, patternObj); - return release_as_jlong(cudf::strings::matches_re(strings_column, pattern.get())); + auto const column_view = reinterpret_cast(j_view_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); + return release_as_jlong(cudf::strings::matches_re(strings_column, *regex_prog)); } CATCH_STD(env, 0); } JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_containsRe(JNIEnv *env, jobject j_object, jlong j_view_handle, - jstring patternObj) { + jstring pattern_obj, + jint regex_flags, + jint capture_groups) { JNI_NULL_CHECK(env, j_view_handle, "column is null", false); - JNI_NULL_CHECK(env, patternObj, "pattern is null", false); + JNI_NULL_CHECK(env, pattern_obj, "pattern is null", false); try { cudf::jni::auto_set_device(env); - cudf::column_view *column_view = reinterpret_cast(j_view_handle); - cudf::strings_column_view strings_column(*column_view); - cudf::jni::native_jstring pattern(env, patternObj); - return release_as_jlong(cudf::strings::contains_re(strings_column, pattern.get())); + auto const column_view = reinterpret_cast(j_view_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const capture = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, capture); + return release_as_jlong(cudf::strings::contains_re(strings_column, *regex_prog)); } CATCH_STD(env, 0); } @@ -1555,21 +1554,24 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapContains(JNIEnv *env, CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceRegex(JNIEnv *env, jclass, - jlong j_column_view, - jstring j_pattern, jlong j_repl, - jlong j_maxrepl) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceRegex( + JNIEnv *env, jclass, jlong j_column_view, jstring j_pattern, jint regex_flags, + jint capture_groups, jlong j_repl, jlong j_maxrepl) { JNI_NULL_CHECK(env, j_column_view, "column is null", 0); JNI_NULL_CHECK(env, j_pattern, "pattern string is null", 0); JNI_NULL_CHECK(env, j_repl, "replace scalar is null", 0); try { cudf::jni::auto_set_device(env); - auto cv = reinterpret_cast(j_column_view); - cudf::strings_column_view scv(*cv); - cudf::jni::native_jstring pattern(env, j_pattern); - auto repl = reinterpret_cast(j_repl); - return release_as_jlong(cudf::strings::replace_re(scv, pattern.get(), *repl, j_maxrepl)); + auto const column_view = reinterpret_cast(j_column_view); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, j_pattern); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); + auto const repl = reinterpret_cast(j_repl); + return release_as_jlong( + cudf::strings::replace_re(strings_column, *regex_prog, *repl, j_maxrepl)); } CATCH_STD(env, 0); } @@ -1595,19 +1597,23 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceMultiRegex(JNIEnv } JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplaceWithBackrefs( - JNIEnv *env, jclass, jlong column_view, jstring patternObj, jstring replaceObj) { + JNIEnv *env, jclass, jlong column_view, jstring pattern_obj, jint regex_flags, + jint capture_groups, jstring replaceObj) { JNI_NULL_CHECK(env, column_view, "column is null", 0); - JNI_NULL_CHECK(env, patternObj, "pattern string is null", 0); + JNI_NULL_CHECK(env, pattern_obj, "pattern string is null", 0); JNI_NULL_CHECK(env, replaceObj, "replace string is null", 0); try { cudf::jni::auto_set_device(env); - cudf::column_view *cv = reinterpret_cast(column_view); - cudf::strings_column_view scv(*cv); - cudf::jni::native_jstring ss_pattern(env, patternObj); + auto const cv = reinterpret_cast(column_view); + auto const strings_column = cudf::strings_column_view{*cv}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); cudf::jni::native_jstring ss_replace(env, replaceObj); return release_as_jlong( - cudf::strings::replace_with_backrefs(scv, ss_pattern.get(), ss_replace.get())); + cudf::strings::replace_with_backrefs(strings_column, *regex_prog, ss_replace.get())); } CATCH_STD(env, 0); } @@ -1663,37 +1669,42 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringStrip(JNIEnv *env, JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_extractRe(JNIEnv *env, jclass, jlong j_view_handle, - jstring patternObj) { + jstring pattern_obj, + jint regex_flags, + jint capture_groups) { JNI_NULL_CHECK(env, j_view_handle, "column is null", nullptr); - JNI_NULL_CHECK(env, patternObj, "pattern is null", nullptr); + JNI_NULL_CHECK(env, pattern_obj, "pattern is null", nullptr); try { cudf::jni::auto_set_device(env); - cudf::strings_column_view const strings_column{ - *reinterpret_cast(j_view_handle)}; - cudf::jni::native_jstring pattern(env, patternObj); - - return cudf::jni::convert_table_for_return( - env, cudf::strings::extract(strings_column, pattern.get())); + auto const column_view = reinterpret_cast(j_view_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); + return cudf::jni::convert_table_for_return(env, + cudf::strings::extract(strings_column, *regex_prog)); } CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractAllRecord(JNIEnv *env, jclass, - jlong j_view_handle, - jstring pattern_obj, - jint idx) { +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_extractAllRecord( + JNIEnv *env, jclass, jlong j_view_handle, jstring pattern_obj, jint regex_flags, + jint capture_groups, jint idx) { JNI_NULL_CHECK(env, j_view_handle, "column is null", 0); + JNI_NULL_CHECK(env, pattern_obj, "pattern is null", 0); try { cudf::jni::auto_set_device(env); - cudf::strings_column_view const strings_column{ - *reinterpret_cast(j_view_handle)}; - cudf::jni::native_jstring pattern(env, pattern_obj); - - auto result = (idx == 0) ? cudf::strings::findall(strings_column, pattern.get()) : - cudf::strings::extract_all_record(strings_column, pattern.get()); - + auto const column_view = reinterpret_cast(j_view_handle); + auto const strings_column = cudf::strings_column_view{*column_view}; + auto const pattern = cudf::jni::native_jstring(env, pattern_obj); + auto const flags = static_cast(regex_flags); + auto const groups = static_cast(capture_groups); + auto const regex_prog = cudf::strings::regex_program::create(pattern.get(), flags, groups); + auto result = (idx == 0) ? cudf::strings::findall(strings_column, *regex_prog) : + cudf::strings::extract_all_record(strings_column, *regex_prog); return release_as_jlong(result); } CATCH_STD(env, 0); diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index fc0a542e0a7..5b846545906 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -4040,41 +4040,50 @@ void testStringFindOperations() { @Test void testExtractRe() { - try (ColumnVector input = ColumnVector.fromStrings("a1", "b2", "c3", null); - Table expected = new Table.TestBuilder() - .column("a", "b", null, null) - .column("1", "2", null, null) - .build(); - Table found = input.extractRe("([ab])(\\d)")) { - assertTablesAreEqual(expected, found); - } + ColumnVector input = ColumnVector.fromStrings("a1", "b2", "c3", null); + Table expected = new Table.TestBuilder() + .column("a", "b", null, null) + .column("1", "2", null, null) + .build(); + try (Table found = input.extractRe("([ab])(\\d)")) { + assertTablesAreEqual(expected, found); + } + try (Table found = input.extractRe(new RegexProgram("([ab])(\\d)"))) { + assertTablesAreEqual(expected, found); + } } @Test void testExtractAllRecord() { String pattern = "([ab])(\\d)"; - try (ColumnVector v = ColumnVector.fromStrings("a1", "b2", "c3", null, "a1b1c3a2"); - ColumnVector expectedIdx0 = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("a1"), - Arrays.asList("b2"), - Arrays.asList(), - null, - Arrays.asList("a1", "b1", "a2")); - ColumnVector expectedIdx12 = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("a", "1"), - Arrays.asList("b", "2"), - null, - null, - Arrays.asList("a", "1", "b", "1", "a", "2")); - - ColumnVector resultIdx0 = v.extractAllRecord(pattern, 0); - ColumnVector resultIdx1 = v.extractAllRecord(pattern, 1); - ColumnVector resultIdx2 = v.extractAllRecord(pattern, 2); - ) { + RegexProgram regexProg = new RegexProgram(pattern); + ColumnVector v = ColumnVector.fromStrings("a1", "b2", "c3", null, "a1b1c3a2"); + ColumnVector expectedIdx0 = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("a1"), + Arrays.asList("b2"), + Arrays.asList(), + null, + Arrays.asList("a1", "b1", "a2")); + ColumnVector expectedIdx12 = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("a", "1"), + Arrays.asList("b", "2"), + null, + null, + Arrays.asList("a", "1", "b", "1", "a", "2")); + try (ColumnVector resultIdx0 = v.extractAllRecord(pattern, 0); + ColumnVector resultIdx1 = v.extractAllRecord(pattern, 1); + ColumnVector resultIdx2 = v.extractAllRecord(pattern, 2)) { + assertColumnsAreEqual(expectedIdx0, resultIdx0); + assertColumnsAreEqual(expectedIdx12, resultIdx1); + assertColumnsAreEqual(expectedIdx12, resultIdx2); + } + try (ColumnVector resultIdx0 = v.extractAllRecord(regexProg, 0); + ColumnVector resultIdx1 = v.extractAllRecord(regexProg, 1); + ColumnVector resultIdx2 = v.extractAllRecord(regexProg, 2)) { assertColumnsAreEqual(expectedIdx0, resultIdx0); assertColumnsAreEqual(expectedIdx12, resultIdx1); assertColumnsAreEqual(expectedIdx12, resultIdx2); @@ -4087,25 +4096,37 @@ void testMatchesRe() { String patternString2 = "[A-Za-z]+\\s@[A-Za-z]+"; String patternString3 = ".*"; String patternString4 = ""; - try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", - "lazy @dog", "1234", "00:0:00"); - ColumnVector res1 = testStrings.matchesRe(patternString1); + RegexProgram regexProg1 = new RegexProgram(patternString1, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg2 = new RegexProgram(patternString2, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg3 = new RegexProgram(patternString3, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg4 = new RegexProgram(patternString4, CaptureGroups.NON_CAPTURE); + ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", + "lazy @dog", "1234", "00:0:00"); + ColumnVector expected1 = ColumnVector.fromBoxedBooleans(false, null, false, false, false, + true, true); + ColumnVector expected2 = ColumnVector.fromBoxedBooleans(false, null, false, false, true, + false, false); + ColumnVector expected3 = ColumnVector.fromBoxedBooleans(true, null, true, true, true, + true, true); + try (ColumnVector res1 = testStrings.matchesRe(patternString1); ColumnVector res2 = testStrings.matchesRe(patternString2); - ColumnVector res3 = testStrings.matchesRe(patternString3); - ColumnVector expected1 = ColumnVector.fromBoxedBooleans(false, null, false, false, false, - true, true); - ColumnVector expected2 = ColumnVector.fromBoxedBooleans(false, null, false, false, true, - false, false); - ColumnVector expected3 = ColumnVector.fromBoxedBooleans(true, null, true, true, true, - true, true)) { + ColumnVector res3 = testStrings.matchesRe(patternString3)) { + assertColumnsAreEqual(expected1, res1); + assertColumnsAreEqual(expected2, res2); + assertColumnsAreEqual(expected3, res3); + } + try (ColumnVector res1 = testStrings.matchesRe(regexProg1); + ColumnVector res2 = testStrings.matchesRe(regexProg2); + ColumnVector res3 = testStrings.matchesRe(regexProg3)) { assertColumnsAreEqual(expected1, res1); assertColumnsAreEqual(expected2, res2); assertColumnsAreEqual(expected3, res3); } assertThrows(AssertionError.class, () -> { - try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", - "lazy @dog", "1234", "00:0:00"); - ColumnVector res = testStrings.matchesRe(patternString4)) {} + try (ColumnVector res = testStrings.matchesRe(patternString4)) {} + }); + assertThrows(AssertionError.class, () -> { + try (ColumnVector res = testStrings.matchesRe(regexProg4)) {} }); } @@ -4115,36 +4136,51 @@ void testContainsRe() { String patternString2 = "[A-Za-z]+\\s@[A-Za-z]+"; String patternString3 = ".*"; String patternString4 = ""; - try (ColumnVector testStrings = ColumnVector.fromStrings(null, "abCD", "ovér the", + RegexProgram regexProg1 = new RegexProgram(patternString1, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg2 = new RegexProgram(patternString2, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg3 = new RegexProgram(patternString3, CaptureGroups.NON_CAPTURE); + RegexProgram regexProg4 = new RegexProgram(patternString4, CaptureGroups.NON_CAPTURE); + ColumnVector testStrings = ColumnVector.fromStrings(null, "abCD", "ovér the", "lazy @dog", "1234", "00:0:00", "abc1234abc", "there @are 2 lazy @dogs"); - ColumnVector res1 = testStrings.containsRe(patternString1); + ColumnVector expected1 = ColumnVector.fromBoxedBooleans(null, false, false, false, + true, true, true, true); + ColumnVector expected2 = ColumnVector.fromBoxedBooleans(null, false, false, true, + false, false, false, true); + ColumnVector expected3 = ColumnVector.fromBoxedBooleans(null, true, true, true, + true, true, true, true); + try (ColumnVector res1 = testStrings.containsRe(patternString1); ColumnVector res2 = testStrings.containsRe(patternString2); - ColumnVector res3 = testStrings.containsRe(patternString3); - ColumnVector expected1 = ColumnVector.fromBoxedBooleans(null, false, false, false, - true, true, true, true); - ColumnVector expected2 = ColumnVector.fromBoxedBooleans(null, false, false, true, - false, false, false, true); - ColumnVector expected3 = ColumnVector.fromBoxedBooleans(null, true, true, true, - true, true, true, true)) { + ColumnVector res3 = testStrings.containsRe(patternString3)) { assertColumnsAreEqual(expected1, res1); assertColumnsAreEqual(expected2, res2); assertColumnsAreEqual(expected3, res3); } + try (ColumnVector res1 = testStrings.containsRe(regexProg1); + ColumnVector res2 = testStrings.containsRe(regexProg2); + ColumnVector res3 = testStrings.containsRe(regexProg3)) { + assertColumnsAreEqual(expected1, res1); + assertColumnsAreEqual(expected2, res2); + assertColumnsAreEqual(expected3, res3); + } + ColumnVector testStringsError = ColumnVector.fromStrings("", null, "abCD", "ovér the", + "lazy @dog", "1234", "00:0:00", "abc1234abc", "there @are 2 lazy @dogs"); + assertThrows(AssertionError.class, () -> { + try (ColumnVector res = testStringsError.containsRe(patternString4)) {}}); assertThrows(AssertionError.class, () -> { - try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "ovér the", - "lazy @dog", "1234", "00:0:00", "abc1234abc", "there @are 2 lazy @dogs"); - ColumnVector res = testStrings.containsRe(patternString4)) {} + try (ColumnVector res = testStringsError.containsRe(regexProg4)) {} }); } @Test - @Disabled("Needs fix for https://github.com/rapidsai/cudf/issues/4671") void testContainsReEmptyInput() { String patternString1 = ".*"; + RegexProgram regexProg1 = new RegexProgram(patternString1, CaptureGroups.NON_CAPTURE); try (ColumnVector testStrings = ColumnVector.fromStrings(""); ColumnVector res1 = testStrings.containsRe(patternString1); + ColumnVector resRe1 = testStrings.containsRe(regexProg1); ColumnVector expected1 = ColumnVector.fromBoxedBooleans(true)) { assertColumnsAreEqual(expected1, res1); + assertColumnsAreEqual(expected1, resRe1); } } @@ -4405,9 +4441,13 @@ void testsubstring() { @Test void testExtractListElements() { - try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); - ColumnVector expected = ColumnVector.fromStrings("Héllo", "thésé", null, "", "ARé", "test"); - ColumnVector list = v.stringSplitRecord(" "); + ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); + ColumnVector expected = ColumnVector.fromStrings("Héllo", "thésé", null, "", "ARé", "test"); + try (ColumnVector list = v.stringSplitRecord(" "); + ColumnVector result = list.extractListElement(0)) { + assertColumnsAreEqual(expected, result); + } + try (ColumnVector list = v.stringSplitRecord(new RegexProgram(" ", CaptureGroups.NON_CAPTURE)); ColumnVector result = list.extractListElement(0)) { assertColumnsAreEqual(expected, result); } @@ -4415,10 +4455,14 @@ void testExtractListElements() { @Test void testExtractListElementsV() { - try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); - ColumnVector indices = ColumnVector.fromInts(0, 2, 0, 0, 1, -1); - ColumnVector expected = ColumnVector.fromStrings("Héllo", null, null, "", "some", "strings"); - ColumnVector list = v.stringSplitRecord(" "); + ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); + ColumnVector indices = ColumnVector.fromInts(0, 2, 0, 0, 1, -1); + ColumnVector expected = ColumnVector.fromStrings("Héllo", null, null, "", "some", "strings"); + try (ColumnVector list = v.stringSplitRecord(" "); + ColumnVector result = list.extractListElement(indices)) { + assertColumnsAreEqual(expected, result); + } + try (ColumnVector list = v.stringSplitRecord(new RegexProgram(" ", CaptureGroups.NON_CAPTURE)); ColumnVector result = list.extractListElement(indices)) { assertColumnsAreEqual(expected, result); } @@ -4947,103 +4991,127 @@ void testReverseList() { @Test void testStringSplit() { String pattern = " "; - try (ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", "ARé some things", "test strings here"); - Table expectedSplitLimit2 = new Table.TestBuilder() - .column("Héllo", "thésé", null, "", "ARé", "test") - .column("there all", null, null, null, "some things", "strings here") - .build(); - Table expectedSplitAll = new Table.TestBuilder() - .column("Héllo", "thésé", null, "", "ARé", "test") - .column("there", null, null, null, "some", "strings") - .column("all", null, null, null, "things", "here") - .build(); - Table resultSplitLimit2 = v.stringSplit(pattern, 2); + Table expectedSplitLimit2 = new Table.TestBuilder() + .column("Héllo", "thésé", null, "", "ARé", "test") + .column("there all", null, null, null, "some things", "strings here") + .build(); + Table expectedSplitAll = new Table.TestBuilder() + .column("Héllo", "thésé", null, "", "ARé", "test") + .column("there", null, null, null, "some", "strings") + .column("all", null, null, null, "things", "here") + .build(); + try (Table resultSplitLimit2 = v.stringSplit(pattern, 2); Table resultSplitAll = v.stringSplit(pattern)) { - assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); - assertTablesAreEqual(expectedSplitAll, resultSplitAll); + assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertTablesAreEqual(expectedSplitAll, resultSplitAll); + } + try (Table resultSplitLimit2 = v.stringSplit(regexProg, 2); + Table resultSplitAll = v.stringSplit(regexProg)) { + assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertTablesAreEqual(expectedSplitAll, resultSplitAll); } } @Test void testStringSplitByRegularExpression() { String pattern = "[_ ]"; - try (ColumnVector v = ColumnVector.fromStrings("Héllo_there all", "thésé", null, "", + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + ColumnVector v = ColumnVector.fromStrings("Héllo_there all", "thésé", null, "", "ARé some_things", "test_strings_here"); - Table expectedSplitLimit2 = new Table.TestBuilder() - .column("Héllo", "thésé", null, "", "ARé", "test") - .column("there all", null, null, null, "some_things", "strings_here") - .build(); - Table expectedSplitAll = new Table.TestBuilder() - .column("Héllo", "thésé", null, "", "ARé", "test") - .column("there", null, null, null, "some", "strings") - .column("all", null, null, null, "things", "here") - .build(); - Table resultSplitLimit2 = v.stringSplit(pattern, 2, true); + Table expectedSplitLimit2 = new Table.TestBuilder() + .column("Héllo", "thésé", null, "", "ARé", "test") + .column("there all", null, null, null, "some_things", "strings_here") + .build(); + Table expectedSplitAll = new Table.TestBuilder() + .column("Héllo", "thésé", null, "", "ARé", "test") + .column("there", null, null, null, "some", "strings") + .column("all", null, null, null, "things", "here") + .build(); + try (Table resultSplitLimit2 = v.stringSplit(pattern, 2, true); Table resultSplitAll = v.stringSplit(pattern, true)) { assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); assertTablesAreEqual(expectedSplitAll, resultSplitAll); } + try (Table resultSplitLimit2 = v.stringSplit(regexProg, 2); + Table resultSplitAll = v.stringSplit(regexProg)) { + assertTablesAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertTablesAreEqual(expectedSplitAll, resultSplitAll); + } } @Test void testStringSplitRecord() { String pattern = " "; - try (ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", "ARé some things", "test strings here"); - ColumnVector expectedSplitLimit2 = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("Héllo", "there all"), - Arrays.asList("thésé"), - null, - Arrays.asList(""), - Arrays.asList("ARé", "some things"), - Arrays.asList("test", "strings here")); - ColumnVector expectedSplitAll = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("Héllo", "there", "all"), - Arrays.asList("thésé"), - null, - Arrays.asList(""), - Arrays.asList("ARé", "some", "things"), - Arrays.asList("test", "strings", "here")); - ColumnVector resultSplitLimit2 = v.stringSplitRecord(pattern, 2); + ColumnVector expectedSplitLimit2 = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("Héllo", "there all"), + Arrays.asList("thésé"), + null, + Arrays.asList(""), + Arrays.asList("ARé", "some things"), + Arrays.asList("test", "strings here")); + ColumnVector expectedSplitAll = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("Héllo", "there", "all"), + Arrays.asList("thésé"), + null, + Arrays.asList(""), + Arrays.asList("ARé", "some", "things"), + Arrays.asList("test", "strings", "here")); + try (ColumnVector resultSplitLimit2 = v.stringSplitRecord(pattern, 2); ColumnVector resultSplitAll = v.stringSplitRecord(pattern)) { assertColumnsAreEqual(expectedSplitLimit2, resultSplitLimit2); assertColumnsAreEqual(expectedSplitAll, resultSplitAll); } + try (ColumnVector resultSplitLimit2 = v.stringSplitRecord(regexProg, 2); + ColumnVector resultSplitAll = v.stringSplitRecord(regexProg)) { + assertColumnsAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertColumnsAreEqual(expectedSplitAll, resultSplitAll); + } } @Test void testStringSplitRecordByRegularExpression() { String pattern = "[_ ]"; - try (ColumnVector v = ColumnVector.fromStrings("Héllo_there all", "thésé", null, "", + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + ColumnVector v = ColumnVector.fromStrings("Héllo_there all", "thésé", null, "", "ARé some_things", "test_strings_here"); - ColumnVector expectedSplitLimit2 = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("Héllo", "there all"), - Arrays.asList("thésé"), - null, - Arrays.asList(""), - Arrays.asList("ARé", "some_things"), - Arrays.asList("test", "strings_here")); - ColumnVector expectedSplitAll = ColumnVector.fromLists( - new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList("Héllo", "there", "all"), - Arrays.asList("thésé"), - null, - Arrays.asList(""), - Arrays.asList("ARé", "some", "things"), - Arrays.asList("test", "strings", "here")); - ColumnVector resultSplitLimit2 = v.stringSplitRecord(pattern, 2, true); + ColumnVector expectedSplitLimit2 = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("Héllo", "there all"), + Arrays.asList("thésé"), + null, + Arrays.asList(""), + Arrays.asList("ARé", "some_things"), + Arrays.asList("test", "strings_here")); + ColumnVector expectedSplitAll = ColumnVector.fromLists( + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("Héllo", "there", "all"), + Arrays.asList("thésé"), + null, + Arrays.asList(""), + Arrays.asList("ARé", "some", "things"), + Arrays.asList("test", "strings", "here")); + try (ColumnVector resultSplitLimit2 = v.stringSplitRecord(pattern, 2, true); ColumnVector resultSplitAll = v.stringSplitRecord(pattern, true)) { assertColumnsAreEqual(expectedSplitLimit2, resultSplitLimit2); assertColumnsAreEqual(expectedSplitAll, resultSplitAll); } + try (ColumnVector resultSplitLimit2 = v.stringSplitRecord(regexProg, 2); + ColumnVector resultSplitAll = v.stringSplitRecord(regexProg)) { + assertColumnsAreEqual(expectedSplitLimit2, resultSplitLimit2); + assertColumnsAreEqual(expectedSplitAll, resultSplitAll); + } } @Test @@ -5091,26 +5159,37 @@ void teststringReplaceThrowsException() { @Test void testReplaceRegex() { - try (ColumnVector v = - ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title"); - Scalar repl = Scalar.fromString("Repl"); - ColumnVector actual = v.replaceRegex("[tT]itle", repl); + ColumnVector v = ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title"); + Scalar repl = Scalar.fromString("Repl"); + String pattern = "[tT]itle"; + RegexProgram regexProg = new RegexProgram(pattern, CaptureGroups.NON_CAPTURE); + try (ColumnVector actual = v.replaceRegex(pattern, repl); ColumnVector expected = ColumnVector.fromStrings("Repl and Repl with Repl", "nothing", null, "Repl")) { assertColumnsAreEqual(expected, actual); } - try (ColumnVector v = - ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title"); - Scalar repl = Scalar.fromString("Repl"); - ColumnVector actual = v.replaceRegex("[tT]itle", repl, 0)) { + try (ColumnVector actual = v.replaceRegex(pattern, repl, 0)) { assertColumnsAreEqual(v, actual); } - try (ColumnVector v = - ColumnVector.fromStrings("title and Title with title", "nothing", null, "Title"); - Scalar repl = Scalar.fromString("Repl"); - ColumnVector actual = v.replaceRegex("[tT]itle", repl, 1); + try (ColumnVector actual = v.replaceRegex(pattern, repl, 1); + ColumnVector expected = + ColumnVector.fromStrings("Repl and Title with title", "nothing", null, "Repl")) { + assertColumnsAreEqual(expected, actual); + } + + try (ColumnVector actual = v.replaceRegex(regexProg, repl); + ColumnVector expected = + ColumnVector.fromStrings("Repl and Repl with Repl", "nothing", null, "Repl")) { + assertColumnsAreEqual(expected, actual); + } + + try (ColumnVector actual = v.replaceRegex(regexProg, repl, 0)) { + assertColumnsAreEqual(v, actual); + } + + try (ColumnVector actual = v.replaceRegex(regexProg, repl, 1); ColumnVector expected = ColumnVector.fromStrings("Repl and Title with title", "nothing", null, "Repl")) { assertColumnsAreEqual(expected, actual); @@ -5132,45 +5211,56 @@ void testReplaceMultiRegex() { @Test void testStringReplaceWithBackrefs() { - try (ColumnVector v = ColumnVector.fromStrings("

title

", "

another title

", - null); + try (ColumnVector v = ColumnVector.fromStrings("

title

", "

another title

", null); ColumnVector expected = ColumnVector.fromStrings("

title

", "

another title

", null); - ColumnVector actual = v.stringReplaceWithBackrefs("

(.*)

", "

\\1

")) { + ColumnVector actual = v.stringReplaceWithBackrefs("

(.*)

", "

\\1

"); + ColumnVector actualRe = + v.stringReplaceWithBackrefs(new RegexProgram("

(.*)

"), "

\\1

")) { assertColumnsAreEqual(expected, actual); + assertColumnsAreEqual(expected, actualRe); } try (ColumnVector v = ColumnVector.fromStrings("2020-1-01", "2020-2-02", null); ColumnVector expected = ColumnVector.fromStrings("2020-01-01", "2020-02-02", null); - ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])-", "-0\\1-")) { + ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])-", "-0\\1-"); + ColumnVector actualRe = + v.stringReplaceWithBackrefs(new RegexProgram("-([0-9])-"), "-0\\1-")) { assertColumnsAreEqual(expected, actual); + assertColumnsAreEqual(expected, actualRe); } try (ColumnVector v = ColumnVector.fromStrings("2020-01-1", "2020-02-2", - "2020-03-3invalid", null); + "2020-03-3invalid", null); ColumnVector expected = ColumnVector.fromStrings("2020-01-01", "2020-02-02", "2020-03-3invalid", null); - ColumnVector actual = v.stringReplaceWithBackrefs( - "-([0-9])$", "-0\\1")) { + ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])$", "-0\\1"); + ColumnVector actualRe = + v.stringReplaceWithBackrefs(new RegexProgram("-([0-9])$"), "-0\\1")) { assertColumnsAreEqual(expected, actual); + assertColumnsAreEqual(expected, actualRe); } try (ColumnVector v = ColumnVector.fromStrings("2020-01-1 random_text", "2020-02-2T12:34:56", - "2020-03-3invalid", null); + "2020-03-3invalid", null); ColumnVector expected = ColumnVector.fromStrings("2020-01-01 random_text", "2020-02-02T12:34:56", "2020-03-3invalid", null); - ColumnVector actual = v.stringReplaceWithBackrefs( - "-([0-9])([ T])", "-0\\1\\2")) { + ColumnVector actual = v.stringReplaceWithBackrefs("-([0-9])([ T])", "-0\\1\\2"); + ColumnVector actualRe = + v.stringReplaceWithBackrefs(new RegexProgram("-([0-9])([ T])"), "-0\\1\\2")) { assertColumnsAreEqual(expected, actual); + assertColumnsAreEqual(expected, actualRe); } // test zero as group index try (ColumnVector v = ColumnVector.fromStrings("aa-11 b2b-345", "aa-11a 1c-2b2 b2-c3", "11-aa", null); ColumnVector expected = ColumnVector.fromStrings("aa-11:aa:11; b2b-345:b:345;", "aa-11:aa:11;a 1c-2:c:2;b2 b2-c3", "11-aa", null); - ColumnVector actual = v.stringReplaceWithBackrefs( - "([a-z]+)-([0-9]+)", "${0}:${1}:${2};")) { + ColumnVector actual = v.stringReplaceWithBackrefs("([a-z]+)-([0-9]+)", "${0}:${1}:${2};"); + ColumnVector actualRe = + v.stringReplaceWithBackrefs(new RegexProgram("([a-z]+)-([0-9]+)"), "${0}:${1}:${2};")) { assertColumnsAreEqual(expected, actual); + assertColumnsAreEqual(expected, actualRe); } // group index exceeds group count @@ -5180,6 +5270,12 @@ void testStringReplaceWithBackrefs() { } }); + assertThrows(CudfException.class, () -> { + try (ColumnVector v = ColumnVector.fromStrings("ABC123defgh"); + ColumnVector r = v.stringReplaceWithBackrefs( + new RegexProgram("([A-Z]+)([0-9]+)([a-z]+)"), "\\4")) { + } + }); } @Test