Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement JNI for strings:repeat_strings that repeats each string separately by different numbers of times #8572

Merged
merged 52 commits into from
Jul 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
f0b5ff7
Add doxygen
ttnghia Jun 17, 2021
a61c1e7
Finish implementation
ttnghia Jun 18, 2021
2639f9a
Finish unit tests
ttnghia Jun 18, 2021
dbbfbf9
Merge branch 'branch-21.08' into repeat_strings
ttnghia Jun 18, 2021
0dec873
Fix merge conflicts
ttnghia Jun 18, 2021
143f853
Rename parameter back to `input`
ttnghia Jun 21, 2021
b39bb06
Fix typo
ttnghia Jun 21, 2021
ae40591
Rewrite using type_dispatcher for different integer types
ttnghia Jun 21, 2021
8612f61
Fix comment typo
ttnghia Jun 21, 2021
3dfec42
Remove input check for int32_t data type
ttnghia Jun 21, 2021
d230498
Remove bool type from the expecting types for `repeat_times` data type
ttnghia Jun 21, 2021
7ec3c76
Merge branch 'repeat_strings' into jni_repeat_strings
ttnghia Jun 21, 2021
c56627d
Implement Java binding for repeatStringsWithColumnRepeatTimes
ttnghia Jun 21, 2021
834892d
Implement JNI for repeatStringsWithColumnRepeatTimes
ttnghia Jun 21, 2021
c4d3a36
Fix repeatString with scalar repeat times
ttnghia Jun 21, 2021
f534372
Implement overflow check for the new API, as it can't be done outside…
ttnghia Jun 21, 2021
554d20d
Update doxygen
ttnghia Jun 21, 2021
d00ba01
Add typed tests for various types of `repeat_times` column
ttnghia Jun 21, 2021
5b5c2a4
Fix doxygen
ttnghia Jun 21, 2021
d488eca
Simplify overflow checking
ttnghia Jun 21, 2021
8498fc6
Just re-order code
ttnghia Jun 21, 2021
855a774
Add a parameter to allow turning on/off overflow checking
ttnghia Jun 24, 2021
e5f5db8
Implement overflow checking
ttnghia Jun 25, 2021
50f05fd
Merge branch 'branch-21.08' into repeat_strings
ttnghia Jun 25, 2021
0bcf8d8
Redesign the API and update doxygen
ttnghia Jun 25, 2021
c7b7c3b
Add an optional column of pre-computed output strings offsets
ttnghia Jul 7, 2021
f6d7ee3
Merge branch 'branch-21.08' into repeat_strings
ttnghia Jul 8, 2021
d22c4e5
Finish implementation
ttnghia Jul 8, 2021
6124e83
Fix JNI
ttnghia Jul 8, 2021
90517aa
Merge branch 'branch-21.08' into repeat_strings
ttnghia Jul 8, 2021
7795f54
Cleanup
ttnghia Jul 8, 2021
9158e4f
Remove duplicate code
ttnghia Jul 9, 2021
5e37782
Add test for computing string output sizes that causes overflow
ttnghia Jul 9, 2021
4dfca75
Fix test build error
ttnghia Jul 9, 2021
95cb6c0
Merge branch 'branch-21.08' into repeat_strings
ttnghia Jul 9, 2021
91c414e
Simple fix comment typo
ttnghia Jul 9, 2021
20f79c6
Merge branch 'repeat_strings' into jni_repeat_strings
ttnghia Jul 12, 2021
5334708
Merge branch 'branch-21.08' into jni_repeat_strings
ttnghia Jul 12, 2021
0a173b8
WIP
ttnghia Jul 20, 2021
2c49666
Merge branch 'branch-21.08' into jni_repeat_strings
ttnghia Jul 21, 2021
31b3aa9
Rewrite javadoc
ttnghia Jul 21, 2021
62c24f5
Add Java API for repeatStringsOutputSizes
ttnghia Jul 22, 2021
d5c5964
Rewrite tests
ttnghia Jul 22, 2021
2c8ac37
Rewrite the output types for repeatStringsOutputSizes
ttnghia Jul 22, 2021
2df33d6
Rename functions
ttnghia Jul 22, 2021
22ab360
Add tests with pre-computed output string sizes
ttnghia Jul 22, 2021
272c24e
Merge branch 'branch-21.08' into jni_repeat_strings
ttnghia Jul 22, 2021
b537d81
Merge branch 'branch-21.10' into jni_repeat_strings
ttnghia Jul 22, 2021
8331836
Rename functions and variables
ttnghia Jul 23, 2021
2f32d40
Fix typo
ttnghia Jul 23, 2021
69000ae
Rewrite javadoc and add public qualifier for class
ttnghia Jul 23, 2021
7488a39
Merge branch 'branch-21.10' into jni_repeat_strings
ttnghia Jul 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 143 additions & 19 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2327,24 +2327,103 @@ public final ColumnVector stringConcatenateListElements(Scalar separator,
}

/**
* Given a strings column, each string in the given column is repeated a number of times
* specified by the <code>repeatTimes</code> parameter. If the parameter has a non-positive value,
* all the rows of the output strings column will be an empty string. Any null row will result
* in a null row regardless of the value of <code>repeatTimes</code>.
* Given a strings column, each string in it is repeated a number of times specified by the
* <code>repeatTimes</code> parameter.
*
* Note that this function cannot handle the cases when the size of the output column exceeds
* the maximum value that can be indexed by int type (i.e., {@link Integer#MAX_VALUE}).
* In such situations, the output result is undefined.
* In special cases:
* - If <code>repeatTimes</code> is not a positive number, a non-null input string will always
* result in an empty output string.
* - A null input string will always result in a null output string regardless of the value of
* the <code>repeatTimes</code> parameter.
*
* @param repeatTimes The number of times each input string is copied to the output.
* The caller is responsible for checking the output column size will not exceed the maximum size
* of a strings column (number of total characters is less than the value {@link Integer#MAX_VALUE}).
*
* @param repeatTimes The number of times each input string is repeated.
* @return A new java column vector containing repeated strings.
*/
public final ColumnVector repeatStrings(int repeatTimes) {
assert type.equals(DType.STRING) : "column type must be a String";

assert type.equals(DType.STRING) : "column type must be String";
return new ColumnVector(repeatStrings(getNativeView(), repeatTimes));
}

/**
* Given a strings column, an output strings column is generated by repeating each of the input
* string by a number of times given by the corresponding row in a <code>repeatTimes</code>
* numeric column.
*
* In special cases:
* - Any null row (from either the input strings column or the <code>repeatTimes</code> column)
* will always result in a null output string.
* - If any value in the <code>repeatTimes</code> column is not a positive number and its
* corresponding input string is not null, the output string will be an empty string.
*
* The caller is responsible for checking the output column size will not exceed the maximum size
* of a strings column (number of total characters is less than the value {@link Integer#MAX_VALUE}).
*
* @param repeatTimes The column containing numbers of times each input string is repeated.
* @return A new java column vector containing repeated strings.
*/
public final ColumnVector repeatStrings(ColumnView repeatTimes) {
assert type.equals(DType.STRING) : "column type must be String";
return new ColumnVector(repeatStringsWithColumnRepeatTimes(getNativeView(),
repeatTimes.getNativeView(), 0));
}

/**
* This function is an overloaded version of {@link #repeatStrings(ColumnView) repeatStrings},
* with an additional parameter <code>outputStringSizes</code> that provides a column containing
* the pre-computed sizes of the output strings.
*
* @param repeatTimes The column containing numbers of times each input string is repeated.
* @param outputStringSizes A numeric column providing the pre-computed sizes of the output strings.
* @return A new java column vector containing repeated strings.
*/
public final ColumnVector repeatStrings(ColumnView repeatTimes, ColumnView outputStringSizes) {
assert type.equals(DType.STRING) : "column type must be String";
return new ColumnVector(repeatStringsWithColumnRepeatTimes(getNativeView(),
repeatTimes.getNativeView(), outputStringSizes.getNativeView()));
}

/** Struct to return the computed strings size column and total size */
public static final class StringSizes implements AutoCloseable {
private final ColumnVector stringSizes;
private final long totalSize;

StringSizes(ColumnVector stringSizes, long totalSize) {
this.stringSizes = stringSizes;
this.totalSize = totalSize;
}

public ColumnVector getStringSizes() { return stringSizes; }
public long getTotalSize() { return totalSize; }

/** Close the underlying resources */
@Override
public void close() {
if (stringSizes != null) {
stringSizes.close();
}
}
};

/**
* Compute sizes of the output strings if each string in an input strings column is repeated by
* a different number of times given by the corresponding row in a <code>repeatTimes</code>
* numeric column (i.e., compute sizes of the strings resulted from
* {@link #repeatStringsWithColumnRepeatTimes}).
*
* @param repeatTimes The column containing numbers of times each input string is repeated.
* @return An instance of StringSizes class which stores a Java column vector containing
* the computed sizes of the repeated strings, and a long value storing sum of all these
* computed sizes.
*/
public final StringSizes repeatStringsSizes(ColumnView repeatTimes) {
assert type.equals(DType.STRING) : "column type must be String";
final long[] sizes = repeatStringsSizes(getNativeView(), repeatTimes.getNativeView());
return new StringSizes(new ColumnVector(sizes[0]), sizes[1]);
}

/**
* Apply a JSONPath string to all rows in an input strings column.
*
Expand Down Expand Up @@ -3022,21 +3101,66 @@ private static native long stringConcatenationListElements(long listColumnHandle
boolean emptyStringOutputIfEmptyList);

/**
* Native method to repeat each string in the given strings column a number of times
* specified by the <code>repeatTimes</code> parameter. If the parameter has a non-positive value,
* all the rows of the output strings column will be an empty string. Any null row will result
* in a null row regardless of the value of <code>repeatTimes</code>.
* Native method to repeat each string in the given input strings column a number of times
* specified by the <code>repeatTimes</code> parameter.
*
* In special cases:
* - If <code>repeatTimes</code> is not a positive number, a non-null input string will always
* result in an empty output string.
* - A null input string will always result in a null output string regardless of the value of
* the <code>repeatTimes</code> parameter.
*
* Note that this function cannot handle the cases when the size of the output column exceeds
* the maximum value that can be indexed by int type (i.e., {@link Integer#MAX_VALUE}).
* In such situations, the output result is undefined.
* The caller is responsible for checking the output column size will not exceed the maximum size
* of a strings column (number of total characters is less than the value {@link Integer#MAX_VALUE}).
*
* @param viewHandle long holding the native handle of the column containing strings to repeat.
* @param repeatTimes The number of times each input string is copied to the output.
* @return native handle of the resulting cudf column containing repeated strings.
* @param repeatTimes The number of times each input string is repeated.
* @return native handle of the resulting cudf strings column containing repeated strings.
*/
private static native long repeatStrings(long viewHandle, int repeatTimes);

/**
* Native method to repeat strings in the given input strings column, each string is repeated
* by a different number of times given by the corresponding row in a <code>repeatTimes</code>
* numeric column.
*
* In special cases:
* - Any null row (from either the input strings column or the <code>repeatTimes</code> column)
* will always result in a null output string.
* - If any value in the <code>repeatTimes</code> column is not a positive number and its
* corresponding input string is not null, the output string will be an empty string.
*
* The caller is responsible for checking the output column size will not exceed the maximum size
* of a strings column (number of total characters is less than the value {@link Integer#MAX_VALUE}).
*
* If the input <code>repeatTimesHandle</code> column does not have a numeric type, or it has a
* size that is different from size of the input strings column, an exception will be thrown.
*
* @param stringsHandle long holding the native handle of the column containing strings to repeat.
* @param repeatTimesHandle long holding the native handle of the column containing the numbers
* of times each input string is repeated.
* @param outputStringSizesHandle long holding the native handle of the column containing the
* pre-computed sizes of the output strings, can be specified as
* <code>0</code> value if that column is not available.
* @return native handle of the resulting cudf strings column containing repeated strings.
*/
private static native long repeatStringsWithColumnRepeatTimes(long stringsHandle,
long repeatTimesHandle, long outputStringSizesHandle);

/**
* Native method to compute sizes of the output strings if each string in the input strings
* column is repeated by a different number of times given by the corresponding row in a
* <code>repeatTimes</code> numeric column (i.e., compute sizes of the strings resulted from
* {@link #repeatStringsWithColumnRepeatTimes}).
*
* @param stringsHandle long holding the native handle of the column containing strings to repeat.
* @param repeatTimesHandle long holding the native handle of the column containing the numbers
* of times each input string is repeated.
* @return An array of two long values, the first one holds native handle of a numeric column
* containing the computed sizes of the repeated strings, and the second value is the sum
* of all those string sizes.
*/
private static native long[] repeatStringsSizes(long stringsHandle, long repeatTimesHandle);

private static native long getJSONObject(long viewHandle, long scalarHandle) throws CudfException;

Expand Down
48 changes: 45 additions & 3 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1943,16 +1943,58 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringConcatenationListEl
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_repeatStrings(JNIEnv *env, jclass,
jlong column_handle,
jlong strings_handle,
jint repeat_times) {
JNI_NULL_CHECK(env, column_handle, "column handle is null", 0);
JNI_NULL_CHECK(env, strings_handle, "strings_handle is null", 0);
try {
cudf::jni::auto_set_device(env);
auto const cv = *reinterpret_cast<cudf::column_view *>(column_handle);
auto const cv = *reinterpret_cast<cudf::column_view *>(strings_handle);
auto const strs_col = cudf::strings_column_view(cv);
return reinterpret_cast<jlong>(cudf::strings::repeat_strings(strs_col, repeat_times).release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_repeatStringsWithColumnRepeatTimes(
JNIEnv *env, jclass, jlong strings_handle, jlong repeat_times_handle,
jlong output_sizes_handle) {
JNI_NULL_CHECK(env, strings_handle, "strings_handle is null", 0);
JNI_NULL_CHECK(env, repeat_times_handle, "repeat_times_handle is null", 0);
try {
cudf::jni::auto_set_device(env);
auto const strings_cv = *reinterpret_cast<cudf::column_view *>(strings_handle);
auto const strs_col = cudf::strings_column_view(strings_cv);
auto const repeat_times_cv = *reinterpret_cast<cudf::column_view *>(repeat_times_handle);
if (output_sizes_handle != 0) {
auto const output_sizes_cv = *reinterpret_cast<cudf::column_view *>(output_sizes_handle);
return reinterpret_cast<jlong>(
cudf::strings::repeat_strings(strs_col, repeat_times_cv, output_sizes_cv).release());
} else {
return reinterpret_cast<jlong>(
cudf::strings::repeat_strings(strs_col, repeat_times_cv).release());
}
}
CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_repeatStringsSizes(
JNIEnv *env, jclass, jlong strings_handle, jlong repeat_times_handle) {
JNI_NULL_CHECK(env, strings_handle, "strings handle is null", 0);
JNI_NULL_CHECK(env, repeat_times_handle, "repeat_times handle is null", 0);
try {
cudf::jni::auto_set_device(env);
auto const strings_cv = *reinterpret_cast<cudf::column_view *>(strings_handle);
auto const strs_col = cudf::strings_column_view(strings_cv);
auto const repeat_times_cv = *reinterpret_cast<cudf::column_view *>(repeat_times_handle);

auto [output_sizes, total_bytes] =
cudf::strings::repeat_strings_output_sizes(strs_col, repeat_times_cv);
auto results = cudf::jni::native_jlongArray(env, 2);
results[0] = reinterpret_cast<jlong>(output_sizes.release());
results[1] = static_cast<jlong>(total_bytes);
return results.get_jArray();
}
CATCH_STD(env, 0);
}

} // extern "C"
112 changes: 95 additions & 17 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2634,38 +2634,116 @@ void testStringConcatWsSingleListColEmptyArrayReturnEmpty() {
}

@Test
void testRepeatStrings() {
// Empty strings column.
try (ColumnVector sv = ColumnVector.fromStrings("", "", "");
ColumnVector result = sv.repeatStrings(1)) {
assertColumnsAreEqual(sv, result);
void testRepeatStringsWithScalarRepeatTimes() {
// Empty strings column.
try (ColumnVector input = ColumnVector.fromStrings("", "", "");
ColumnVector results = input.repeatStrings(1)) {
assertColumnsAreEqual(input, results);
}

// Zero repeatTimes.
try (ColumnVector sv = ColumnVector.fromStrings("abc", "xyz", "123");
ColumnVector result = sv.repeatStrings(0);
try (ColumnVector input = ColumnVector.fromStrings("abc", "xyz", "123");
ColumnVector results = input.repeatStrings(0);
ColumnVector expected = ColumnVector.fromStrings("", "", "")) {
assertColumnsAreEqual(expected, result);
assertColumnsAreEqual(expected, results);
}

// Negative repeatTimes.
try (ColumnVector sv = ColumnVector.fromStrings("abc", "xyz", "123");
ColumnVector result = sv.repeatStrings(-1);
try (ColumnVector input = ColumnVector.fromStrings("abc", "xyz", "123");
ColumnVector results = input.repeatStrings(-1);
ColumnVector expected = ColumnVector.fromStrings("", "", "")) {
assertColumnsAreEqual(expected, result);
assertColumnsAreEqual(expected, results);
}

// Strings column containing both null and empty, output is copied exactly from input.
try (ColumnVector sv = ColumnVector.fromStrings("abc", "", null, "123", null);
ColumnVector result = sv.repeatStrings(1)) {
assertColumnsAreEqual(sv, result);
try (ColumnVector input = ColumnVector.fromStrings("abc", "", null, "123", null);
ColumnVector results = input.repeatStrings(1)) {
assertColumnsAreEqual(input, results);
}

// Strings column containing both null and empty.
try (ColumnVector sv = ColumnVector.fromStrings("abc", "", null, "123", null);
ColumnVector result = sv.repeatStrings( 2);
try (ColumnVector input = ColumnVector.fromStrings("abc", "", null, "123", null);
ColumnVector results = input.repeatStrings( 2);
ColumnVector expected = ColumnVector.fromStrings("abcabc", "", null, "123123", null)) {
assertColumnsAreEqual(expected, result);
assertColumnsAreEqual(expected, results);
}
}

@Test
void testRepeatStringsWithColumnRepeatTimes() {
// Empty strings column.
try (ColumnVector input = ColumnVector.fromStrings("", "", "");
ColumnVector repeatTimes = ColumnVector.fromInts(-1, 0, 1);
ColumnVector results = input.repeatStrings(repeatTimes)) {
assertColumnsAreEqual(input, results);
}

// Zero and negative repeatTimes.
try (ColumnVector input = ColumnVector.fromStrings("abc", "xyz", "123", "456", "789", "a1");
ColumnVector repeatTimes = ColumnVector.fromInts(-200, -100, 0, 0, 1, 2);
ColumnVector results = input.repeatStrings(repeatTimes);
ColumnVector expected = ColumnVector.fromStrings("", "", "", "", "789", "a1a1")) {
assertColumnsAreEqual(expected, results);
}

// Strings column contains both null and empty, output is copied exactly from input.
try (ColumnVector input = ColumnVector.fromStrings("abc", "", null, "123", null);
ColumnVector repeatTimes = ColumnVector.fromInts(1, 1, 1, 1, 1);
ColumnVector results = input.repeatStrings(repeatTimes)) {
assertColumnsAreEqual(input, results);
}

// Strings column contains both null and empty.
try (ColumnVector input = ColumnVector.fromStrings("abc", "", null, "123", null);
ColumnVector repeatTimes = ColumnVector.fromInts(2, 3, 1, 3, 2);
ColumnVector results = input.repeatStrings(repeatTimes);
ColumnVector expected = ColumnVector.fromStrings("abcabc", "", null, "123123123", null)) {
assertColumnsAreEqual(expected, results);
}
}

@Test
void testRepeatStringsWithColumnRepeatTimesAndPrecomputedOutputSizes() {
// Empty strings column.
try (ColumnVector input = ColumnVector.fromStrings("", "", "");
ColumnVector repeatTimes = ColumnVector.fromInts(-1, 0, 1);
ColumnView.StringSizes outputSizes = input.repeatStringsSizes(repeatTimes)) {
assertEquals(0, outputSizes.getTotalSize());
try (ColumnVector results = input.repeatStrings(repeatTimes, outputSizes.getStringSizes())) {
assertColumnsAreEqual(input, results);
}
}

// Zero and negative repeatTimes.
try (ColumnVector input = ColumnVector.fromStrings("abc", "xyz", "123", "456", "789", "a1");
ColumnVector repeatTimes = ColumnVector.fromInts(-200, -100, 0, 0, 1, 2);
ColumnVector expected = ColumnVector.fromStrings("", "", "", "", "789", "a1a1");
ColumnView.StringSizes outputSizes = input.repeatStringsSizes(repeatTimes)) {
assertEquals(7, outputSizes.getTotalSize());
try (ColumnVector results = input.repeatStrings(repeatTimes, outputSizes.getStringSizes())) {
assertColumnsAreEqual(expected, results);
}
}

// Strings column contains both null and empty, output is copied exactly from input.
try (ColumnVector input = ColumnVector.fromStrings("abc", "", null, "123", null);
ColumnVector repeatTimes = ColumnVector.fromInts(1, 1, 1, 1, 1);
ColumnView.StringSizes outputSizes = input.repeatStringsSizes(repeatTimes)) {
assertEquals(6, outputSizes.getTotalSize());
try (ColumnVector results = input.repeatStrings(repeatTimes, outputSizes.getStringSizes())) {
assertColumnsAreEqual(input, results);
}
}

// Strings column contains both null and empty.
try (ColumnVector input = ColumnVector.fromStrings("abc", "", null, "123", null);
ColumnVector repeatTimes = ColumnVector.fromInts(2, 3, 1, 3, 2);
ColumnVector expected = ColumnVector.fromStrings("abcabc", "", null, "123123123", null);
ColumnView.StringSizes outputSizes = input.repeatStringsSizes(repeatTimes)) {
assertEquals(15, outputSizes.getTotalSize());
try (ColumnVector results = input.repeatStrings(repeatTimes, outputSizes.getStringSizes())) {
assertColumnsAreEqual(expected, results);
}
}
}

Expand Down