Skip to content

Commit

Permalink
Add maxSplit parameter to Java binding for strings::split (#10137)
Browse files Browse the repository at this point in the history
Currently, the Java binding for `strings::split` calls the API with the default max split value. This PR adds a parameter for it, allowing the max split to be specified from Java.

This is non-breaking because the existing Java method without the new parameter is kept and still uses the default max split value.

Authors:
  - Nghia Truong (https://github.com/ttnghia)

Approvers:
  - Jason Lowe (https://github.com/jlowe)

URL: #10137
  • Loading branch information
ttnghia authored Jan 27, 2022
1 parent b684f17 commit b290ec7
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 13 deletions.
29 changes: 22 additions & 7 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -2331,13 +2331,27 @@ public final ColumnVector stringLocate(Scalar substring, int start, int end) {
* Null string entries return corresponding null output columns.
* @param delimiter UTF-8 encoded string identifying the split points in each string.
* An empty string indicates split on whitespace.
* @param maxSplit the maximum number of splits to perform, or -1 for all possible splits.
* @return New table of strings columns.
*/
public final Table stringSplit(Scalar delimiter) {
public final Table stringSplit(Scalar delimiter, int maxSplit) {
assert type.equals(DType.STRING) : "column type must be a String";
assert delimiter != null : "delimiter may not be null";
assert delimiter.getType().equals(DType.STRING) : "delimiter must be a string scalar";
return new Table(stringSplit(this.getNativeView(), delimiter.getScalarHandle()));
return new Table(stringSplit(this.getNativeView(), delimiter.getScalarHandle(), maxSplit));
}

/**
* Returns a list of columns by splitting each string using the specified delimiter.
* All possible splits are performed; this overload delegates to
* {@code stringSplit(delimiter, -1)}.
* The number of rows in the output columns will be the same as the input column.
* Null entries are added for a row where split results have been exhausted.
* Null string entries return corresponding null output columns.
* @param delimiter UTF-8 encoded string identifying the split points in each string.
* An empty string indicates split on whitespace.
* @return New table of strings columns.
*/
public final Table stringSplit(Scalar delimiter) {
return stringSplit(delimiter, -1);
}

/**
Expand All @@ -2349,7 +2363,7 @@ public final Table stringSplit(Scalar delimiter) {
*/
public final Table stringSplit() {
try (Scalar emptyString = Scalar.fromString("")) {
return stringSplit(emptyString);
return stringSplit(emptyString, -1);
}
}

Expand All @@ -2362,7 +2376,7 @@ public final ColumnVector stringSplitRecord() {

/**
* Returns a column of lists of strings by splitting each string using whitespace as the delimiter.
* @param maxSplit the maximum number of records to split, or -1 for all of them.
* @param maxSplit the maximum number of splits to perform, or -1 for all possible splits.
*/
public final ColumnVector stringSplitRecord(int maxSplit) {
try (Scalar emptyString = Scalar.fromString("")) {
Expand All @@ -2384,7 +2398,7 @@ public final ColumnVector stringSplitRecord(Scalar delimiter) {
* string using the specified delimiter.
* @param delimiter UTF-8 encoded string identifying the split points in each string.
* An empty string indicates split on whitespace.
* @param maxSplit the maximum number of records to split, or -1 for all of them.
* @param maxSplit the maximum number of splits to perform, or -1 for all possible splits.
* @return New column of lists of strings.
*/
public final ColumnVector stringSplitRecord(Scalar delimiter, int maxSplit) {
Expand Down Expand Up @@ -3490,8 +3504,9 @@ private static native long repeatStringsWithColumnRepeatTimes(long stringsHandle
* delimiter.
* @param columnView native handle of the cudf::column_view being operated on.
* @param delimiter UTF-8 encoded string identifying the split points in each string.
* @param maxSplit the maximum number of splits to perform, or -1 for all possible splits.
*/
private static native long[] stringSplit(long columnView, long delimiter);
private static native long[] stringSplit(long columnView, long delimiter, int maxSplit);

private static native long stringSplitRecord(long nativeView, long scalarHandle, int maxSplit);

Expand Down
5 changes: 3 additions & 2 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listSortRows(JNIEnv *env,

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv *env, jclass,
jlong column_view,
jlong delimiter) {
jlong delimiter,
jint max_split) {
JNI_NULL_CHECK(env, column_view, "column is null", 0);
JNI_NULL_CHECK(env, delimiter, "string scalar delimiter is null", 0);
try {
Expand All @@ -570,7 +571,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv *
cudf::strings_column_view scv(*cv);
cudf::string_scalar *ss_scalar = reinterpret_cast<cudf::string_scalar *>(delimiter);

std::unique_ptr<cudf::table> table_result = cudf::strings::split(scv, *ss_scalar);
std::unique_ptr<cudf::table> table_result = cudf::strings::split(scv, *ss_scalar, max_split);
return cudf::jni::convert_table_for_return(env, table_result);
}
CATCH_STD(env, 0);
Expand Down
16 changes: 12 additions & 4 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4703,13 +4703,21 @@ void testStringSplitRecord() {

@Test
void testStringSplit() {
try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings");
Table expected = new Table.TestBuilder().column("Héllo", "thésé", null, "", "ARé", "test")
try (ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", "ARé some things", "test strings here");
Table expectedSplitOnce = new Table.TestBuilder()
.column("Héllo", "thésé", null, "", "ARé", "test")
.column("there all", null, null, null, "some things", "strings here")
.build();
Table expectedSplitAll = new Table.TestBuilder()
.column("Héllo", "thésé", null, "", "ARé", "test")
.column("there", null, null, null, "some", "strings")
.column("all", null, null, null, "things", "here")
.build();
Scalar pattern = Scalar.fromString(" ");
Table result = v.stringSplit(pattern)) {
assertTablesAreEqual(expected, result);
Table resultSplitOnce = v.stringSplit(pattern, 1);
Table resultSplitAll = v.stringSplit(pattern)) {
assertTablesAreEqual(expectedSplitOnce, resultSplitOnce);
assertTablesAreEqual(expectedSplitAll, resultSplitAll);
}
}

Expand Down

0 comments on commit b290ec7

Please sign in to comment.