Add JNI for strings::repeat_strings [skip ci] #8491

Merged 6 commits on Jun 11, 2021
Changes from 5 commits
36 changes: 36 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
@@ -2194,6 +2194,25 @@ public final ColumnVector stringConcatenateListElements(Scalar separator,
emptyStringOutputIfEmptyList));
}

/**
* Given a strings column, repeats each string in the column the number of times specified by
* the `repeatTimes` parameter. If `repeatTimes` is not a positive value, every non-null row of
* the output strings column will be an empty string. Any null row will result in a null row
* regardless of the value of `repeatTimes`.
*
* Note that this function cannot handle cases where the size of the output column exceeds
* the maximum value that can be indexed by the int type. In such situations, the output
* is undefined.
Contributor

Interesting, couldn't we just repeat as many times as will fit?

Contributor Author

The bounds check should be done on the plugin side to make sure the output is correct. I tried to implement a bounds check in cudf C++, but it was rejected.

*
* @param repeatTimes The number of times each input string is copied to the output.
* @return A new java column vector containing repeated strings.
*/
public final ColumnVector repeatStrings(int repeatTimes) {
assert type.equals(DType.STRING) : "column type must be a String";

return new ColumnVector(repeatStrings(getNativeView(), repeatTimes));
}
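
The review thread above leaves the bounds check to the caller. As a rough illustration only, the sketch below shows what such a caller-side guard could look like on host data. `RepeatBoundsCheck` and `fitsInIntRange` are hypothetical names, not part of the cudf Java API, and the sketch assumes the limit in question is the total number of UTF-8 bytes in the output staying at or below `Integer.MAX_VALUE`.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;

/** Hypothetical caller-side guard; not part of the cudf Java API. */
final class RepeatBoundsCheck {
  /**
   * Returns true if repeating every string repeatTimes times keeps the total
   * output size in UTF-8 bytes within the range an int can index (an assumption
   * about what the documented limit means).
   */
  static boolean fitsInIntRange(List<String> strings, int repeatTimes) {
    if (repeatTimes <= 0) {
      return true; // non-positive counts yield only empty strings
    }
    long totalBytes = 0;
    for (String s : strings) {
      if (s != null) { // null rows stay null and contribute no bytes
        totalBytes += s.getBytes(StandardCharsets.UTF_8).length;
      }
    }
    // Compare by division so the product is never computed and cannot overflow.
    return totalBytes <= Integer.MAX_VALUE / repeatTimes;
  }

  public static void main(String[] args) {
    List<String> rows = Arrays.asList("abc", null, "");
    System.out.println(fitsInIntRange(rows, 2));                 // true
    System.out.println(fitsInIntRange(rows, Integer.MAX_VALUE)); // false
  }
}
```

A real caller would more likely derive the input byte count from the column's metadata than from host strings; the comparison itself would stay the same.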

/**
* Apply a JSONPath string to all rows in an input strings column.
*
@@ -2870,6 +2889,23 @@ private static native long stringConcatenationListElements(long listColumnHandle
boolean separateNulls,
boolean emptyStringOutputIfEmptyList);

/**
* Native method to repeat each string in the given strings column the number of times
* specified by the `repeatTimes` parameter. If `repeatTimes` is not a positive value, every
* non-null row of the output strings column will be an empty string. Any null row will result
* in a null row regardless of the value of `repeatTimes`.
*
* Note that this function cannot handle cases where the size of the output column exceeds
* the maximum value that can be indexed by the int type. In such situations, the output
* is undefined.
*
* @param viewHandle long holding the native handle of the column containing strings to repeat.
* @param repeatTimes The number of times each input string is copied to the output.
* @return native handle of the resulting cudf column containing repeated strings.
*/
private static native long repeatStrings(long viewHandle, int repeatTimes);
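
The row-level behavior described in the Javadoc above can be modeled in plain Java. This is a host-side reference sketch for reading the documentation, not the GPU implementation; `RepeatStringsModel` is a hypothetical name, and a `List<String>` with `null` entries stands in for a strings column with null rows.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical host-side model of the documented semantics. */
final class RepeatStringsModel {
  static List<String> repeatStrings(List<String> input, int repeatTimes) {
    List<String> out = new ArrayList<>(input.size());
    for (String s : input) {
      if (s == null) {
        out.add(null); // a null row stays null regardless of repeatTimes
      } else if (repeatTimes <= 0) {
        out.add(""); // non-positive counts produce an empty string
      } else {
        StringBuilder sb = new StringBuilder(s.length() * repeatTimes);
        for (int i = 0; i < repeatTimes; i++) {
          sb.append(s);
        }
        out.add(sb.toString());
      }
    }
    return out;
  }

  public static void main(String[] args) {
    List<String> rows = Arrays.asList("abc", "", null, "123", null);
    System.out.println(repeatStrings(rows, 2));
  }
}
```

The model ignores the output-size limit entirely, since the documentation leaves that case undefined.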


private static native long getJSONObject(long viewHandle, long scalarHandle) throws CudfException;

/**
14 changes: 14 additions & 0 deletions java/src/main/java/ai/rapids/cudf/Scalar.java
@@ -496,6 +496,17 @@ private static ColumnVector buildNullColumnVector(HostColumnVector.DataType host
private static native long makeListScalar(long viewHandle, boolean isValid);
private static native long makeStructScalar(long[] viewHandles, boolean isValid);

/**
* Native method to repeat the given string scalar the number of times specified by the
* `repeatTimes` parameter. If `repeatTimes` is not a positive value, an empty (valid) string
* scalar will be returned. An invalid input scalar will always result in an invalid output
* scalar regardless of the value of `repeatTimes`.
*
* @param scalarHandle long holding the native handle of the scalar containing the string to repeat.
* @param repeatTimes The number of times the input string is copied to the output.
* @return native handle of the resulting cudf string_scalar containing repeated input string.
*/
private static native long repeatString(long scalarHandle, int repeatTimes);

Scalar(DType type, long scalarHandle) {
this.type = type;
@@ -865,6 +876,9 @@ public String toString() {
return sb.toString();
}

/** Repeats this string scalar repeatTimes times; a non-positive value yields an empty string. */
public Scalar repeatString(int repeatTimes) {
return new Scalar(DType.STRING, repeatString(getScalarHandle(), repeatTimes));
}

/**
* Holds the off-heap state of the scalar so it can be cleaned up, even if it is leaked.
*/
14 changes: 14 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
@@ -50,6 +50,7 @@
#include <cudf/strings/extract.hpp>
#include <cudf/strings/find.hpp>
#include <cudf/strings/padding.hpp>
#include <cudf/strings/repeat_strings.hpp>
#include <cudf/strings/replace.hpp>
#include <cudf/strings/replace_re.hpp>
#include <cudf/strings/split/split.hpp>
@@ -1962,4 +1963,17 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringConcatenationListEl
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_repeatStrings(JNIEnv *env, jclass,
jlong column_handle,
jint repeat_times) {
JNI_NULL_CHECK(env, column_handle, "column handle is null", 0);
try {
cudf::jni::auto_set_device(env);
auto const cv = *reinterpret_cast<cudf::column_view *>(column_handle);
auto const strs_col = cudf::strings_column_view(cv);
return reinterpret_cast<jlong>(cudf::strings::repeat_strings(strs_col, repeat_times).release());
}
CATCH_STD(env, 0);
}

} // extern "C"
12 changes: 12 additions & 0 deletions java/src/main/native/src/ScalarJni.cpp
@@ -17,6 +17,7 @@
#include <cudf/binaryop.hpp>
#include <cudf/fixed_point/fixed_point.hpp>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/repeat_strings.hpp>

#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"
@@ -512,4 +513,15 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeStructScalar(JNIEnv *env,
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_repeatString(JNIEnv *env, jclass, jlong handle,
jint repeat_times) {
JNI_NULL_CHECK(env, handle, "scalar handle is null", 0);
try {
cudf::jni::auto_set_device(env);
// Bind by reference to avoid copying the scalar.
auto const &str = *reinterpret_cast<cudf::string_scalar *>(handle);
return reinterpret_cast<jlong>(cudf::strings::repeat_strings(str, repeat_times).release());
}
CATCH_STD(env, 0);
}

} // extern "C"
36 changes: 36 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
@@ -2621,6 +2621,42 @@ void testStringConcatWsSingleListColEmptyArrayReturnEmpty() {
}
}

@Test
void testRepeatStrings() {
// Empty strings column.
try (ColumnVector sv = ColumnVector.fromStrings("", "", "");
ColumnVector result = sv.repeatStrings(1)) {
assertColumnsAreEqual(sv, result);
}

// Zero repeatTimes.
try (ColumnVector sv = ColumnVector.fromStrings("abc", "xyz", "123");
ColumnVector result = sv.repeatStrings(0);
ColumnVector expected = ColumnVector.fromStrings("", "", "")) {
assertColumnsAreEqual(expected, result);
}

// Negative repeatTimes.
try (ColumnVector sv = ColumnVector.fromStrings("abc", "xyz", "123");
ColumnVector result = sv.repeatStrings(-1);
ColumnVector expected = ColumnVector.fromStrings("", "", "")) {
assertColumnsAreEqual(expected, result);
}

// Strings column containing both null and empty, output is copied exactly from input.
try (ColumnVector sv = ColumnVector.fromStrings("abc", "", null, "123", null);
ColumnVector result = sv.repeatStrings(1)) {
assertColumnsAreEqual(sv, result);
}

// Strings column containing both null and empty.
try (ColumnVector sv = ColumnVector.fromStrings("abc", "", null, "123", null);
ColumnVector result = sv.repeatStrings(2);
ColumnVector expected = ColumnVector.fromStrings("abcabc", "", null, "123123", null)) {
assertColumnsAreEqual(expected, result);
}
}

@Test
void testListConcatByRow() {
try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true,
56 changes: 50 additions & 6 deletions java/src/test/java/ai/rapids/cudf/ScalarTest.java
@@ -18,18 +18,18 @@

package ai.rapids.cudf;

import static ai.rapids.cudf.TableTest.assertColumnsAreEqual;
import static org.junit.jupiter.api.Assertions.*;

import ai.rapids.cudf.HostColumnVector.BasicType;
import ai.rapids.cudf.HostColumnVector.DataType;
import ai.rapids.cudf.HostColumnVector.ListType;
import ai.rapids.cudf.HostColumnVector.StructData;
import ai.rapids.cudf.HostColumnVector.StructType;

import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import org.junit.jupiter.api.Test;

import static ai.rapids.cudf.TableTest.assertColumnsAreEqual;
import static org.junit.jupiter.api.Assertions.*;

public class ScalarTest extends CudfTestBase {
@Test
@@ -405,4 +405,48 @@ public void testStruct() {
}
}
}

@Test
public void testRepeatString() {
// Invalid scalar.
try (Scalar nullString = Scalar.fromString(null);
Scalar result = nullString.repeatString(5)) {
assertFalse(result.isValid());
}

// Empty string.
try (Scalar emptyString = Scalar.fromString("");
Scalar result = emptyString.repeatString(5)) {
assertTrue(result.isValid());
assertEquals("", result.getJavaString());
}

// Negative repeatTimes.
try (Scalar s = Scalar.fromString("Hello World");
Scalar result = s.repeatString(-100)) {
assertTrue(result.isValid());
assertEquals("", result.getJavaString());
}

// Zero repeatTimes.
try (Scalar s = Scalar.fromString("Hello World");
Scalar result = s.repeatString(0)) {
assertTrue(result.isValid());
assertEquals("", result.getJavaString());
}

// Trivial input, output is copied exactly from input.
try (Scalar s = Scalar.fromString("Hello World");
Scalar result = s.repeatString(1)) {
assertTrue(result.isValid());
assertEquals(s.getJavaString(), result.getJavaString());
}

// Trivial input.
try (Scalar s = Scalar.fromString("abcxyz-");
Scalar result = s.repeatString(3)) {
assertTrue(result.isValid());
assertEquals("abcxyz-abcxyz-abcxyz-", result.getJavaString());
}
}
}
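
The tests above declare both the input and the result in a single try-with-resources header, so each object is closed when the block exits. The sketch below illustrates why that matters for objects that own off-heap state. `TrackedHandle` is a hypothetical stand-in for such an object and only counts open instances; it is not a cudf class.

```java
/** Demonstrates closing both an input and a derived result with try-with-resources. */
final class ResourceDemo {
  /** Hypothetical stand-in for an object backed by off-heap memory. */
  static final class TrackedHandle implements AutoCloseable {
    static int open = 0; // number of handles not yet closed

    TrackedHandle() {
      open++;
    }

    /** Produces a new handle that the caller owns and must close. */
    TrackedHandle derive() {
      return new TrackedHandle();
    }

    @Override
    public void close() {
      open--;
    }
  }

  /** Returns the number of handles still open after the block exits. */
  static int run() {
    try (TrackedHandle input = new TrackedHandle();
         TrackedHandle result = input.derive()) {
      // Work with result here; both handles are closed in reverse order on exit.
    }
    return TrackedHandle.open;
  }

  public static void main(String[] args) {
    System.out.println("handles still open: " + run());
  }
}
```

Assigning the derived object to a local variable inside the block instead of declaring it in the header would leave it open in this model, which is the kind of leak the header form avoids.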