diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index a30d276d954..76999f402c7 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -4342,37 +4342,64 @@ void testMakeStruct() { @Test void testMakeListEmpty() { - final int numRows = 10; - try (ColumnVector expected = + final int numRows = 4; + List> emptyListOfList = new ArrayList<>(); + emptyListOfList.add(Arrays.asList()); + try ( + ColumnVector expectedList = ColumnVector.fromLists( new ListType(false, new BasicType(false, DType.STRING)), Arrays.asList(), Arrays.asList(), Arrays.asList(), - Arrays.asList(), - Arrays.asList(), - Arrays.asList(), - Arrays.asList(), - Arrays.asList(), - Arrays.asList(), Arrays.asList()); - ColumnVector created = ColumnVector.makeList(numRows, DType.STRING)) { - assertColumnsAreEqual(expected, created); + ColumnVector expectedListOfList = ColumnVector.fromLists(new HostColumnVector.ListType(false, + new HostColumnVector.ListType(false, + new HostColumnVector.BasicType(false, DType.STRING))), + emptyListOfList, emptyListOfList, emptyListOfList, emptyListOfList); + + ColumnVector createdList = ColumnVector.makeList(numRows, DType.STRING); + ColumnVector createdListOfList = ColumnVector.makeList(createdList)) { + assertColumnsAreEqual(expectedList, createdList); + assertColumnsAreEqual(expectedListOfList, createdListOfList); } } @Test void testMakeList() { - try (ColumnVector expected = - ColumnVector.fromLists( - new ListType(false, new BasicType(false, DType.INT32)), - Arrays.asList(1, 3, 5), - Arrays.asList(2, 4, 6)); + List list1 = Arrays.asList(1, 3); + List list2 = Arrays.asList(2, 4); + List list3 = Arrays.asList(5, 7, 9); + List list4 = Arrays.asList(6, 8, 10); + List> mainList1 = new ArrayList<>(Arrays.asList(list1, list3)); + List> mainList2 = new ArrayList<>(Arrays.asList(list2, list4)); + try (ColumnVector expectedList1 = + ColumnVector.fromLists(new ListType(false, + new BasicType(false, DType.INT32)), list1, list2); + ColumnVector expectedList2 = + ColumnVector.fromLists(new ListType(false, + new BasicType(false, DType.INT32)), list3, list4); + ColumnVector expectedListOfList = ColumnVector.fromLists(new HostColumnVector.ListType(true, + new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32))), + mainList1, mainList2); ColumnVector child1 = ColumnVector.fromInts(1, 2); ColumnVector child2 = ColumnVector.fromInts(3, 4); ColumnVector child3 = ColumnVector.fromInts(5, 6); - ColumnVector created = ColumnVector.makeList(child1, child2, child3)) { - assertColumnsAreEqual(expected, created); + ColumnVector child4 = ColumnVector.fromInts(7, 8); + ColumnVector child5 = ColumnVector.fromInts(9, 10); + ColumnVector createdList1 = ColumnVector.makeList(child1, child2); + ColumnVector createdList2 = ColumnVector.makeList(child3, child4, child5); + ColumnVector createdListOfList = ColumnVector.makeList(createdList1, createdList2); + HostColumnVector hcv = createdListOfList.copyToHost()) { + + assertColumnsAreEqual(expectedList1, createdList1); + assertColumnsAreEqual(expectedList2, createdList2); + assertColumnsAreEqual(expectedListOfList, createdListOfList); + + List> ret1 = hcv.getList(0); + List> ret2 = hcv.getList(1); + assertEquals(mainList1, ret1, "Lists don't match"); + assertEquals(mainList2, ret2, "Lists don't match"); } }