diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index fcdb5d44ad3..9d1bdc2e21c 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -527,6 +527,47 @@ public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, Col return new ColumnVector(stringConcatenation(column_views, separator.getScalarHandle(), narep.getScalarHandle())); } + /** + * Concatenate columns of lists horizontally (row by row), combining a corresponding row + * from each column into a single list row of a new column. Null list element of input columns + * will be ignored (skipped) during the concatenation. + * + * @param columns array of columns containing lists, must be more than 2 columns + * @return A new java column vector containing the concatenated lists. + */ + public static ColumnVector listConcatenateByRow(ColumnView[] columns) { + return listConcatenateByRow(columns, false); + } + + /** + * Concatenate columns of lists horizontally (row by row), combining a corresponding row + * from each column into a single list row of a new column. + * + * @param columns array of columns containing lists, must be more than 2 columns + * @param ignoreNull whether to ignore null list element of input columns: If true, null list + * will be ignored from concatenation; Otherwise, any concatenation involving + * a null list element will result in a null list + * @return A new java column vector containing the concatenated lists. + */ + public static ColumnVector listConcatenateByRow(ColumnView[] columns, boolean ignoreNull) { + assert columns.length >= 2 : "listConcatenateByRow requires at least 2 columns"; + long size = columns[0].getRowCount(); + long[] columnViews = new long[columns.length]; + DType childType = columns[0].getChildColumnView(0).getType(); + assert !childType.isNestedType() : "listConcatenateByRow only supports lists with depth 1"; + + for(int i = 0; i < columns.length; i++) { + assert columns[i] != null : "Column vectors passed may not be null"; + assert columns[i].getType().equals(DType.LIST) : "All columns must be of type list"; + assert columns[i].getRowCount() == size : "Row count mismatch"; + assert childType.equals(columns[i].getChildColumnView(0).getType()) : "Element type mismatch"; + + columnViews[i] = columns[i].getNativeView(); + } + + return new ColumnVector(listConcatenationByRow(columnViews, ignoreNull)); + } + /** * Create a new vector containing the MD5 hash of each row in the table. * diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 51f89ea1114..73fbc73757e 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -2837,7 +2837,7 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat /** * Native method to concatenate columns of strings together, combining a row from - * each colunm into a single string. + * each column into a single string. * @param columnViews array of longs holding the native handles of the column_views to combine. * @param separator string scalar inserted between each string being merged, may not be null. * @param narep string scalar indicating null behavior. If set to null and any string in the row is null @@ -2849,6 +2849,16 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat */ protected static native long stringConcatenation(long[] columnViews, long separator, long narep); + + /** + * Native method to concatenate columns of lists together, combining a row from + * each column into a single list. + * @param columnViews array of longs holding the native handles of the column_views to combine. + * @return native handle of the resulting cudf column, used to construct the Java column + * by the listConcatenateByRow method. + */ + protected static native long listConcatenationByRow(long[] columnViews, boolean ignoreNull); + /** * Native method for map lookup over a column of List> * @param columnView the column view handle of the map diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index c9bafa5abee..b1139318c4e 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -1049,6 +1050,28 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringConcatenation( CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listConcatenationByRow(JNIEnv *env, + jobject j_object, + jlongArray column_handles, + jboolean ignore_null) { + JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0); + try { + cudf::jni::auto_set_device(env); + auto null_policy = ignore_null ? cudf::lists::concatenate_null_policy::IGNORE + : cudf::lists::concatenate_null_policy::NULLIFY_OUTPUT_ROW; + cudf::jni::native_jpointerArray n_cudf_columns(env, column_handles); + std::vector column_views; + std::transform(n_cudf_columns.data(), n_cudf_columns.data() + n_cudf_columns.size(), + std::back_inserter(column_views), + [](auto const &p_column) { return *p_column; }); + + std::unique_ptr result = + cudf::lists::concatenate_rows(cudf::table_view(column_views), null_policy); + return reinterpret_cast(result.release()); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVV(JNIEnv *env, jclass, jlong lhs_view, jlong rhs_view, jint int_op, jint out_dtype, diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index a30d276d954..eba39e5ceec 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -2058,6 +2058,139 @@ void testStringConcatSeparators() { } } + @Test + void testListConcatByRow() { + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(0), Arrays.asList(1, 2, 3), null, Arrays.asList(), Arrays.asList()); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(1, 2, 3), Arrays.asList((Integer) null), Arrays.asList(10, 12), Arrays.asList(100, 200, 300), + Arrays.asList()); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2}); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(0, 1, 2, 3), Arrays.asList(1, 2, 3, null), null, Arrays.asList(100, 200, 300), + Arrays.asList())) { + assertColumnsAreEqual(expect, result); + } + + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("AAA", "BBB"), Arrays.asList("aaa"), Arrays.asList("111"), Arrays.asList("X"), + Arrays.asList()); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList(), Arrays.asList("bbb", "ccc"), null, Arrays.asList("Y", null), + Arrays.asList()); + ColumnVector cv3 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("CCC"), Arrays.asList(), Arrays.asList("222", "333"), Arrays.asList("Z"), + Arrays.asList()); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2, cv3}); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("AAA", "BBB", "CCC"), Arrays.asList("aaa", "bbb", "ccc"), null, + Arrays.asList("X", "Y", null, "Z"), Arrays.asList())) { + assertColumnsAreEqual(expect, result); + } + + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.FLOAT64)), + Arrays.asList(1.23, 0.0, Double.NaN), Arrays.asList(), null, Arrays.asList(-1.23e10, null)); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv, cv}); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.FLOAT64)), + Arrays.asList(1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN), + Arrays.asList(), null, Arrays.asList(-1.23e10, null, -1.23e10, null, -1.23e10, null))) { + assertColumnsAreEqual(expect, result); + } + + assertThrows(AssertionError.class, () -> { + try (ColumnVector cv = ColumnVector.fromInts(1, 2, 3); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv})) { + } + }, "All columns must be of type list"); + + assertThrows(AssertionError.class, () -> { + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3)); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv})) { + } + }, "listConcatenateByRow requires at least 2 columns"); + + assertThrows(AssertionError.class, () -> { + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32))), Arrays.asList(Arrays.asList(1))); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv})) { + } + }, "listConcatenateByRow only supports lists with depth 1"); + + assertThrows(AssertionError.class, () -> { + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3)); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2), Arrays.asList(3)); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2})) { + } + }, "Row count mismatch"); + + assertThrows(AssertionError.class, () -> { + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3)); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT64)), Arrays.asList(1L)); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2})) { + } + }, "Element type mismatch"); + } + + @Test + void testListConcatByRowIgnoreNull() { + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList((Integer) null), Arrays.asList(1, 2, 3), null, Arrays.asList(), null); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(1, 2, 3), null, Arrays.asList(10, 12), Arrays.asList(100, 200, 300), null); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2}, true); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(null, 1, 2, 3), Arrays.asList(1, 2, 3), Arrays.asList(10, 12), + Arrays.asList(100, 200, 300), null)) { + assertColumnsAreEqual(expect, result); + } + + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("AAA", "BBB"), Arrays.asList("aaa"), Arrays.asList("111"), null, null); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + null, Arrays.asList("bbb", "ccc"), null, Arrays.asList("Y", null), null); + ColumnVector cv3 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("CCC"), Arrays.asList(), Arrays.asList("222", "333"), null, null); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2, cv3}, true); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("AAA", "BBB", "CCC"), Arrays.asList("aaa", "bbb", "ccc"), + Arrays.asList("111", "222", "333"), Arrays.asList("Y", null), null)) { + assertColumnsAreEqual(expect, result); + } + + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.FLOAT64)), + Arrays.asList(1.23, 0.0, Double.NaN), Arrays.asList(), null, Arrays.asList(-1.23e10, null)); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv, cv}, true); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.FLOAT64)), + Arrays.asList(1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN), + Arrays.asList(), null, Arrays.asList(-1.23e10, null, -1.23e10, null, -1.23e10, null))) { + assertColumnsAreEqual(expect, result); + } + } + @Test void testPrefixSum() { try (ColumnVector v1 = ColumnVector.fromLongs(1, 2, 3, 5, 8, 10);