diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index bb9e5e40cb9..86f6bec9eef 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -2483,6 +2483,65 @@ void testWindowingRowNumber() { } } + @Test + void testWindowingCollect() { + Aggregation aggCollect = Aggregation.collect(); + WindowOptions winOpts = WindowOptions.builder() + .minPeriods(1) + .window(2, 1).build(); + StructType nestedType = new StructType(false, + new BasicType(false, DType.INT32), new BasicType(false, DType.STRING)); + try (Table raw = new Table.TestBuilder() + .column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column( 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // GBY Key + .column( 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8) // OBY Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6) // Agg Column of INT32 + .column(nestedType, // Agg Column of Struct + new StructData(1, "s1"), new StructData(2, "s2"), new StructData(3, "s3"), + new StructData(4, "s4"), new StructData(11, "s11"), new StructData(22, "s22"), + new StructData(33, "s33"), new StructData(44, "s44"), new StructData(111, "s111"), + new StructData(222, "s222"), new StructData(333, "s333"), new StructData(444, "s444") + ).build(); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { + try (Table sorted = raw.orderBy(Table.asc(0), Table.asc(1), Table.asc(2))) { + ColumnVector sortedAggColumn = sorted.getColumn(3); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + + // Primitive type: INT32 + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(aggCollect.onColumn(3).overWindow(winOpts)); + ColumnVector expected = ColumnVector.fromLists( + new ListType(false, new BasicType(false, DType.INT32)), + Arrays.asList(7,5), Arrays.asList(7,5,1), Arrays.asList(5,1,9), Arrays.asList(1,9), + Arrays.asList(7,9), Arrays.asList(7,9,8), Arrays.asList(9,8,2), Arrays.asList(8,2), + Arrays.asList(8,0), Arrays.asList(8,0,6), Arrays.asList(0,6,6), Arrays.asList(6,6))) { + assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); + } + + // Nested type: Struct + List[] expectedNestedData = new List[12]; + expectedNestedData[0] = Arrays.asList(new StructData(1, "s1"),new StructData(2, "s2")); + expectedNestedData[1] = Arrays.asList(new StructData(1, "s1"),new StructData(2, "s2"),new StructData(3, "s3")); + expectedNestedData[2] = Arrays.asList(new StructData(2, "s2"),new StructData(3, "s3"),new StructData(4, "s4")); + expectedNestedData[3] = Arrays.asList(new StructData(3, "s3"),new StructData(4, "s4")); + expectedNestedData[4] = Arrays.asList(new StructData(11, "s11"),new StructData(22, "s22")); + expectedNestedData[5] = Arrays.asList(new StructData(11, "s11"),new StructData(22, "s22"),new StructData(33, "s33")); + expectedNestedData[6] = Arrays.asList(new StructData(22, "s22"),new StructData(33, "s33"), new StructData(44, "s44")); + expectedNestedData[7] = Arrays.asList(new StructData(33, "s33"), new StructData(44, "s44")); + expectedNestedData[8] = Arrays.asList(new StructData(111, "s111"),new StructData(222, "s222")); + expectedNestedData[9] = Arrays.asList(new StructData(111, "s111"),new StructData(222, "s222"),new StructData(333, "s333")); + expectedNestedData[10] = Arrays.asList(new StructData(222, "s222"),new StructData(333, "s333"),new StructData(444, "s444")); + expectedNestedData[11] = Arrays.asList(new StructData(333, "s333"),new StructData(444, "s444")); + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(aggCollect.onColumn(4).overWindow(winOpts)); + ColumnVector expected = ColumnVector.fromLists( + new ListType(false, nestedType), expectedNestedData)) { + assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); + } + } + } + } + @Test void testWindowingLead() { try (Table unsorted = new Table.TestBuilder()