From 7af8966e705b1da5d5c694e063d89a47b88e3b9c Mon Sep 17 00:00:00 2001 From: Alfred Xu Date: Wed, 19 May 2021 18:50:31 +0800 Subject: [PATCH] Fix incorrect assertion in Java concat (#8258) Fix #8246. Authors: - Alfred Xu (https://github.com/sperlingxx) Approvers: - Liangcai Li (https://github.com/firestarman) - Jason Lowe (https://github.com/jlowe) - Robert (Bobby) Evans (https://github.com/revans2) URL: https://github.com/rapidsai/cudf/pull/8258 --- .../java/ai/rapids/cudf/ColumnVector.java | 13 +++---- .../java/ai/rapids/cudf/ColumnVectorTest.java | 36 +++++++++++++------ 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/ColumnVector.java b/java/src/main/java/ai/rapids/cudf/ColumnVector.java index adf0f317340..ea93a2daf36 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnVector.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnVector.java @@ -488,7 +488,7 @@ public static ColumnVector concatenate(ColumnView... columns) { * Concatenate columns of strings together, combining a corresponding row from each column * into a single string row of a new column with no separator string inserted between each * combined string and maintaining null values in combined rows. - * @param columns array of columns containing strings. + * @param columns array of columns containing strings, must be non-empty * @return A new java column vector containing the concatenated strings. */ public static ColumnVector stringConcatenate(ColumnView[] columns) { @@ -505,11 +505,12 @@ public static ColumnVector stringConcatenate(ColumnView[] columns) { * @param narep string scalar indicating null behavior. If set to null and any string in the row * is null the resulting string will be null. If not null, null values in any column * will be replaced by the specified string. - * @param columns array of columns containing strings, must be more than 2 columns + * @param columns array of columns containing strings, must be non-empty * @return A new java column vector containing the concatenated strings. */ public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, ColumnView[] columns) { - assert columns.length >= 2 : ".stringConcatenate() operation requires at least 2 columns"; + assert columns != null : "input columns should not be null"; + assert columns.length > 0 : "input columns should not be empty"; assert separator != null : "separator scalar provided may not be null"; assert separator.getType().equals(DType.STRING) : "separator scalar must be a string scalar"; assert narep != null : "narep scalar provided may not be null"; @@ -529,7 +530,7 @@ public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, Col * from each column into a single list row of a new column. * NOTICE: Any concatenation involving a null list element will result in a null list. * - * @param columns array of columns containing lists, must be more than 2 columns + * @param columns array of columns containing lists, must be non-empty * @return A new java column vector containing the concatenated lists. */ public static ColumnVector listConcatenateByRow(ColumnView... columns) { @@ -543,12 +544,12 @@ public static ColumnVector listConcatenateByRow(ColumnView... columns) { * @param ignoreNull whether to ignore null list element of input columns: If true, null list * will be ignored from concatenation; Otherwise, any concatenation involving * a null list element will result in a null list - * @param columns array of columns containing lists, must be more than 2 columns + * @param columns array of columns containing lists, must be non-empty * @return A new java column vector containing the concatenated lists. */ public static ColumnVector listConcatenateByRow(boolean ignoreNull, ColumnView... columns) { assert columns != null : "input columns should not be null"; - assert columns.length >= 2 : "listConcatenateByRow requires at least 2 columns"; + assert columns.length > 0 : "input columns should not be empty"; long[] columnViews = new long[columns.length]; for(int i = 0; i < columns.length; i++) { diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index a7733897d10..09ddef633e3 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -2037,9 +2037,8 @@ void testStringConcat() { } }); assertThrows(AssertionError.class, () -> { - try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd"); - Scalar emptyString = Scalar.fromString(""); - ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, new ColumnView[]{sv})) { + try (Scalar emptyString = Scalar.fromString(""); + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, new ColumnView[]{})) { } }); assertThrows(CudfException.class, () -> { @@ -2085,6 +2084,16 @@ void testStringConcatWithNulls() { ColumnVector concat = ColumnVector.stringConcatenate(emptyString, nullSubstitute, new ColumnView[]{v, v})) { assertColumnsAreEqual(concat, e_concat); } + + try (ColumnVector v = ColumnVector.fromStrings("a", "B", "cd", "\u0480\u0481", "E\tf", + "g\nH", "IJ\"\u0100\u0101\u0500\u0501", + "kl m", "Nop1", "\\qRs2", null, + "3tuV\'", "wX4Yz", "\ud720\ud721"); + Scalar emptyString = Scalar.fromString(""); + Scalar nullSubstitute = Scalar.fromString("NULL"); + ColumnVector concat = ColumnVector.stringConcatenate(emptyString, nullSubstitute, new ColumnView[]{v})) { + assertColumnsAreEqual(v, concat); + } } @Test @@ -2102,6 +2111,13 @@ void testStringConcatSeparators() { @Test void testListConcatByRow() { + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(0), Arrays.asList(1, 2, 3), null, Arrays.asList(), Arrays.asList()); + ColumnVector result = ColumnVector.listConcatenateByRow(cv)) { + assertColumnsAreEqual(cv, result); + } + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(0), Arrays.asList(1, 2, 3), null, Arrays.asList(), Arrays.asList()); @@ -2148,13 +2164,6 @@ void testListConcatByRow() { assertColumnsAreEqual(expect, result); } - assertThrows(AssertionError.class, () -> { - try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, - new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3)); - ColumnVector result = ColumnVector.listConcatenateByRow(cv)) { - } - }); - assertThrows(CudfException.class, () -> { try (ColumnVector cv = ColumnVector.fromInts(1, 2, 3); ColumnVector result = ColumnVector.listConcatenateByRow(cv, cv)) { @@ -2190,6 +2199,13 @@ void testListConcatByRow() { @Test void testListConcatByRowIgnoreNull() { + try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(0), Arrays.asList(1, 2, 3), null, Arrays.asList(), Arrays.asList()); + ColumnVector result = ColumnVector.listConcatenateByRow(true, cv)) { + assertColumnsAreEqual(cv, result); + } + try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList((Integer) null), Arrays.asList(1, 2, 3), null, Arrays.asList(), null);