From accd7f4c4f1a6de4a96bafb0684373c60eba656c Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Thu, 6 May 2021 12:19:44 +0800 Subject: [PATCH 1/4] provide Java Api for concat_list_by_rows Signed-off-by: sperlingxx --- .../java/ai/rapids/cudf/ColumnVector.java | 41 ++++++ .../main/java/ai/rapids/cudf/ColumnView.java | 12 +- java/src/main/native/src/ColumnViewJni.cpp | 23 +++ .../java/ai/rapids/cudf/ColumnVectorTest.java | 133 ++++++++++++++++++ 4 files changed, 208 insertions(+), 1 deletion(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index fcdb5d44ad3..9d1bdc2e21c 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -527,6 +527,47 @@ public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, Col return new ColumnVector(stringConcatenation(column_views, separator.getScalarHandle(), narep.getScalarHandle())); } + /** + * Concatenate columns of lists horizontally (row by row), combining a corresponding row + * from each column into a single list row of a new column. Null list element of input columns + * will be ignored (skipped) during the concatenation. + * + * @param columns array of columns containing lists, must be more than 2 columns + * @return A new java column vector containing the concatenated lists. + */ + public static ColumnVector listConcatenateByRow(ColumnView[] columns) { + return listConcatenateByRow(columns, false); + } + + /** + * Concatenate columns of lists horizontally (row by row), combining a corresponding row + * from each column into a single list row of a new column. + * + * @param columns array of columns containing lists, must be more than 2 columns + * @param ignoreNull whether to ignore null list element of input columns: If true, null list + * will be ignored from concatenation; Otherwise, any concatenation involving + * a null list element will result in a null list + * @return A new java column vector containing the concatenated lists. + */ + public static ColumnVector listConcatenateByRow(ColumnView[] columns, boolean ignoreNull) { + assert columns.length >= 2 : "listConcatenateByRow requires at least 2 columns"; + long size = columns[0].getRowCount(); + long[] columnViews = new long[columns.length]; + DType childType = columns[0].getChildColumnView(0).getType(); + assert !childType.isNestedType() : "listConcatenateByRow only supports lists with depth 1"; + + for(int i = 0; i < columns.length; i++) { + assert columns[i] != null : "Column vectors passed may not be null"; + assert columns[i].getType().equals(DType.LIST) : "All columns must be of type list"; + assert columns[i].getRowCount() == size : "Row count mismatch"; + assert childType.equals(columns[i].getChildColumnView(0).getType()) : "Element type mismatch"; + + columnViews[i] = columns[i].getNativeView(); + } + + return new ColumnVector(listConcatenationByRow(columnViews, ignoreNull)); + } + /** * Create a new vector containing the MD5 hash of each row in the table. * diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 51f89ea1114..73fbc73757e 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -2837,7 +2837,7 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat /** * Native method to concatenate columns of strings together, combining a row from - * each colunm into a single string. + * each column into a single string. * @param columnViews array of longs holding the native handles of the column_views to combine. * @param separator string scalar inserted between each string being merged, may not be null. * @param narep string scalar indicating null behavior. If set to null and any string in the row is null @@ -2849,6 +2849,16 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat */ protected static native long stringConcatenation(long[] columnViews, long separator, long narep); + + /** + * Native method to concatenate columns of lists together, combining a row from + * each column into a single list. + * @param columnViews array of longs holding the native handles of the column_views to combine. + * @return native handle of the resulting cudf column, used to construct the Java column + * by the listConcatenateByRow method. + */ + protected static native long listConcatenationByRow(long[] columnViews, boolean ignoreNull); + /** * Native method for map lookup over a column of List> * @param columnView the column view handle of the map diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index c9bafa5abee..b1139318c4e 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -1049,6 +1050,28 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringConcatenation( CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listConcatenationByRow(JNIEnv *env, + jobject j_object, + jlongArray column_handles, + jboolean ignore_null) { + JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0); + try { + cudf::jni::auto_set_device(env); + auto null_policy = ignore_null ? cudf::lists::concatenate_null_policy::IGNORE + : cudf::lists::concatenate_null_policy::NULLIFY_OUTPUT_ROW; + cudf::jni::native_jpointerArray n_cudf_columns(env, column_handles); + std::vector column_views; + std::transform(n_cudf_columns.data(), n_cudf_columns.data() + n_cudf_columns.size(), + std::back_inserter(column_views), + [](auto const &p_column) { return *p_column; }); + + std::unique_ptr result = + cudf::lists::concatenate_rows(cudf::table_view(column_views), null_policy); + return reinterpret_cast(result.release()); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVV(JNIEnv *env, jclass, jlong lhs_view, jlong rhs_view, jint int_op, jint out_dtype, diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index a30d276d954..eba39e5ceec 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -2058,6 +2058,139 @@ void testStringConcatSeparators() { } } + @Test + void testListConcatByRow() { + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(0), Arrays.asList(1, 2, 3), null, Arrays.asList(), Arrays.asList()); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(1, 2, 3), Arrays.asList((Integer) null), Arrays.asList(10, 12), Arrays.asList(100, 200, 300), + Arrays.asList()); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2}); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(0, 1, 2, 3), Arrays.asList(1, 2, 3, null), null, Arrays.asList(100, 200, 300), + Arrays.asList())) { + assertColumnsAreEqual(expect, result); + } + + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("AAA", "BBB"), Arrays.asList("aaa"), Arrays.asList("111"), Arrays.asList("X"), + Arrays.asList()); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList(), Arrays.asList("bbb", "ccc"), null, Arrays.asList("Y", null), + Arrays.asList()); + ColumnVector cv3 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("CCC"), Arrays.asList(), Arrays.asList("222", "333"), Arrays.asList("Z"), + Arrays.asList()); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2, cv3}); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("AAA", "BBB", "CCC"), Arrays.asList("aaa", "bbb", "ccc"), null, + Arrays.asList("X", "Y", null, "Z"), Arrays.asList())) { + assertColumnsAreEqual(expect, result); + } + + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.FLOAT64)), + Arrays.asList(1.23, 0.0, Double.NaN), Arrays.asList(), null, Arrays.asList(-1.23e10, null)); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv, cv}); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.FLOAT64)), + Arrays.asList(1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN), + Arrays.asList(), null, Arrays.asList(-1.23e10, null, -1.23e10, null, -1.23e10, null))) { + assertColumnsAreEqual(expect, result); + } + + assertThrows(AssertionError.class, () -> { + try (ColumnVector cv = ColumnVector.fromInts(1, 2, 3); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv})) { + } + }, "All columns must be of type list"); + + assertThrows(AssertionError.class, () -> { + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3)); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv})) { + } + }, "listConcatenateByRow requires at least 2 columns"); + + assertThrows(AssertionError.class, () -> { + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32))), Arrays.asList(Arrays.asList(1))); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv})) { + } + }, "listConcatenateByRow only supports lists with depth 1"); + + assertThrows(AssertionError.class, () -> { + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3)); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2), Arrays.asList(3)); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2})) { + } + }, "Row count mismatch"); + + assertThrows(AssertionError.class, () -> { + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3)); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT64)), Arrays.asList(1L)); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2})) { + } + }, "Element type mismatch"); + } + + @Test + void testListConcatByRowIgnoreNull() { + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList((Integer) null), Arrays.asList(1, 2, 3), null, Arrays.asList(), null); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(1, 2, 3), null, Arrays.asList(10, 12), Arrays.asList(100, 200, 300), null); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2}, true); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(null, 1, 2, 3), Arrays.asList(1, 2, 3), Arrays.asList(10, 12), + Arrays.asList(100, 200, 300), null)) { + assertColumnsAreEqual(expect, result); + } + + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("AAA", "BBB"), Arrays.asList("aaa"), Arrays.asList("111"), null, null); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + null, Arrays.asList("bbb", "ccc"), null, Arrays.asList("Y", null), null); + ColumnVector cv3 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("CCC"), Arrays.asList(), Arrays.asList("222", "333"), null, null); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2, cv3}, true); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("AAA", "BBB", "CCC"), Arrays.asList("aaa", "bbb", "ccc"), + Arrays.asList("111", "222", "333"), Arrays.asList("Y", null), null)) { + assertColumnsAreEqual(expect, result); + } + + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.FLOAT64)), + Arrays.asList(1.23, 0.0, Double.NaN), Arrays.asList(), null, Arrays.asList(-1.23e10, null)); + ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv, cv}, true); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.FLOAT64)), + Arrays.asList(1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN), + Arrays.asList(), null, Arrays.asList(-1.23e10, null, -1.23e10, null, -1.23e10, null))) { + assertColumnsAreEqual(expect, result); + } + } + @Test void testPrefixSum() { try (ColumnVector v1 = ColumnVector.fromLongs(1, 2, 3, 5, 8, 10); From 9c53378135fb696a95b4f643136d5a0ae6069773 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Thu, 6 May 2021 17:49:40 +0800 Subject: [PATCH 2/4] some refinement Signed-off-by: sperlingxx --- .../java/ai/rapids/cudf/ColumnVector.java | 9 ++++--- .../java/ai/rapids/cudf/ColumnVectorTest.java | 26 +++++++++---------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 9d1bdc2e21c..1e988166377 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -535,21 +535,22 @@ public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, Col * @param columns array of columns containing lists, must be more than 2 columns * @return A new java column vector containing the concatenated lists. */ - public static ColumnVector listConcatenateByRow(ColumnView[] columns) { - return listConcatenateByRow(columns, false); + public static ColumnVector listConcatenateByRow(ColumnView... columns) { + return listConcatenateByRow(false, columns); } /** * Concatenate columns of lists horizontally (row by row), combining a corresponding row * from each column into a single list row of a new column. * - * @param columns array of columns containing lists, must be more than 2 columns * @param ignoreNull whether to ignore null list element of input columns: If true, null list * will be ignored from concatenation; Otherwise, any concatenation involving * a null list element will result in a null list + * @param columns array of columns containing lists, must be more than 2 columns * @return A new java column vector containing the concatenated lists. */ - public static ColumnVector listConcatenateByRow(ColumnView[] columns, boolean ignoreNull) { + public static ColumnVector listConcatenateByRow(boolean ignoreNull, ColumnView... columns) { + assert columns != null : "input columns should not be null"; assert columns.length >= 2 : "listConcatenateByRow requires at least 2 columns"; long size = columns[0].getRowCount(); long[] columnViews = new long[columns.length]; diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index eba39e5ceec..8fcafe5b5de 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -2067,7 +2067,7 @@ void testListConcatByRow() { new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3), Arrays.asList((Integer) null), Arrays.asList(10, 12), Arrays.asList(100, 200, 300), Arrays.asList()); - ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2}); + ColumnVector result = ColumnVector.listConcatenateByRow(cv1, cv2); ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(0, 1, 2, 3), Arrays.asList(1, 2, 3, null), null, Arrays.asList(100, 200, 300), @@ -2081,24 +2081,24 @@ void testListConcatByRow() { Arrays.asList()); ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.STRING)), - Arrays.asList(), Arrays.asList("bbb", "ccc"), null, Arrays.asList("Y", null), + Arrays.asList(), Arrays.asList("bbb", "ccc"), null, Arrays.asList((String) null), Arrays.asList()); ColumnVector cv3 = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.STRING)), Arrays.asList("CCC"), Arrays.asList(), Arrays.asList("222", "333"), Arrays.asList("Z"), Arrays.asList()); - ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2, cv3}); + ColumnVector result = ColumnVector.listConcatenateByRow(cv1, cv2, cv3); ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.STRING)), Arrays.asList("AAA", "BBB", "CCC"), Arrays.asList("aaa", "bbb", "ccc"), null, - Arrays.asList("X", "Y", null, "Z"), Arrays.asList())) { + Arrays.asList("X", null, "Z"), Arrays.asList())) { assertColumnsAreEqual(expect, result); } try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.FLOAT64)), Arrays.asList(1.23, 0.0, Double.NaN), Arrays.asList(), null, Arrays.asList(-1.23e10, null)); - ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv, cv}); + ColumnVector result = ColumnVector.listConcatenateByRow(cv, cv, cv); ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.FLOAT64)), Arrays.asList(1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN), @@ -2108,14 +2108,14 @@ void testListConcatByRow() { assertThrows(AssertionError.class, () -> { try (ColumnVector cv = ColumnVector.fromInts(1, 2, 3); - ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv})) { + ColumnVector result = ColumnVector.listConcatenateByRow(cv, cv)) { } }, "All columns must be of type list"); assertThrows(AssertionError.class, () -> { try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3)); - ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv})) { + ColumnVector result = ColumnVector.listConcatenateByRow(cv)) { } }, "listConcatenateByRow requires at least 2 columns"); @@ -2123,7 +2123,7 @@ void testListConcatByRow() { try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32))), Arrays.asList(Arrays.asList(1))); - ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv})) { + ColumnVector result = ColumnVector.listConcatenateByRow(cv, cv)) { } }, "listConcatenateByRow only supports lists with depth 1"); @@ -2132,7 +2132,7 @@ void testListConcatByRow() { new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3)); ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2), Arrays.asList(3)); - ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2})) { + ColumnVector result = ColumnVector.listConcatenateByRow(cv1, cv2)) { } }, "Row count mismatch"); @@ -2141,7 +2141,7 @@ void testListConcatByRow() { new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3)); ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT64)), Arrays.asList(1L)); - ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2})) { + ColumnVector result = ColumnVector.listConcatenateByRow(cv1, cv2)) { } }, "Element type mismatch"); } @@ -2154,7 +2154,7 @@ void testListConcatByRowIgnoreNull() { ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3), null, Arrays.asList(10, 12), Arrays.asList(100, 200, 300), null); - ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2}, true); + ColumnVector result = ColumnVector.listConcatenateByRow(true, cv1, cv2); ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(null, 1, 2, 3), Arrays.asList(1, 2, 3), Arrays.asList(10, 12), @@ -2171,7 +2171,7 @@ void testListConcatByRowIgnoreNull() { ColumnVector cv3 = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.STRING)), Arrays.asList("CCC"), Arrays.asList(), Arrays.asList("222", "333"), null, null); - ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv1, cv2, cv3}, true); + ColumnVector result = ColumnVector.listConcatenateByRow(true, cv1, cv2, cv3); ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.STRING)), Arrays.asList("AAA", "BBB", "CCC"), Arrays.asList("aaa", "bbb", "ccc"), @@ -2182,7 +2182,7 @@ void testListConcatByRowIgnoreNull() { try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.FLOAT64)), Arrays.asList(1.23, 0.0, Double.NaN), Arrays.asList(), null, Arrays.asList(-1.23e10, null)); - ColumnVector result = ColumnVector.listConcatenateByRow(new ColumnVector[]{cv, cv, cv}, true); + ColumnVector result = ColumnVector.listConcatenateByRow(true, cv, cv, cv); ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.FLOAT64)), Arrays.asList(1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN), From 9af566a1c6257bdc6bfd4625f1b9067cda811162 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 7 May 2021 09:36:32 +0800 Subject: [PATCH 3/4] address comments Signed-off-by: sperlingxx --- .../java/ai/rapids/cudf/ColumnVector.java | 50 +++++---- .../main/java/ai/rapids/cudf/ColumnView.java | 24 ----- java/src/main/native/src/ColumnVectorJni.cpp | 52 +++++++++ java/src/main/native/src/ColumnViewJni.cpp | 47 -------- .../java/ai/rapids/cudf/ColumnVectorTest.java | 100 +++++++++--------- 5 files changed, 133 insertions(+), 140 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 1e988166377..7c358c90eb5 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -491,7 +491,7 @@ public static ColumnVector concatenate(ColumnView... columns) { * @param columns array of columns containing strings. * @return A new java column vector containing the concatenated strings. */ - public static ColumnVector stringConcatenate(ColumnView[] columns) { + public static ColumnVector stringConcatenate(ColumnView... columns) { try (Scalar emptyString = Scalar.fromString(""); Scalar nullString = Scalar.fromString(null)) { return stringConcatenate(emptyString, nullString, columns); @@ -508,19 +508,16 @@ public static ColumnVector stringConcatenate(ColumnView[] columns) { * @param columns array of columns containing strings, must be more than 2 columns * @return A new java column vector containing the concatenated strings. */ - public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, ColumnView[] columns) { + public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, ColumnView... columns) { assert columns.length >= 2 : ".stringConcatenate() operation requires at least 2 columns"; assert separator != null : "separator scalar provided may not be null"; assert separator.getType().equals(DType.STRING) : "separator scalar must be a string scalar"; assert narep != null : "narep scalar provided may not be null"; assert narep.getType().equals(DType.STRING) : "narep scalar must be a string scalar"; - long size = columns[0].getRowCount(); - long[] column_views = new long[columns.length]; + long[] column_views = new long[columns.length]; for(int i = 0; i < columns.length; i++) { assert columns[i] != null : "Column vectors passed may not be null"; - assert columns[i].getType().equals(DType.STRING) : "All columns must be of type string for .cat() operation"; - assert columns[i].getRowCount() == size : "Row count mismatch, all columns must have the same number of rows"; column_views[i] = columns[i].getNativeView(); } @@ -529,8 +526,8 @@ public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, Col /** * Concatenate columns of lists horizontally (row by row), combining a corresponding row - * from each column into a single list row of a new column. Null list element of input columns - * will be ignored (skipped) during the concatenation. + * from each column into a single list row of a new column. + * NOTICE: Any concatenation involving a null list element will result in a null list. * * @param columns array of columns containing lists, must be more than 2 columns * @return A new java column vector containing the concatenated lists. @@ -552,21 +549,13 @@ public static ColumnVector listConcatenateByRow(ColumnView... columns) { public static ColumnVector listConcatenateByRow(boolean ignoreNull, ColumnView... columns) { assert columns != null : "input columns should not be null"; assert columns.length >= 2 : "listConcatenateByRow requires at least 2 columns"; - long size = columns[0].getRowCount(); - long[] columnViews = new long[columns.length]; - DType childType = columns[0].getChildColumnView(0).getType(); - assert !childType.isNestedType() : "listConcatenateByRow only supports lists with depth 1"; + long[] columnViews = new long[columns.length]; for(int i = 0; i < columns.length; i++) { - assert columns[i] != null : "Column vectors passed may not be null"; - assert columns[i].getType().equals(DType.LIST) : "All columns must be of type list"; - assert columns[i].getRowCount() == size : "Row count mismatch"; - assert childType.equals(columns[i].getChildColumnView(0).getType()) : "Element type mismatch"; - columnViews[i] = columns[i].getNativeView(); } - return new ColumnVector(listConcatenationByRow(columnViews, ignoreNull)); + return new ColumnVector(concatListByRow(columnViews, ignoreNull)); } /** @@ -711,6 +700,31 @@ private static native long makeList(long[] handles, long typeHandle, int scale, private static native long concatenate(long[] viewHandles) throws CudfException; + /** + * Native method to concatenate columns of lists horizontally (row by row), combining a row + * from each column into a single list. + * + * @param columnViews array of longs holding the native handles of the column_views to combine. + * @return native handle of the resulting cudf column, used to construct the Java column + * by the listConcatenateByRow method. + */ + private static native long concatListByRow(long[] columnViews, boolean ignoreNull); + + /** + * Native method to concatenate columns of strings together, combining a row from + * each column into a single string. + * + * @param columnViews array of longs holding the native handles of the column_views to combine. + * @param separator string scalar inserted between each string being merged, may not be null. + * @param narep string scalar indicating null behavior. If set to null and any string in the row is null + * the resulting string will be null. If not null, null values in any column will be + * replaced by the specified string. The underlying value in the string scalar may be null, + * but the object passed in may not. + * @return native handle of the resulting cudf column, used to construct the Java column + * by the stringConcatenate method. + */ + private static native long stringConcatenation(long[] columnViews, long separator, long narep); + /** * Native method to hash each row of the given table. Hashing function dispatched on the * native side using the hashId. diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 73fbc73757e..0c7073d95de 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -2835,30 +2835,6 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long urlEncode(long cudfViewHandle); - /** - * Native method to concatenate columns of strings together, combining a row from - * each column into a single string. - * @param columnViews array of longs holding the native handles of the column_views to combine. - * @param separator string scalar inserted between each string being merged, may not be null. - * @param narep string scalar indicating null behavior. If set to null and any string in the row is null - * the resulting string will be null. If not null, null values in any column will be - * replaced by the specified string. The underlying value in the string scalar may be null, - * but the object passed in may not. - * @return native handle of the resulting cudf column, used to construct the Java column - * by the stringConcatenate method. - */ - protected static native long stringConcatenation(long[] columnViews, long separator, long narep); - - - /** - * Native method to concatenate columns of lists together, combining a row from - * each column into a single list. - * @param columnViews array of longs holding the native handles of the column_views to combine. - * @return native handle of the resulting cudf column, used to construct the Java column - * by the listConcatenateByRow method. - */ - protected static native long listConcatenationByRow(long[] columnViews, boolean ignoreNull); - /** * Native method for map lookup over a column of List> * @param columnView the column view handle of the map diff --git a/java/src/main/native/src/ColumnVectorJni.cpp b/java/src/main/native/src/ColumnVectorJni.cpp index 858dcf6fd5d..f9efba673c6 100644 --- a/java/src/main/native/src/ColumnVectorJni.cpp +++ b/java/src/main/native/src/ColumnVectorJni.cpp @@ -23,9 +23,11 @@ #include #include #include +#include #include #include #include +#include #include #include "cudf_jni_apis.hpp" @@ -172,6 +174,56 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromArrow(JNIEnv *env, CATCH_STD(env, 0); } + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_stringConcatenation(JNIEnv *env, jclass, + jlongArray column_handles, + jlong separator, + jlong narep) { + JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0); + JNI_NULL_CHECK(env, separator, "separator string scalar object is null", 0); + JNI_NULL_CHECK(env, narep, "narep string scalar object is null", 0); + try { + cudf::jni::auto_set_device(env); + const auto& separator_scalar = *reinterpret_cast(separator); + const auto& narep_scalar = *reinterpret_cast(narep); + + cudf::jni::native_jpointerArray n_cudf_columns(env, column_handles); + std::vector column_views; + std::transform(n_cudf_columns.data(), + n_cudf_columns.data() + n_cudf_columns.size(), + std::back_inserter(column_views), + [](auto const &p_column) { return *p_column; }); + + std::unique_ptr result = + cudf::strings::concatenate(cudf::table_view(column_views), separator_scalar, narep_scalar); + return reinterpret_cast(result.release()); + } + CATCH_STD(env, 0); +} + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_concatListByRow(JNIEnv *env, jclass, + jlongArray column_handles, + jboolean ignore_null) { + JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0); + try { + cudf::jni::auto_set_device(env); + auto null_policy = ignore_null ? cudf::lists::concatenate_null_policy::IGNORE + : cudf::lists::concatenate_null_policy::NULLIFY_OUTPUT_ROW; + + cudf::jni::native_jpointerArray n_cudf_columns(env, column_handles); + std::vector column_views; + std::transform(n_cudf_columns.data(), + n_cudf_columns.data() + n_cudf_columns.size(), + std::back_inserter(column_views), + [](auto const &p_column) { return *p_column; }); + + std::unique_ptr result = + cudf::lists::concatenate_rows(cudf::table_view(column_views), null_policy); + return reinterpret_cast(result.release()); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeList(JNIEnv *env, jobject j_object, jlongArray handles, jlong j_type, diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index b1139318c4e..09212ada78b 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -23,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -39,7 +38,6 @@ #include #include #include -#include #include #include #include @@ -1027,51 +1025,6 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_containsRe(JNIEnv *env, j CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringConcatenation( - JNIEnv *env, jobject j_object, jlongArray column_handles, jlong separator, jlong narep) { - JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0); - JNI_NULL_CHECK(env, separator, "separator string scalar object is null", 0); - JNI_NULL_CHECK(env, narep, "narep string scalar object is null", 0); - try { - cudf::jni::auto_set_device(env); - cudf::string_scalar *separator_scalar = reinterpret_cast(separator); - cudf::string_scalar *narep_scalar = reinterpret_cast(narep); - cudf::jni::native_jpointerArray n_cudf_columns(env, column_handles); - std::vector column_views; - std::transform(n_cudf_columns.data(), n_cudf_columns.data() + n_cudf_columns.size(), - std::back_inserter(column_views), - [](auto const &p_column) { return *p_column; }); - cudf::table_view *string_columns = new cudf::table_view(column_views); - - std::unique_ptr result = - cudf::strings::concatenate(*string_columns, *separator_scalar, *narep_scalar); - return reinterpret_cast(result.release()); - } - CATCH_STD(env, 0); -} - -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listConcatenationByRow(JNIEnv *env, - jobject j_object, - jlongArray column_handles, - jboolean ignore_null) { - JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0); - try { - cudf::jni::auto_set_device(env); - auto null_policy = ignore_null ? cudf::lists::concatenate_null_policy::IGNORE - : cudf::lists::concatenate_null_policy::NULLIFY_OUTPUT_ROW; - cudf::jni::native_jpointerArray n_cudf_columns(env, column_handles); - std::vector column_views; - std::transform(n_cudf_columns.data(), n_cudf_columns.data() + n_cudf_columns.size(), - std::back_inserter(column_views), - [](auto const &p_column) { return *p_column; }); - - std::unique_ptr result = - cudf::lists::concatenate_rows(cudf::table_view(column_views), null_policy); - return reinterpret_cast(result.release()); - } - CATCH_STD(env, 0); -} - JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVV(JNIEnv *env, jclass, jlong lhs_view, jlong rhs_view, jint int_op, jint out_dtype, diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 8fcafe5b5de..af9e78f49fc 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -1965,81 +1965,80 @@ void testStringManipulationWithNulls() { @Test void testStringConcat() { try (ColumnVector v = ColumnVector.fromStrings("a", "B", "cd", "\u0480\u0481", "E\tf", - "g\nH", "IJ\"\u0100\u0101\u0500\u0501", - "kl m", "Nop1", "\\qRs2", "3tuV\'", - "wX4Yz", "\ud720\ud721"); + "g\nH", "IJ\"\u0100\u0101\u0500\u0501", + "kl m", "Nop1", "\\qRs2", "3tuV\'", + "wX4Yz", "\ud720\ud721"); ColumnVector e_concat = ColumnVector.fromStrings("aa", "BB", "cdcd", - "\u0480\u0481\u0480\u0481", "E\tfE\tf", "g\nHg\nH", - "IJ\"\u0100\u0101\u0500\u0501IJ\"\u0100\u0101\u0500\u0501", - "kl mkl m", "Nop1Nop1", "\\qRs2\\qRs2", "3tuV\'3tuV\'", - "wX4YzwX4Yz", "\ud720\ud721\ud720\ud721"); + "\u0480\u0481\u0480\u0481", "E\tfE\tf", "g\nHg\nH", + "IJ\"\u0100\u0101\u0500\u0501IJ\"\u0100\u0101\u0500\u0501", + "kl mkl m", "Nop1Nop1", "\\qRs2\\qRs2", "3tuV\'3tuV\'", + "wX4YzwX4Yz", "\ud720\ud721\ud720\ud721"); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, - new ColumnVector[]{v, v})) { + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, v, v)) { assertColumnsAreEqual(concat, e_concat); } - assertThrows(AssertionError.class, () -> { + assertThrows(CudfException.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("B", "cd", "\u0480\u0481", "E\tf"); ColumnVector cv = ColumnVector.fromInts(1, 2, 3, 4); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, - new ColumnVector[]{sv, cv})) {} + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, sv, cv)) { + } }); - assertThrows(AssertionError.class, () -> { + assertThrows(CudfException.class, () -> { try (ColumnVector sv1 = ColumnVector.fromStrings("a", "B", "cd"); ColumnVector sv2 = ColumnVector.fromStrings("a", "B"); Scalar emptyString = Scalar.fromString(""); ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, - new ColumnVector[]{sv1, sv2})) {} + new ColumnVector[]{sv1, sv2})) { + } }); assertThrows(AssertionError.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, - new ColumnVector[]{sv})) {} + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, sv)) { + } }); assertThrows(CudfException.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); Scalar emptyString = Scalar.fromString(""); Scalar nullString = Scalar.fromString(null); - ColumnVector concat = ColumnVector.stringConcatenate(nullString, emptyString, - new ColumnVector[]{sv, sv})) {} + ColumnVector concat = ColumnVector.stringConcatenate(nullString, emptyString, sv, sv)) { + } }); assertThrows(AssertionError.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(null, emptyString, - new ColumnVector[]{sv, sv})) {} + ColumnVector concat = ColumnVector.stringConcatenate(null, emptyString, sv, sv)) { + } }); assertThrows(AssertionError.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, null, - new ColumnVector[]{sv, sv})) {} + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, null, sv, sv)) { + } }); assertThrows(AssertionError.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, - new ColumnVector[]{sv, null})) {} + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, sv, null)) { + } }); } @Test void testStringConcatWithNulls() { try (ColumnVector v = ColumnVector.fromStrings("a", "B", "cd", "\u0480\u0481", "E\tf", - "g\nH", "IJ\"\u0100\u0101\u0500\u0501", - "kl m", "Nop1", "\\qRs2", null, - "3tuV\'", "wX4Yz", "\ud720\ud721"); + "g\nH", "IJ\"\u0100\u0101\u0500\u0501", + "kl m", "Nop1", "\\qRs2", null, + "3tuV\'", "wX4Yz", "\ud720\ud721"); ColumnVector e_concat = ColumnVector.fromStrings("aa", "BB", "cdcd", - "\u0480\u0481\u0480\u0481", "E\tfE\tf", "g\nHg\nH", - "IJ\"\u0100\u0101\u0500\u0501IJ\"\u0100\u0101\u0500\u0501", - "kl mkl m", "Nop1Nop1", "\\qRs2\\qRs2", "NULLNULL", - "3tuV\'3tuV\'", "wX4YzwX4Yz", "\ud720\ud721\ud720\ud721"); - Scalar emptyString = Scalar.fromString(""); - Scalar nullSubstitute = Scalar.fromString("NULL"); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, nullSubstitute, - new ColumnVector[]{v, v})) { + "\u0480\u0481\u0480\u0481", "E\tfE\tf", "g\nHg\nH", + "IJ\"\u0100\u0101\u0500\u0501IJ\"\u0100\u0101\u0500\u0501", + "kl mkl m", "Nop1Nop1", "\\qRs2\\qRs2", "NULLNULL", + "3tuV\'3tuV\'", "wX4YzwX4Yz", "\ud720\ud721\ud720\ud721"); + Scalar emptyString = Scalar.fromString(""); + Scalar nullSubstitute = Scalar.fromString("NULL"); + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, nullSubstitute, v, v)) { assertColumnsAreEqual(concat, e_concat); } } @@ -2049,11 +2048,10 @@ void testStringConcatSeparators() { try (ColumnVector sv1 = ColumnVector.fromStrings("a", "B", "cd", "\u0480\u0481", "E\tf", null, null, "\\G\u0100"); ColumnVector sv2 = ColumnVector.fromStrings("b", "C", "\u0500\u0501", "x\nYz", null, null, "", null); ColumnVector e_concat = ColumnVector.fromStrings("aA1\t\ud721b", "BA1\t\ud721C", "cdA1\t\ud721\u0500\u0501", - "\u0480\u0481A1\t\ud721x\nYz", null, null, null, null); + "\u0480\u0481A1\t\ud721x\nYz", null, null, null, null); Scalar separatorString = Scalar.fromString("A1\t\ud721"); Scalar nullString = Scalar.fromString(null); - ColumnVector concat = ColumnVector.stringConcatenate(separatorString, nullString, - new ColumnVector[]{sv1, sv2})) { + ColumnVector concat = ColumnVector.stringConcatenate(separatorString, nullString, sv1, sv2)) { assertColumnsAreEqual(concat, e_concat); } } @@ -2106,44 +2104,44 @@ void testListConcatByRow() { assertColumnsAreEqual(expect, result); } - assertThrows(AssertionError.class, () -> { - try (ColumnVector cv = ColumnVector.fromInts(1, 2, 3); - ColumnVector result = ColumnVector.listConcatenateByRow(cv, cv)) { - } - }, "All columns must be of type list"); - assertThrows(AssertionError.class, () -> { try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3)); ColumnVector result = ColumnVector.listConcatenateByRow(cv)) { } - }, "listConcatenateByRow requires at least 2 columns"); + }); - assertThrows(AssertionError.class, () -> { + assertThrows(CudfException.class, () -> { + try (ColumnVector cv = ColumnVector.fromInts(1, 2, 3); + ColumnVector result = ColumnVector.listConcatenateByRow(cv, cv)) { + } + }); + + assertThrows(CudfException.class, () -> { try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32))), Arrays.asList(Arrays.asList(1))); ColumnVector result = ColumnVector.listConcatenateByRow(cv, cv)) { } - }, "listConcatenateByRow only supports lists with depth 1"); + }); - assertThrows(AssertionError.class, () -> { + assertThrows(CudfException.class, () -> { try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3)); ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2), Arrays.asList(3)); ColumnVector result = ColumnVector.listConcatenateByRow(cv1, cv2)) { } - }, "Row count mismatch"); + }); - assertThrows(AssertionError.class, () -> { + assertThrows(CudfException.class, () -> { try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3)); ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT64)), Arrays.asList(1L)); ColumnVector result = ColumnVector.listConcatenateByRow(cv1, cv2)) { } - }, "Element type mismatch"); + }); } @Test From c0097279fcdefa459e1073a772950c4be2424f08 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 7 May 2021 10:36:57 +0800 Subject: [PATCH 4/4] some fix --- .../main/java/ai/rapids/cudf/ColumnVector.java | 4 ++-- .../java/ai/rapids/cudf/ColumnVectorTest.java | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index 7c358c90eb5..7756d7d7ce4 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -491,7 +491,7 @@ public static ColumnVector concatenate(ColumnView... columns) { * @param columns array of columns containing strings. * @return A new java column vector containing the concatenated strings. */ - public static ColumnVector stringConcatenate(ColumnView... columns) { + public static ColumnVector stringConcatenate(ColumnView[] columns) { try (Scalar emptyString = Scalar.fromString(""); Scalar nullString = Scalar.fromString(null)) { return stringConcatenate(emptyString, nullString, columns); @@ -508,7 +508,7 @@ public static ColumnVector stringConcatenate(ColumnView... columns) { * @param columns array of columns containing strings, must be more than 2 columns * @return A new java column vector containing the concatenated strings. */ - public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, ColumnView... columns) { + public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, ColumnView[] columns) { assert columns.length >= 2 : ".stringConcatenate() operation requires at least 2 columns"; assert separator != null : "separator scalar provided may not be null"; assert separator.getType().equals(DType.STRING) : "separator scalar must be a string scalar"; diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index af9e78f49fc..d4f4e0aac6d 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -1974,14 +1974,14 @@ void testStringConcat() { "kl mkl m", "Nop1Nop1", "\\qRs2\\qRs2", "3tuV\'3tuV\'", "wX4YzwX4Yz", "\ud720\ud721\ud720\ud721"); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, v, v)) { + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, new ColumnView[]{v, v})) { assertColumnsAreEqual(concat, e_concat); } assertThrows(CudfException.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("B", "cd", "\u0480\u0481", "E\tf"); ColumnVector cv = ColumnVector.fromInts(1, 2, 3, 4); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, sv, cv)) { + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, new ColumnView[]{sv, cv})) { } }); assertThrows(CudfException.class, () -> { @@ -1995,32 +1995,32 @@ void testStringConcat() { assertThrows(AssertionError.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, sv)) { + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, new ColumnView[]{sv})) { } }); assertThrows(CudfException.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); Scalar emptyString = Scalar.fromString(""); Scalar nullString = Scalar.fromString(null); - ColumnVector concat = ColumnVector.stringConcatenate(nullString, emptyString, sv, sv)) { + ColumnVector concat = ColumnVector.stringConcatenate(nullString, emptyString, new ColumnView[]{sv, sv})) { } }); assertThrows(AssertionError.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(null, emptyString, sv, sv)) { + ColumnVector concat = ColumnVector.stringConcatenate(null, emptyString, new ColumnView[]{sv, sv})) { } }); assertThrows(AssertionError.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, null, sv, sv)) { + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, null, new ColumnView[]{sv, sv})) { } }); assertThrows(AssertionError.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, sv, null)) { + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, new ColumnView[]{sv, null})) { } }); } @@ -2038,7 +2038,7 @@ void testStringConcatWithNulls() { "3tuV\'3tuV\'", "wX4YzwX4Yz", "\ud720\ud721\ud720\ud721"); Scalar emptyString = Scalar.fromString(""); Scalar nullSubstitute = Scalar.fromString("NULL"); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, nullSubstitute, v, v)) { + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, nullSubstitute, new ColumnView[]{v, v})) { assertColumnsAreEqual(concat, e_concat); } } @@ -2051,7 +2051,7 @@ void testStringConcatSeparators() { "\u0480\u0481A1\t\ud721x\nYz", null, null, null, null); Scalar separatorString = Scalar.fromString("A1\t\ud721"); Scalar nullString = Scalar.fromString(null); - ColumnVector concat = ColumnVector.stringConcatenate(separatorString, nullString, sv1, sv2)) { + ColumnVector concat = ColumnVector.stringConcatenate(separatorString, nullString, new ColumnView[]{sv1, sv2})) { assertColumnsAreEqual(concat, e_concat); } }