Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JNI for strings::repeat_strings [skip ci] #8491

Merged
merged 6 commits into from
Jun 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2194,6 +2194,25 @@ public final ColumnVector stringConcatenateListElements(Scalar separator,
emptyStringOutputIfEmptyList));
}

/**
 * Repeats every string of this strings column <code>repeatTimes</code> times and returns the
 * repeated strings as a new column.
 * <p>
 * A null row always produces a null row, whatever the value of <code>repeatTimes</code>. When
 * <code>repeatTimes</code> is zero or negative, every non-null row of the output is an empty
 * string.
 * <p>
 * Note: this function cannot handle an output column whose size exceeds the largest value
 * that can be indexed by an int (i.e., {@link Integer#MAX_VALUE}). The output is undefined in
 * that situation.
 *
 * @param repeatTimes The number of copies of each input string written to the output.
 * @return A new java column vector containing the repeated strings.
 */
public final ColumnVector repeatStrings(int repeatTimes) {
  assert type.equals(DType.STRING) : "column type must be a String";

  // The native call hands back the handle of the freshly built column; wrap it for the caller.
  final long repeatedHandle = repeatStrings(getNativeView(), repeatTimes);
  return new ColumnVector(repeatedHandle);
}

/**
* Apply a JSONPath string to all rows in an input strings column.
*
Expand Down Expand Up @@ -2870,6 +2889,23 @@ private static native long stringConcatenationListElements(long listColumnHandle
boolean separateNulls,
boolean emptyStringOutputIfEmptyList);

/**
 * Native method that repeats each string of the given strings column <code>repeatTimes</code>
 * times.
 * <p>
 * A null row always results in a null row regardless of the value of <code>repeatTimes</code>.
 * If <code>repeatTimes</code> is non-positive, every non-null row of the output strings column
 * is an empty string.
 * <p>
 * Note that this function cannot handle the cases when the size of the output column exceeds
 * the maximum value that can be indexed by int type (i.e., {@link Integer#MAX_VALUE}).
 * In such situations, the output result is undefined.
 *
 * @param viewHandle long holding the native handle of the column containing strings to repeat.
 * @param repeatTimes The number of times each input string is copied to the output.
 * @return native handle of the resulting cudf column containing repeated strings.
 */
private static native long repeatStrings(long viewHandle, int repeatTimes);


/**
 * Native method that takes two long handles (<code>viewHandle</code> and
 * <code>scalarHandle</code>), returns a long, and is declared to throw {@link CudfException}.
 * NOTE(review): presumably viewHandle is a strings column view, scalarHandle a JSONPath string
 * scalar, and the return value the native handle of the result column - confirm against the
 * JNI implementation.
 */
private static native long getJSONObject(long viewHandle, long scalarHandle) throws CudfException;

/**
Expand Down
16 changes: 15 additions & 1 deletion java/src/main/java/ai/rapids/cudf/Scalar.java
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ private static ColumnVector buildNullColumnVector(HostColumnVector.DataType host
// Native scalar factories.
// NOTE(review): each presumably returns the native handle of the newly created scalar, with
// isValid selecting a valid vs. null scalar - confirm against ScalarJni.cpp.
private static native long makeDecimal64Scalar(long value, int scale, boolean isValid);
private static native long makeListScalar(long viewHandle, boolean isValid);
private static native long makeStructScalar(long[] viewHandles, boolean isValid);

// Native counterpart of repeatString(int): receives the handle of this scalar and the repeat
// count; the returned long is wrapped by the caller as the handle of a new STRING scalar.
private static native long repeatString(long scalarHandle, int repeatTimes);

Scalar(DType type, long scalarHandle) {
this.type = type;
Expand Down Expand Up @@ -865,6 +865,20 @@ public String toString() {
return sb.toString();
}


/**
 * Repeat the given string scalar a number of times specified by the <code>repeatTimes</code>
 * parameter. If that parameter has a non-positive value, an empty (valid) string scalar will be
 * returned. An invalid input scalar will always result in an invalid output scalar regardless
 * of the value of <code>repeatTimes</code>.
 * <p>
 * This scalar must be of type {@link DType#STRING}; this is checked with an assertion.
 *
 * @param repeatTimes The number of times the input string is copied to the output.
 * @return The resulting scalar containing repeated result of the current string.
 */
public Scalar repeatString(int repeatTimes) {
  // The native side reinterprets the handle as a string scalar, so reject any other type up
  // front, mirroring the type assertion ColumnView.repeatStrings performs for columns.
  assert type.equals(DType.STRING) : "scalar type must be a String";

  return new Scalar(DType.STRING, repeatString(getScalarHandle(), repeatTimes));
}

/**
* Holds the off-heap state of the scalar so it can be cleaned up, even if it is leaked.
*/
Expand Down
14 changes: 14 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include <cudf/strings/extract.hpp>
#include <cudf/strings/find.hpp>
#include <cudf/strings/padding.hpp>
#include <cudf/strings/repeat_strings.hpp>
#include <cudf/strings/replace.hpp>
#include <cudf/strings/replace_re.hpp>
#include <cudf/strings/split/split.hpp>
Expand Down Expand Up @@ -1962,4 +1963,17 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringConcatenationListEl
CATCH_STD(env, 0);
}

// JNI entry point for ColumnView.repeatStrings: repeats every string of the column behind
// column_handle repeat_times times and returns the released result column as a jlong handle.
// Returns 0 when the handle is null or when an exception is translated by CATCH_STD.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_repeatStrings(JNIEnv *env, jclass,
                                                                     jlong column_handle,
                                                                     jint repeat_times) {
  JNI_NULL_CHECK(env, column_handle, "column handle is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    auto const input = reinterpret_cast<cudf::column_view const *>(column_handle);
    cudf::strings_column_view const strings_view(*input);
    auto repeated = cudf::strings::repeat_strings(strings_view, repeat_times);
    // Ownership of the result column passes to the Java side through the raw handle.
    return reinterpret_cast<jlong>(repeated.release());
  }
  CATCH_STD(env, 0);
}

} // extern "C"
12 changes: 12 additions & 0 deletions java/src/main/native/src/ScalarJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cudf/binaryop.hpp>
#include <cudf/fixed_point/fixed_point.hpp>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/repeat_strings.hpp>

#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"
Expand Down Expand Up @@ -512,4 +513,15 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeStructScalar(JNIEnv *env,
CATCH_STD(env, 0);
}

// JNI entry point for Scalar.repeatString: repeats the string scalar behind handle
// repeat_times times and returns the released result scalar as a jlong handle.
// Returns 0 when the handle is null or when an exception is translated by CATCH_STD.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_repeatString(JNIEnv *env, jclass, jlong handle,
                                                                jint repeat_times) {
  JNI_NULL_CHECK(env, handle, "scalar handle is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    // Bind by const reference: dereferencing into a value would copy-construct the
    // string_scalar on every call just to read it.
    auto const &str = *reinterpret_cast<cudf::string_scalar const *>(handle);
    // Ownership of the result scalar passes to the Java side through the raw handle.
    return reinterpret_cast<jlong>(cudf::strings::repeat_strings(str, repeat_times).release());
  }
  CATCH_STD(env, 0);
}

} // extern "C"
36 changes: 36 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2621,6 +2621,42 @@ void testStringConcatWsSingleListColEmptyArrayReturnEmpty() {
}
}

@Test
void testRepeatStrings() {
  // A column holding only empty strings is unchanged when repeated once.
  try (ColumnVector input = ColumnVector.fromStrings("", "", "");
       ColumnVector repeated = input.repeatStrings(1)) {
    assertColumnsAreEqual(input, repeated);
  }

  // Repeating zero times turns every string into an empty string.
  try (ColumnVector input = ColumnVector.fromStrings("abc", "xyz", "123");
       ColumnVector repeated = input.repeatStrings(0);
       ColumnVector allEmpty = ColumnVector.fromStrings("", "", "")) {
    assertColumnsAreEqual(allEmpty, repeated);
  }

  // A negative repeat count behaves like zero.
  try (ColumnVector input = ColumnVector.fromStrings("abc", "xyz", "123");
       ColumnVector repeated = input.repeatStrings(-1);
       ColumnVector allEmpty = ColumnVector.fromStrings("", "", "")) {
    assertColumnsAreEqual(allEmpty, repeated);
  }

  // Mixed null and empty rows repeated once: the output equals the input.
  try (ColumnVector input = ColumnVector.fromStrings("abc", "", null, "123", null);
       ColumnVector repeated = input.repeatStrings(1)) {
    assertColumnsAreEqual(input, repeated);
  }

  // Mixed null and empty rows repeated twice: nulls stay null, empty stays empty.
  try (ColumnVector input = ColumnVector.fromStrings("abc", "", null, "123", null);
       ColumnVector repeated = input.repeatStrings(2);
       ColumnVector doubled = ColumnVector.fromStrings("abcabc", "", null, "123123", null)) {
    assertColumnsAreEqual(doubled, repeated);
  }
}

@Test
void testListConcatByRow() {
try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true,
Expand Down
56 changes: 50 additions & 6 deletions java/src/test/java/ai/rapids/cudf/ScalarTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@

package ai.rapids.cudf;

import static ai.rapids.cudf.TableTest.assertColumnsAreEqual;
import static org.junit.jupiter.api.Assertions.*;

import ai.rapids.cudf.HostColumnVector.BasicType;
import ai.rapids.cudf.HostColumnVector.DataType;
import ai.rapids.cudf.HostColumnVector.ListType;
import ai.rapids.cudf.HostColumnVector.StructData;
import ai.rapids.cudf.HostColumnVector.StructType;

import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import org.junit.jupiter.api.Test;

import static ai.rapids.cudf.TableTest.assertColumnsAreEqual;
import static org.junit.jupiter.api.Assertions.*;

public class ScalarTest extends CudfTestBase {
@Test
Expand Down Expand Up @@ -405,4 +405,48 @@ public void testStruct() {
}
}
}

/**
 * Exercises {@link Scalar#repeatString(int)} on invalid, empty and regular string scalars with
 * negative, zero and positive repeat counts. Every result scalar is declared in the
 * try-with-resources header so that it is closed even when an assertion fails.
 */
@Test
public void testRepeatString() {
  // Invalid scalar.
  try (Scalar nullString = Scalar.fromString(null);
       Scalar result = nullString.repeatString(5)) {
    assertFalse(result.isValid());
  }

  // Empty string.
  try (Scalar emptyString = Scalar.fromString("");
       Scalar result = emptyString.repeatString(5)) {
    assertTrue(result.isValid());
    assertEquals("", result.getJavaString());
  }

  // Negative repeatTimes.
  try (Scalar s = Scalar.fromString("Hello World");
       Scalar result = s.repeatString(-100)) {
    assertTrue(result.isValid());
    assertEquals("", result.getJavaString());
  }

  // Zero repeatTimes.
  try (Scalar s = Scalar.fromString("Hello World");
       Scalar result = s.repeatString(0)) {
    assertTrue(result.isValid());
    assertEquals("", result.getJavaString());
  }

  // Trivial input, output is copied exactly from input.
  try (Scalar s = Scalar.fromString("Hello World");
       Scalar result = s.repeatString(1)) {
    assertTrue(result.isValid());
    assertEquals(s.getJavaString(), result.getJavaString());
  }

  // Trivial input.
  try (Scalar s = Scalar.fromString("abcxyz-");
       Scalar result = s.repeatString(3)) {
    assertTrue(result.isValid());
    assertEquals("abcxyz-abcxyz-abcxyz-", result.getJavaString());
  }
}
}