diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java
index db42a8c9ca2..f8b6aac863d 100644
--- a/java/src/main/java/ai/rapids/cudf/ColumnView.java
+++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java
@@ -2194,6 +2194,25 @@ public final ColumnVector stringConcatenateListElements(Scalar separator,
emptyStringOutputIfEmptyList));
}
+  /**
+   * Repeats every string of this strings column {@code repeatTimes} times and returns the
+   * repeated strings as a new column.
+   *
+   * When {@code repeatTimes} is non-positive, all the rows of the output strings column will be
+   * an empty string. A null row always produces a null row, regardless of the value of
+   * {@code repeatTimes}.
+   *
+   * Note that the size of the output column must not exceed the maximum value that can be
+   * indexed by int type (i.e., {@link Integer#MAX_VALUE}); otherwise the output result is
+   * undefined.
+   *
+   * @param repeatTimes The number of times each input string is copied to the output.
+   * @return A new java column vector containing repeated strings.
+   */
+  public final ColumnVector repeatStrings(int repeatTimes) {
+    assert type.equals(DType.STRING) : "column type must be a String";
+
+    // Let the native side build the repeated column, then wrap the handle it hands back.
+    final long repeatedHandle = repeatStrings(getNativeView(), repeatTimes);
+    return new ColumnVector(repeatedHandle);
+  }
+
/**
* Apply a JSONPath string to all rows in an input strings column.
*
@@ -2870,6 +2889,23 @@ private static native long stringConcatenationListElements(long listColumnHandle
boolean separateNulls,
boolean emptyStringOutputIfEmptyList);
+ /**
+ * Native method to repeat each string in the given strings column a number of times
+ * specified by the {@code repeatTimes} parameter. If the parameter has a non-positive value,
+ * all the rows of the output strings column will be an empty string. Any null row will result
+ * in a null row regardless of the value of {@code repeatTimes}.
+ *
+ * Note that this function cannot handle the cases when the size of the output column exceeds
+ * the maximum value that can be indexed by int type (i.e., {@link Integer#MAX_VALUE}).
+ * In such situations, the output result is undefined.
+ *
+ * @param viewHandle long holding the native handle of the column containing strings to repeat.
+ * @param repeatTimes The number of times each input string is copied to the output.
+ * @return native handle of the resulting cudf column containing repeated strings.
+ */
+ private static native long repeatStrings(long viewHandle, int repeatTimes);
+
+
private static native long getJSONObject(long viewHandle, long scalarHandle) throws CudfException;
/**
diff --git a/java/src/main/java/ai/rapids/cudf/Scalar.java b/java/src/main/java/ai/rapids/cudf/Scalar.java
index 925cc89a51a..631f091005a 100644
--- a/java/src/main/java/ai/rapids/cudf/Scalar.java
+++ b/java/src/main/java/ai/rapids/cudf/Scalar.java
@@ -495,7 +495,7 @@ private static ColumnVector buildNullColumnVector(HostColumnVector.DataType host
private static native long makeDecimal64Scalar(long value, int scale, boolean isValid);
private static native long makeListScalar(long viewHandle, boolean isValid);
private static native long makeStructScalar(long[] viewHandles, boolean isValid);
-
+ // Native backing for repeatString(int): takes a scalar handle and the repeat count.
+ private static native long repeatString(long scalarHandle, int repeatTimes);
Scalar(DType type, long scalarHandle) {
this.type = type;
@@ -865,6 +865,20 @@ public String toString() {
return sb.toString();
}
+
+ /**
+ * Repeat the given string scalar a number of times specified by the repeatTimes
+ * parameter. If that parameter has a non-positive value, an empty (valid) string scalar will be
+ * returned. An invalid input scalar will always result in an invalid output scalar regardless
+ * of the value of repeatTimes
.
+ *
+ * @param repeatTimes The number of times the input string is copied to the output.
+ * @return The resulting scalar containing repeated result of the current string.
+ */
+ public Scalar repeatString(int repeatTimes) {
+ return new Scalar(DType.STRING, repeatString(getScalarHandle(), repeatTimes));
+ }
+
/**
* Holds the off-heap state of the scalar so it can be cleaned up, even if it is leaked.
*/
diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp
index d41ed97b4cb..866e1e96188 100644
--- a/java/src/main/native/src/ColumnViewJni.cpp
+++ b/java/src/main/native/src/ColumnViewJni.cpp
@@ -50,6 +50,7 @@
#include
#include
#include
+#include
#include
#include
#include
@@ -1962,4 +1963,17 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringConcatenationListEl
CATCH_STD(env, 0);
}
+// JNI entry point for ColumnView.repeatStrings(long, int).
+// column_handle is the native handle of the input strings column view; repeat_times is the
+// number of copies of each string. Returns the handle of the newly created column, or 0 after
+// raising a Java exception (null handle, or any C++ exception translated by CATCH_STD).
+JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_repeatStrings(JNIEnv *env, jclass,
+                                                                     jlong column_handle,
+                                                                     jint repeat_times) {
+  JNI_NULL_CHECK(env, column_handle, "column handle is null", 0);
+  try {
+    cudf::jni::auto_set_device(env);
+    // Keep a pointer to the caller-owned view rather than copying it by value.
+    auto const input = reinterpret_cast<cudf::column_view const *>(column_handle);
+    cudf::strings_column_view const strs_col(*input);
+    // Ownership of the result column is handed over to the Java side through the raw handle.
+    return reinterpret_cast<jlong>(cudf::strings::repeat_strings(strs_col, repeat_times).release());
+  }
+  CATCH_STD(env, 0);
+}
+
} // extern "C"
diff --git a/java/src/main/native/src/ScalarJni.cpp b/java/src/main/native/src/ScalarJni.cpp
index e0fad0a60c4..f58290395e3 100644
--- a/java/src/main/native/src/ScalarJni.cpp
+++ b/java/src/main/native/src/ScalarJni.cpp
@@ -17,6 +17,7 @@
#include
#include
#include
+#include
#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"
@@ -512,4 +513,15 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeStructScalar(JNIEnv *env,
CATCH_STD(env, 0);
}
+// JNI entry point for Scalar.repeatString(long, int).
+// handle is the native handle of the input string scalar; repeat_times is the number of copies
+// of the string. Returns the handle of the newly created scalar, or 0 after raising a Java
+// exception (null handle, or any C++ exception translated by CATCH_STD).
+JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_repeatString(JNIEnv *env, jclass, jlong handle,
+                                                                jint repeat_times) {
+  JNI_NULL_CHECK(env, handle, "scalar handle is null", 0);
+  try {
+    cudf::jni::auto_set_device(env);
+    // Bind by pointer: dereferencing into a value would copy-construct the whole scalar
+    // before repeating it.
+    auto const str = reinterpret_cast<cudf::string_scalar const *>(handle);
+    // Ownership of the result scalar is handed over to the Java side through the raw handle.
+    return reinterpret_cast<jlong>(cudf::strings::repeat_strings(*str, repeat_times).release());
+  }
+  CATCH_STD(env, 0);
+}
+
} // extern "C"
diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
index 16570483f17..e3ca880d587 100644
--- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
+++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
@@ -2621,6 +2621,42 @@ void testStringConcatWsSingleListColEmptyArrayReturnEmpty() {
}
}
+  /**
+   * Exercises ColumnView.repeatStrings on empty strings, non-positive repeat counts,
+   * and columns that mix nulls with empty and non-empty strings.
+   */
+  @Test
+  void testRepeatStrings() {
+    // A column of empty strings repeated once is unchanged.
+    try (ColumnVector input = ColumnVector.fromStrings("", "", "");
+         ColumnVector actual = input.repeatStrings(1)) {
+      assertColumnsAreEqual(input, actual);
+    }
+
+    // repeatTimes == 0 turns every row into an empty string.
+    try (ColumnVector input = ColumnVector.fromStrings("abc", "xyz", "123");
+         ColumnVector actual = input.repeatStrings(0);
+         ColumnVector allEmpty = ColumnVector.fromStrings("", "", "")) {
+      assertColumnsAreEqual(allEmpty, actual);
+    }
+
+    // repeatTimes < 0 behaves the same as zero.
+    try (ColumnVector input = ColumnVector.fromStrings("abc", "xyz", "123");
+         ColumnVector actual = input.repeatStrings(-1);
+         ColumnVector allEmpty = ColumnVector.fromStrings("", "", "")) {
+      assertColumnsAreEqual(allEmpty, actual);
+    }
+
+    // Nulls and empty strings mixed in; repeating once reproduces the input exactly.
+    try (ColumnVector input = ColumnVector.fromStrings("abc", "", null, "123", null);
+         ColumnVector actual = input.repeatStrings(1)) {
+      assertColumnsAreEqual(input, actual);
+    }
+
+    // Nulls and empty strings mixed in; repeating twice doubles only the non-empty rows.
+    try (ColumnVector input = ColumnVector.fromStrings("abc", "", null, "123", null);
+         ColumnVector actual = input.repeatStrings(2);
+         ColumnVector doubled = ColumnVector.fromStrings("abcabc", "", null, "123123", null)) {
+      assertColumnsAreEqual(doubled, actual);
+    }
+  }
+
@Test
void testListConcatByRow() {
try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true,
diff --git a/java/src/test/java/ai/rapids/cudf/ScalarTest.java b/java/src/test/java/ai/rapids/cudf/ScalarTest.java
index e317392196e..37fd2ecb714 100644
--- a/java/src/test/java/ai/rapids/cudf/ScalarTest.java
+++ b/java/src/test/java/ai/rapids/cudf/ScalarTest.java
@@ -18,18 +18,18 @@
package ai.rapids.cudf;
-import static ai.rapids.cudf.TableTest.assertColumnsAreEqual;
-import static org.junit.jupiter.api.Assertions.*;
-
import ai.rapids.cudf.HostColumnVector.BasicType;
-import ai.rapids.cudf.HostColumnVector.DataType;
import ai.rapids.cudf.HostColumnVector.ListType;
-import ai.rapids.cudf.HostColumnVector.StructData;
import ai.rapids.cudf.HostColumnVector.StructType;
+
+import org.junit.jupiter.api.Test;
+
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
-import org.junit.jupiter.api.Test;
+
+import static ai.rapids.cudf.TableTest.assertColumnsAreEqual;
+import static org.junit.jupiter.api.Assertions.*;
public class ScalarTest extends CudfTestBase {
@Test
@@ -405,4 +405,48 @@ public void testStruct() {
}
}
}
+
+  /**
+   * Exercises Scalar.repeatString on an invalid scalar, an empty string, non-positive repeat
+   * counts, and ordinary inputs. Every result scalar is declared as a try-with-resources
+   * resource so that it is closed when its case finishes.
+   */
+  @Test
+  public void testRepeatString() {
+    // Invalid scalar.
+    try (Scalar nullString = Scalar.fromString(null);
+         Scalar result = nullString.repeatString(5)) {
+      assertFalse(result.isValid());
+    }
+
+    // Empty string.
+    try (Scalar emptyString = Scalar.fromString("");
+         Scalar result = emptyString.repeatString(5)) {
+      assertTrue(result.isValid());
+      assertEquals("", result.getJavaString());
+    }
+
+    // Negative repeatTimes.
+    try (Scalar s = Scalar.fromString("Hello World");
+         Scalar result = s.repeatString(-100)) {
+      assertTrue(result.isValid());
+      assertEquals("", result.getJavaString());
+    }
+
+    // Zero repeatTimes.
+    try (Scalar s = Scalar.fromString("Hello World");
+         Scalar result = s.repeatString(0)) {
+      assertTrue(result.isValid());
+      assertEquals("", result.getJavaString());
+    }
+
+    // Trivial input, output is copied exactly from input.
+    try (Scalar s = Scalar.fromString("Hello World");
+         Scalar result = s.repeatString(1)) {
+      assertTrue(result.isValid());
+      assertEquals(s.getJavaString(), result.getJavaString());
+    }
+
+    // Trivial input.
+    try (Scalar s = Scalar.fromString("abcxyz-");
+         Scalar result = s.repeatString(3)) {
+      assertTrue(result.isValid());
+      assertEquals("abcxyz-abcxyz-abcxyz-", result.getJavaString());
+    }
+  }
}