Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add regex_program strings splitting java APIs and tests #12713

Merged
111 changes: 97 additions & 14 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2531,12 +2531,34 @@ public final ColumnVector stringLocate(Scalar substring, int start, int end) {
* regular expression pattern or just by a string literal delimiter.
* @return list of strings columns as a table.
*/
@Deprecated
public final Table stringSplit(String pattern, int limit, boolean splitByRegex) {
if (splitByRegex) {
return stringSplit(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE), limit);
} else {
return stringSplit(pattern, limit);
}
}

/**
* Returns a list of columns by splitting each string using the specified regex program. The
* number of rows in the output columns will be the same as the input column. Null entries
* are added for a row where split results have been exhausted. Null input entries result in
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* are added for a row where split results have been exhausted. Null input entries result in
* are added for the rows where split results have been exhausted. Null input entries result in

* all nulls in the corresponding rows of the output columns.
*
* @param regexProg the regex program with UTF-8 encoded string identifying the split pattern
* for each input string.
* @param limit the maximum size of the list resulting from splitting each input string,
* or -1 for all possible splits. Note that limit = 0 (all possible splits without
* trailing empty strings) and limit = 1 (no split at all) are not supported.
* @return list of strings columns as a table.
*/
public final Table stringSplit(RegexProgram regexProg, int limit) {
assert type.equals(DType.STRING) : "column type must be a String";
assert pattern != null : "pattern is null";
assert pattern.length() > 0 : "empty pattern is not supported";
assert regexProg != null : "regex program is null";
assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported";
return new Table(stringSplit(this.getNativeView(), pattern, limit, splitByRegex));
return new Table(stringSplit(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(),
regexProg.capture().nativeId, limit, true));
}

/**
Expand All @@ -2550,6 +2572,7 @@ public final Table stringSplit(String pattern, int limit, boolean splitByRegex)
* regular expression pattern or just by a string literal delimiter.
* @return list of strings columns as a table.
*/
@Deprecated
public final Table stringSplit(String pattern, boolean splitByRegex) {
return stringSplit(pattern, -1, splitByRegex);
}
Expand All @@ -2567,7 +2590,11 @@ public final Table stringSplit(String pattern, boolean splitByRegex) {
* @return list of strings columns as a table.
*/
public final Table stringSplit(String delimiter, int limit) {
return stringSplit(delimiter, limit, false);
assert type.equals(DType.STRING) : "column type must be a String";
assert delimiter != null : "delimiter is null";
assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported";
return new Table(stringSplit(this.getNativeView(), delimiter, RegexFlag.DEFAULT.nativeId,
CaptureGroups.NON_CAPTURE.nativeId, limit, false));
}

/**
Expand All @@ -2580,7 +2607,21 @@ public final Table stringSplit(String delimiter, int limit) {
* @return list of strings columns as a table.
*/
public final Table stringSplit(String delimiter) {
return stringSplit(delimiter, -1, false);
return stringSplit(delimiter, -1);
}

/**
* Returns a list of columns by splitting each string using the specified regex program with
* string literal delimiter. The number of rows in the output columns will be the same as the
* input column. Null entries are added for a row where split results have been exhausted.
* Null input entries result in all nulls in the corresponding rows of the output columns.
*
* @param regexProg the regex program with UTF-8 encoded string identifying the split pattern
* for each input string.
* @return list of strings columns as a table.
*/
public final Table stringSplit(RegexProgram regexProg) {
return stringSplit(regexProg, -1);
}

/**
Expand All @@ -2595,13 +2636,33 @@ public final Table stringSplit(String delimiter) {
* regular expression pattern or just by a string literal delimiter.
* @return a LIST column of string elements.
*/
@Deprecated
public final ColumnVector stringSplitRecord(String pattern, int limit, boolean splitByRegex) {
if (splitByRegex) {
return stringSplitRecord(new RegexProgram(pattern, CaptureGroups.NON_CAPTURE), limit);
} else {
return stringSplitRecord(pattern, limit);
}
}

/**
* Returns a column that are lists of strings in which each list is made by splitting the
* corresponding input string using the specified regex program pattern.
*
* @param regexProg the regex program with UTF-8 encoded string identifying the split pattern
* for each input string.
* @param limit the maximum size of the list resulting from splitting each input string,
* or -1 for all possible splits. Note that limit = 0 (all possible splits without
* trailing empty strings) and limit = 1 (no split at all) are not supported.
* @return a LIST column of string elements.
*/
public final ColumnVector stringSplitRecord(RegexProgram regexProg, int limit) {
assert type.equals(DType.STRING) : "column type must be String";
assert pattern != null : "pattern is null";
assert pattern.length() > 0 : "empty pattern is not supported";
assert regexProg != null : "regex program is null";
assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported";
return new ColumnVector(
stringSplitRecord(this.getNativeView(), pattern, limit, splitByRegex));
stringSplitRecord(this.getNativeView(), regexProg.pattern(), regexProg.combinedFlags(),
regexProg.capture().nativeId, limit, true));
}

/**
Expand All @@ -2613,6 +2674,7 @@ public final ColumnVector stringSplitRecord(String pattern, int limit, boolean s
* regular expression pattern or just by a string literal delimiter.
* @return a LIST column of string elements.
*/
@Deprecated
public final ColumnVector stringSplitRecord(String pattern, boolean splitByRegex) {
return stringSplitRecord(pattern, -1, splitByRegex);
}
Expand All @@ -2628,7 +2690,12 @@ public final ColumnVector stringSplitRecord(String pattern, boolean splitByRegex
* @return a LIST column of string elements.
*/
public final ColumnVector stringSplitRecord(String delimiter, int limit) {
return stringSplitRecord(delimiter, limit, false);
assert type.equals(DType.STRING) : "column type must be String";
assert delimiter != null : "delimiter is null";
assert limit != 0 && limit != 1 : "split limit == 0 and limit == 1 are not supported";
return new ColumnVector(
stringSplitRecord(this.getNativeView(), delimiter, RegexFlag.DEFAULT.nativeId,
CaptureGroups.NON_CAPTURE.nativeId, limit, false));
}

/**
Expand All @@ -2639,7 +2706,19 @@ public final ColumnVector stringSplitRecord(String delimiter, int limit) {
* @return a LIST column of string elements.
*/
public final ColumnVector stringSplitRecord(String delimiter) {
return stringSplitRecord(delimiter, -1, false);
return stringSplitRecord(delimiter, -1);
}

/**
* Returns a column that are lists of strings in which each list is made by splitting the
* corresponding input string using the specified regex program with string literal delimiter.
*
* @param regexProg the regex program with UTF-8 encoded string identifying the split pattern
* for each input string.
* @return a LIST column of string elements.
*/
public final ColumnVector stringSplitRecord(RegexProgram regexProg) {
return stringSplitRecord(regexProg, -1);
}

/**
Expand Down Expand Up @@ -3965,29 +4044,33 @@ private static native long repeatStringsWithColumnRepeatTimes(long stringsHandle
*
* @param nativeHandle native handle of the input strings column that being operated on.
* @param pattern UTF-8 encoded string identifying the split pattern for each input string.
* @param flags regex flags setting.
* @param capture capture groups setting.
* @param limit the maximum size of the list resulting from splitting each input string,
* or -1 for all possible splits. Note that limit = 0 (all possible splits without
* trailing empty strings) and limit = 1 (no split at all) are not supported.
* @param splitByRegex a boolean flag indicating whether the input strings will be split by a
* regular expression pattern or just by a string literal delimiter.
*/
private static native long[] stringSplit(long nativeHandle, String pattern, int limit,
boolean splitByRegex);
private static native long[] stringSplit(long nativeHandle, String pattern, int flags,
int capture, int limit, boolean splitByRegex);
jlowe marked this conversation as resolved.
Show resolved Hide resolved

/**
* Returns a column that are lists of strings in which each list is made by splitting the
* corresponding input string using the specified string literal delimiter.
*
* @param nativeHandle native handle of the input strings column that being operated on.
* @param pattern UTF-8 encoded string identifying the split pattern for each input string.
* @param flags regex flags setting.
* @param capture capture groups setting.
* @param limit the maximum size of the list resulting from splitting each input string,
* or -1 for all possible splits. Note that limit = 0 (all possible splits without
* trailing empty strings) and limit = 1 (no split at all) are not supported.
* @param splitByRegex a boolean flag indicating whether the input strings will be split by a
* regular expression pattern or just by a string literal delimiter.
*/
private static native long stringSplitRecord(long nativeHandle, String pattern, int limit,
boolean splitByRegex);
private static native long stringSplitRecord(long nativeHandle, String pattern, int flags,
int capture, int limit, boolean splitByRegex);

/**
* Native method to calculate substring from a given string column. 0 indexing.
Expand Down
68 changes: 31 additions & 37 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,11 +679,9 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_reverseStringsOrLists(JNI
CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv *env, jclass,
jlong input_handle,
jstring pattern_obj,
jint limit,
jboolean split_by_regex) {
JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(
JNIEnv *env, jclass, jlong input_handle, jstring pattern_obj, jint regex_flags,
jint capture_groups, jint limit, jboolean split_by_regex) {
JNI_NULL_CHECK(env, input_handle, "input_handle is null", 0);

if (limit == 0 || limit == 1) {
Expand All @@ -697,31 +695,28 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv *

try {
cudf::jni::auto_set_device(env);
auto const input = reinterpret_cast<cudf::column_view *>(input_handle);
auto const strs_input = cudf::strings_column_view{*input};

auto const column_view = reinterpret_cast<cudf::column_view const *>(input_handle);
auto const strings_column = cudf::strings_column_view{*column_view};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest to avoid column_view name as it may clash with cudf::column_view.

Suggested change
auto const column_view = reinterpret_cast<cudf::column_view const *>(input_handle);
auto const strings_column = cudf::strings_column_view{*column_view};
auto const input = reinterpret_cast<cudf::column_view const *>(input_handle);
auto const strings_column = cudf::strings_column_view{*input};

auto const pattern_jstr = cudf::jni::native_jstring(env, pattern_obj);
if (pattern_jstr.is_empty()) {
// Java's split API produces different behaviors than cudf when splitting with empty
// pattern.
JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Empty pattern is not supported", 0);
}

auto const pattern = std::string(pattern_jstr.get(), pattern_jstr.size_bytes());
auto const max_split = limit > 1 ? limit - 1 : limit;
auto result = split_by_regex ?
cudf::strings::split_re(strs_input, pattern, max_split) :
cudf::strings::split(strs_input, cudf::string_scalar{pattern}, max_split);
return cudf::jni::convert_table_for_return(env, std::move(result));
if (split_by_regex) {
auto const flags = static_cast<cudf::strings::regex_flags>(regex_flags);
auto const groups = static_cast<cudf::strings::capture_groups>(capture_groups);
auto const regex_prog = cudf::strings::regex_program::create(pattern, flags, groups);
auto result = cudf::strings::split_re(strings_column, *regex_prog, max_split);
return cudf::jni::convert_table_for_return(env, std::move(result));
} else {
auto result = cudf::strings::split(strings_column, cudf::string_scalar{pattern}, max_split);
return cudf::jni::convert_table_for_return(env, std::move(result));
}
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRecord(JNIEnv *env, jclass,
jlong input_handle,
jstring pattern_obj,
jint limit,
jboolean split_by_regex) {
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRecord(
JNIEnv *env, jclass, jlong input_handle, jstring pattern_obj, jint regex_flags,
jint capture_groups, jint limit, jboolean split_by_regex) {
JNI_NULL_CHECK(env, input_handle, "input_handle is null", 0);

if (limit == 0 || limit == 1) {
Expand All @@ -735,23 +730,22 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringSplitRecord(JNIEnv

try {
cudf::jni::auto_set_device(env);
auto const input = reinterpret_cast<cudf::column_view *>(input_handle);
auto const strs_input = cudf::strings_column_view{*input};

auto const column_view = reinterpret_cast<cudf::column_view const *>(input_handle);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! All changes updated.

auto const strings_column = cudf::strings_column_view{*column_view};
auto const pattern_jstr = cudf::jni::native_jstring(env, pattern_obj);
if (pattern_jstr.is_empty()) {
// Java's split API produces different behaviors than cudf when splitting with empty
// pattern.
JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Empty pattern is not supported", 0);
}

auto const pattern = std::string(pattern_jstr.get(), pattern_jstr.size_bytes());
auto const max_split = limit > 1 ? limit - 1 : limit;
auto result =
split_by_regex ?
cudf::strings::split_record_re(strs_input, pattern, max_split) :
cudf::strings::split_record(strs_input, cudf::string_scalar{pattern}, max_split);
return release_as_jlong(result);
if (split_by_regex) {
auto const flags = static_cast<cudf::strings::regex_flags>(regex_flags);
auto const groups = static_cast<cudf::strings::capture_groups>(capture_groups);
auto const regex_prog = cudf::strings::regex_program::create(pattern, flags, groups);
auto result = cudf::strings::split_record_re(strings_column, *regex_prog, max_split);
return release_as_jlong(result);
} else {
auto result =
cudf::strings::split_record(strings_column, cudf::string_scalar{pattern}, max_split);
return release_as_jlong(result);
}
}
CATCH_STD(env, 0);
}
Expand Down
Loading