diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index a2e080e02f6..422311fc8e0 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -2331,13 +2331,27 @@ public final ColumnVector stringLocate(Scalar substring, int start, int end) { * Null string entries return corresponding null output columns. * @param delimiter UTF-8 encoded string identifying the split points in each string. * An empty string indicates split on whitespace. + * @param maxSplit the maximum number of splits to perform, or -1 for all possible splits. * @return New table of strings columns. */ - public final Table stringSplit(Scalar delimiter) { + public final Table stringSplit(Scalar delimiter, int maxSplit) { assert type.equals(DType.STRING) : "column type must be a String"; assert delimiter != null : "delimiter may not be null"; assert delimiter.getType().equals(DType.STRING) : "delimiter must be a string scalar"; - return new Table(stringSplit(this.getNativeView(), delimiter.getScalarHandle())); + return new Table(stringSplit(this.getNativeView(), delimiter.getScalarHandle(), maxSplit)); + } + + /** + * Returns a list of columns by splitting each string using the specified delimiter. + * The number of rows in the output columns will be the same as the input column. + * Null entries are added for a row where split results have been exhausted. + * Null string entries return corresponding null output columns. + * @param delimiter UTF-8 encoded string identifying the split points in each string. + * An empty string indicates split on whitespace. + * @return New table of strings columns. + */ + public final Table stringSplit(Scalar delimiter) { + return stringSplit(delimiter, -1); } /** @@ -2349,7 +2363,7 @@ public final Table stringSplit(Scalar delimiter) { */ public final Table stringSplit() { try (Scalar emptyString = Scalar.fromString("")) { - return stringSplit(emptyString); + return stringSplit(emptyString, -1); } } @@ -2362,7 +2376,7 @@ public final ColumnVector stringSplitRecord() { /** * Returns a column of lists of strings by splitting each string using whitespace as the delimiter. - * @param maxSplit the maximum number of records to split, or -1 for all of them. + * @param maxSplit the maximum number of splits to perform, or -1 for all possible splits. */ public final ColumnVector stringSplitRecord(int maxSplit) { try (Scalar emptyString = Scalar.fromString("")) { @@ -2384,7 +2398,7 @@ public final ColumnVector stringSplitRecord(Scalar delimiter) { * string using the specified delimiter. * @param delimiter UTF-8 encoded string identifying the split points in each string. * An empty string indicates split on whitespace. - * @param maxSplit the maximum number of records to split, or -1 for all of them. + * @param maxSplit the maximum number of splits to perform, or -1 for all possible splits. * @return New table of strings columns. */ public final ColumnVector stringSplitRecord(Scalar delimiter, int maxSplit) { @@ -3490,8 +3504,9 @@ private static native long repeatStringsWithColumnRepeatTimes(long stringsHandle * delimiter. * @param columnView native handle of the cudf::column_view being operated on. * @param delimiter UTF-8 encoded string identifying the split points in each string. + * @param maxSplit the maximum number of splits to perform, or -1 for all possible splits. */ - private static native long[] stringSplit(long columnView, long delimiter); + private static native long[] stringSplit(long columnView, long delimiter, int maxSplit); private static native long stringSplitRecord(long nativeView, long scalarHandle, int maxSplit); diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 38c6bb3740e..fe1173e417e 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -561,7 +561,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listSortRows(JNIEnv *env, JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv *env, jclass, jlong column_view, - jlong delimiter) { + jlong delimiter, + jint max_split) { JNI_NULL_CHECK(env, column_view, "column is null", 0); JNI_NULL_CHECK(env, delimiter, "string scalar delimiter is null", 0); try { @@ -570,7 +571,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv * cudf::strings_column_view scv(*cv); cudf::string_scalar *ss_scalar = reinterpret_cast(delimiter); - std::unique_ptr table_result = cudf::strings::split(scv, *ss_scalar); + std::unique_ptr table_result = cudf::strings::split(scv, *ss_scalar, max_split); return cudf::jni::convert_table_for_return(env, table_result); } CATCH_STD(env, 0); diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 8d4bbff1542..2dbec454eb2 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -4703,13 +4703,21 @@ void testStringSplitRecord() { @Test void testStringSplit() { - try (ColumnVector v = ColumnVector.fromStrings("Héllo there", "thésé", null, "", "ARé some", "test strings"); - Table expected = new Table.TestBuilder().column("Héllo", "thésé", null, "", "ARé", "test") + try (ColumnVector v = ColumnVector.fromStrings("Héllo there all", "thésé", null, "", "ARé some things", "test strings here"); + Table expectedSplitOnce = new Table.TestBuilder() + .column("Héllo", "thésé", null, "", "ARé", "test") + .column("there all", null, null, null, "some things", "strings here") + .build(); + Table expectedSplitAll = new Table.TestBuilder() + .column("Héllo", "thésé", null, "", "ARé", "test") .column("there", null, null, null, "some", "strings") + .column("all", null, null, null, "things", "here") .build(); Scalar pattern = Scalar.fromString(" "); - Table result = v.stringSplit(pattern)) { - assertTablesAreEqual(expected, result); + Table resultSplitOnce = v.stringSplit(pattern, 1); + Table resultSplitAll = v.stringSplit(pattern)) { + assertTablesAreEqual(expectedSplitOnce, resultSplitOnce); + assertTablesAreEqual(expectedSplitAll, resultSplitAll); } }