Skip to content

Commit

Permalink
Add JNI for strings::repeat_strings (#8491)
Browse files Browse the repository at this point in the history
This PR adds JNI for `strings::repeat_strings`.

Authors:
  - Nghia Truong (https://github.com/ttnghia)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)
  - Gera Shegalov (https://github.com/gerashegalov)

URL: #8491
  • Loading branch information
ttnghia authored Jun 11, 2021
1 parent 00a9398 commit 709adb1
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 7 deletions.
36 changes: 36 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2194,6 +2194,25 @@ public final ColumnVector stringConcatenateListElements(Scalar separator,
emptyStringOutputIfEmptyList));
}

/**
* Given a strings column, each string in the given column is repeated a number of times
* specified by the <code>repeatTimes</code> parameter. If the parameter has a non-positive value,
* all the rows of the output strings column will be an empty string. Any null row will result
* in a null row regardless of the value of <code>repeatTimes</code>.
*
* Note that this function cannot handle the cases when the size of the output column exceeds
* the maximum value that can be indexed by int type (i.e., {@link Integer#MAX_VALUE}).
* In such situations, the output result is undefined.
*
* @param repeatTimes The number of times each input string is copied to the output.
* @return A new java column vector containing repeated strings.
*/
public final ColumnVector repeatStrings(int repeatTimes) {
assert type.equals(DType.STRING) : "column type must be a String";

return new ColumnVector(repeatStrings(getNativeView(), repeatTimes));
}

/**
* Apply a JSONPath string to all rows in an input strings column.
*
Expand Down Expand Up @@ -2870,6 +2889,23 @@ private static native long stringConcatenationListElements(long listColumnHandle
boolean separateNulls,
boolean emptyStringOutputIfEmptyList);

/**
* Native method to repeat each string in the given strings column a number of times
* specified by the <code>repeatTimes</code> parameter. If the parameter has a non-positive value,
* all the rows of the output strings column will be an empty string. Any null row will result
* in a null row regardless of the value of <code>repeatTimes</code>.
*
* Note that this function cannot handle the cases when the size of the output column exceeds
* the maximum value that can be indexed by int type (i.e., {@link Integer#MAX_VALUE}).
* In such situations, the output result is undefined.
*
* @param viewHandle long holding the native handle of the column containing strings to repeat.
* @param repeatTimes The number of times each input string is copied to the output.
* @return native handle of the resulting cudf column containing repeated strings.
*/
private static native long repeatStrings(long viewHandle, int repeatTimes);


private static native long getJSONObject(long viewHandle, long scalarHandle) throws CudfException;

/**
Expand Down
16 changes: 15 additions & 1 deletion java/src/main/java/ai/rapids/cudf/Scalar.java
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ private static ColumnVector buildNullColumnVector(HostColumnVector.DataType host
private static native long makeDecimal64Scalar(long value, int scale, boolean isValid);
private static native long makeListScalar(long viewHandle, boolean isValid);
private static native long makeStructScalar(long[] viewHandles, boolean isValid);

private static native long repeatString(long scalarHandle, int repeatTimes);

Scalar(DType type, long scalarHandle) {
this.type = type;
Expand Down Expand Up @@ -865,6 +865,20 @@ public String toString() {
return sb.toString();
}


/**
* Repeat the given string scalar a number of times specified by the <code>repeatTimes</code>
* parameter. If that parameter has a non-positive value, an empty (valid) string scalar will be
* returned. An invalid input scalar will always result in an invalid output scalar regardless
* of the value of <code>repeatTimes</code>.
*
* @param repeatTimes The number of times the input string is copied to the output.
* @return The resulting scalar containing repeated result of the current string.
*/
public Scalar repeatString(int repeatTimes) {
return new Scalar(DType.STRING, repeatString(getScalarHandle(), repeatTimes));
}

/**
* Holds the off-heap state of the scalar so it can be cleaned up, even if it is leaked.
*/
Expand Down
14 changes: 14 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include <cudf/strings/extract.hpp>
#include <cudf/strings/find.hpp>
#include <cudf/strings/padding.hpp>
#include <cudf/strings/repeat_strings.hpp>
#include <cudf/strings/replace.hpp>
#include <cudf/strings/replace_re.hpp>
#include <cudf/strings/split/split.hpp>
Expand Down Expand Up @@ -1962,4 +1963,17 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringConcatenationListEl
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_repeatStrings(JNIEnv *env, jclass,
jlong column_handle,
jint repeat_times) {
JNI_NULL_CHECK(env, column_handle, "column handle is null", 0);
try {
cudf::jni::auto_set_device(env);
auto const cv = *reinterpret_cast<cudf::column_view *>(column_handle);
auto const strs_col = cudf::strings_column_view(cv);
return reinterpret_cast<jlong>(cudf::strings::repeat_strings(strs_col, repeat_times).release());
}
CATCH_STD(env, 0);
}

} // extern "C"
12 changes: 12 additions & 0 deletions java/src/main/native/src/ScalarJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cudf/binaryop.hpp>
#include <cudf/fixed_point/fixed_point.hpp>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/repeat_strings.hpp>

#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"
Expand Down Expand Up @@ -512,4 +513,15 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeStructScalar(JNIEnv *env,
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_repeatString(JNIEnv *env, jclass, jlong handle,
jint repeat_times) {
JNI_NULL_CHECK(env, handle, "scalar handle is null", 0)
try {
cudf::jni::auto_set_device(env);
auto const str = *reinterpret_cast<cudf::string_scalar *>(handle);
return reinterpret_cast<jlong>(cudf::strings::repeat_strings(str, repeat_times).release());
}
CATCH_STD(env, 0);
}

} // extern "C"
36 changes: 36 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2621,6 +2621,42 @@ void testStringConcatWsSingleListColEmptyArrayReturnEmpty() {
}
}

@Test
void testRepeatStrings() {
// Empty strings column.
try (ColumnVector sv = ColumnVector.fromStrings("", "", "");
ColumnVector result = sv.repeatStrings(1)) {
assertColumnsAreEqual(sv, result);
}

// Zero repeatTimes.
try (ColumnVector sv = ColumnVector.fromStrings("abc", "xyz", "123");
ColumnVector result = sv.repeatStrings(0);
ColumnVector expected = ColumnVector.fromStrings("", "", "")) {
assertColumnsAreEqual(expected, result);
}

// Negative repeatTimes.
try (ColumnVector sv = ColumnVector.fromStrings("abc", "xyz", "123");
ColumnVector result = sv.repeatStrings(-1);
ColumnVector expected = ColumnVector.fromStrings("", "", "")) {
assertColumnsAreEqual(expected, result);
}

// Strings column containing both null and empty, output is copied exactly from input.
try (ColumnVector sv = ColumnVector.fromStrings("abc", "", null, "123", null);
ColumnVector result = sv.repeatStrings(1)) {
assertColumnsAreEqual(sv, result);
}

// Strings column containing both null and empty.
try (ColumnVector sv = ColumnVector.fromStrings("abc", "", null, "123", null);
ColumnVector result = sv.repeatStrings( 2);
ColumnVector expected = ColumnVector.fromStrings("abcabc", "", null, "123123", null)) {
assertColumnsAreEqual(expected, result);
}
}

@Test
void testListConcatByRow() {
try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true,
Expand Down
56 changes: 50 additions & 6 deletions java/src/test/java/ai/rapids/cudf/ScalarTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@

package ai.rapids.cudf;

import static ai.rapids.cudf.TableTest.assertColumnsAreEqual;
import static org.junit.jupiter.api.Assertions.*;

import ai.rapids.cudf.HostColumnVector.BasicType;
import ai.rapids.cudf.HostColumnVector.DataType;
import ai.rapids.cudf.HostColumnVector.ListType;
import ai.rapids.cudf.HostColumnVector.StructData;
import ai.rapids.cudf.HostColumnVector.StructType;

import org.junit.jupiter.api.Test;

import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import org.junit.jupiter.api.Test;

import static ai.rapids.cudf.TableTest.assertColumnsAreEqual;
import static org.junit.jupiter.api.Assertions.*;

public class ScalarTest extends CudfTestBase {
@Test
Expand Down Expand Up @@ -405,4 +405,48 @@ public void testStruct() {
}
}
}

@Test
public void testRepeatString() {
// Invalid scalar.
try (Scalar nullString = Scalar.fromString(null)) {
Scalar result = nullString.repeatString(5);
assertFalse(result.isValid());
}

// Empty string.
try (Scalar emptyString = Scalar.fromString("")) {
Scalar result = emptyString.repeatString(5);
assertTrue(result.isValid());
assertEquals("", result.getJavaString());
}

// Negative repeatTimes.
try (Scalar s = Scalar.fromString("Hello World");
Scalar result = s.repeatString(-100)) {
assertTrue(result.isValid());
assertEquals("", result.getJavaString());
}

// Zero repeatTimes.
try (Scalar s = Scalar.fromString("Hello World");
Scalar result = s.repeatString(0)) {
assertTrue(result.isValid());
assertEquals("", result.getJavaString());
}

// Trivial input, output is copied exactly from input.
try (Scalar s = Scalar.fromString("Hello World");
Scalar result = s.repeatString(1)) {
assertTrue(result.isValid());
assertEquals(s.getJavaString(), result.getJavaString());
}

// Trivial input.
try (Scalar s = Scalar.fromString("abcxyz-");
Scalar result = s.repeatString(3)) {
assertTrue(result.isValid());
assertEquals("abcxyz-abcxyz-abcxyz-", result.getJavaString());
}
}
}

0 comments on commit 709adb1

Please sign in to comment.