diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java
index db42a8c9ca2..f8b6aac863d 100644
--- a/java/src/main/java/ai/rapids/cudf/ColumnView.java
+++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java
@@ -2194,6 +2194,25 @@ public final ColumnVector stringConcatenateListElements(Scalar separator,
         emptyStringOutputIfEmptyList));
   }
 
+  /**
+   * Given a strings column, each string in the given column is repeated a number of times
+   * specified by the repeatTimes parameter. If the parameter has a non-positive value,
+   * all the rows of the output strings column will be an empty string. Any null row will result
+   * in a null row regardless of the value of repeatTimes.
+   *
+   * Note that this function cannot handle the cases when the size of the output column exceeds
+   * the maximum value that can be indexed by int type (i.e., {@link Integer#MAX_VALUE}).
+   * In such situations, the output result is undefined.
+   *
+   * @param repeatTimes The number of times each input string is copied to the output.
+   * @return A new java column vector containing repeated strings.
+   */
+  public final ColumnVector repeatStrings(int repeatTimes) {
+    assert type.equals(DType.STRING) : "column type must be a String";
+
+    return new ColumnVector(repeatStrings(getNativeView(), repeatTimes));
+  }
+
   /**
    * Apply a JSONPath string to all rows in an input strings column.
    *
@@ -2870,6 +2889,23 @@ private static native long stringConcatenationListElements(long listColumnHandle
                                                              boolean separateNulls,
                                                              boolean emptyStringOutputIfEmptyList);
 
+  /**
+   * Native method to repeat each string in the given strings column a number of times
+   * specified by the repeatTimes parameter. If the parameter has a non-positive value,
+   * all the rows of the output strings column will be an empty string. Any null row will result
+   * in a null row regardless of the value of repeatTimes.
+   *
+   * Note that this function cannot handle the cases when the size of the output column exceeds
+   * the maximum value that can be indexed by int type (i.e., {@link Integer#MAX_VALUE}).
+   * In such situations, the output result is undefined.
+   *
+   * @param viewHandle long holding the native handle of the column containing strings to repeat.
+   * @param repeatTimes The number of times each input string is copied to the output.
+   * @return native handle of the resulting cudf column containing repeated strings.
+   */
+  private static native long repeatStrings(long viewHandle, int repeatTimes);
+
+
   private static native long getJSONObject(long viewHandle, long scalarHandle) throws CudfException;
 
   /**
diff --git a/java/src/main/java/ai/rapids/cudf/Scalar.java b/java/src/main/java/ai/rapids/cudf/Scalar.java
index 925cc89a51a..631f091005a 100644
--- a/java/src/main/java/ai/rapids/cudf/Scalar.java
+++ b/java/src/main/java/ai/rapids/cudf/Scalar.java
@@ -495,7 +495,7 @@ private static ColumnVector buildNullColumnVector(HostColumnVector.DataType host
   private static native long makeDecimal64Scalar(long value, int scale, boolean isValid);
   private static native long makeListScalar(long viewHandle, boolean isValid);
   private static native long makeStructScalar(long[] viewHandles, boolean isValid);
-
+  private static native long repeatString(long scalarHandle, int repeatTimes);
 
   Scalar(DType type, long scalarHandle) {
     this.type = type;
@@ -865,6 +865,22 @@ public String toString() {
     return sb.toString();
   }
 
+
+  /**
+   * Repeat the given string scalar a number of times specified by the repeatTimes
+   * parameter. If that parameter has a non-positive value, an empty (valid) string scalar will be
+   * returned. An invalid input scalar will always result in an invalid output scalar regardless
+   * of the value of repeatTimes.
+   *
+   * @param repeatTimes The number of times the input string is copied to the output.
+   * @return The resulting scalar containing repeated result of the current string.
+   */
+  public Scalar repeatString(int repeatTimes) {
+    assert type.equals(DType.STRING) : "scalar type must be a String";
+
+    return new Scalar(DType.STRING, repeatString(getScalarHandle(), repeatTimes));
+  }
+
   /**
    * Holds the off-heap state of the scalar so it can be cleaned up, even if it is leaked.
    */
diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp
index d41ed97b4cb..866e1e96188 100644
--- a/java/src/main/native/src/ColumnViewJni.cpp
+++ b/java/src/main/native/src/ColumnViewJni.cpp
@@ -50,6 +50,7 @@
 #include
 #include
 #include
+#include <cudf/strings/repeat_strings.hpp>
 #include
 #include
 #include
@@ -1962,4 +1963,17 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringConcatenationListEl
   CATCH_STD(env, 0);
 }
 
+JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_repeatStrings(JNIEnv *env, jclass,
+                                                                     jlong column_handle,
+                                                                     jint repeat_times) {
+  JNI_NULL_CHECK(env, column_handle, "column handle is null", 0);
+  try {
+    cudf::jni::auto_set_device(env);
+    auto const cv = *reinterpret_cast<cudf::column_view *>(column_handle);
+    auto const strs_col = cudf::strings_column_view(cv);
+    return reinterpret_cast<jlong>(cudf::strings::repeat_strings(strs_col, repeat_times).release());
+  }
+  CATCH_STD(env, 0);
+}
+
 } // extern "C"
diff --git a/java/src/main/native/src/ScalarJni.cpp b/java/src/main/native/src/ScalarJni.cpp
index e0fad0a60c4..f58290395e3 100644
--- a/java/src/main/native/src/ScalarJni.cpp
+++ b/java/src/main/native/src/ScalarJni.cpp
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include <cudf/strings/repeat_strings.hpp>
 
 #include "cudf_jni_apis.hpp"
 #include "dtype_utils.hpp"
@@ -512,4 +513,15 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeStructScalar(JNIEnv *env,
   CATCH_STD(env, 0);
 }
 
+JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_repeatString(JNIEnv *env, jclass, jlong handle,
+                                                                jint repeat_times) {
+  JNI_NULL_CHECK(env, handle, "scalar handle is null", 0);
+  try {
+    cudf::jni::auto_set_device(env);
+    auto const str = *reinterpret_cast<cudf::string_scalar *>(handle);
+    return reinterpret_cast<jlong>(cudf::strings::repeat_strings(str, repeat_times).release());
+  }
+  CATCH_STD(env, 0);
+}
+
 } // extern "C"
diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
index 16570483f17..e3ca880d587 100644
--- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
+++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
@@ -2621,6 +2621,42 @@ void testStringConcatWsSingleListColEmptyArrayReturnEmpty() {
     }
   }
 
+  @Test
+  void testRepeatStrings() {
+    // Empty strings column.
+    try (ColumnVector sv = ColumnVector.fromStrings("", "", "");
+         ColumnVector result = sv.repeatStrings(1)) {
+      assertColumnsAreEqual(sv, result);
+    }
+
+    // Zero repeatTimes.
+    try (ColumnVector sv = ColumnVector.fromStrings("abc", "xyz", "123");
+         ColumnVector result = sv.repeatStrings(0);
+         ColumnVector expected = ColumnVector.fromStrings("", "", "")) {
+      assertColumnsAreEqual(expected, result);
+    }
+
+    // Negative repeatTimes.
+    try (ColumnVector sv = ColumnVector.fromStrings("abc", "xyz", "123");
+         ColumnVector result = sv.repeatStrings(-1);
+         ColumnVector expected = ColumnVector.fromStrings("", "", "")) {
+      assertColumnsAreEqual(expected, result);
+    }
+
+    // Strings column containing both null and empty, output is copied exactly from input.
+    try (ColumnVector sv = ColumnVector.fromStrings("abc", "", null, "123", null);
+         ColumnVector result = sv.repeatStrings(1)) {
+      assertColumnsAreEqual(sv, result);
+    }
+
+    // Strings column containing both null and empty.
+    try (ColumnVector sv = ColumnVector.fromStrings("abc", "", null, "123", null);
+         ColumnVector result = sv.repeatStrings(2);
+         ColumnVector expected = ColumnVector.fromStrings("abcabc", "", null, "123123", null)) {
+      assertColumnsAreEqual(expected, result);
+    }
+  }
+
   @Test
   void testListConcatByRow() {
     try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true,
diff --git a/java/src/test/java/ai/rapids/cudf/ScalarTest.java b/java/src/test/java/ai/rapids/cudf/ScalarTest.java
index e317392196e..37fd2ecb714 100644
--- a/java/src/test/java/ai/rapids/cudf/ScalarTest.java
+++ b/java/src/test/java/ai/rapids/cudf/ScalarTest.java
@@ -18,18 +18,18 @@
 
 package ai.rapids.cudf;
 
-import static ai.rapids.cudf.TableTest.assertColumnsAreEqual;
-import static org.junit.jupiter.api.Assertions.*;
-
 import ai.rapids.cudf.HostColumnVector.BasicType;
-import ai.rapids.cudf.HostColumnVector.DataType;
 import ai.rapids.cudf.HostColumnVector.ListType;
-import ai.rapids.cudf.HostColumnVector.StructData;
 import ai.rapids.cudf.HostColumnVector.StructType;
+
+import org.junit.jupiter.api.Test;
+
 import java.math.BigDecimal;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
-import org.junit.jupiter.api.Test;
+
+import static ai.rapids.cudf.TableTest.assertColumnsAreEqual;
+import static org.junit.jupiter.api.Assertions.*;
 
 public class ScalarTest extends CudfTestBase {
   @Test
@@ -405,4 +405,48 @@ public void testStruct() {
       }
     }
   }
+
+  @Test
+  public void testRepeatString() {
+    // Invalid scalar.
+    try (Scalar nullString = Scalar.fromString(null);
+         Scalar result = nullString.repeatString(5)) {
+      assertFalse(result.isValid());
+    }
+
+    // Empty string.
+    try (Scalar emptyString = Scalar.fromString("");
+         Scalar result = emptyString.repeatString(5)) {
+      assertTrue(result.isValid());
+      assertEquals("", result.getJavaString());
+    }
+
+    // Negative repeatTimes.
+    try (Scalar s = Scalar.fromString("Hello World");
+         Scalar result = s.repeatString(-100)) {
+      assertTrue(result.isValid());
+      assertEquals("", result.getJavaString());
+    }
+
+    // Zero repeatTimes.
+    try (Scalar s = Scalar.fromString("Hello World");
+         Scalar result = s.repeatString(0)) {
+      assertTrue(result.isValid());
+      assertEquals("", result.getJavaString());
+    }
+
+    // Trivial input, output is copied exactly from input.
+    try (Scalar s = Scalar.fromString("Hello World");
+         Scalar result = s.repeatString(1)) {
+      assertTrue(result.isValid());
+      assertEquals(s.getJavaString(), result.getJavaString());
+    }
+
+    // Trivial input.
+    try (Scalar s = Scalar.fromString("abcxyz-");
+         Scalar result = s.repeatString(3)) {
+      assertTrue(result.isValid());
+      assertEquals("abcxyz-abcxyz-abcxyz-", result.getJavaString());
+    }
+  }
 }