Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JNI for strings::code_points #14533

Merged
merged 5 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,16 @@ public final ColumnVector getByteCount() {
return new ColumnVector(byteCount(getNativeView()));
}

/**
* Get the code point values (integers) for each character of each string.
Copy link
Contributor

@jlowe jlowe Nov 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This API seems very problematic in light of the effort to move to large strings. A strings column will soon support more than 2^31 characters. Calling this API on such a column will crash since it cannot manifest an INT32 column with more than 2^31 entries.

It also seems problematic from a usability point of view. Since it returns only a column of INT32 instead of LIST(INT32), it's not straightforward to figure out where the code points of one string stops and another starts. We can't use the offset column of the original string, since that's byte offsets instead of character offsets. I guess one would need to get the character lengths of the original string (converting nulls to zereoes) and then do a prefix scan to compute the code point offsets to know where one string's codepoints are in the result.

It also seems very wasteful for what NVIDIA/spark-rapids#9585 needs if called directly, since it will explode the memory of many string columns by 4X. We should first slice the original string column to only select the first character of each string. That would work around the large strings issue, the "where does a string start" issue, as well as the waste, since we only need the codepoint of the first character for that Spark feature.

Copy link
Contributor Author

@thirtiseven thirtiseven Dec 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review and analysis! I'm trying the 'only select the first character' way in the plugin.

Another problem with the spark issue is that the results of Latin-1 Supplement chars are mismatched between spark and code_points. For example é is 50089 for code_points and utf-8, and 233 for spark and Unicode (and Latin-1 and utf-16?), I'm trying to work around it but it is possible that we need a custom kernel for ascii.

*
* @return ColumnVector, with code point integer values for each character as INT32
*/
public final ColumnVector codePoints() {
assert type.equals(DType.STRING) : "type has to be a String";
return new ColumnVector(codePoints(getNativeView()));
}

/**
* Get the number of elements for each list. Null lists will have a value of null.
* @return the number of elements in each list as an INT32 value.
Expand Down Expand Up @@ -4510,6 +4520,8 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat

private static native long byteCount(long viewHandle) throws CudfException;

private static native long codePoints(long viewHandle);

private static native long extractListElement(long nativeView, int index);

private static native long extractListElementV(long nativeView, long indicesView);
Expand Down
11 changes: 11 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,17 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_byteCount(JNIEnv *env, jc
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_codePoints(JNIEnv *env, jclass clazz,
jlong view_handle) {
JNI_NULL_CHECK(env, view_handle, "input column is null", 0);
try {
cudf::jni::auto_set_device(env);
auto const input = reinterpret_cast<cudf::column_view const *>(view_handle);
return release_as_jlong(cudf::strings::code_points(cudf::strings_column_view{*input}));
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_findAndReplaceAll(JNIEnv *env, jclass clazz,
jlong old_values_handle,
jlong new_values_handle,
Expand Down
9 changes: 9 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2169,6 +2169,15 @@ void testGetByteCount() {
}
}

@Test
void testCodePoints() {
try (ColumnVector cv = ColumnVector.fromStrings("eee", "bb", null, "", "aa", "bbb", "ééé");
ColumnVector codePoints = cv.codePoints();
ColumnVector expected = ColumnVector.fromBoxedInts(101, 101, 101, 98, 98, 97, 97, 98, 98, 98, 50089, 50089, 50089)) {
assertColumnsAreEqual(expected, codePoints);
}
}

@Test
void testEmptyStringColumnOpts() {
try (ColumnVector cv = ColumnVector.fromStrings()) {
Expand Down