Skip to content

Commit

Permalink
Fix incorrect assertion in Java concat (#8258)
Browse files: browse the repository at this point in the history
Fix #8246.

Authors:
  - Alfred Xu (https://github.com/sperlingxx)

Approvers:
  - Liangcai Li (https://github.com/firestarman)
  - Jason Lowe (https://github.com/jlowe)
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #8258
  • Loading branch information
sperlingxx authored May 19, 2021
1 parent 072cbee commit 7af8966
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 16 deletions.
13 changes: 7 additions & 6 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ public static ColumnVector concatenate(ColumnView... columns) {
* Concatenate columns of strings together, combining a corresponding row from each column
* into a single string row of a new column with no separator string inserted between each
* combined string and maintaining null values in combined rows.
* @param columns array of columns containing strings.
* @param columns array of columns containing strings, must be non-empty
* @return A new java column vector containing the concatenated strings.
*/
public static ColumnVector stringConcatenate(ColumnView[] columns) {
Expand All @@ -505,11 +505,12 @@ public static ColumnVector stringConcatenate(ColumnView[] columns) {
* @param narep string scalar indicating null behavior. If set to null and any string in the row
* is null the resulting string will be null. If not null, null values in any column
* will be replaced by the specified string.
* @param columns array of columns containing strings, must be more than 2 columns
* @param columns array of columns containing strings, must be non-empty
* @return A new java column vector containing the concatenated strings.
*/
public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, ColumnView[] columns) {
assert columns.length >= 2 : ".stringConcatenate() operation requires at least 2 columns";
assert columns != null : "input columns should not be null";
assert columns.length > 0 : "input columns should not be empty";
assert separator != null : "separator scalar provided may not be null";
assert separator.getType().equals(DType.STRING) : "separator scalar must be a string scalar";
assert narep != null : "narep scalar provided may not be null";
Expand All @@ -529,7 +530,7 @@ public static ColumnVector stringConcatenate(Scalar separator, Scalar narep, Col
* from each column into a single list row of a new column.
* NOTICE: Any concatenation involving a null list element will result in a null list.
*
* @param columns array of columns containing lists, must be more than 2 columns
* @param columns array of columns containing lists, must be non-empty
* @return A new java column vector containing the concatenated lists.
*/
public static ColumnVector listConcatenateByRow(ColumnView... columns) {
Expand All @@ -543,12 +544,12 @@ public static ColumnVector listConcatenateByRow(ColumnView... columns) {
* @param ignoreNull whether to ignore null list element of input columns: If true, null list
* will be ignored from concatenation; Otherwise, any concatenation involving
* a null list element will result in a null list
* @param columns array of columns containing lists, must be more than 2 columns
* @param columns array of columns containing lists, must be non-empty
* @return A new java column vector containing the concatenated lists.
*/
public static ColumnVector listConcatenateByRow(boolean ignoreNull, ColumnView... columns) {
assert columns != null : "input columns should not be null";
assert columns.length >= 2 : "listConcatenateByRow requires at least 2 columns";
assert columns.length > 0 : "input columns should not be empty";

long[] columnViews = new long[columns.length];
for(int i = 0; i < columns.length; i++) {
Expand Down
36 changes: 26 additions & 10 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2037,9 +2037,8 @@ void testStringConcat() {
}
});
assertThrows(AssertionError.class, () -> {
try (ColumnVector sv = ColumnVector.fromStrings("a", "B", "cd");
Scalar emptyString = Scalar.fromString("");
ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, new ColumnView[]{sv})) {
try (Scalar emptyString = Scalar.fromString("");
ColumnVector concat = ColumnVector.stringConcatenate(emptyString, emptyString, new ColumnView[]{})) {
}
});
assertThrows(CudfException.class, () -> {
Expand Down Expand Up @@ -2085,6 +2084,16 @@ void testStringConcatWithNulls() {
ColumnVector concat = ColumnVector.stringConcatenate(emptyString, nullSubstitute, new ColumnView[]{v, v})) {
assertColumnsAreEqual(concat, e_concat);
}

try (ColumnVector v = ColumnVector.fromStrings("a", "B", "cd", "\u0480\u0481", "E\tf",
"g\nH", "IJ\"\u0100\u0101\u0500\u0501",
"kl m", "Nop1", "\\qRs2", null,
"3tuV\'", "wX4Yz", "\ud720\ud721");
Scalar emptyString = Scalar.fromString("");
Scalar nullSubstitute = Scalar.fromString("NULL");
ColumnVector concat = ColumnVector.stringConcatenate(emptyString, nullSubstitute, new ColumnView[]{v})) {
assertColumnsAreEqual(v, concat);
}
}

@Test
Expand All @@ -2102,6 +2111,13 @@ void testStringConcatSeparators() {

@Test
void testListConcatByRow() {
try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32)),
Arrays.asList(0), Arrays.asList(1, 2, 3), null, Arrays.asList(), Arrays.asList());
ColumnVector result = ColumnVector.listConcatenateByRow(cv)) {
assertColumnsAreEqual(cv, result);
}

try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32)),
Arrays.asList(0), Arrays.asList(1, 2, 3), null, Arrays.asList(), Arrays.asList());
Expand Down Expand Up @@ -2148,13 +2164,6 @@ void testListConcatByRow() {
assertColumnsAreEqual(expect, result);
}

assertThrows(AssertionError.class, () -> {
try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32)), Arrays.asList(1, 2, 3));
ColumnVector result = ColumnVector.listConcatenateByRow(cv)) {
}
});

assertThrows(CudfException.class, () -> {
try (ColumnVector cv = ColumnVector.fromInts(1, 2, 3);
ColumnVector result = ColumnVector.listConcatenateByRow(cv, cv)) {
Expand Down Expand Up @@ -2190,6 +2199,13 @@ void testListConcatByRow() {

@Test
void testListConcatByRowIgnoreNull() {
try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32)),
Arrays.asList(0), Arrays.asList(1, 2, 3), null, Arrays.asList(), Arrays.asList());
ColumnVector result = ColumnVector.listConcatenateByRow(true, cv)) {
assertColumnsAreEqual(cv, result);
}

try (ColumnVector cv1 = ColumnVector.fromLists(new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32)),
Arrays.asList((Integer) null), Arrays.asList(1, 2, 3), null, Arrays.asList(), null);
Expand Down

0 comments on commit 7af8966

Please sign in to comment.