From 5f9dade58a1b438ec73d4bcd8cf48ca938e7cb65 Mon Sep 17 00:00:00 2001 From: Alfred Xu Date: Fri, 7 May 2021 22:34:10 +0800 Subject: [PATCH] Support listConcatenateByRows in Java package (#8171) Current PR is to provide Java API for `cudf::lists::concatenate_rows`, which is added in #8049. Authors: - Alfred Xu (https://github.com/sperlingxx) Approvers: - Jason Lowe (https://github.com/jlowe) - Robert (Bobby) Evans (https://github.com/revans2) URL: https://github.com/rapidsai/cudf/pull/8171 --- .../java/ai/rapids/cudf/ColumnVector.java | 64 +++++- .../main/java/ai/rapids/cudf/ColumnView.java | 14 -- java/src/main/native/src/ColumnVectorJni.cpp | 52 +++++ java/src/main/native/src/ColumnViewJni.cpp | 24 -- .../java/ai/rapids/cudf/ColumnVectorTest.java | 207 ++++++++++++++---- 5 files changed, 281 insertions(+), 80 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index fcdb5d44ad3..7756d7d7ce4 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -514,19 +514,50 @@ public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, Col assert separator.getType().equals(DType.STRING) : "separator scalar must be a string scalar"; assert narep != null : "narep scalar provided may not be null"; assert narep.getType().equals(DType.STRING) : "narep scalar must be a string scalar"; - long size = columns[0].getRowCount(); - long[] column_views = new long[columns.length]; + long[] column_views = new long[columns.length]; for(int i = 0; i < columns.length; i++) { assert columns[i] != null : "Column vectors passed may not be null"; - assert columns[i].getType().equals(DType.STRING) : "All columns must be of type string for .cat() operation"; - assert columns[i].getRowCount() == size : "Row count mismatch, all columns must have the same number of rows"; column_views[i] = columns[i].getNativeView(); } return new 
ColumnVector(stringConcatenation(column_views, separator.getScalarHandle(), narep.getScalarHandle())); } + /** + * Concatenate columns of lists horizontally (row by row), combining a corresponding row + * from each column into a single list row of a new column. + * NOTICE: Any concatenation involving a null list element will result in a null list. + * + * @param columns array of columns containing lists, must be at least 2 columns + * @return A new java column vector containing the concatenated lists. + */ + public static ColumnVector listConcatenateByRow(ColumnView... columns) { + return listConcatenateByRow(false, columns); + } + + /** + * Concatenate columns of lists horizontally (row by row), combining a corresponding row + * from each column into a single list row of a new column. + * + * @param ignoreNull whether to ignore null list element of input columns: If true, null list + * will be ignored from concatenation; Otherwise, any concatenation involving + * a null list element will result in a null list + * @param columns array of columns containing lists, must be at least 2 columns + * @return A new java column vector containing the concatenated lists. + */ + public static ColumnVector listConcatenateByRow(boolean ignoreNull, ColumnView... columns) { + assert columns != null : "input columns should not be null"; + assert columns.length >= 2 : "listConcatenateByRow requires at least 2 columns"; + + long[] columnViews = new long[columns.length]; + for(int i = 0; i < columns.length; i++) { + columnViews[i] = columns[i].getNativeView(); + } + + return new ColumnVector(concatListByRow(columnViews, ignoreNull)); + } + /** * Create a new vector containing the MD5 hash of each row in the table. 
* @@ -669,6 +700,31 @@ private static native long makeList(long[] handles, long typeHandle, int scale, private static native long concatenate(long[] viewHandles) throws CudfException; + /** + * Native method to concatenate columns of lists horizontally (row by row), combining a row + * from each column into a single list. + * + * @param columnViews array of longs holding the native handles of the column_views to combine. + * @return native handle of the resulting cudf column, used to construct the Java column + * by the listConcatenateByRow method. + */ + private static native long concatListByRow(long[] columnViews, boolean ignoreNull); + + /** + * Native method to concatenate columns of strings together, combining a row from + * each column into a single string. + * + * @param columnViews array of longs holding the native handles of the column_views to combine. + * @param separator string scalar inserted between each string being merged, may not be null. + * @param narep string scalar indicating null behavior. If set to null and any string in the row is null + * the resulting string will be null. If not null, null values in any column will be + * replaced by the specified string. The underlying value in the string scalar may be null, + * but the object passed in may not. + * @return native handle of the resulting cudf column, used to construct the Java column + * by the stringConcatenate method. + */ + private static native long stringConcatenation(long[] columnViews, long separator, long narep); + /** * Native method to hash each row of the given table. Hashing function dispatched on the * native side using the hashId. 
diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 51f89ea1114..0c7073d95de 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -2835,20 +2835,6 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long urlEncode(long cudfViewHandle); - /** - * Native method to concatenate columns of strings together, combining a row from - * each colunm into a single string. - * @param columnViews array of longs holding the native handles of the column_views to combine. - * @param separator string scalar inserted between each string being merged, may not be null. - * @param narep string scalar indicating null behavior. If set to null and any string in the row is null - * the resulting string will be null. If not null, null values in any column will be - * replaced by the specified string. The underlying value in the string scalar may be null, - * but the object passed in may not. - * @return native handle of the resulting cudf column, used to construct the Java column - * by the stringConcatenate method. 
- */ - protected static native long stringConcatenation(long[] columnViews, long separator, long narep); - /** * Native method for map lookup over a column of List> * @param columnView the column view handle of the map diff --git a/java/src/main/native/src/ColumnVectorJni.cpp b/java/src/main/native/src/ColumnVectorJni.cpp index 858dcf6fd5d..f9efba673c6 100644 --- a/java/src/main/native/src/ColumnVectorJni.cpp +++ b/java/src/main/native/src/ColumnVectorJni.cpp @@ -23,9 +23,11 @@ #include #include #include +#include #include #include #include +#include #include #include "cudf_jni_apis.hpp" @@ -172,6 +174,56 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromArrow(JNIEnv *env, CATCH_STD(env, 0); } + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_stringConcatenation(JNIEnv *env, jclass, + jlongArray column_handles, + jlong separator, + jlong narep) { + JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0); + JNI_NULL_CHECK(env, separator, "separator string scalar object is null", 0); + JNI_NULL_CHECK(env, narep, "narep string scalar object is null", 0); + try { + cudf::jni::auto_set_device(env); + const auto& separator_scalar = *reinterpret_cast(separator); + const auto& narep_scalar = *reinterpret_cast(narep); + + cudf::jni::native_jpointerArray n_cudf_columns(env, column_handles); + std::vector column_views; + std::transform(n_cudf_columns.data(), + n_cudf_columns.data() + n_cudf_columns.size(), + std::back_inserter(column_views), + [](auto const &p_column) { return *p_column; }); + + std::unique_ptr result = + cudf::strings::concatenate(cudf::table_view(column_views), separator_scalar, narep_scalar); + return reinterpret_cast(result.release()); + } + CATCH_STD(env, 0); +} + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_concatListByRow(JNIEnv *env, jclass, + jlongArray column_handles, + jboolean ignore_null) { + JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0); + try { + 
cudf::jni::auto_set_device(env); + auto null_policy = ignore_null ? cudf::lists::concatenate_null_policy::IGNORE + : cudf::lists::concatenate_null_policy::NULLIFY_OUTPUT_ROW; + + cudf::jni::native_jpointerArray n_cudf_columns(env, column_handles); + std::vector column_views; + std::transform(n_cudf_columns.data(), + n_cudf_columns.data() + n_cudf_columns.size(), + std::back_inserter(column_views), + [](auto const &p_column) { return *p_column; }); + + std::unique_ptr result = + cudf::lists::concatenate_rows(cudf::table_view(column_views), null_policy); + return reinterpret_cast(result.release()); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeList(JNIEnv *env, jobject j_object, jlongArray handles, jlong j_type, diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index c9bafa5abee..09212ada78b 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -38,7 +38,6 @@ #include #include #include -#include #include #include #include @@ -1026,29 +1025,6 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_containsRe(JNIEnv *env, j CATCH_STD(env, 0); } -JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringConcatenation( - JNIEnv *env, jobject j_object, jlongArray column_handles, jlong separator, jlong narep) { - JNI_NULL_CHECK(env, column_handles, "array of column handles is null", 0); - JNI_NULL_CHECK(env, separator, "separator string scalar object is null", 0); - JNI_NULL_CHECK(env, narep, "narep string scalar object is null", 0); - try { - cudf::jni::auto_set_device(env); - cudf::string_scalar *separator_scalar = reinterpret_cast(separator); - cudf::string_scalar *narep_scalar = reinterpret_cast(narep); - cudf::jni::native_jpointerArray n_cudf_columns(env, column_handles); - std::vector column_views; - std::transform(n_cudf_columns.data(), n_cudf_columns.data() + n_cudf_columns.size(), - 
std::back_inserter(column_views), - [](auto const &p_column) { return *p_column; }); - cudf::table_view *string_columns = new cudf::table_view(column_views); - - std::unique_ptr result = - cudf::strings::concatenate(*string_columns, *separator_scalar, *narep_scalar); - return reinterpret_cast(result.release()); - } - CATCH_STD(env, 0); -} - JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVV(JNIEnv *env, jclass, jlong lhs_view, jlong rhs_view, jint int_op, jint out_dtype, diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 76999f402c7..44c6324ff8f 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -1965,81 +1965,80 @@ void testStringManipulationWithNulls() { @Test void testStringConcat() { try (ColumnVector v = ColumnVector.fromStrings("a", "B", "cd", "\u0480\u0481", "E\tf", - "g\nH", "IJ\"\u0100\u0101\u0500\u0501", - "kl m", "Nop1", "\\qRs2", "3tuV\'", - "wX4Yz", "\ud720\ud721"); + "g\nH", "IJ\"\u0100\u0101\u0500\u0501", + "kl m", "Nop1", "\\qRs2", "3tuV\'", + "wX4Yz", "\ud720\ud721"); ColumnVector e_concat = ColumnVector.fromStrings("aa", "BB", "cdcd", - "\u0480\u0481\u0480\u0481", "E\tfE\tf", "g\nHg\nH", - "IJ\"\u0100\u0101\u0500\u0501IJ\"\u0100\u0101\u0500\u0501", - "kl mkl m", "Nop1Nop1", "\\qRs2\\qRs2", "3tuV\'3tuV\'", - "wX4YzwX4Yz", "\ud720\ud721\ud720\ud721"); + "\u0480\u0481\u0480\u0481", "E\tfE\tf", "g\nHg\nH", + "IJ\"\u0100\u0101\u0500\u0501IJ\"\u0100\u0101\u0500\u0501", + "kl mkl m", "Nop1Nop1", "\\qRs2\\qRs2", "3tuV\'3tuV\'", + "wX4YzwX4Yz", "\ud720\ud721\ud720\ud721"); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, - new ColumnVector[]{v, v})) { + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, new ColumnView[]{v, v})) { assertColumnsAreEqual(concat, e_concat); } - 
assertThrows(AssertionError.class, () -> { + assertThrows(CudfException.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("B", "cd", "\u0480\u0481", "E\tf"); ColumnVector cv = ColumnVector.fromInts(1, 2, 3, 4); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, - new ColumnVector[]{sv, cv})) {} + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, new ColumnView[]{sv, cv})) { + } }); - assertThrows(AssertionError.class, () -> { + assertThrows(CudfException.class, () -> { try (ColumnVector sv1 = ColumnVector.fromStrings("a", "B", "cd"); ColumnVector sv2 = ColumnVector.fromStrings("a", "B"); Scalar emptyString = Scalar.fromString(""); ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, - new ColumnVector[]{sv1, sv2})) {} + new ColumnVector[]{sv1, sv2})) { + } }); assertThrows(AssertionError.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, - new ColumnVector[]{sv})) {} + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, new ColumnView[]{sv})) { + } }); assertThrows(CudfException.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); Scalar emptyString = Scalar.fromString(""); Scalar nullString = Scalar.fromString(null); - ColumnVector concat = ColumnVector.stringConcatenate(nullString, emptyString, - new ColumnVector[]{sv, sv})) {} + ColumnVector concat = ColumnVector.stringConcatenate(nullString, emptyString, new ColumnView[]{sv, sv})) { + } }); assertThrows(AssertionError.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(null, emptyString, - new ColumnVector[]{sv, sv})) {} + ColumnVector concat 
= ColumnVector.stringConcatenate(null, emptyString, new ColumnView[]{sv, sv})) { + } }); assertThrows(AssertionError.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, null, - new ColumnVector[]{sv, sv})) {} + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, null, new ColumnView[]{sv, sv})) { + } }); assertThrows(AssertionError.class, () -> { try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, - new ColumnVector[]{sv, null})) {} + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, new ColumnView[]{sv, null})) { + } }); } @Test void testStringConcatWithNulls() { try (ColumnVector v = ColumnVector.fromStrings("a", "B", "cd", "\u0480\u0481", "E\tf", - "g\nH", "IJ\"\u0100\u0101\u0500\u0501", - "kl m", "Nop1", "\\qRs2", null, - "3tuV\'", "wX4Yz", "\ud720\ud721"); + "g\nH", "IJ\"\u0100\u0101\u0500\u0501", + "kl m", "Nop1", "\\qRs2", null, + "3tuV\'", "wX4Yz", "\ud720\ud721"); ColumnVector e_concat = ColumnVector.fromStrings("aa", "BB", "cdcd", - "\u0480\u0481\u0480\u0481", "E\tfE\tf", "g\nHg\nH", - "IJ\"\u0100\u0101\u0500\u0501IJ\"\u0100\u0101\u0500\u0501", - "kl mkl m", "Nop1Nop1", "\\qRs2\\qRs2", "NULLNULL", - "3tuV\'3tuV\'", "wX4YzwX4Yz", "\ud720\ud721\ud720\ud721"); - Scalar emptyString = Scalar.fromString(""); - Scalar nullSubstitute = Scalar.fromString("NULL"); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, nullSubstitute, - new ColumnVector[]{v, v})) { + "\u0480\u0481\u0480\u0481", "E\tfE\tf", "g\nHg\nH", + "IJ\"\u0100\u0101\u0500\u0501IJ\"\u0100\u0101\u0500\u0501", + "kl mkl m", "Nop1Nop1", "\\qRs2\\qRs2", "NULLNULL", + "3tuV\'3tuV\'", "wX4YzwX4Yz", "\ud720\ud721\ud720\ud721"); + Scalar emptyString = Scalar.fromString(""); + 
Scalar nullSubstitute = Scalar.fromString("NULL"); + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, nullSubstitute, new ColumnView[]{v, v})) { assertColumnsAreEqual(concat, e_concat); } } @@ -2049,15 +2048,147 @@ void testStringConcatSeparators() { try (ColumnVector sv1 = ColumnVector.fromStrings("a", "B", "cd", "\u0480\u0481", "E\tf", null, null, "\\G\u0100"); ColumnVector sv2 = ColumnVector.fromStrings("b", "C", "\u0500\u0501", "x\nYz", null, null, "", null); ColumnVector e_concat = ColumnVector.fromStrings("aA1\t\ud721b", "BA1\t\ud721C", "cdA1\t\ud721\u0500\u0501", - "\u0480\u0481A1\t\ud721x\nYz", null, null, null, null); + "\u0480\u0481A1\t\ud721x\nYz", null, null, null, null); Scalar separatorString = Scalar.fromString("A1\t\ud721"); Scalar nullString = Scalar.fromString(null); - ColumnVector concat = ColumnVector.stringConcatenate(separatorString, nullString, - new ColumnVector[]{sv1, sv2})) { + ColumnVector concat = ColumnVector.stringConcatenate(separatorString, nullString, new ColumnView[]{sv1, sv2})) { assertColumnsAreEqual(concat, e_concat); } } + @Test + void testListConcatByRow() { + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(0), Arrays.asList(1, 2, 3), null, Arrays.asList(), Arrays.asList()); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(1, 2, 3), Arrays.asList((Integer) null), Arrays.asList(10, 12), Arrays.asList(100, 200, 300), + Arrays.asList()); + ColumnVector result = ColumnVector.listConcatenateByRow(cv1, cv2); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(0, 1, 2, 3), Arrays.asList(1, 2, 3, null), null, Arrays.asList(100, 200, 300), + Arrays.asList())) { + assertColumnsAreEqual(expect, result); + } + + try 
(ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("AAA", "BBB"), Arrays.asList("aaa"), Arrays.asList("111"), Arrays.asList("X"), + Arrays.asList()); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList(), Arrays.asList("bbb", "ccc"), null, Arrays.asList((String) null), + Arrays.asList()); + ColumnVector cv3 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("CCC"), Arrays.asList(), Arrays.asList("222", "333"), Arrays.asList("Z"), + Arrays.asList()); + ColumnVector result = ColumnVector.listConcatenateByRow(cv1, cv2, cv3); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("AAA", "BBB", "CCC"), Arrays.asList("aaa", "bbb", "ccc"), null, + Arrays.asList("X", null, "Z"), Arrays.asList())) { + assertColumnsAreEqual(expect, result); + } + + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.FLOAT64)), + Arrays.asList(1.23, 0.0, Double.NaN), Arrays.asList(), null, Arrays.asList(-1.23e10, null)); + ColumnVector result = ColumnVector.listConcatenateByRow(cv, cv, cv); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.FLOAT64)), + Arrays.asList(1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN), + Arrays.asList(), null, Arrays.asList(-1.23e10, null, -1.23e10, null, -1.23e10, null))) { + assertColumnsAreEqual(expect, result); + } + + assertThrows(AssertionError.class, () -> { + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 
3)); + ColumnVector result = ColumnVector.listConcatenateByRow(cv)) { + } + }); + + assertThrows(CudfException.class, () -> { + try (ColumnVector cv = ColumnVector.fromInts(1, 2, 3); + ColumnVector result = ColumnVector.listConcatenateByRow(cv, cv)) { + } + }); + + assertThrows(CudfException.class, () -> { + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32))), Arrays.asList(Arrays.asList(1))); + ColumnVector result = ColumnVector.listConcatenateByRow(cv, cv)) { + } + }); + + assertThrows(CudfException.class, () -> { + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3)); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2), Arrays.asList(3)); + ColumnVector result = ColumnVector.listConcatenateByRow(cv1, cv2)) { + } + }); + + assertThrows(CudfException.class, () -> { + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3)); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT64)), Arrays.asList(1L)); + ColumnVector result = ColumnVector.listConcatenateByRow(cv1, cv2)) { + } + }); + } + + @Test + void testListConcatByRowIgnoreNull() { + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList((Integer) null), Arrays.asList(1, 2, 3), null, Arrays.asList(), null); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(1, 2, 3), null, Arrays.asList(10, 12), Arrays.asList(100, 
200, 300), null); + ColumnVector result = ColumnVector.listConcatenateByRow(true, cv1, cv2); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(null, 1, 2, 3), Arrays.asList(1, 2, 3), Arrays.asList(10, 12), + Arrays.asList(100, 200, 300), null)) { + assertColumnsAreEqual(expect, result); + } + + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("AAA", "BBB"), Arrays.asList("aaa"), Arrays.asList("111"), null, null); + ColumnVector cv2 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + null, Arrays.asList("bbb", "ccc"), null, Arrays.asList("Y", null), null); + ColumnVector cv3 = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("CCC"), Arrays.asList(), Arrays.asList("222", "333"), null, null); + ColumnVector result = ColumnVector.listConcatenateByRow(true, cv1, cv2, cv3); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.STRING)), + Arrays.asList("AAA", "BBB", "CCC"), Arrays.asList("aaa", "bbb", "ccc"), + Arrays.asList("111", "222", "333"), Arrays.asList("Y", null), null)) { + assertColumnsAreEqual(expect, result); + } + + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.FLOAT64)), + Arrays.asList(1.23, 0.0, Double.NaN), Arrays.asList(), null, Arrays.asList(-1.23e10, null)); + ColumnVector result = ColumnVector.listConcatenateByRow(true, cv, cv, cv); + ColumnVector expect = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.FLOAT64)), + Arrays.asList(1.23, 0.0, Double.NaN, 1.23, 0.0, Double.NaN, 1.23, 0.0, 
Double.NaN), + Arrays.asList(), null, Arrays.asList(-1.23e10, null, -1.23e10, null, -1.23e10, null))) { + assertColumnsAreEqual(expect, result); + } + } + @Test void testPrefixSum() { try (ColumnVector v1 = ColumnVector.fromLongs(1, 2, 3, 5, 8, 10);