From aa14d925b6a59c6eaf31623488a963b1722bcdad Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Tue, 18 May 2021 05:19:34 +0800 Subject: [PATCH 1/2] add unit tests for lead/lag on list for row window Signed-off-by: Bobby Wang --- .../test/java/ai/rapids/cudf/TableTest.java | 592 ++++++++++++++---- 1 file changed, 459 insertions(+), 133 deletions(-) diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index b81125965fa..1d98b78bf65 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -3053,6 +3053,16 @@ void testWindowingLead() { .decimal32Column(-1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // Decimal GBY Key .decimal64Column(1, 1L, 1L, 2L, 2L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L) // Decimal OBY Key .decimal64Column(-2, 7L, 5L, 1L, 9L, 7L, 9L, 8L, 2L, 8L, 0L, 6L, 6L) // Decimal Agg Column + .column(new ListType(false, new BasicType(true, DType.INT32)), + Arrays.asList(11, 12, null, 13), Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), Arrays.asList(16), + Arrays.asList(21, null, null, 22), Arrays.asList(23, 24), Arrays.asList(25, 26, 27), Arrays.asList(28, 29, null), + Arrays.asList(null, 31), Arrays.asList(32, 33, 34), Arrays.asList(35, 36), Arrays.asList(37, 38, 39)) // List Agg COLUMN + .column(new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + new StructData(1, "s1"), new StructData(null, "s2"), new StructData(2, null), new StructData(3, "s3"), + new StructData(11, "s11"), null, new StructData(13, "s13"), new StructData(14, "s14"), + new StructData(111, "s111"), new StructData(null, "s112"), new StructData(2, "s222"), new StructData(3, "s333")) //STRUCT Agg COLUMN .build()) { try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); @@ -3069,88 +3079,188 @@ void testWindowingLead() { try (Scalar two = Scalar.fromInt(2); Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(two, one).build(); - WindowOptions options1 = windowBuilder.window(two, one).build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lead(0) - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lead(0) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); - ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lead(0) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = decSorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lead(0) + .onColumn(6) + .overWindow(options)); + Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lead(0) + .onColumn(7) + .overWindow(options)); + Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lead(0) + .onColumn(8) + .overWindow(options)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); + ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); + ColumnVector listExpectAggResult = ColumnVector.fromLists( + new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(11, 12, null, 13), Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), Arrays.asList(16), + Arrays.asList(21, null, null, 22), Arrays.asList(23, 24), Arrays.asList(25, 26, 27), Arrays.asList(28, 29, null), + Arrays.asList(null, 31), Arrays.asList(32, 33, 34), Arrays.asList(35, 36), Arrays.asList(37, 38, 39)); + ColumnVector structExpectAggResult = ColumnVector.fromStructs( + new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + new StructData(1, "s1"), new StructData(null, "s2"), new StructData(2, null), new StructData(3, "s3"), + new StructData(11, "s11"), null, new StructData(13, "s13"), new StructData(14, "s14"), + new StructData(111, "s111"), new StructData(null, "s112"), new StructData(2, "s222"), new StructData(3, "s333"))) { + + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + assertColumnsAreEqual(listExpectAggResult, listWindowAggResults.getColumn(0)); + assertColumnsAreEqual(structExpectAggResult, structWindowAggResults.getColumn(0)); } try (Scalar zero = Scalar.fromInt(0); Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(zero, one).build(); - WindowOptions options1 = windowBuilder.window(zero, one).build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lead(1) - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lead(1) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, null); - ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, 5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, null)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lead(1) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lead(1) + .onColumn(6) + .overWindow(options)); + Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lead(1) + .onColumn(7) + .overWindow(options)); + Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lead(1) + .onColumn(8) + .overWindow(options)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, null); + ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, 5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, null); + ColumnVector listExpectAggResult = ColumnVector.fromLists( + new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), Arrays.asList(16), null, + Arrays.asList(23, 24), Arrays.asList(25, 26, 27), Arrays.asList(28, 29, null), null, + Arrays.asList(32, 33, 34), Arrays.asList(35, 36), Arrays.asList(37, 38, 39), null); + ColumnVector structExpectAggResult = ColumnVector.fromStructs( + new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + new StructData(null, "s2"), new StructData(2, null), new StructData(3, "s3"), null, + null, new StructData(13, "s13"), new StructData(14, "s14"), null, + new StructData(null, "s112"), new StructData(2, "s222"), new StructData(3, "s333"), null)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + assertColumnsAreEqual(listExpectAggResult, listWindowAggResults.getColumn(0)); + assertColumnsAreEqual(structExpectAggResult, structWindowAggResults.getColumn(0)); } try (Scalar zero = Scalar.fromInt(0); Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(zero, one).build(); - WindowOptions options1 = windowBuilder.window(zero, one).build()) { - try (ColumnVector defaultOutput = ColumnVector.fromBoxedInts(0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); - ColumnVector decDefaultOutput = ColumnVector.decimalFromLongs(-2, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); - Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lead(1, defaultOutput) - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lead(1, decDefaultOutput) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(5, 1, 9, -3, 9, 8, 2, -7, 0, 6, 6, -11); - ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 5, 1, 9, -3, 9, 8, 2, -7, 0, 6, 6, -11)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + ColumnVector defaultOutput = ColumnVector.fromBoxedInts(0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); + ColumnVector decDefaultOutput = ColumnVector.decimalFromLongs(-2, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); + ColumnVector listDefaultOutput = ColumnVector.fromLists( + new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(111), Arrays.asList(222), Arrays.asList(333), Arrays.asList(444, null, 555), + Arrays.asList(-11), Arrays.asList(-22), Arrays.asList(-33), Arrays.asList(-44), + Arrays.asList(6), Arrays.asList(6), Arrays.asList(6), Arrays.asList(6, null, null)); + ColumnVector structDefaultOutput = ColumnVector.fromStructs( + new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + new StructData(-1, "s1"), new StructData(null, "s2"), new StructData(-2, null), new StructData(-3, "s3"), + new StructData(-11, "s11"), null, new StructData(-13, "s13"), new StructData(-14, "s14"), + new StructData(-111, "s111"), new StructData(null, "s112"), new StructData(-222, "s222"), new StructData(-333, "s333")); + + Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lead(1, defaultOutput) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lead(1, decDefaultOutput) + .onColumn(6) + .overWindow(options)); + Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lead(1, listDefaultOutput) + .onColumn(7) + .overWindow(options)); + Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lead(1, structDefaultOutput) + .onColumn(8) + .overWindow(options)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(5, 1, 9, -3, 9, 8, 2, -7, 0, 6, 6, -11); + ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 5, 1, 9, -3, 9, 8, 2, -7, 0, 6, 6, -11); + ColumnVector listExpectAggResult = ColumnVector.fromLists( + new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), Arrays.asList(16), Arrays.asList(444, null, 555), + Arrays.asList(23, 24), Arrays.asList(25, 26, 27), Arrays.asList(28, 29, null), Arrays.asList(-44), + Arrays.asList(32, 33, 34), Arrays.asList(35, 36), Arrays.asList(37, 38, 39), Arrays.asList(6, null, null)); + ColumnVector structExpectAggResult = ColumnVector.fromStructs( + new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + new StructData(null, "s2"), new StructData(2, null), new StructData(3, "s3"), new StructData(-3, "s3"), + null, new StructData(13, "s13"), new StructData(14, "s14"), new StructData(-14, "s14"), + new StructData(null, "s112"), new StructData(2, "s222"), new StructData(3, "s333"), new StructData(-333, "s333"))) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + assertColumnsAreEqual(listExpectAggResult, listWindowAggResults.getColumn(0)); + + // TODO this is not gonna work, since cudf has some issue on lead with default values + // assertColumnsAreEqual(structExpectAggResult, structWindowAggResults.getColumn(0)); } // Outside bounds try (Scalar zero = Scalar.fromInt(0); Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(zero, one).build(); - WindowOptions options1 = windowBuilder.window(zero, one).build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lead(3) - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lead(3) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, null, null, null, null, null, null, null, null, null, null, null); - ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, null, null, null, null, null, null, null, null, null, null, null)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lead(3) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lead(3) + .onColumn(6) + .overWindow(options)); + Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lead(3) + .onColumn(7) + .overWindow(options)); + Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lead(3) + .onColumn(8) + .overWindow(options)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, null, null, null, null, null, null, null, null, null, null, null); + ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, null, null, null, null, null, null, null, null, null, null, null); + ColumnVector listExpectAggResult = ColumnVector.fromLists( + new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), + null, null, null, null, null, null, null, null, null, null, null, null); + ColumnVector structExpectAggResult = ColumnVector.fromStructs( + new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + null, null, null, null, null, null, null, null, null, null, null, null)){ + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + assertColumnsAreEqual(listExpectAggResult, listWindowAggResults.getColumn(0)); + assertColumnsAreEqual(structExpectAggResult, structWindowAggResults.getColumn(0)); } } } @@ -3166,6 +3276,16 @@ void testWindowingLag() { .decimal32Column(-1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // Decimal GBY Key .decimal64Column(1, 1L, 1L, 2L, 2L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L) // Decimal OBY Key .decimal64Column(-2, 7L, 5L, 1L, 9L, 7L, 9L, 8L, 2L, 8L, 0L, 6L, 6L) // Decimal Agg Column + .column(new ListType(false, new BasicType(true, DType.INT32)), + Arrays.asList(11, 12, null, 13), Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), Arrays.asList(16), + Arrays.asList(21, null, null, 22), Arrays.asList(23, 24), Arrays.asList(25, 26, 27), Arrays.asList(28, 29, null), + Arrays.asList(null, 31), Arrays.asList(32, 33, 34), Arrays.asList(35, 36), Arrays.asList(37, 38, 39)) // List Agg COLUMN + .column(new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + new StructData(1, "s1"), new StructData(null, "s2"), new StructData(2, null), new StructData(3, "s3"), + new StructData(11, "s11"), null, new StructData(13, "s13"), new StructData(14, "s14"), + new StructData(111, "s111"), new StructData(null, "s112"), new StructData(2, "s222"), new StructData(3, "s333")) //STRUCT Agg COLUMN .build()) { try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); @@ -3182,88 +3302,185 @@ void testWindowingLag() { try (Scalar two = Scalar.fromInt(2); Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(two, one).build(); - WindowOptions options1 = windowBuilder.window(two, one).build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lag(0) - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lag(0) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); - ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lag(0) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lag(0) + .onColumn(6) + .overWindow(options)); + Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lag(0) + .onColumn(7) + .overWindow(options)); + Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lag(0) + .onColumn(8) + .overWindow(options)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); + ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); + ColumnVector listExpectAggResult = ColumnVector.fromLists( + new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(11, 12, null, 13), Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), Arrays.asList(16), + Arrays.asList(21, null, null, 22), Arrays.asList(23, 24), Arrays.asList(25, 26, 27), Arrays.asList(28, 29, null), + Arrays.asList(null, 31), Arrays.asList(32, 33, 34), Arrays.asList(35, 36), Arrays.asList(37, 38, 39)); + ColumnVector structExpectAggResult = ColumnVector.fromStructs( + new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + new StructData(1, "s1"), new StructData(null, "s2"), new StructData(2, null), new StructData(3, "s3"), + new StructData(11, "s11"), null, new StructData(13, "s13"), new StructData(14, "s14"), + new StructData(111, "s111"), new StructData(null, "s112"), new StructData(2, "s222"), new StructData(3, "s333"))) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + assertColumnsAreEqual(listExpectAggResult, listWindowAggResults.getColumn(0)); + assertColumnsAreEqual(structExpectAggResult, structWindowAggResults.getColumn(0)); } try (Scalar zero = Scalar.fromInt(0); Scalar two = Scalar.fromInt(2); WindowOptions options = windowBuilder.window(two, zero).build(); - WindowOptions options1 = windowBuilder.window(two, zero).build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lag(1) - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lag(1) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6); - ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lag(1) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lag(1) + .onColumn(6) + .overWindow(options)); + Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lag(1) + .onColumn(7) + .overWindow(options)); + Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lag(1) + .onColumn(8) + .overWindow(options)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6); + ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6); + ColumnVector listExpectAggResult = ColumnVector.fromLists( + new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), + null, Arrays.asList(11, 12, null, 13), Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), + null, Arrays.asList(21, null, null, 22), Arrays.asList(23, 24), Arrays.asList(25, 26, 27), + null, Arrays.asList(null, 31), Arrays.asList(32, 33, 34), Arrays.asList(35, 36)); + ColumnVector structExpectAggResult = ColumnVector.fromStructs( + new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + null, new StructData(1, "s1"), new StructData(null, "s2"), new StructData(2, null), + null, new StructData(11, "s11"), null, new StructData(13, "s13"), + null, new StructData(111, "s111"), new StructData(null, "s112"), new StructData(2, "s222"))) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + assertColumnsAreEqual(listExpectAggResult, listWindowAggResults.getColumn(0)); + assertColumnsAreEqual(structExpectAggResult, structWindowAggResults.getColumn(0)); } try (Scalar zero = Scalar.fromInt(0); Scalar two = Scalar.fromInt(2); WindowOptions options = windowBuilder.window(two, zero).build(); - WindowOptions options1 = windowBuilder.window(two, zero).build()) { - try (ColumnVector defaultOutput = ColumnVector.fromBoxedInts(0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); - ColumnVector decDefaultOutput = ColumnVector.decimalFromLongs(-2, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); - Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lag(1, defaultOutput) - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lag(1, decDefaultOutput) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(0, 7, 5, 1, -4, 7, 9, 8, -8, 8, 0, 6); - ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 0, 7, 5, 1, -4, 7, 9, 8, -8, 8, 0, 6)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + ColumnVector defaultOutput = ColumnVector.fromBoxedInts(0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); + ColumnVector decDefaultOutput = ColumnVector.decimalFromLongs(-2, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); + ColumnVector listDefaultOutput = ColumnVector.fromLists( + new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(111), Arrays.asList(222), Arrays.asList(333), Arrays.asList(444, null, 555), + Arrays.asList(-11), Arrays.asList(-22), Arrays.asList(-33), Arrays.asList(-44), + Arrays.asList(6), Arrays.asList(6), Arrays.asList(6), Arrays.asList(6, null, null)); + ColumnVector structDefaultOutput = ColumnVector.fromStructs( + new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + new StructData(-1, "s1"), new StructData(null, "s2"), new StructData(-2, null), new StructData(-3, "s3"), + new StructData(-11, "s11"), null, new StructData(-13, "s13"), new StructData(-14, "s14"), + new StructData(-111, "s111"), new StructData(null, "s112"), new StructData(-222, "s222"), new StructData(-333, "s333")); + Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lag(1, defaultOutput) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lag(1, decDefaultOutput) + .onColumn(6) + .overWindow(options)); + Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lag(1, listDefaultOutput) + .onColumn(7) + .overWindow(options)); + Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lag(1, structDefaultOutput) + .onColumn(8) + .overWindow(options)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(0, 7, 5, 1, -4, 7, 9, 8, -8, 8, 0, 6); + ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 0, 7, 5, 1, -4, 7, 9, 8, -8, 8, 0, 6); + ColumnVector listExpectAggResult = ColumnVector.fromLists( + new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(111), Arrays.asList(11, 12, null, 13), Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), + Arrays.asList(-11), Arrays.asList(21, null, null, 22), Arrays.asList(23, 24), Arrays.asList(25, 26, 27), + Arrays.asList(6), Arrays.asList(null, 31), Arrays.asList(32, 33, 34), Arrays.asList(35, 36)); + ColumnVector structExpectAggResult = ColumnVector.fromStructs( + new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + new StructData(-1, "s1"), new StructData(1, "s1"), new StructData(null, "s2"), new StructData(2, null), + new StructData(-11, "s11"), new StructData(11, "s11"), null, new StructData(13, "s13"), + new StructData(-111, "s111"), new StructData(111, "s111"), new StructData(null, "s112"), new StructData(2, "s222"))) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + assertColumnsAreEqual(listExpectAggResult, listWindowAggResults.getColumn(0)); + // TODO this is not gonna work, since cudf has some issue on lag with default values + // assertColumnsAreEqual(structExpectAggResult, structWindowAggResults.getColumn(0)); } // Outside bounds try (Scalar zero = Scalar.fromInt(0); Scalar one = Scalar.fromInt(1); WindowOptions options = windowBuilder.window(one, zero).build(); - WindowOptions options1 = windowBuilder.window(one, zero).build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lag(3) - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lag(3) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, null, null, null, null, null, null, null, null, null, null, null); - ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, null, null, null, null, null, null, null, null, null, null, null);) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lag(3) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lag(3) + .onColumn(6) + .overWindow(options)); + Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lag(3) + .onColumn(7) + .overWindow(options)); + Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( + Aggregation + .lag(3) + .onColumn(8) + .overWindow(options)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, null, null, null, null, null, null, null, null, null, null, null); + ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, null, null, null, null, null, null, null, null, null, null, null); + ColumnVector listExpectAggResult = ColumnVector.fromLists( + new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), + null, null, null, null, null, null, null, null, null, null, null, null); + ColumnVector structExpectAggResult = ColumnVector.fromStructs( + new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + null, null, null, null, null, null, null, null, null, null, null, null);) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + assertColumnsAreEqual(listExpectAggResult, listWindowAggResults.getColumn(0)); + assertColumnsAreEqual(structExpectAggResult, structWindowAggResults.getColumn(0)); } } } @@ -3464,6 +3681,16 @@ void testRangeWindowingLead() { .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .column(new ListType(false, new BasicType(true, DType.INT32)), + Arrays.asList(11, 12, null, 13), Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), Arrays.asList(16), + Arrays.asList(21, null, null, 22), Arrays.asList(23, 24), Arrays.asList(25, 26, 27), Arrays.asList(28, 29, null), + Arrays.asList(null, 31), Arrays.asList(32, 33, 34), Arrays.asList(35, 36), Arrays.asList(37, 38, 39), Arrays.asList(40, 41)) // List Agg COLUMN + .column(new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + new StructData(1, "s1"), new StructData(null, "s2"), new StructData(2, null), new StructData(3, "s3"), + new StructData(11, "s11"), null, new StructData(13, "s13"), new StructData(14, "s14"), + new StructData(111, "s111"), new StructData(null, "s112"), new StructData(2, "s222"), new StructData(3, "s333"), new StructData(4, "s444")) //STRUCT Agg COLUMN .column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key .column((short) 1, (short)1, (short)2, (short)3, (short)3, (short)3, (short)4, (short)4, (short)5, (short)5, (short)6, (short)6, (short)7) // orderBy Key .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key @@ -3475,7 +3702,7 @@ void testRangeWindowingLead() { .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) .build()) { - for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { + for (int orderIndex = 5; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex)); ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { ColumnVector sortedAggColumn = sorted.getColumn(2); @@ -3497,6 +3724,105 @@ void testRangeWindowingLead() { ColumnVector expect = ColumnVector.fromBoxedInts(5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, 8, null)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } + try (Table listAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverRanges(Aggregation.lead(1) + .onColumn(3) // List Agg Column + .overWindow(window)); + Table structAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverRanges(Aggregation.lead(1) + .onColumn(4) // Struct Agg Column + .overWindow(window)); + ColumnVector listExpectAggResult = ColumnVector.fromLists( + new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), + Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), Arrays.asList(16), null, + Arrays.asList(23, 24), Arrays.asList(25, 26, 27), Arrays.asList(28, 29, null), null, + Arrays.asList(32, 33, 34), Arrays.asList(35, 36), Arrays.asList(37, 38, 39), Arrays.asList(40, 41), null); + ColumnVector structExpectAggResult = ColumnVector.fromStructs( + new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + new StructData(null, "s2"), new StructData(2, null), new StructData(3, "s3"), null, + null, new StructData(13, "s13"), new StructData(14, "s14"), null, + new StructData(null, "s112"), new StructData(2, "s222"), new StructData(3, "s333"), new StructData(4, "s444"), null)) { + assertColumnsAreEqual(listExpectAggResult, listAggResults.getColumn(0)); + assertColumnsAreEqual(structExpectAggResult, structAggResults.getColumn(0)); + } + } + } + } + } + } + } + + @Test + void testRangeWindowingLag() { + try (Table unsorted = new Table.TestBuilder() + .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .column(new ListType(false, new BasicType(true, DType.INT32)), + Arrays.asList(11, 12, null, 13), Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), Arrays.asList(16), + Arrays.asList(21, null, null, 22), Arrays.asList(23, 24), Arrays.asList(25, 26, 27), Arrays.asList(28, 29, null), + Arrays.asList(null, 31), Arrays.asList(32, 33, 34), Arrays.asList(35, 36), Arrays.asList(37, 38, 39), Arrays.asList(40, 41)) // List Agg COLUMN + .column(new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + new StructData(1, "s1"), new StructData(null, "s2"), new StructData(2, null), new StructData(3, "s3"), + new StructData(11, "s11"), null, new StructData(13, "s13"), new StructData(14, "s14"), + new StructData(111, "s111"), new StructData(null, "s112"), new StructData(2, "s222"), new StructData(3, "s333"), new StructData(4, "s444")) //STRUCT Agg COLUMN + .column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key + .column((short) 1, (short)1, (short)2, (short)3, (short)3, (short)3, (short)4, (short)4, (short)5, (short)5, (short)6, (short)6, (short)7) // orderBy Key + .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key + .column((byte) 1, (byte)1, (byte)2, (byte)3, (byte)3, (byte)3, (byte)4, (byte)4, (byte)5, (byte)5, (byte)6, (byte)6, (byte)7) // orderBy Key + .timestampDayColumn(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp orderBy Key + .timestampSecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampMicrosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampMillisecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .build()) { + + for (int orderIndex = 5; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(2); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + + DType type = unsorted.getColumn(orderIndex).getType(); + try (Scalar one = getScalar(type, 1L); + WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(one, one) + .orderByColumnIndex(orderIndex) + .build()) { + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverRanges(Aggregation.lag(1) + .onColumn(2) + .overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6, 6)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + } + try (Table listAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverRanges(Aggregation.lag(1) + .onColumn(3) // List Agg Column + .overWindow(window)); + Table structAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverRanges(Aggregation.lag(1) + .onColumn(4) // Struct Agg Column + .overWindow(window)); + ColumnVector listExpectAggResult = ColumnVector.fromLists( + new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), + null, Arrays.asList(11, 12, null, 13), Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), + null, Arrays.asList(21, null, null, 22), Arrays.asList(23, 24), Arrays.asList(25, 26, 27), + null, Arrays.asList(null, 31), Arrays.asList(32, 33, 34), Arrays.asList(35, 36), Arrays.asList(37, 38, 39)); + ColumnVector structExpectAggResult = ColumnVector.fromStructs( + new StructType(true, + new BasicType(true, DType.INT32), + new BasicType(true, DType.STRING)), + null, new StructData(1, "s1"), new StructData(null, "s2"), new StructData(2, null), + null, new StructData(11, "s11"), null, new StructData(13, "s13"), + null, new StructData(111, "s111"), new StructData(null, "s112"), new StructData(2, "s222"), new StructData(3, "s333"))) { + assertColumnsAreEqual(listExpectAggResult, listAggResults.getColumn(0)); + assertColumnsAreEqual(structExpectAggResult, structAggResults.getColumn(0)); } } } From 6c5afa95506a4dd52ea1b5a73957c7e42f53db23 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Wed, 19 May 2021 11:58:46 +0800 Subject: [PATCH 2/2] resolve comments --- .../test/java/ai/rapids/cudf/TableTest.java | 181 ++++-------------- 1 file changed, 36 insertions(+), 145 deletions(-) diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 1d98b78bf65..0a3156a6862 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -3049,7 +3049,7 @@ void testWindowingLead() { .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // GBY Key .column(1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6) // OBY Key - .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6) // Agg Column + .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6) // Int Agg Column .decimal32Column(-1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // Decimal GBY Key .decimal64Column(1, 1L, 1L, 2L, 2L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L) // Decimal OBY Key .decimal64Column(-2, 7L, 5L, 1L, 9L, 7L, 9L, 8L, 2L, 8L, 0L, 6L, 6L) // Decimal Agg Column @@ -3082,22 +3082,22 @@ void testWindowingLead() { Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows(Aggregation .lead(0) - .onColumn(3) + .onColumn(3) // Int Agg Column .overWindow(options)); Table decWindowAggResults = decSorted.groupBy(0, 4) .aggregateWindows(Aggregation .lead(0) - .onColumn(6) + .onColumn(6) // Decimal Agg Column .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lead(0) - .onColumn(7) + .onColumn(7) // List Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lead(0) - .onColumn(8) + .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); ColumnVector expectAggResult = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); @@ -3126,22 +3126,22 @@ void testWindowingLead() { Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows(Aggregation .lead(1) - .onColumn(3) + .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) .aggregateWindows(Aggregation .lead(1) - .onColumn(6) + .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lead(1) - .onColumn(7) + .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lead(1) - .onColumn(8) + .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); ColumnVector expectAggResult = ColumnVector.fromBoxedInts(5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, null); ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, 5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, null); @@ -3184,22 +3184,22 @@ null, new StructData(13, "s13"), new StructData(14, "s14"), null, Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows(Aggregation .lead(1, defaultOutput) - .onColumn(3) + .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) .aggregateWindows(Aggregation .lead(1, decDefaultOutput) - .onColumn(6) + .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lead(1, listDefaultOutput) - .onColumn(7) + .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lead(1, structDefaultOutput) - .onColumn(8) + .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); ColumnVector expectAggResult = ColumnVector.fromBoxedInts(5, 1, 9, -3, 9, 8, 2, -7, 0, 6, 6, -11); ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 5, 1, 9, -3, 9, 8, 2, -7, 0, 6, 6, -11); @@ -3219,7 +3219,7 @@ null, new StructData(13, "s13"), new StructData(14, "s14"), new StructData(-14, assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); assertColumnsAreEqual(listExpectAggResult, listWindowAggResults.getColumn(0)); - // TODO this is not gonna work, since cudf has some issue on lead with default values + // TODO this is not gonna work, since libcudf has some issue for lead on struct with default values // assertColumnsAreEqual(structExpectAggResult, structWindowAggResults.getColumn(0)); } @@ -3230,22 +3230,22 @@ null, new StructData(13, "s13"), new StructData(14, "s14"), new StructData(-14, Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows(Aggregation .lead(3) - .onColumn(3) + .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) .aggregateWindows(Aggregation .lead(3) - .onColumn(6) + .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lead(3) - .onColumn(7) + .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lead(3) - .onColumn(8) + .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, null, null, null, null, null, null, null, null, null, null, null); ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, null, null, null, null, null, null, null, null, null, null, null); @@ -3305,22 +3305,22 @@ void testWindowingLag() { Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows(Aggregation .lag(0) - .onColumn(3) + .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) .aggregateWindows(Aggregation .lag(0) - .onColumn(6) + .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lag(0) - .onColumn(7) + .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lag(0) - .onColumn(8) + .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); ColumnVector expectAggResult = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); @@ -3348,22 +3348,22 @@ void testWindowingLag() { Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows(Aggregation .lag(1) - .onColumn(3) + .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) .aggregateWindows(Aggregation .lag(1) - .onColumn(6) + .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lag(1) - .onColumn(7) + .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lag(1) - .onColumn(8) + .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6); ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6); @@ -3405,22 +3405,22 @@ null, new StructData(111, "s111"), new StructData(null, "s112"), new StructData( Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows(Aggregation .lag(1, defaultOutput) - .onColumn(3) + .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) .aggregateWindows(Aggregation .lag(1, decDefaultOutput) - .onColumn(6) + .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lag(1, listDefaultOutput) - .onColumn(7) + .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lag(1, structDefaultOutput) - .onColumn(8) + .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); ColumnVector expectAggResult = ColumnVector.fromBoxedInts(0, 7, 5, 1, -4, 7, 9, 8, -8, 8, 0, 6); ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 0, 7, 5, 1, -4, 7, 9, 8, -8, 8, 0, 6); @@ -3439,7 +3439,7 @@ null, new StructData(111, "s111"), new StructData(null, "s112"), new StructData( assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); assertColumnsAreEqual(listExpectAggResult, listWindowAggResults.getColumn(0)); - // TODO this is not gonna work, since cudf has some issue on lag with default values + // TODO this is not gonna work, since libcudf has some issue for lag on struct with default values // assertColumnsAreEqual(structExpectAggResult, structWindowAggResults.getColumn(0)); } @@ -3450,22 +3450,22 @@ null, new StructData(111, "s111"), new StructData(null, "s112"), new StructData( Table windowAggResults = sorted.groupBy(0, 1) .aggregateWindows(Aggregation .lag(3) - .onColumn(3) + .onColumn(3) //Int Agg COLUMN .overWindow(options)); Table decWindowAggResults = sorted.groupBy(0, 4) .aggregateWindows(Aggregation .lag(3) - .onColumn(6) + .onColumn(6) //Decimal Agg COLUMN .overWindow(options)); Table listWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lag(3) - .onColumn(7) + .onColumn(7) //LIST Agg COLUMN .overWindow(options)); Table structWindowAggResults = sorted.groupBy(0, 1).aggregateWindows( Aggregation .lag(3) - .onColumn(8) + .onColumn(8) //STRUCT Agg COLUMN .overWindow(options)); ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, null, null, null, null, null, null, null, null, null, null, null); ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, null, null, null, null, null, null, null, null, null, null, null); @@ -3681,16 +3681,6 @@ void testRangeWindowingLead() { .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .column(new ListType(false, new BasicType(true, DType.INT32)), - Arrays.asList(11, 12, null, 13), Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), Arrays.asList(16), - Arrays.asList(21, null, null, 22), Arrays.asList(23, 24), Arrays.asList(25, 26, 27), Arrays.asList(28, 29, null), - Arrays.asList(null, 31), Arrays.asList(32, 33, 34), Arrays.asList(35, 36), Arrays.asList(37, 38, 39), Arrays.asList(40, 41)) // List Agg COLUMN - .column(new StructType(true, - new BasicType(true, DType.INT32), - new BasicType(true, DType.STRING)), - new StructData(1, "s1"), new StructData(null, "s2"), new StructData(2, null), new StructData(3, "s3"), - new StructData(11, "s11"), null, new StructData(13, "s13"), new StructData(14, "s14"), - new StructData(111, "s111"), new StructData(null, "s112"), new StructData(2, "s222"), new StructData(3, "s333"), new StructData(4, "s444")) //STRUCT Agg COLUMN .column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key .column((short) 1, (short)1, (short)2, (short)3, (short)3, (short)3, (short)4, (short)4, (short)5, (short)5, (short)6, (short)6, (short)7) // orderBy Key .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key @@ -3702,7 +3692,7 @@ void testRangeWindowingLead() { .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) .build()) { - for (int orderIndex = 5; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { + for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex)); ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { ColumnVector sortedAggColumn = sorted.getColumn(2); @@ -3724,105 +3714,6 @@ void testRangeWindowingLead() { ColumnVector expect = ColumnVector.fromBoxedInts(5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, 8, null)) { assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } - try (Table listAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.lead(1) - .onColumn(3) // List Agg Column - .overWindow(window)); - Table structAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.lead(1) - .onColumn(4) // Struct Agg Column - .overWindow(window)); - ColumnVector listExpectAggResult = ColumnVector.fromLists( - new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), - Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), Arrays.asList(16), null, - Arrays.asList(23, 24), Arrays.asList(25, 26, 27), Arrays.asList(28, 29, null), null, - Arrays.asList(32, 33, 34), Arrays.asList(35, 36), Arrays.asList(37, 38, 39), Arrays.asList(40, 41), null); - ColumnVector structExpectAggResult = ColumnVector.fromStructs( - new StructType(true, - new BasicType(true, DType.INT32), - new BasicType(true, DType.STRING)), - new StructData(null, "s2"), new StructData(2, null), new StructData(3, "s3"), null, - null, new StructData(13, "s13"), new StructData(14, "s14"), null, - new StructData(null, "s112"), new StructData(2, "s222"), new StructData(3, "s333"), new StructData(4, "s444"), null)) { - assertColumnsAreEqual(listExpectAggResult, listAggResults.getColumn(0)); - assertColumnsAreEqual(structExpectAggResult, structAggResults.getColumn(0)); - } - } - } - } - } - } - } - - @Test - void testRangeWindowingLag() { - try (Table unsorted = new Table.TestBuilder() - .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .column(new ListType(false, new BasicType(true, DType.INT32)), - Arrays.asList(11, 12, null, 13), Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), Arrays.asList(16), - Arrays.asList(21, null, null, 22), Arrays.asList(23, 24), Arrays.asList(25, 26, 27), Arrays.asList(28, 29, null), - Arrays.asList(null, 31), Arrays.asList(32, 33, 34), Arrays.asList(35, 36), Arrays.asList(37, 38, 39), Arrays.asList(40, 41)) // List Agg COLUMN - .column(new StructType(true, - new BasicType(true, DType.INT32), - new BasicType(true, DType.STRING)), - new StructData(1, "s1"), new StructData(null, "s2"), new StructData(2, null), new StructData(3, "s3"), - new StructData(11, "s11"), null, new StructData(13, "s13"), new StructData(14, "s14"), - new StructData(111, "s111"), new StructData(null, "s112"), new StructData(2, "s222"), new StructData(3, "s333"), new StructData(4, "s444")) //STRUCT Agg COLUMN - .column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key - .column((short) 1, (short)1, (short)2, (short)3, (short)3, (short)3, (short)4, (short)4, (short)5, (short)5, (short)6, (short)6, (short)7) // orderBy Key - .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key - .column((byte) 1, (byte)1, (byte)2, (byte)3, (byte)3, (byte)3, (byte)4, (byte)4, (byte)5, (byte)5, (byte)6, (byte)6, (byte)7) // orderBy Key - .timestampDayColumn(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp orderBy Key - .timestampSecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampMicrosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampMillisecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .build()) { - - for (int orderIndex = 5; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(2); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - - DType type = unsorted.getColumn(orderIndex).getType(); - try (Scalar one = getScalar(type, 1L); - WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(one, one) - .orderByColumnIndex(orderIndex) - .build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.lag(1) - .onColumn(2) - .overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6, 6)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - } - try (Table listAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.lag(1) - .onColumn(3) // List Agg Column - .overWindow(window)); - Table structAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.lag(1) - .onColumn(4) // Struct Agg Column - .overWindow(window)); - ColumnVector listExpectAggResult = ColumnVector.fromLists( - new HostColumnVector.ListType(true, new HostColumnVector.BasicType(true, DType.INT32)), - null, Arrays.asList(11, 12, null, 13), Arrays.asList(14, null, 15, null), Arrays.asList((Integer) null), - null, Arrays.asList(21, null, null, 22), Arrays.asList(23, 24), Arrays.asList(25, 26, 27), - null, Arrays.asList(null, 31), Arrays.asList(32, 33, 34), Arrays.asList(35, 36), Arrays.asList(37, 38, 39)); - ColumnVector structExpectAggResult = ColumnVector.fromStructs( - new StructType(true, - new BasicType(true, DType.INT32), - new BasicType(true, DType.STRING)), - null, new StructData(1, "s1"), new StructData(null, "s2"), new StructData(2, null), - null, new StructData(11, "s11"), null, new StructData(13, "s13"), - null, new StructData(111, "s111"), new StructData(null, "s112"), new StructData(2, "s222"), new StructData(3, "s333"))) { - assertColumnsAreEqual(listExpectAggResult, listAggResults.getColumn(0)); - assertColumnsAreEqual(structExpectAggResult, structAggResults.getColumn(0)); } } }