From 2a169c801adc7f5d45144f8291a81fc21b9bf759 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Fri, 14 May 2021 06:21:24 +0800 Subject: [PATCH 01/12] add java bindings for non-timestamps range window queries (#7909) This PR adds Java bindings so that range window queries support integral (Boolean-exclusive) and timestamp order-by columns. Authors: - Bobby Wang (https://github.com/wbo4958) Approvers: - MithunR (https://github.com/mythrocks) - Robert (Bobby) Evans (https://github.com/revans2) URL: https://github.com/rapidsai/cudf/pull/7909 --- java/src/main/java/ai/rapids/cudf/Scalar.java | 5 + java/src/main/java/ai/rapids/cudf/Table.java | 87 +- .../java/ai/rapids/cudf/WindowOptions.java | 182 +- java/src/main/native/src/TableJni.cpp | 77 +- .../java/ai/rapids/cudf/ColumnVectorTest.java | 153 +- .../test/java/ai/rapids/cudf/TableTest.java | 1618 ++++++++++------- 6 files changed, 1316 insertions(+), 806 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Scalar.java b/java/src/main/java/ai/rapids/cudf/Scalar.java index ec20f39af27..62dd9bda13b 100644 --- a/java/src/main/java/ai/rapids/cudf/Scalar.java +++ b/java/src/main/java/ai/rapids/cudf/Scalar.java @@ -620,6 +620,7 @@ public int hashCode() { case UINT32: case TIMESTAMP_DAYS: case DECIMAL32: + case DURATION_DAYS: valueHash = getInt(); break; case INT64: @@ -629,6 +630,10 @@ public int hashCode() { case TIMESTAMP_MICROSECONDS: case TIMESTAMP_NANOSECONDS: case DECIMAL64: + case DURATION_MICROSECONDS: + case DURATION_SECONDS: + case DURATION_MILLISECONDS: + case DURATION_NANOSECONDS: valueHash = Long.hashCode(getLong()); break; case FLOAT32: diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index b2f2ad5bad1..e939411eece 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -475,10 +475,10 @@ private static native long[] rollingWindowAggregate( int[] following, boolean ignoreNullKeys) throws CudfException; - 
private static native long[] timeRangeRollingWindowAggregate(long inputTable, int[] keyIndices, int[] timestampIndices, boolean[] isTimesampAscending, - int[] aggColumnsIndices, long[] aggInstances, int[] minPeriods, - int[] preceding, int[] following, boolean[] unboundedPreceding, boolean[] unboundedFollowing, - boolean ignoreNullKeys) throws CudfException; + private static native long[] rangeRollingWindowAggregate(long inputTable, int[] keyIndices, int[] orderByIndices, boolean[] isOrderByAscending, + int[] aggColumnsIndices, long[] aggInstances, int[] minPeriods, + long[] preceding, long[] following, boolean[] unboundedPreceding, boolean[] unboundedFollowing, + boolean ignoreNullKeys) throws CudfException; private static native long sortOrder(long inputTable, long[] sortKeys, boolean[] isDescending, boolean[] areNullsSmallest) throws CudfException; @@ -2457,7 +2457,7 @@ public Table aggregateWindows(AggregationOverWindow... windowAggregates) { } /** - * Computes time-range-based window aggregation functions on the Table/projection, + * Computes range-based window aggregation functions on the Table/projection, * based on windows specified in the argument. * * This method enables queries such as the following SQL: @@ -2506,10 +2506,10 @@ public Table aggregateWindows(AggregationOverWindow... windowAggregates) { * @param windowAggregates the window-aggregations to be performed * @return Table instance, with each column containing the result of each aggregation. * @throws IllegalArgumentException if the window arguments are not of type - * {@link WindowOptions.FrameType#RANGE}, + * {@link WindowOptions.FrameType#RANGE} or the orderBy column is not of a (Boolean-exclusive) integral or timestamp type * i.e. the timestamp-column was not specified for the aggregation. */ - public Table aggregateWindowsOverTimeRanges(AggregationOverWindow... windowAggregates) { + public Table aggregateWindowsOverRanges(AggregationOverWindow... 
windowAggregates) { // To improve performance and memory we want to remove duplicate operations // and also group the operations by column so hopefully cudf can do multiple aggregations // in a single pass. @@ -2521,51 +2521,76 @@ public Table aggregateWindowsOverTimeRanges(AggregationOverWindow... windowAggre for (int outputIndex = 0; outputIndex < windowAggregates.length; outputIndex++) { AggregationOverWindow agg = windowAggregates[outputIndex]; if (agg.getWindowOptions().getFrameType() != WindowOptions.FrameType.RANGE) { - throw new IllegalArgumentException("Expected time-range-based window specification. Unexpected window type: " - + agg.getWindowOptions().getFrameType()); + throw new IllegalArgumentException("Expected range-based window specification. Unexpected window type: " + + agg.getWindowOptions().getFrameType()); } + + DType orderByType = operation.table.getColumn(agg.getWindowOptions().getOrderByColumnIndex()).getType(); + switch (orderByType.getTypeId()) { + case INT8: + case INT16: + case INT32: + case INT64: + case UINT8: + case UINT16: + case UINT32: + case UINT64: + case TIMESTAMP_MILLISECONDS: + case TIMESTAMP_SECONDS: + case TIMESTAMP_DAYS: + case TIMESTAMP_NANOSECONDS: + case TIMESTAMP_MICROSECONDS: + break; + default: + throw new IllegalArgumentException("Expected range-based window orderBy's " + + "type: integral (Boolean-exclusive) and timestamp"); + } + ColumnWindowOps ops = groupedOps.computeIfAbsent(agg.getColumnIndex(), (idx) -> new ColumnWindowOps()); totalOps += ops.add(agg, outputIndex); } int[] aggColumnIndexes = new int[totalOps]; - int[] timestampColumnIndexes = new int[totalOps]; - boolean[] isTimestampOrderAscending = new boolean[totalOps]; + int[] orderByColumnIndexes = new int[totalOps]; + boolean[] isOrderByOrderAscending = new boolean[totalOps]; long[] aggInstances = new long[totalOps]; + long[] aggPrecedingWindows = new long[totalOps]; + long[] aggFollowingWindows = new long[totalOps]; try { - int[] aggPrecedingWindows = 
new int[totalOps]; - int[] aggFollowingWindows = new int[totalOps]; boolean[] aggPrecedingWindowsUnbounded = new boolean[totalOps]; boolean[] aggFollowingWindowsUnbounded = new boolean[totalOps]; int[] aggMinPeriods = new int[totalOps]; int opIndex = 0; for (Map.Entry entry: groupedOps.entrySet()) { int columnIndex = entry.getKey(); - for (AggregationOverWindow operation: entry.getValue().operations()) { + for (AggregationOverWindow op: entry.getValue().operations()) { aggColumnIndexes[opIndex] = columnIndex; - aggInstances[opIndex] = operation.createNativeInstance(); - aggPrecedingWindows[opIndex] = operation.getWindowOptions().getPreceding(); - aggFollowingWindows[opIndex] = operation.getWindowOptions().getFollowing(); - aggPrecedingWindowsUnbounded[opIndex] = operation.getWindowOptions().isUnboundedPreceding(); - aggFollowingWindowsUnbounded[opIndex] = operation.getWindowOptions().isUnboundedFollowing(); - aggMinPeriods[opIndex] = operation.getWindowOptions().getMinPeriods(); - assert (operation.getWindowOptions().getFrameType() == WindowOptions.FrameType.RANGE); - timestampColumnIndexes[opIndex] = operation.getWindowOptions().getTimestampColumnIndex(); - isTimestampOrderAscending[opIndex] = operation.getWindowOptions().isTimestampOrderAscending(); - if (operation.getDefaultOutput() != 0) { + aggInstances[opIndex] = op.createNativeInstance(); + aggPrecedingWindows[opIndex] = op.getWindowOptions().getPrecedingScalar() == + null ? 0 : op.getWindowOptions().getPrecedingScalar().getScalarHandle(); + aggFollowingWindows[opIndex] = op.getWindowOptions().getFollowingScalar() == + null ? 
0 : op.getWindowOptions().getFollowingScalar().getScalarHandle(); + aggPrecedingWindowsUnbounded[opIndex] = op.getWindowOptions().isUnboundedPreceding(); + aggFollowingWindowsUnbounded[opIndex] = op.getWindowOptions().isUnboundedFollowing(); + aggMinPeriods[opIndex] = op.getWindowOptions().getMinPeriods(); + assert (op.getWindowOptions().getFrameType() == WindowOptions.FrameType.RANGE); + orderByColumnIndexes[opIndex] = op.getWindowOptions().getOrderByColumnIndex(); + isOrderByOrderAscending[opIndex] = op.getWindowOptions().isOrderByOrderAscending(); + if (op.getDefaultOutput() != 0) { throw new IllegalArgumentException("Operations with a default output are not " + "supported on time based rolling windows"); } + opIndex++; } } assert opIndex == totalOps : opIndex + " == " + totalOps; - try (Table aggregate = new Table(timeRangeRollingWindowAggregate( + try (Table aggregate = new Table(rangeRollingWindowAggregate( operation.table.nativeHandle, operation.indices, - timestampColumnIndexes, - isTimestampOrderAscending, + orderByColumnIndexes, + isOrderByOrderAscending, aggColumnIndexes, aggInstances, aggMinPeriods, aggPrecedingWindows, aggFollowingWindows, aggPrecedingWindowsUnbounded, aggFollowingWindowsUnbounded, @@ -2630,6 +2655,14 @@ public ContiguousTable[] contiguousSplitGroups() { groupByOptions.getKeysDescending(), groupByOptions.getKeysNullSmallest()); } + + /** + * @deprecated use aggregateWindowsOverRanges + */ + @Deprecated + public Table aggregateWindowsOverTimeRanges(AggregationOverWindow... windowAggregates) { + return aggregateWindowsOverRanges(windowAggregates); + } } public static final class TableOperation { diff --git a/java/src/main/java/ai/rapids/cudf/WindowOptions.java b/java/src/main/java/ai/rapids/cudf/WindowOptions.java index 429d4e1d978..826784a33f1 100644 --- a/java/src/main/java/ai/rapids/cudf/WindowOptions.java +++ b/java/src/main/java/ai/rapids/cudf/WindowOptions.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2019-2020, NVIDIA CORPORATION. 
+ * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,33 +21,50 @@ /** * Options for rolling windows. */ -public class WindowOptions { +public class WindowOptions implements AutoCloseable { enum FrameType {ROWS, RANGE} - private final int preceding; private final int minPeriods; + private final int preceding; private final int following; + private final Scalar precedingScalar; + private final Scalar followingScalar; private final ColumnVector precedingCol; private final ColumnVector followingCol; - private final int timestampColumnIndex; - private final boolean timestampOrderAscending; + private final int orderByColumnIndex; + private final boolean orderByOrderAscending; private final FrameType frameType; private final boolean isUnboundedPreceding; private final boolean isUnboundedFollowing; private WindowOptions(Builder builder) { - this.preceding = builder.preceding; this.minPeriods = builder.minPeriods; + this.preceding = builder.preceding; this.following = builder.following; + this.precedingScalar = builder.precedingScalar; + if (precedingScalar != null) { + precedingScalar.incRefCount(); + } + this.followingScalar = builder.followingScalar; + if (followingScalar != null) { + followingScalar.incRefCount(); + } this.precedingCol = builder.precedingCol; + if (precedingCol != null) { + precedingCol.incRefCount(); + } this.followingCol = builder.followingCol; - this.timestampColumnIndex = builder.timestampColumnIndex; - this.timestampOrderAscending = builder.timestampOrderAscending; - this.frameType = timestampColumnIndex == -1? FrameType.ROWS : FrameType.RANGE; + if (followingCol != null) { + followingCol.incRefCount(); + } + this.orderByColumnIndex = builder.orderByColumnIndex; + this.orderByOrderAscending = builder.orderByOrderAscending; + this.frameType = orderByColumnIndex == -1? 
FrameType.ROWS : FrameType.RANGE; this.isUnboundedPreceding = builder.isUnboundedPreceding; this.isUnboundedFollowing = builder.isUnboundedFollowing; + } @Override @@ -59,8 +76,8 @@ public boolean equals(Object other) { boolean ret = this.preceding == o.preceding && this.following == o.following && this.minPeriods == o.minPeriods && - this.timestampColumnIndex == o.timestampColumnIndex && - this.timestampOrderAscending == o.timestampOrderAscending && + this.orderByColumnIndex == o.orderByColumnIndex && + this.orderByOrderAscending == o.orderByOrderAscending && this.frameType == o.frameType && this.isUnboundedPreceding == o.isUnboundedPreceding && this.isUnboundedFollowing == o.isUnboundedFollowing; @@ -70,6 +87,12 @@ public boolean equals(Object other) { if (followingCol != null) { ret = ret && followingCol.equals(o.followingCol); } + if (precedingScalar != null) { + ret = ret && precedingScalar.equals(o.precedingScalar); + } + if (followingScalar != null) { + ret = ret && followingScalar.equals(o.followingScalar); + } return ret; } return false; @@ -81,8 +104,8 @@ public int hashCode() { ret = 31 * ret + preceding; ret = 31 * ret + following; ret = 31 * ret + minPeriods; - ret = 31 * ret + timestampColumnIndex; - ret = 31 * ret + Boolean.hashCode(timestampOrderAscending); + ret = 31 * ret + orderByColumnIndex; + ret = 31 * ret + Boolean.hashCode(orderByOrderAscending); ret = 31 * ret + frameType.hashCode(); if (precedingCol != null) { ret = 31 * ret + precedingCol.hashCode(); @@ -90,6 +113,12 @@ public int hashCode() { if (followingCol != null) { ret = 31 * ret + followingCol.hashCode(); } + if (precedingScalar != null) { + ret = 31 * ret + precedingScalar.hashCode(); + } + if (followingScalar != null) { + ret = 31 * ret + followingScalar.hashCode(); + } ret = 31 * ret + Boolean.hashCode(isUnboundedPreceding); ret = 31 * ret + Boolean.hashCode(isUnboundedFollowing); return ret; @@ -105,13 +134,23 @@ public static Builder builder(){ int getFollowing() { return 
this.following; } + Scalar getPrecedingScalar() { return this.precedingScalar; } + + Scalar getFollowingScalar() { return this.followingScalar; } + ColumnVector getPrecedingCol() { return precedingCol; } ColumnVector getFollowingCol() { return this.followingCol; } - int getTimestampColumnIndex() { return this.timestampColumnIndex; } + @Deprecated + int getTimestampColumnIndex() { return getOrderByColumnIndex(); } - boolean isTimestampOrderAscending() { return this.timestampOrderAscending; } + int getOrderByColumnIndex() { return this.orderByColumnIndex; } + + @Deprecated + boolean isTimestampOrderAscending() { return isOrderByOrderAscending(); } + + boolean isOrderByOrderAscending() { return this.orderByOrderAscending; } boolean isUnboundedPreceding() { return this.isUnboundedPreceding; } @@ -123,11 +162,14 @@ public static class Builder { private int minPeriods = 1; private int preceding = 0; private int following = 1; + // for range window + private Scalar precedingScalar = null; + private Scalar followingScalar = null; boolean staticSet = false; private ColumnVector precedingCol = null; private ColumnVector followingCol = null; - private int timestampColumnIndex = -1; - private boolean timestampOrderAscending = true; + private int orderByColumnIndex = -1; + private boolean orderByOrderAscending = true; private boolean isUnboundedPreceding = false; private boolean isUnboundedFollowing = false; @@ -147,8 +189,10 @@ public Builder minPeriods(int minPeriods) { * Set the size of the window, one entry per row. This does not take ownership of the * columns passed in so you have to be sure that the life time of the column outlives * this operation. - * @param precedingCol the number of rows preceding the current row. - * @param followingCol the number of rows following the current row. + * @param precedingCol the number of rows preceding the current row and + * precedingCol will be live outside of WindowOptions. 
+ * @param followingCol the number of rows following the current row and + * followingCol will be live outside of WindowOptions. */ public Builder window(ColumnVector precedingCol, ColumnVector followingCol) { assert (precedingCol != null && precedingCol.getNullCount() == 0); @@ -158,21 +202,60 @@ public Builder window(ColumnVector precedingCol, ColumnVector followingCol) { return this; } + /** + * Set the size of the range window. + * @param precedingScalar the relative number preceding the current row and + * the precedingScalar will be live outside of WindowOptions. + * @param followingScalar the relative number following the current row and + * the followingScalar will be live outside of WindowOptions. + */ + public Builder window(Scalar precedingScalar, Scalar followingScalar) { + assert (precedingScalar != null && precedingScalar.isValid()); + assert (followingScalar != null && followingScalar.isValid()); + this.precedingScalar = precedingScalar; + this.followingScalar = followingScalar; + return this; + } + + /** + * @deprecated Use orderByColumnIndex(int index) + */ + @Deprecated public Builder timestampColumnIndex(int index) { - this.timestampColumnIndex = index; + return orderByColumnIndex(index); + } + + public Builder orderByColumnIndex(int index) { + this.orderByColumnIndex = index; return this; } + /** + * @deprecated Use orderByAscending() + */ + @Deprecated public Builder timestampAscending() { - this.timestampOrderAscending = true; + return orderByAscending(); + } + + public Builder orderByAscending() { + this.orderByOrderAscending = true; return this; } - public Builder timestampDescending() { - this.timestampOrderAscending = false; + public Builder orderByDescending() { + this.orderByOrderAscending = false; return this; } + /** + * @deprecated Use orderByDescending() + */ + @Deprecated + public Builder timestampDescending() { + return orderByDescending(); + } + public Builder unboundedPreceding() { this.isUnboundedPreceding = true; return this; @@ 
-193,6 +276,26 @@ public Builder following(int following) { return this; } + /** + * Set the relative number preceding the current row for range window + * @param preceding + * @return Builder + */ + public Builder preceding(Scalar preceding) { + this.precedingScalar = preceding; + return this; + } + + /** + * Set the relative number following the current row for range window + * @param following + * @return Builder + */ + public Builder following(Scalar following) { + this.followingScalar = following; + return this; + } + /** * Set the size of the window. * @param preceding the number of rows preceding the current row @@ -209,7 +312,40 @@ public WindowOptions build() { if (staticSet && precedingCol != null) { throw new IllegalArgumentException("Cannot set both a static window and a non-static window"); } + return new WindowOptions(this); } } + + public synchronized WindowOptions incRefCount() { + if (precedingScalar != null) { + precedingScalar.incRefCount(); + } + if (followingScalar != null) { + followingScalar.incRefCount(); + } + if (precedingCol != null) { + precedingCol.incRefCount(); + } + if (followingCol != null) { + followingCol.incRefCount(); + } + return this; + } + + @Override + public void close() { + if (precedingScalar != null) { + precedingScalar.close(); + } + if (followingScalar != null) { + followingScalar.close(); + } + if (precedingCol != null) { + precedingCol.close(); + } + if (followingCol != null) { + followingCol.close(); + } + } } \ No newline at end of file diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 3799a5dbab3..4b01745382b 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -728,6 +728,16 @@ bool valid_window_parameters(native_jintArray const &values, values.size() == preceding.size() && values.size() == following.size(); } +// Check that window parameters are valid. 
+bool valid_window_parameters(native_jintArray const &values, + native_jpointerArray const &ops, + native_jintArray const &min_periods, + native_jpointerArray const &preceding, + native_jpointerArray const &following) { + return values.size() == ops.size() && values.size() == min_periods.size() && + values.size() == preceding.size() && values.size() == following.size(); +} + // Generate gather maps needed to manifest the result of a join between two tables. // The resulting Java long array contains the following at each index: // 0: Size of each gather map in bytes @@ -2315,20 +2325,22 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rollingWindowAggregate( CATCH_STD(env, NULL); } -JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_timeRangeRollingWindowAggregate( +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggregate( JNIEnv *env, jclass, jlong j_input_table, jintArray j_keys, - jintArray j_timestamp_column_indices, jbooleanArray j_is_timestamp_ascending, + jintArray j_orderby_column_indices, jbooleanArray j_is_orderby_ascending, jintArray j_aggregate_column_indices, jlongArray j_agg_instances, jintArray j_min_periods, - jintArray j_preceding, jintArray j_following, - jbooleanArray j_unbounded_preceding, jbooleanArray j_unbounded_following, + jlongArray j_preceding, jlongArray j_following, + jbooleanArray j_unbounded_preceding, jbooleanArray j_unbounded_following, jboolean ignore_null_keys) { JNI_NULL_CHECK(env, j_input_table, "input table is null", NULL); JNI_NULL_CHECK(env, j_keys, "input keys are null", NULL); - JNI_NULL_CHECK(env, j_timestamp_column_indices, "input timestamp_column_indices are null", NULL); - JNI_NULL_CHECK(env, j_is_timestamp_ascending, "input timestamp_ascending is null", NULL); + JNI_NULL_CHECK(env, j_orderby_column_indices, "input orderby_column_indices are null", NULL); + JNI_NULL_CHECK(env, j_is_orderby_ascending, "input orderby_ascending is null", NULL); JNI_NULL_CHECK(env, 
j_aggregate_column_indices, "input aggregate_column_indices are null", NULL); JNI_NULL_CHECK(env, j_agg_instances, "agg_instances are null", NULL); + JNI_NULL_CHECK(env, j_preceding, "preceding are null", NULL); + JNI_NULL_CHECK(env, j_following, "following are null", NULL); try { cudf::jni::auto_set_device(env); @@ -2338,15 +2350,15 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_timeRangeRollingWindowAgg // Convert from j-types to native. cudf::table_view *input_table{reinterpret_cast(j_input_table)}; cudf::jni::native_jintArray keys{env, j_keys}; - cudf::jni::native_jintArray timestamps{env, j_timestamp_column_indices}; - cudf::jni::native_jbooleanArray timestamp_ascending{env, j_is_timestamp_ascending}; + cudf::jni::native_jintArray orderbys{env, j_orderby_column_indices}; + cudf::jni::native_jbooleanArray orderbys_ascending{env, j_is_orderby_ascending}; cudf::jni::native_jintArray values{env, j_aggregate_column_indices}; cudf::jni::native_jpointerArray agg_instances(env, j_agg_instances); cudf::jni::native_jintArray min_periods{env, j_min_periods}; - cudf::jni::native_jintArray preceding{env, j_preceding}; - cudf::jni::native_jintArray following{env, j_following}; cudf::jni::native_jbooleanArray unbounded_preceding{env, j_unbounded_preceding}; cudf::jni::native_jbooleanArray unbounded_following{env, j_unbounded_following}; + cudf::jni::native_jpointerArray preceding(env, j_preceding); + cudf::jni::native_jpointerArray following(env, j_following); if (not valid_window_parameters(values, agg_instances, min_periods, preceding, following)) { JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", @@ -2361,21 +2373,48 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_timeRangeRollingWindowAgg std::vector> result_columns; for (int i(0); i < values.size(); ++i) { int agg_column_index = values[i]; + cudf::column_view const &order_by_column = input_table->column(orderbys[i]); + cudf::data_type order_by_type = order_by_column.type(); + cudf::data_type 
unbounded_type = order_by_type; + + if (unbounded_preceding[i] || unbounded_following[i]) { + switch (order_by_type.id()) { + case cudf::type_id::TIMESTAMP_DAYS: + unbounded_type = cudf::data_type{cudf::type_id::DURATION_DAYS}; + break; + case cudf::type_id::TIMESTAMP_SECONDS: + unbounded_type =cudf::data_type{cudf::type_id::DURATION_SECONDS}; + break; + case cudf::type_id::TIMESTAMP_MILLISECONDS: + unbounded_type = cudf::data_type{cudf::type_id::DURATION_MILLISECONDS}; + break; + case cudf::type_id::TIMESTAMP_MICROSECONDS: + unbounded_type = cudf::data_type{cudf::type_id::DURATION_MICROSECONDS}; + break; + case cudf::type_id::TIMESTAMP_NANOSECONDS: + unbounded_type = cudf::data_type{cudf::type_id::DURATION_NANOSECONDS}; + break; + default: + break; + } + } cudf::rolling_aggregation * agg = dynamic_cast(agg_instances[i]); JNI_ARG_CHECK(env, agg != nullptr, "aggregation is not an instance of rolling_aggregation", nullptr); result_columns.emplace_back( std::move( - cudf::grouped_time_range_rolling_window( - groupby_keys, - input_table->column(timestamps[i]), - timestamp_ascending[i] ? cudf::order::ASCENDING : cudf::order::DESCENDING, - input_table->column(agg_column_index), - unbounded_preceding[i] ? cudf::window_bounds::unbounded() : cudf::window_bounds::get(preceding[i]), - unbounded_following[i] ? cudf::window_bounds::unbounded() : cudf::window_bounds::get(following[i]), - min_periods[i], - *agg + cudf::grouped_range_rolling_window( + groupby_keys, + order_by_column, + orderbys_ascending[i] ? cudf::order::ASCENDING : cudf::order::DESCENDING, + input_table->column(agg_column_index), + unbounded_preceding[i] ? cudf::range_window_bounds::unbounded(unbounded_type) : + cudf::range_window_bounds::get(*preceding[i]), + unbounded_following[i] ? 
cudf::range_window_bounds::unbounded(unbounded_type) : + cudf::range_window_bounds::get(*following[i]), + min_periods[i], + *agg ) ) ); diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 4c5ee7295d9..20508bdb23b 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -2263,60 +2263,62 @@ void testPrefixSumErrors() { @Test void testWindowStatic() { - WindowOptions options = WindowOptions.builder().window(2, 1) - .minPeriods(2).build(); - try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8)) { - try (ColumnVector expected = ColumnVector.fromLongs(9, 16, 17, 21, 14); - ColumnVector result = v1.rollingWindow(Aggregation.sum(), options)) { - assertColumnsAreEqual(expected, result); - } + try (WindowOptions options = WindowOptions.builder().window(2, 1) + .minPeriods(2).build()) { + try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8)) { + try (ColumnVector expected = ColumnVector.fromLongs(9, 16, 17, 21, 14); + ColumnVector result = v1.rollingWindow(Aggregation.sum(), options)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector expected = ColumnVector.fromInts(4, 4, 4, 6, 6); - ColumnVector result = v1.rollingWindow(Aggregation.min(), options)) { - assertColumnsAreEqual(expected, result); - } + try (ColumnVector expected = ColumnVector.fromInts(4, 4, 4, 6, 6); + ColumnVector result = v1.rollingWindow(Aggregation.min(), options)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector expected = ColumnVector.fromInts(5, 7, 7, 8, 8); - ColumnVector result = v1.rollingWindow(Aggregation.max(), options)) { - assertColumnsAreEqual(expected, result); - } + try (ColumnVector expected = ColumnVector.fromInts(5, 7, 7, 8, 8); + ColumnVector result = v1.rollingWindow(Aggregation.max(), options)) { + assertColumnsAreEqual(expected, result); + } - // The rolling window produces the 
same result type as the input - try (ColumnVector expected = ColumnVector.fromDoubles(4.5, 16.0 / 3, 17.0 / 3, 7, 7); - ColumnVector result = v1.rollingWindow(Aggregation.mean(), options)) { - assertColumnsAreEqual(expected, result); - } + // The rolling window produces the same result type as the input + try (ColumnVector expected = ColumnVector.fromDoubles(4.5, 16.0 / 3, 17.0 / 3, 7, 7); + ColumnVector result = v1.rollingWindow(Aggregation.mean(), options)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector expected = ColumnVector.fromBoxedInts(4, 7, 6, 8, null); - ColumnVector result = v1.rollingWindow(Aggregation.lead(1), options)) { - assertColumnsAreEqual(expected, result); - } + try (ColumnVector expected = ColumnVector.fromBoxedInts(4, 7, 6, 8, null); + ColumnVector result = v1.rollingWindow(Aggregation.lead(1), options)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector expected = ColumnVector.fromBoxedInts(null, 5, 4, 7, 6); - ColumnVector result = v1.rollingWindow(Aggregation.lag(1), options)) { - assertColumnsAreEqual(expected, result); - } + try (ColumnVector expected = ColumnVector.fromBoxedInts(null, 5, 4, 7, 6); + ColumnVector result = v1.rollingWindow(Aggregation.lag(1), options)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector defaultOutput = ColumnVector.fromInts(-1, -2, -3, -4, -5); - ColumnVector expected = ColumnVector.fromBoxedInts(-1, 5, 4, 7, 6); - ColumnVector result = v1.rollingWindow(Aggregation.lag(1, defaultOutput), options)) { - assertColumnsAreEqual(expected, result); + try (ColumnVector defaultOutput = ColumnVector.fromInts(-1, -2, -3, -4, -5); + ColumnVector expected = ColumnVector.fromBoxedInts(-1, 5, 4, 7, 6); + ColumnVector result = v1.rollingWindow(Aggregation.lag(1, defaultOutput), options)) { + assertColumnsAreEqual(expected, result); + } } } } @Test void testWindowStaticCounts() { - WindowOptions options = WindowOptions.builder().window(2, 1) - .minPeriods(2).build(); 
- try (ColumnVector v1 = ColumnVector.fromBoxedInts(5, 4, null, 6, 8)) { - try (ColumnVector expected = ColumnVector.fromInts(2, 2, 2, 2, 2); - ColumnVector result = v1.rollingWindow(Aggregation.count(NullPolicy.EXCLUDE), options)) { - assertColumnsAreEqual(expected, result); - } - try (ColumnVector expected = ColumnVector.fromInts(2, 3, 3, 3, 2); - ColumnVector result = v1.rollingWindow(Aggregation.count(NullPolicy.INCLUDE), options)) { - assertColumnsAreEqual(expected, result); + try (WindowOptions options = WindowOptions.builder().window(2, 1) + .minPeriods(2).build()) { + try (ColumnVector v1 = ColumnVector.fromBoxedInts(5, 4, null, 6, 8)) { + try (ColumnVector expected = ColumnVector.fromInts(2, 2, 2, 2, 2); + ColumnVector result = v1.rollingWindow(Aggregation.count(NullPolicy.EXCLUDE), options)) { + assertColumnsAreEqual(expected, result); + } + try (ColumnVector expected = ColumnVector.fromInts(2, 3, 3, 3, 2); + ColumnVector result = v1.rollingWindow(Aggregation.count(NullPolicy.INCLUDE), options)) { + assertColumnsAreEqual(expected, result); + } } } } @@ -2325,24 +2327,26 @@ void testWindowStaticCounts() { void testWindowDynamicNegative() { try (ColumnVector precedingCol = ColumnVector.fromInts(3, 3, 3, 4, 4); ColumnVector followingCol = ColumnVector.fromInts(-1, -1, -1, -1, 0)) { - WindowOptions window = WindowOptions.builder() - .minPeriods(2).window(precedingCol, followingCol).build(); - try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); - ColumnVector expected = ColumnVector.fromBoxedLongs(null, null, 9L, 16L, 25L); - ColumnVector result = v1.rollingWindow(Aggregation.sum(), window)) { - assertColumnsAreEqual(expected, result); + try (WindowOptions window = WindowOptions.builder() + .minPeriods(2).window(precedingCol, followingCol).build()) { + try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); + ColumnVector expected = ColumnVector.fromBoxedLongs(null, null, 9L, 16L, 25L); + ColumnVector result = v1.rollingWindow(Aggregation.sum(), 
window)) { + assertColumnsAreEqual(expected, result); + } } } } @Test void testWindowLag() { - WindowOptions window = WindowOptions.builder().minPeriods(1) - .window(2, -1).build(); - try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); - ColumnVector expected = ColumnVector.fromBoxedInts(null, 5, 4, 7, 6); - ColumnVector result = v1.rollingWindow(Aggregation.max(), window)) { - assertColumnsAreEqual(expected, result); + try (WindowOptions window = WindowOptions.builder().minPeriods(1) + .window(2, -1).build()) { + try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); + ColumnVector expected = ColumnVector.fromBoxedInts(null, 5, 4, 7, 6); + ColumnVector result = v1.rollingWindow(Aggregation.max(), window)) { + assertColumnsAreEqual(expected, result); + } } } @@ -2350,12 +2354,13 @@ void testWindowLag() { void testWindowDynamic() { try (ColumnVector precedingCol = ColumnVector.fromInts(1, 2, 3, 1, 2); ColumnVector followingCol = ColumnVector.fromInts(2, 2, 2, 2, 2)) { - WindowOptions window = WindowOptions.builder().minPeriods(2) - .window(precedingCol, followingCol).build(); - try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); - ColumnVector expected = ColumnVector.fromLongs(16, 22, 30, 14, 14); - ColumnVector result = v1.rollingWindow(Aggregation.sum(), window)) { - assertColumnsAreEqual(expected, result); + try (WindowOptions window = WindowOptions.builder().minPeriods(2) + .window(precedingCol, followingCol).build()) { + try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); + ColumnVector expected = ColumnVector.fromLongs(16, 22, 30, 14, 14); + ColumnVector result = v1.rollingWindow(Aggregation.sum(), window)) { + assertColumnsAreEqual(expected, result); + } } } } @@ -2363,17 +2368,23 @@ void testWindowDynamic() { @Test void testWindowThrowsException() { try (ColumnVector arraywindowCol = ColumnVector.fromBoxedInts(1, 2, 3 ,1, 1)) { - assertThrows(IllegalArgumentException.class, () -> WindowOptions.builder() - .window(3, 
2).minPeriods(3) - .window(arraywindowCol, arraywindowCol).build()); - - assertThrows(IllegalArgumentException.class, - () -> arraywindowCol.rollingWindow(Aggregation.sum(), - WindowOptions.builder() - .window(2, 1) - .minPeriods(1) - .timestampColumnIndex(0) - .build())); + assertThrows(IllegalArgumentException.class, () -> { + try (WindowOptions options = WindowOptions.builder() + .window(3, 2).minPeriods(3) + .window(arraywindowCol, arraywindowCol).build()) { + + } + }); + + assertThrows(IllegalArgumentException.class, () -> { + try (WindowOptions options = WindowOptions.builder() + .window(2, 1) + .minPeriods(1) + .orderByColumnIndex(0) + .build()) { + arraywindowCol.rollingWindow(Aggregation.sum(), options); + } + }); } } diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 735dc86af17..041f386f3f9 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -2750,18 +2750,19 @@ void testWindowingCount() { ColumnVector decSortedAggColumn = decSorted.getColumn(3); assertColumnsAreEqual(expectSortedAggColumn, decSortedAggColumn); - WindowOptions window = WindowOptions.builder() + try (WindowOptions window = WindowOptions.builder() .minPeriods(1) .window(2, 1) - .build(); + .build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.count().onColumn(3).overWindow(window)); - Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation.count().onColumn(3).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(2, 3, 3, 2, 2, 3, 3, 2, 2, 3, 3, 2)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expect, decWindowAggResults.getColumn(0)); + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation.count().onColumn(3).overWindow(window)); + Table decWindowAggResults = decSorted.groupBy(0, 4) + 
.aggregateWindows(Aggregation.count().onColumn(3).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(2, 3, 3, 2, 2, 3, 3, 2, 2, 3, 3, 2)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect, decWindowAggResults.getColumn(0)); + } } } } @@ -2787,19 +2788,20 @@ void testWindowingMin() { ColumnVector decSortedAggColumn = decSorted.getColumn(6); assertColumnsAreEqual(expectDecSortedAggCol, decSortedAggColumn); - WindowOptions window = WindowOptions.builder() + try (WindowOptions window = WindowOptions.builder() .minPeriods(1) .window(2, 1) - .build(); + .build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.min().onColumn(3).overWindow(window)); - Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation.min().onColumn(6).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(5, 1, 1, 1, 7, 7, 2, 2, 0, 0, 0, 6); - ColumnVector decExpect = ColumnVector.decimalFromLongs(2, 5, 1, 1, 1, 7, 7, 2, 2, 0, 0, 0, 6)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpect, decWindowAggResults.getColumn(0)); + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation.min().onColumn(3).overWindow(window)); + Table decWindowAggResults = decSorted.groupBy(0, 4) + .aggregateWindows(Aggregation.min().onColumn(6).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(5, 1, 1, 1, 7, 7, 2, 2, 0, 0, 0, 6); + ColumnVector decExpect = ColumnVector.decimalFromLongs(2, 5, 1, 1, 1, 7, 7, 2, 2, 0, 0, 0, 6)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpect, decWindowAggResults.getColumn(0)); + } } } } @@ -2825,19 +2827,20 @@ void testWindowingMax() { ColumnVector decSortedAggColumn = decSorted.getColumn(6); assertColumnsAreEqual(expectDecSortedAggCol, decSortedAggColumn); - WindowOptions window = 
WindowOptions.builder() + try (WindowOptions window = WindowOptions.builder() .minPeriods(1) .window(2, 1) - .build(); + .build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.max().onColumn(3).overWindow(window)); - Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation.max().onColumn(6).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6); - ColumnVector decExpect = ColumnVector.decimalFromLongs(2, 7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpect, decWindowAggResults.getColumn(0)); + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation.max().onColumn(3).overWindow(window)); + Table decWindowAggResults = decSorted.groupBy(0, 4) + .aggregateWindows(Aggregation.max().onColumn(6).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6); + ColumnVector decExpect = ColumnVector.decimalFromLongs(2, 7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpect, decWindowAggResults.getColumn(0)); + } } } } @@ -2856,15 +2859,16 @@ void testWindowingSum() { ColumnVector sortedAggColumn = sorted.getColumn(3); assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - WindowOptions window = WindowOptions.builder() + try (WindowOptions window = WindowOptions.builder() .minPeriods(1) .window(2, 1) - .build(); + .build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.sum().onColumn(3).overWindow(window)); - ColumnVector expectAggResult = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 10L, 16L, 24L, 19L, 10L, 8L, 14L, 12L, 12L)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + try (Table windowAggResults = sorted.groupBy(0, 1) + 
.aggregateWindows(Aggregation.sum().onColumn(3).overWindow(window)); + ColumnVector expectAggResult = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 10L, 16L, 24L, 19L, 10L, 8L, 14L, 12L, 12L)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + } } } } @@ -2892,49 +2896,58 @@ void testWindowingRowNumber() { WindowOptions.Builder windowBuilder = WindowOptions.builder().minPeriods(1); - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .rowNumber() - .onColumn(3) - .overWindow(windowBuilder.window(2, 1).build())); - Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .rowNumber() - .onColumn(6) - .overWindow(windowBuilder.window(2, 1).build())); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(1, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expectAggResult, decWindowAggResults.getColumn(0)); + try (WindowOptions options = windowBuilder.window(2, 1).build(); + WindowOptions options1 = windowBuilder.window(2, 1).build()) { + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .rowNumber() + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = decSorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .rowNumber() + .onColumn(6) + .overWindow(options1)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(1, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expectAggResult, decWindowAggResults.getColumn(0)); + } } - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .rowNumber() - .onColumn(3) - .overWindow(windowBuilder.window(3, 2).build())); - Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .rowNumber() - .onColumn(6) - .overWindow(windowBuilder.window(3, 2).build())); - 
ColumnVector expectAggResult = ColumnVector.fromBoxedInts(1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expectAggResult, decWindowAggResults.getColumn(0)); + try (WindowOptions options = windowBuilder.window(3, 2).build(); + WindowOptions options1 = windowBuilder.window(3, 2).build()) { + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .rowNumber() + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = decSorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .rowNumber() + .onColumn(6) + .overWindow(options1)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expectAggResult, decWindowAggResults.getColumn(0)); + } } - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .rowNumber() - .onColumn(3) - .overWindow(windowBuilder.window(4, 3).build())); - Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .rowNumber() - .onColumn(6) - .overWindow(windowBuilder.window(4, 3).build())); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expectAggResult, decWindowAggResults.getColumn(0)); + try (WindowOptions options = windowBuilder.window(4, 3).build(); + WindowOptions options1 = windowBuilder.window(4, 3).build()) { + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .rowNumber() + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = decSorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .rowNumber() + .onColumn(6) + .overWindow(options1)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(1, 2, 3, 4, 1, 2, 3, 4, 1, 
2, 3, 4)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expectAggResult, decWindowAggResults.getColumn(0)); + } } } } @@ -2944,69 +2957,71 @@ void testWindowingRowNumber() { void testWindowingCollectList() { Aggregation aggCollectWithNulls = Aggregation.collectList(NullPolicy.INCLUDE); Aggregation aggCollect = Aggregation.collectList(); - WindowOptions winOpts = WindowOptions.builder() - .minPeriods(1) - .window(2, 1).build(); - StructType nestedType = new StructType(false, - new BasicType(false, DType.INT32), new BasicType(false, DType.STRING)); - try (Table raw = new Table.TestBuilder() - .column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // GBY Key - .column( 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8) // OBY Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, null, 0, 6, null) // Agg Column of INT32 - .column(nestedType, // Agg Column of Struct - new StructData(1, "s1"), new StructData(2, "s2"), new StructData(3, "s3"), - new StructData(4, "s4"), new StructData(11, "s11"), new StructData(22, "s22"), - new StructData(33, "s33"), new StructData(44, "s44"), new StructData(111, "s111"), - new StructData(222, "s222"), new StructData(333, "s333"), new StructData(444, "s444") - ).build(); - ColumnVector expectSortedAggColumn = ColumnVector - .fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, null, 0, 6, null)) { - try (Table sorted = raw.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2))) { - ColumnVector sortedAggColumn = sorted.getColumn(3); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - - // Primitive type: INT32 - // a) including nulls - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(aggCollectWithNulls.onColumn(3).overWindow(winOpts)); - ColumnVector expected = ColumnVector.fromLists( - new ListType(false, new BasicType(false, DType.INT32)), - Arrays.asList(7,5), Arrays.asList(7,5,1), Arrays.asList(5,1,9), Arrays.asList(1,9), - 
Arrays.asList(7,9), Arrays.asList(7,9,8), Arrays.asList(9,8,2), Arrays.asList(8,2), - Arrays.asList(null,0), Arrays.asList(null,0,6), Arrays.asList(0,6,null), Arrays.asList(6,null))) { - assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); - } - // b) excluding nulls - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(aggCollect.onColumn(3).overWindow(winOpts)); - ColumnVector expected = ColumnVector.fromLists( - new ListType(false, new BasicType(false, DType.INT32)), - Arrays.asList(7,5), Arrays.asList(7,5,1), Arrays.asList(5,1,9), Arrays.asList(1,9), - Arrays.asList(7,9), Arrays.asList(7,9,8), Arrays.asList(9,8,2), Arrays.asList(8,2), - Arrays.asList(0), Arrays.asList(0,6), Arrays.asList(0,6), Arrays.asList(6))) { - assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); - } + try (WindowOptions winOpts = WindowOptions.builder() + .minPeriods(1) + .window(2, 1) + .build()) { + StructType nestedType = new StructType(false, + new BasicType(false, DType.INT32), new BasicType(false, DType.STRING)); + try (Table raw = new Table.TestBuilder() + .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // GBY Key + .column(1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8) // OBY Key + .column(7, 5, 1, 9, 7, 9, 8, 2, null, 0, 6, null) // Agg Column of INT32 + .column(nestedType, // Agg Column of Struct + new StructData(1, "s1"), new StructData(2, "s2"), new StructData(3, "s3"), + new StructData(4, "s4"), new StructData(11, "s11"), new StructData(22, "s22"), + new StructData(33, "s33"), new StructData(44, "s44"), new StructData(111, "s111"), + new StructData(222, "s222"), new StructData(333, "s333"), new StructData(444, "s444") + ).build(); + ColumnVector expectSortedAggColumn = ColumnVector + .fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, null, 0, 6, null)) { + try (Table sorted = raw.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2))) { + ColumnVector sortedAggColumn = sorted.getColumn(3); + 
assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + + // Primitive type: INT32 + // a) including nulls + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(aggCollectWithNulls.onColumn(3).overWindow(winOpts)); + ColumnVector expected = ColumnVector.fromLists( + new ListType(false, new BasicType(false, DType.INT32)), + Arrays.asList(7, 5), Arrays.asList(7, 5, 1), Arrays.asList(5, 1, 9), Arrays.asList(1, 9), + Arrays.asList(7, 9), Arrays.asList(7, 9, 8), Arrays.asList(9, 8, 2), Arrays.asList(8, 2), + Arrays.asList(null, 0), Arrays.asList(null, 0, 6), Arrays.asList(0, 6, null), Arrays.asList(6, null))) { + assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); + } + // b) excluding nulls + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(aggCollect.onColumn(3).overWindow(winOpts)); + ColumnVector expected = ColumnVector.fromLists( + new ListType(false, new BasicType(false, DType.INT32)), + Arrays.asList(7, 5), Arrays.asList(7, 5, 1), Arrays.asList(5, 1, 9), Arrays.asList(1, 9), + Arrays.asList(7, 9), Arrays.asList(7, 9, 8), Arrays.asList(9, 8, 2), Arrays.asList(8, 2), + Arrays.asList(0), Arrays.asList(0, 6), Arrays.asList(0, 6), Arrays.asList(6))) { + assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); + } - // Nested type: Struct - List[] expectedNestedData = new List[12]; - expectedNestedData[0] = Arrays.asList(new StructData(1, "s1"),new StructData(2, "s2")); - expectedNestedData[1] = Arrays.asList(new StructData(1, "s1"),new StructData(2, "s2"),new StructData(3, "s3")); - expectedNestedData[2] = Arrays.asList(new StructData(2, "s2"),new StructData(3, "s3"),new StructData(4, "s4")); - expectedNestedData[3] = Arrays.asList(new StructData(3, "s3"),new StructData(4, "s4")); - expectedNestedData[4] = Arrays.asList(new StructData(11, "s11"),new StructData(22, "s22")); - expectedNestedData[5] = Arrays.asList(new StructData(11, "s11"),new StructData(22, "s22"),new StructData(33, "s33")); - 
expectedNestedData[6] = Arrays.asList(new StructData(22, "s22"),new StructData(33, "s33"), new StructData(44, "s44")); - expectedNestedData[7] = Arrays.asList(new StructData(33, "s33"), new StructData(44, "s44")); - expectedNestedData[8] = Arrays.asList(new StructData(111, "s111"),new StructData(222, "s222")); - expectedNestedData[9] = Arrays.asList(new StructData(111, "s111"),new StructData(222, "s222"),new StructData(333, "s333")); - expectedNestedData[10] = Arrays.asList(new StructData(222, "s222"),new StructData(333, "s333"),new StructData(444, "s444")); - expectedNestedData[11] = Arrays.asList(new StructData(333, "s333"),new StructData(444, "s444")); - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(aggCollect.onColumn(4).overWindow(winOpts)); - ColumnVector expected = ColumnVector.fromLists( - new ListType(false, nestedType), expectedNestedData)) { - assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); + // Nested type: Struct + List[] expectedNestedData = new List[12]; + expectedNestedData[0] = Arrays.asList(new StructData(1, "s1"), new StructData(2, "s2")); + expectedNestedData[1] = Arrays.asList(new StructData(1, "s1"), new StructData(2, "s2"), new StructData(3, "s3")); + expectedNestedData[2] = Arrays.asList(new StructData(2, "s2"), new StructData(3, "s3"), new StructData(4, "s4")); + expectedNestedData[3] = Arrays.asList(new StructData(3, "s3"), new StructData(4, "s4")); + expectedNestedData[4] = Arrays.asList(new StructData(11, "s11"), new StructData(22, "s22")); + expectedNestedData[5] = Arrays.asList(new StructData(11, "s11"), new StructData(22, "s22"), new StructData(33, "s33")); + expectedNestedData[6] = Arrays.asList(new StructData(22, "s22"), new StructData(33, "s33"), new StructData(44, "s44")); + expectedNestedData[7] = Arrays.asList(new StructData(33, "s33"), new StructData(44, "s44")); + expectedNestedData[8] = Arrays.asList(new StructData(111, "s111"), new StructData(222, "s222")); + expectedNestedData[9] = 
Arrays.asList(new StructData(111, "s111"), new StructData(222, "s222"), new StructData(333, "s333")); + expectedNestedData[10] = Arrays.asList(new StructData(222, "s222"), new StructData(333, "s333"), new StructData(444, "s444")); + expectedNestedData[11] = Arrays.asList(new StructData(333, "s333"), new StructData(444, "s444")); + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(aggCollect.onColumn(4).overWindow(winOpts)); + ColumnVector expected = ColumnVector.fromLists( + new ListType(false, nestedType), expectedNestedData)) { + assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); + } } } } @@ -3035,71 +3050,83 @@ void testWindowingLead() { WindowOptions.Builder windowBuilder = WindowOptions.builder().minPeriods(1); - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lead(0) - .onColumn(3) - .overWindow(windowBuilder.window(2, 1).build())); - Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lead(0) - .onColumn(6) - .overWindow(windowBuilder.window(2, 1).build())); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); - ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + try (WindowOptions options = windowBuilder.window(2, 1).build(); + WindowOptions options1 = windowBuilder.window(2, 1).build()) { + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lead(0) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = decSorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lead(0) + .onColumn(6) + .overWindow(options1)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); + ColumnVector decExpectAggResult = 
ColumnVector.decimalFromLongs(-2, 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + } } - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lead(1) - .onColumn(3) - .overWindow(windowBuilder.window(0,1).build())); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lead(1) - .onColumn(6) - .overWindow(windowBuilder.window(0,1).build())); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, null); - ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, 5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, null)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + try (WindowOptions options = windowBuilder.window(0,1).build(); + WindowOptions options1 = windowBuilder.window(0,1).build()) { + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lead(1) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lead(1) + .onColumn(6) + .overWindow(options1)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, null); + ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, 5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, null)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + } } - try (ColumnVector defaultOutput = ColumnVector.fromBoxedInts(0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); - ColumnVector decDefaultOutput = ColumnVector.decimalFromLongs(-2, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); - Table windowAggResults = sorted.groupBy(0, 1) - 
.aggregateWindows(Aggregation - .lead(1, defaultOutput) - .onColumn(3) - .overWindow(windowBuilder.window(0,1).build())); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lead(1, decDefaultOutput) - .onColumn(6) - .overWindow(windowBuilder.window(0,1).build())); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(5, 1, 9, -3, 9, 8, 2, -7, 0, 6, 6, -11); - ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 5, 1, 9, -3, 9, 8, 2, -7, 0, 6, 6, -11)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + try (WindowOptions options = windowBuilder.window(0,1).build(); + WindowOptions options1 = windowBuilder.window(0,1).build()) { + try (ColumnVector defaultOutput = ColumnVector.fromBoxedInts(0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); + ColumnVector decDefaultOutput = ColumnVector.decimalFromLongs(-2, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); + Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lead(1, defaultOutput) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lead(1, decDefaultOutput) + .onColumn(6) + .overWindow(options1)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(5, 1, 9, -3, 9, 8, 2, -7, 0, 6, 6, -11); + ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 5, 1, 9, -3, 9, 8, 2, -7, 0, 6, 6, -11)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + } } // Outside bounds - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lead(3) - .onColumn(3) - .overWindow(windowBuilder.window(0,1).build())); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lead(3) - .onColumn(6) - 
.overWindow(windowBuilder.window(0,1).build())); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, null, null, null, null, null, null, null, null, null, null, null); - ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, null, null, null, null, null, null, null, null, null, null, null)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + try (WindowOptions options = windowBuilder.window(0,1).build(); + WindowOptions options1 = windowBuilder.window(0,1).build()) { + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lead(3) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lead(3) + .onColumn(6) + .overWindow(options1)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, null, null, null, null, null, null, null, null, null, null, null); + ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, null, null, null, null, null, null, null, null, null, null, null)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + } } } } @@ -3128,71 +3155,83 @@ void testWindowingLag() { WindowOptions.Builder windowBuilder = WindowOptions.builder().minPeriods(1); - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lag(0) - .onColumn(3) - .overWindow(windowBuilder.window(2,1).build())); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lag(0) - .onColumn(6) - .overWindow(windowBuilder.window(2,1).build())); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); - ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { - 
assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + try (WindowOptions options = windowBuilder.window(2,1).build(); + WindowOptions options1 = windowBuilder.window(2,1).build()) { + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lag(0) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lag(0) + .onColumn(6) + .overWindow(options1)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); + ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + } } - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lag(1) - .onColumn(3) - .overWindow(windowBuilder.window(2,0).build())); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lag(1) - .onColumn(6) - .overWindow(windowBuilder.window(2,0).build())); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6); - ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + try (WindowOptions options = windowBuilder.window(2,0).build(); + WindowOptions options1 = windowBuilder.window(2,0).build()) { + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lag(1) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lag(1) + .onColumn(6) + 
.overWindow(options1)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6); + ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + } } - try (ColumnVector defaultOutput = ColumnVector.fromBoxedInts(0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); - ColumnVector decDefaultOutput = ColumnVector.decimalFromLongs(-2, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); - Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lag(1, defaultOutput) - .onColumn(3) - .overWindow(windowBuilder.window(2, 0).build())); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lag(1, decDefaultOutput) - .onColumn(6) - .overWindow(windowBuilder.window(2, 0).build())); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(0, 7, 5, 1, -4, 7, 9, 8, -8, 8, 0, 6); - ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 0, 7, 5, 1, -4, 7, 9, 8, -8, 8, 0, 6)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + try (WindowOptions options = windowBuilder.window(2, 0).build(); + WindowOptions options1 = windowBuilder.window(2, 0).build()) { + try (ColumnVector defaultOutput = ColumnVector.fromBoxedInts(0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); + ColumnVector decDefaultOutput = ColumnVector.decimalFromLongs(-2, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); + Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lag(1, defaultOutput) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lag(1, decDefaultOutput) + .onColumn(6) + 
.overWindow(options1)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(0, 7, 5, 1, -4, 7, 9, 8, -8, 8, 0, 6); + ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 0, 7, 5, 1, -4, 7, 9, 8, -8, 8, 0, 6)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + } } // Outside bounds - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lag(3) - .onColumn(3) - .overWindow(windowBuilder.window(1, 0).build())); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lag(3) - .onColumn(6) - .overWindow(windowBuilder.window(1, 0).build())); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, null, null, null, null, null, null, null, null, null, null, null); - ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, null, null, null, null, null, null, null, null, null, null, null);) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + try (WindowOptions options = windowBuilder.window(1, 0).build(); + WindowOptions options1 = windowBuilder.window(1, 0).build()) { + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lag(3) + .onColumn(3) + .overWindow(options)); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lag(3) + .onColumn(6) + .overWindow(options1)); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, null, null, null, null, null, null, null, null, null, null, null); + ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, null, null, null, null, null, null, null, null, null, null, null);) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); + } } } } 
@@ -3201,24 +3240,25 @@ void testWindowingLag() { @Test void testWindowingMean() { try (Table unsorted = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // GBY Key - .column( 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6) // OBY Key - .column( 7, 5, 3, 7, 7, 9, 8, 4, 8, 0, 4, 8) // Agg Column - .build()) { + .column( 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // GBY Key + .column( 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6) // OBY Key + .column( 7, 5, 3, 7, 7, 9, 8, 4, 8, 0, 4, 8) // Agg Column + .build()) { try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); ColumnVector expectedSortedAggCol = ColumnVector.fromBoxedInts(7, 5, 3, 7, 7, 9, 8, 4, 8, 0, 4, 8)) { ColumnVector sortedAggColumn = sorted.getColumn(3); assertColumnsAreEqual(expectedSortedAggCol, sortedAggColumn); - WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(2, 1) - .build(); + try (WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(2, 1) + .build();) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.mean().onColumn(3).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedDoubles(6.0d, 5.0d, 5.0d, 5.0d, 8.0d, 8.0d, 7.0d, 6.0d, 4.0d, 4.0d, 4.0d, 6.0d)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation.mean().onColumn(3).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedDoubles(6.0d, 5.0d, 5.0d, 5.0d, 8.0d, 8.0d, 7.0d, 6.0d, 4.0d, 4.0d, 4.0d, 6.0d)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + } } } } @@ -3227,48 +3267,50 @@ void testWindowingMean() { @Test void testWindowingOnMultipleDifferentColumns() { try (Table unsorted = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // GBY Key - .column( 1, 
1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6) // OBY Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6) // Agg Column - .build()) { + .column( 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // GBY Key + .column( 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6) // OBY Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6) // Agg Column + .build()) { try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); ColumnVector expectedSortedAggCol = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { ColumnVector sortedAggColumn = sorted.getColumn(3); assertColumnsAreEqual(expectedSortedAggCol, sortedAggColumn); + try ( // Window (1,1), with a minimum of 1 reading. WindowOptions window_1 = WindowOptions.builder() - .minPeriods(1) - .window(2, 1) - .build(); + .minPeriods(1) + .window(2, 1) + .build(); // Window (2,2), with a minimum of 2 readings. WindowOptions window_2 = WindowOptions.builder() - .minPeriods(2) - .window(3, 2) - .build(); + .minPeriods(2) + .window(3, 2) + .build(); // Window (1,1), with a minimum of 3 readings. 
WindowOptions window_3 = WindowOptions.builder() - .minPeriods(3) - .window(2, 1) - .build(); - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows( - Aggregation.sum().onColumn(3).overWindow(window_1), - Aggregation.max().onColumn(3).overWindow(window_1), - Aggregation.sum().onColumn(3).overWindow(window_2), - Aggregation.min().onColumn(2).overWindow(window_3) - ); - ColumnVector expect_0 = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 10L, 16L, 24L, 19L, 10L, 8L, 14L, 12L, 12L); - ColumnVector expect_1 = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6); - ColumnVector expect_2 = ColumnVector.fromBoxedLongs(13L, 22L, 22L, 15L, 24L, 26L, 26L, 19L, 14L, 20L, 20L, 12L); - ColumnVector expect_3 = ColumnVector.fromBoxedInts(null, 1, 1, null, null, 3, 3, null, null, 5, 5, null)) { - assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); - assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); - assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); + .minPeriods(3) + .window(2, 1) + .build();) { + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows( + Aggregation.sum().onColumn(3).overWindow(window_1), + Aggregation.max().onColumn(3).overWindow(window_1), + Aggregation.sum().onColumn(3).overWindow(window_2), + Aggregation.min().onColumn(2).overWindow(window_3) + ); + ColumnVector expect_0 = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 10L, 16L, 24L, 19L, 10L, 8L, 14L, 12L, 12L); + ColumnVector expect_1 = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6); + ColumnVector expect_2 = ColumnVector.fromBoxedLongs(13L, 22L, 22L, 15L, 24L, 26L, 26L, 19L, 14L, 20L, 20L, 12L); + ColumnVector expect_3 = ColumnVector.fromBoxedInts(null, 1, 1, null, null, 3, 3, null, null, 5, 5, null)) { + assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); + 
assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); + assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); + } } } } @@ -3277,556 +3319,800 @@ void testWindowingOnMultipleDifferentColumns() { @Test void testWindowingWithoutGroupByColumns() { try (Table unsorted = new Table.TestBuilder().column( 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6) // OBY Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6) // Agg Column - .build(); + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6) // Agg Column + .build(); ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { try (Table sorted = unsorted.orderBy(OrderByArg.asc(0))) { ColumnVector sortedAggColumn = sorted.getColumn(1); assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(2, 1) - .build(); + try (WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(2, 1) + .build();) { - try (Table windowAggResults = sorted.groupBy().aggregateWindows( - Aggregation.sum().onColumn(1).overWindow(window)); - ColumnVector expectAggResult = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 17L, 25L, 24L, 19L, 18L, 10L, 14L, 12L, 12L); - ) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + try (Table windowAggResults = sorted.groupBy().aggregateWindows( + Aggregation.sum().onColumn(1).overWindow(window)); + ColumnVector expectAggResult = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 17L, 25L, 24L, 19L, 18L, 10L, 14L, 12L, 12L); + ) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + } } } } } - @Test - void testTimeRangeWindowingCount() { - try (Table unsorted = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - 
.build()) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(3); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + private Scalar getScalar(DType type, long value) { + if (type.equals(DType.INT32)) { + return Scalar.fromInt((int) value); + } else if (type.equals(DType.INT64)) { + return Scalar.fromLong(value); + } else if (type.equals(DType.INT16)) { + return Scalar.fromShort((short) value); + } else if (type.equals(DType.INT8)) { + return Scalar.fromByte((byte) value); + } else if (type.equals(DType.UINT8)) { + return Scalar.fromUnsignedByte((byte) value); + } else if (type.equals(DType.UINT16)) { + return Scalar.fromUnsignedShort((short) value); + } else if (type.equals(DType.UINT32)) { + return Scalar.fromUnsignedInt((int) value); + } else if (type.equals(DType.UINT64)) { + return Scalar.fromUnsignedLong(value); + } else if (type.equals(DType.TIMESTAMP_DAYS)) { + return Scalar.durationFromLong(DType.DURATION_DAYS, value); + } else if (type.equals(DType.TIMESTAMP_SECONDS)) { + return Scalar.durationFromLong(DType.DURATION_SECONDS, value); + } else if (type.equals(DType.TIMESTAMP_MILLISECONDS)) { + return Scalar.durationFromLong(DType.DURATION_MILLISECONDS, value); + } else if (type.equals(DType.TIMESTAMP_MICROSECONDS)) { + return Scalar.durationFromLong(DType.DURATION_MICROSECONDS, value); + } else if (type.equals(DType.TIMESTAMP_NANOSECONDS)) { + return Scalar.durationFromLong(DType.DURATION_NANOSECONDS, value); + } else { + return Scalar.fromNull(type); + } + } + + @Test + void testRangeWindowingCount() { + try ( + Table unsorted = new Table.TestBuilder() + .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + 
.column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key + .column((short) 1, (short)1, (short)2, (short)3, (short)3, (short)3, (short)4, (short)4, (short)5, (short)5, (short)6, (short)6, (short)7) // orderBy Key + .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key + .column((byte) 1, (byte)1, (byte)2, (byte)3, (byte)3, (byte)3, (byte)4, (byte)4, (byte)5, (byte)5, (byte)6, (byte)6, (byte)7) // orderBy Key + .timestampDayColumn(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // timestamp orderBy Key + .timestampSecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampMicrosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampMillisecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .build()) { - WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(1, 1) - .timestampColumnIndex(2) - .build(); + for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(2); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - try (Table windowAggResults = sorted.groupBy(0, 1).aggregateWindowsOverTimeRanges( - Aggregation.count().onColumn(3).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 4, 2, 4, 4, 4, 4, 4, 4, 5, 5, 3)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + DType type = unsorted.getColumn(orderIndex).getType(); + try (Scalar preceding = getScalar(type, 1L); + Scalar following = getScalar(type, 1L)) { + try (WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(preceding, following) + .orderByColumnIndex(orderIndex) + .build()) 
{ + try (Table windowAggResults = sorted.groupBy(0, 1).aggregateWindowsOverRanges( + Aggregation.count().onColumn(2).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 4, 2, 4, 4, 4, 4, 4, 4, 5, 5, 3)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + } + } + } } } } } @Test - void testTimeRangeWindowingLead() { + void testRangeWindowingLead() { try (Table unsorted = new Table.TestBuilder() - .column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key + .column((short) 1, (short)1, (short)2, (short)3, (short)3, (short)3, (short)4, (short)4, (short)5, (short)5, (short)6, (short)6, (short)7) // orderBy Key + .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key + .column((byte) 1, (byte)1, (byte)2, (byte)3, (byte)3, (byte)3, (byte)4, (byte)4, (byte)5, (byte)5, (byte)6, (byte)6, (byte)7) // orderBy Key + .timestampDayColumn(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp orderBy Key + .timestampSecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampMicrosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampMillisecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) .build()) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = 
sorted.getColumn(3); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(1, 1) - .timestampColumnIndex(2) - .build(); + for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(2); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverTimeRanges(Aggregation.lead(1) - .onColumn(3) - .overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, 8, null)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + DType type = unsorted.getColumn(orderIndex).getType(); + try (Scalar preceding = getScalar(type, 1L); + Scalar following = getScalar(type, 1L)) { + try (WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(preceding, following) + .orderByColumnIndex(orderIndex) + .build()) { + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverRanges(Aggregation.lead(1) + .onColumn(2) + .overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, 8, null)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + } + } + } } } } } @Test - void testTimeRangeWindowingMax() { - try (Table unsorted = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .build()) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), 
OrderByArg.asc(1), OrderByArg.asc(2)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(3); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - - WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(1, 1) - .timestampColumnIndex(2) - .build(); + void testRangeWindowingMax() { + try (Table unsorted = new Table.TestBuilder() + .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key + .column((short) 1, (short)1, (short)2, (short)3, (short)3, (short)3, (short)4, (short)4, (short)5, (short)5, (short)6, (short)6, (short)7) // orderBy Key + .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key + .column((byte) 1, (byte)1, (byte)2, (byte)3, (byte)3, (byte)3, (byte)4, (byte)4, (byte)5, (byte)5, (byte)6, (byte)6, (byte)7) // orderBy Key + .timestampDayColumn(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp orderBy Key + .timestampSecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampMicrosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampMillisecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .build()) { - try (Table windowAggResults = sorted.groupBy(0, 1).aggregateWindowsOverTimeRanges( - Aggregation.max().onColumn(3).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts( 7, 7, 9, 9, 9, 9, 9, 9, 8, 8, 8, 8, 8)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - } + for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), 
OrderByArg.asc(orderIndex)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(2); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - window = WindowOptions.builder() - .minPeriods(1) - .window(2, 1) - .build(); + DType type = unsorted.getColumn(orderIndex).getType(); + try (Scalar preceding = getScalar(type, 1L); + Scalar following = getScalar(type, 1L)) { + try (WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(preceding, following) + .orderByColumnIndex(orderIndex) + .build()) { + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverRanges(Aggregation.max().onColumn(2).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 9, 8, 8, 8, 8, 8)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + } + } - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.max().onColumn(3).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts( 7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 8, 8)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + try (WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(2, 1) + .build();) { + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation.max().onColumn(2).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 8, 8)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + } + } + } } } } } @Test - void testTimeRangeWindowingRowNumber() { - try (Table unsorted = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .build()) { - 
try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(3); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + void testRangeWindowingRowNumber() { + try (Table unsorted = new Table.TestBuilder() + .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key + .column((short) 1, (short)1, (short)2, (short)3, (short)3, (short)3, (short)4, (short)4, (short)5, (short)5, (short)6, (short)6, (short)7) // orderBy Key + .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key + .column((byte) 1, (byte)1, (byte)2, (byte)3, (byte)3, (byte)3, (byte)4, (byte)4, (byte)5, (byte)5, (byte)6, (byte)6, (byte)7) // orderBy Key + .timestampDayColumn(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp orderBy Key + .timestampSecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampMicrosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampMillisecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .build()) { - WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(3, 0) - .timestampColumnIndex(2) - .build(); + for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(2); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); 
- try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverTimeRanges(Aggregation.rowNumber().onColumn(3).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 5)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + DType type = unsorted.getColumn(orderIndex).getType(); + try (Scalar preceding = getScalar(type, 2L); + Scalar following = getScalar(type, 0L)) { + try (WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(preceding, following) + .orderByColumnIndex(orderIndex) + .build()) { + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverRanges(Aggregation.rowNumber().onColumn(2).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 5)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + } + } + } } } } } @Test - void testTimeRangeWindowingCountDescendingTimestamps() { - try (Table unsorted = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .timestampDayColumn( 7, 6, 6, 5, 5, 4, 4, 3, 3, 3, 2, 1, 1) // Timestamp Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .build()) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.desc(2)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(3); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + void testRangeWindowingCountDescendingTimestamps() { + try (Table unsorted = new Table.TestBuilder() + .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .column((short)7, (short)6, (short)6, (short)5, (short)5, (short)4, (short)4, 
(short)3, (short)3, (short)3, (short)2, (short)1, (short)1) + .column(7L, 6L, 6L, 5L, 5L, 4L, 4L, 3L, 3L, 3L, 2L, 1L, 1L) + .column(7, 6, 6, 5, 5, 4, 4, 3, 3, 3, 2, 1, 1) + .column((byte)7, (byte)6, (byte)6, (byte)5, (byte)5, (byte)4, (byte)4, (byte)3, (byte)3, (byte)3, (byte)2, (byte)1, (byte)1) + .timestampDayColumn(7, 6, 6, 5, 5, 4, 4, 3, 3, 3, 2, 1, 1) // Timestamp Key + .timestampSecondsColumn(7L, 6L, 6L, 5L, 5L, 4L, 4L, 3L, 3L, 3L, 2L, 1L, 1L) + .timestampMicrosecondsColumn(7L, 6L, 6L, 5L, 5L, 4L, 4L, 3L, 3L, 3L, 2L, 1L, 1L) + .timestampMillisecondsColumn(7L, 6L, 6L, 5L, 5L, 4L, 4L, 3L, 3L, 3L, 2L, 1L, 1L) + .timestampNanosecondsColumn(7L, 6L, 6L, 5L, 5L, 4L, 4L, 3L, 3L, 3L, 2L, 1L, 1L) + .build()) { - WindowOptions window_0 = WindowOptions.builder() - .minPeriods(1) - .window(2, 1) - .timestampColumnIndex(2) - .timestampDescending() - .build(); + for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.desc(orderIndex)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(2); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + + DType type = unsorted.getColumn(orderIndex).getType(); + try (Scalar preceding_0 = getScalar(type, 2L); + Scalar following_0 = getScalar(type, 1L); + Scalar preceding_1 = getScalar(type, 3L); + Scalar following_1 = getScalar(type, 0L)) { + + try (WindowOptions window_0 = WindowOptions.builder() + .minPeriods(1) + .window(preceding_0, following_0) + .orderByColumnIndex(orderIndex) + .orderByDescending() + .build(); + + WindowOptions window_1 = WindowOptions.builder() + .minPeriods(1) + .window(preceding_1, following_1) + .orderByColumnIndex(orderIndex) + .orderByDescending() + .build();) { + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverRanges( + 
Aggregation.count().onColumn(2).overWindow(window_0), + Aggregation.sum().onColumn(2).overWindow(window_1)); + ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 4, 4, 4, 3, 4, 4, 4, 3, 3, 5, 5, 5); + ColumnVector expect_1 = ColumnVector.fromBoxedLongs(7L, 13L, 13L, 22L, 7L, 24L, 24L, 26L, 8L, 8L, 14L, 28L, 28L)) { + assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); + } + } + } + } + } + } + } - WindowOptions window_1 = WindowOptions.builder() - .minPeriods(1) - .window(3, 0) - .timestampColumnIndex(2) - .timestampDescending() - .build(); + @Test + void testRangeWindowingWithoutGroupByColumns() { + try (Table unsorted = new Table.TestBuilder() + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key + .column((short) 1, (short)1, (short)2, (short)3, (short)3, (short)3, (short)4, (short)4, (short)5, (short)5, (short)6, (short)6, (short)7) // orderBy Key + .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key + .column((byte) 1, (byte)1, (byte)2, (byte)3, (byte)3, (byte)3, (byte)4, (byte)4, (byte)5, (byte)5, (byte)6, (byte)6, (byte)7) // orderBy Key + .timestampDayColumn(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp orderBy Key + .timestampSecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampMicrosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampMillisecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .build()) { + + for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(orderIndex)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(0); + 
assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverTimeRanges( - Aggregation.count().onColumn(3).overWindow(window_0), - Aggregation.sum().onColumn(3).overWindow(window_1)); - ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 4, 4, 4, 3, 4, 4, 4, 3, 3, 5, 5, 5); - ColumnVector expect_1 = ColumnVector.fromBoxedLongs(7L, 13L, 13L, 22L, 7L, 24L, 24L, 26L, 8L, 8L, 14L, 28L, 28L)) { - assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); + DType type = unsorted.getColumn(orderIndex).getType(); + try (Scalar preceding = getScalar(type, 1L); + Scalar following = getScalar(type, 1L)) { + try (WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(preceding, following) + .orderByColumnIndex(orderIndex) + .build();) { + + try (Table windowAggResults = sorted.groupBy() + .aggregateWindowsOverRanges(Aggregation.count().onColumn(1).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 6, 6, 6, 6, 7, 7, 6, 6, 5, 5, 3)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + } + } + } } } } } @Test - void testTimeRangeWindowingWithoutGroupByColumns() { - try (Table unsorted = new Table.TestBuilder().timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .build()) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(1); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + void testRangeWindowingOrderByUnsupportedDataTypeExceptions() { + try (Table table = new Table.TestBuilder() + .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + 
.column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .column(true, false, true, false, true, false, true, false, false, false, false, false, false) // orderBy Key + .build()) { - WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(1, 1) - .timestampColumnIndex(0) - .build(); + try (WindowOptions rangeBasedWindow = WindowOptions.builder() + .minPeriods(1) + .window(1, 1) + .orderByColumnIndex(3) + .build();) { - try (Table windowAggResults = sorted.groupBy() - .aggregateWindowsOverTimeRanges(Aggregation.count().onColumn(1).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 6, 6, 6, 6, 7, 7, 6, 6, 5, 5, 3)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - } + assertThrows(IllegalArgumentException.class, + () -> table + .groupBy(0, 1) + .aggregateWindowsOverRanges(Aggregation.max().onColumn(2).overWindow(rangeBasedWindow))); } } } @Test void testInvalidWindowTypeExceptions() { - try (Table table = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .build()) { + try (Table table = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + .timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .build()) { - WindowOptions rowBasedWindow = WindowOptions.builder().minPeriods(1).window(1,1).build(); - assertThrows(IllegalArgumentException.class, () -> - table.groupBy(0, 1) - .aggregateWindowsOverTimeRanges(Aggregation.max().onColumn(3).overWindow(rowBasedWindow))); - WindowOptions rangeBasedWindow = WindowOptions.builder().minPeriods(1).window(1,1).timestampColumnIndex(2).build(); - 
assertThrows(IllegalArgumentException.class, () -> - table.groupBy(0, 1) - .aggregateWindows(Aggregation.max().onColumn(3).overWindow(rangeBasedWindow))); + try (WindowOptions rowBasedWindow = WindowOptions.builder().minPeriods(1).window(1,1).build();) { + assertThrows(IllegalArgumentException.class, () -> table.groupBy(0, 1).aggregateWindowsOverRanges(Aggregation.max().onColumn(3).overWindow(rowBasedWindow))); + } + try (WindowOptions rangeBasedWindow = WindowOptions.builder().minPeriods(1).window(1,1).orderByColumnIndex(2).build();) { + assertThrows(IllegalArgumentException.class, () -> table.groupBy(0, 1).aggregateWindows(Aggregation.max().onColumn(3).overWindow(rangeBasedWindow))); } + } } @Test - void testTimeRangeWindowingCountUnboundedPreceding() { - try (Table unsorted = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .build()) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(3); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + void testRangeWindowingCountUnboundedPreceding() { + try (Table unsorted = new Table.TestBuilder() + .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key + .column((short) 1, (short)1, (short)2, (short)3, (short)3, (short)3, (short)4, (short)4, (short)5, (short)5, (short)6, (short)6, (short)7) // orderBy Key + .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key + 
.column((byte) 1, (byte)1, (byte)2, (byte)3, (byte)3, (byte)3, (byte)4, (byte)4, (byte)5, (byte)5, (byte)6, (byte)6, (byte)7) // orderBy Key + .timestampDayColumn(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp orderBy Key + .timestampSecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampMicrosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampMillisecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .build()) { - WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .unboundedPreceding() - .following(1) - .timestampColumnIndex(2) - .build(); + for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(2); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverTimeRanges(Aggregation.count().onColumn(3).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + DType type = unsorted.getColumn(orderIndex).getType(); + try (Scalar following = getScalar(type, 1L)) { + try (WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .unboundedPreceding() + .following(following) + .orderByColumnIndex(orderIndex) + .build();) { + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverRanges(Aggregation.count().onColumn(2).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + 
} + } + } } } } } @Test - void testTimeRangeWindowingCountUnboundedASCWithNullsFirst() { - Integer X = null; + void testRangeWindowingCountUnboundedASCWithNullsFirst() { try (Table unsorted = new Table.TestBuilder() - .column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .timestampDayColumn( X, X, X, 2, 3, 5, X, X, 1, 2, 4, 5, 7) // Timestamp Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .build()) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2, true)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(3); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .column( null, null, null, 2, 3, 5, null, null, 1, 2, 4, 5, 7) // Timestamp Key + .column( null, null, null, 2L, 3L, 5L, null, null, 1L, 2L, 4L, 5L, 7L) // orderBy Key + .column( null, null, null, (short)2, (short)3, (short)5, null, null, (short)1, (short)2, (short)4, (short)5, (short)7) // orderBy Key + .column( null, null, null, (byte)2, (byte)3, (byte)5, null, null, (byte)1, (byte)2, (byte)4, (byte)5, (byte)7) // orderBy Key + .timestampDayColumn( null, null, null, 2, 3, 5, null, null, 1, 2, 4, 5, 7) // Timestamp orderBy Key + .timestampSecondsColumn( null, null, null, 2L, 3L, 5L, null, null, 1L, 2L, 4L, 5L, 7L) + .timestampMicrosecondsColumn( null, null, null, 2L, 3L, 5L, null, null, 1L, 2L, 4L, 5L, 7L) + .timestampMillisecondsColumn( null, null, null, 2L, 3L, 5L, null, null, 1L, 2L, 4L, 5L, 7L) + .timestampNanosecondsColumn( null, null, null, 2L, 3L, 5L, null, null, 1L, 2L, 4L, 5L, 7L) + .build()) { - WindowOptions unboundedPrecedingOneFollowing = 
WindowOptions.builder() - .minPeriods(1) - .unboundedPreceding() - .following(1) - .timestampColumnIndex(2) - .build(); + for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex, true)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(2); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + + DType type = unsorted.getColumn(orderIndex).getType(); + try (Scalar following1 = getScalar(type, 1L); + Scalar preceding1 = getScalar(type, 1L); + Scalar following0 = getScalar(type, 0L); + Scalar preceding0 = getScalar(type, 0L);) { + try (WindowOptions unboundedPrecedingOneFollowing = WindowOptions.builder() + .minPeriods(1) + .unboundedPreceding() + .following(following1) + .orderByColumnIndex(orderIndex) + .build(); - WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() + WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() .minPeriods(1) - .preceding(1) + .preceding(preceding1) .unboundedFollowing() - .timestampColumnIndex(2) + .orderByColumnIndex(orderIndex) .build(); - WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() + WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() .unboundedFollowing() - .timestampColumnIndex(2) + .orderByColumnIndex(orderIndex) .build(); - WindowOptions unboundedPrecedingAndCurrentRow = WindowOptions.builder() + WindowOptions unboundedPrecedingAndCurrentRow = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() - .following(0) - .timestampColumnIndex(2) + .following(following0) + .orderByColumnIndex(orderIndex) .build(); - WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() + WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() 
.minPeriods(1) - .preceding(0) + .preceding(preceding0) .unboundedFollowing() - .timestampColumnIndex(2) - .build(); - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverTimeRanges( - Aggregation.count().onColumn(3).overWindow(unboundedPrecedingOneFollowing), - Aggregation.count().onColumn(3).overWindow(onePrecedingUnboundedFollowing), - Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndFollowing), - Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(3).overWindow(currentRowAndUnboundedFollowing)); - ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 3, 3, 5, 5, 6, 2, 2, 4, 4, 6, 6, 7); - ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 3, 1, 7, 7, 5, 5, 3, 3, 1); - ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); - ColumnVector expect_3 = ColumnVector.fromBoxedInts(3, 3, 3, 4, 5, 6, 2, 2, 3, 4, 5, 6, 7); - ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 2, 1, 7, 7, 5, 4, 3, 2, 1)) { - - assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); - assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); - assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); - assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); + .orderByColumnIndex(orderIndex) + .build();) { + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverRanges( + Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), + Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), + Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), + Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), + Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); + ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 3, 3, 5, 5, 6, 
2, 2, 4, 4, 6, 6, 7); + ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 3, 1, 7, 7, 5, 5, 3, 3, 1); + ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); + ColumnVector expect_3 = ColumnVector.fromBoxedInts(3, 3, 3, 4, 5, 6, 2, 2, 3, 4, 5, 6, 7); + ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 2, 1, 7, 7, 5, 4, 3, 2, 1)) { + + assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); + assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); + assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); + assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); + } + } + } } } } } @Test - void testTimeRangeWindowingCountUnboundedDESCWithNullsFirst() { - Integer X = null; + void testRangeWindowingCountUnboundedDESCWithNullsFirst() { try (Table unsorted = new Table.TestBuilder() - .column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .timestampDayColumn( X, X, X, 5, 3, 2, X, X, 7, 5, 4, 2, 1) // Timestamp Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .build()) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.desc(2, false)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(3); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .column(null, null, null, 5, 3, 2, null, null, 7, 5, 4, 2, 1) // Timestamp Key + .column(null, null, null, 5L, 3L, 2L, null, null, 7L, 5L, 4L, 2L, 1L) // orderby Key + .column(null, null, null, (short)5, (short)3, (short)2, null, null, (short)7, (short)5, 
(short)4, (short)2, (short)1) // orderby Key + .column(null, null, null, (byte)5, (byte)3, (byte)2, null, null, (byte)7, (byte)5, (byte)4, (byte)2, (byte)1) // orderby Key + .timestampDayColumn(null, null, null, 5, 3, 2, null, null, 7, 5, 4, 2, 1) // Timestamp orderby Key + .timestampSecondsColumn( null, null, null, 5L, 3L, 2L, null, null, 7L, 5L, 4L, 2L, 1L) + .timestampMicrosecondsColumn( null, null, null, 5L, 3L, 2L, null, null, 7L, 5L, 4L, 2L, 1L) + .timestampMillisecondsColumn( null, null, null, 5L, 3L, 2L, null, null, 7L, 5L, 4L, 2L, 1L) + .timestampNanosecondsColumn( null, null, null, 5L, 3L, 2L, null, null, 7L, 5L, 4L, 2L, 1L) + .build()) { - WindowOptions unboundedPrecedingOneFollowing = WindowOptions.builder() + for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.desc(orderIndex, false)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(2); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + + DType type = unsorted.getColumn(orderIndex).getType(); + try (Scalar following1 = getScalar(type, 1L); + Scalar preceding1 = getScalar(type, 1L); + Scalar following0 = getScalar(type, 0L); + Scalar preceding0 = getScalar(type, 0L);) { + + try (WindowOptions unboundedPrecedingOneFollowing = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() - .following(1) - .timestampColumnIndex(2) - .timestampDescending() + .following(following1) + .orderByColumnIndex(orderIndex) + .orderByDescending() .build(); - WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() + WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() .minPeriods(1) - .preceding(1) + .preceding(preceding1) .unboundedFollowing() - .timestampColumnIndex(2) - .timestampDescending() + .orderByColumnIndex(orderIndex) + 
.orderByDescending() .build(); - WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() + WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() .unboundedFollowing() - .timestampColumnIndex(2) - .timestampDescending() + .orderByColumnIndex(orderIndex) + .orderByDescending() .build(); - WindowOptions unboundedPrecedingAndCurrentRow = WindowOptions.builder() + WindowOptions unboundedPrecedingAndCurrentRow = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() - .following(0) - .timestampColumnIndex(2) - .timestampDescending() + .following(following0) + .orderByColumnIndex(orderIndex) + .orderByDescending() .build(); - WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() + WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() .minPeriods(1) - .preceding(0) + .preceding(preceding0) .unboundedFollowing() - .timestampColumnIndex(2) - .timestampDescending() - .build(); - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverTimeRanges( - Aggregation.count().onColumn(3).overWindow(unboundedPrecedingOneFollowing), - Aggregation.count().onColumn(3).overWindow(onePrecedingUnboundedFollowing), - Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndFollowing), - Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(3).overWindow(currentRowAndUnboundedFollowing)); - ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 3, 3, 4, 6, 6, 2, 2, 3, 5, 5, 7, 7); - ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 2, 2, 7, 7, 5, 4, 4, 2, 2); - ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); - ColumnVector expect_3 = ColumnVector.fromBoxedInts(3, 3, 3, 4, 5, 6, 2, 2, 3, 4, 5, 6, 7); - ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 2, 1, 7, 7, 5, 4, 3, 2, 1)) { - - assertColumnsAreEqual(expect_0, 
windowAggResults.getColumn(0)); - assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); - assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); - assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); - assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); + .orderByColumnIndex(orderIndex) + .orderByDescending() + .build();) { + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverRanges( + Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), + Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), + Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), + Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), + Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); + ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 3, 3, 4, 6, 6, 2, 2, 3, 5, 5, 7, 7); + ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 2, 2, 7, 7, 5, 4, 4, 2, 2); + ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); + ColumnVector expect_3 = ColumnVector.fromBoxedInts(3, 3, 3, 4, 5, 6, 2, 2, 3, 4, 5, 6, 7); + ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 2, 1, 7, 7, 5, 4, 3, 2, 1)) { + + assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); + assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); + assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); + assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); + } + } + } } } } } @Test - void testTimeRangeWindowingCountUnboundedASCWithNullsLast() { - Integer X = null; + void testRangeWindowingCountUnboundedASCWithNullsLast() { try (Table unsorted = new Table.TestBuilder() - .column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key - 
.timestampDayColumn( 2, 3, 5, X, X, X, 1, 2, 4, 5, 7, X, X) // Timestamp Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .build()) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2, false)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(3); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - WindowOptions unboundedPrecedingOneFollowing = WindowOptions.builder() + .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .column(2, 3, 5, null, null, null, 1, 2, 4, 5, 7, null, null) // Timestamp Key + .column(2L, 3L, 5L, null, null, null, 1L, 2L, 4L, 5L, 7L, null, null) // order by Key + .column((short)2, (short)3, (short)5, null, null, null, (short)1, (short)2, (short)4, (short)5, (short)7, null, null) // order by Key + .column((byte)2, (byte)3, (byte)5, null, null, null, (byte)1, (byte)2, (byte)4, (byte)5, (byte)7, null, null) // order by Key + .timestampDayColumn( 2, 3, 5, null, null, null, 1, 2, 4, 5, 7, null, null) // Timestamp order by Key + .timestampSecondsColumn( 2L, 3L, 5L, null, null, null, 1L, 2L, 4L, 5L, 7L, null, null) + .timestampMicrosecondsColumn( 2L, 3L, 5L, null, null, null, 1L, 2L, 4L, 5L, 7L, null, null) + .timestampMillisecondsColumn( 2L, 3L, 5L, null, null, null, 1L, 2L, 4L, 5L, 7L, null, null) + .timestampNanosecondsColumn( 2L, 3L, 5L, null, null, null, 1L, 2L, 4L, 5L, 7L, null, null) + .build()) { + for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex, false)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector 
sortedAggColumn = sorted.getColumn(2); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + + DType type = unsorted.getColumn(orderIndex).getType(); + try (Scalar following1 = getScalar(type, 1L); + Scalar preceding1 = getScalar(type, 1L); + Scalar following0 = getScalar(type, 0L); + Scalar preceding0 = getScalar(type, 0L);) { + try (WindowOptions unboundedPrecedingOneFollowing = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() - .following(1) - .timestampColumnIndex(2) + .following(following1) + .orderByColumnIndex(orderIndex) .build(); - WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() + WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() .minPeriods(1) - .preceding(1) + .preceding(preceding1) .unboundedFollowing() - .timestampColumnIndex(2) + .orderByColumnIndex(orderIndex) .build(); - WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() + WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() .unboundedFollowing() - .timestampColumnIndex(2) + .orderByColumnIndex(orderIndex) .build(); - WindowOptions unboundedPrecedingAndCurrentRow = WindowOptions.builder() + WindowOptions unboundedPrecedingAndCurrentRow = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() - .following(0) - .timestampColumnIndex(2) + .following(following0) + .orderByColumnIndex(orderIndex) .build(); - WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() + WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() .minPeriods(1) - .preceding(0) + .preceding(preceding0) .unboundedFollowing() - .timestampColumnIndex(2) - .build(); - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverTimeRanges( - Aggregation.count().onColumn(3).overWindow(unboundedPrecedingOneFollowing), - Aggregation.count().onColumn(3).overWindow(onePrecedingUnboundedFollowing), - 
Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndFollowing), - Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(3).overWindow(currentRowAndUnboundedFollowing)); - ColumnVector expect_0 = ColumnVector.fromBoxedInts(2, 2, 3, 6, 6, 6, 2, 2, 4, 4, 5, 7, 7); - ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 4, 3, 3, 3, 7, 7, 5, 5, 3, 2, 2); - ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); - ColumnVector expect_3 = ColumnVector.fromBoxedInts(1, 2, 3, 6, 6, 6, 1, 2, 3, 4, 5, 7, 7); - ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 5, 4, 3, 3, 3, 7, 6, 5, 4, 3, 2, 2)) { - - assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); - assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); - assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); - assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); + .orderByColumnIndex(orderIndex) + .build();) { + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverRanges( + Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), + Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), + Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), + Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), + Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); + ColumnVector expect_0 = ColumnVector.fromBoxedInts(2, 2, 3, 6, 6, 6, 2, 2, 4, 4, 5, 7, 7); + ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 4, 3, 3, 3, 7, 7, 5, 5, 3, 2, 2); + ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); + ColumnVector expect_3 = ColumnVector.fromBoxedInts(1, 2, 3, 6, 6, 6, 1, 2, 3, 4, 5, 7, 7); + ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 5, 4, 3, 
3, 3, 7, 6, 5, 4, 3, 2, 2)) { + + assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); + assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); + assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); + assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); + } + } + } } } } } @Test - void testTimeRangeWindowingCountUnboundedDESCWithNullsLast() { + void testRangeWindowingCountUnboundedDESCWithNullsLast() { Integer X = null; try (Table unsorted = new Table.TestBuilder() - .column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .timestampDayColumn( 5, 3, 2, X, X, X, 7, 5, 4, 2, 1, X, X) // Timestamp Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .build()) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.desc(2, true)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(3); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - - WindowOptions unboundedPrecedingOneFollowing = WindowOptions.builder() + .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .column( 5, 3, 2, null, null, null, 7, 5, 4, 2, 1, null, null) // Timestamp Key + .column(5L, 3L, 2L, null, null, null, 7L, 5L, 4L, 2L, 1L, null, null) // Timestamp Key + .column((short)5, (short)3, (short)2, null, null, null, (short)7, (short)5, (short)4, (short)2, (short)1, null, null) // Timestamp Key + .column((byte)5, (byte)3, (byte)2, null, null, null, (byte)7, (byte)5, (byte)4, (byte)2, (byte)1, null, null) // Timestamp Key + .timestampDayColumn( 5, 3, 2, X, X, X, 7, 5, 4, 2, 1, X, X) // Timestamp Key + .timestampSecondsColumn( 5L, 3L, 
2L, null, null, null, 7L, 5L, 4L, 2L, 1L, null, null) + .timestampMicrosecondsColumn( 5L, 3L, 2L, null, null, null, 7L, 5L, 4L, 2L, 1L, null, null) + .timestampMillisecondsColumn( 5L, 3L, 2L, null, null, null, 7L, 5L, 4L, 2L, 1L, null, null) + .timestampNanosecondsColumn( 5L, 3L, 2L, null, null, null, 7L, 5L, 4L, 2L, 1L, null, null) + .build()) { + for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.desc(orderIndex, true)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(2); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + + DType type = unsorted.getColumn(orderIndex).getType(); + try (Scalar following1 = getScalar(type, 1L); + Scalar preceding1 = getScalar(type, 1L); + Scalar following0 = getScalar(type, 0L); + Scalar preceding0 = getScalar(type, 0L);) { + try (WindowOptions unboundedPrecedingOneFollowing = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() - .following(1) - .timestampColumnIndex(2) - .timestampDescending() + .following(following1) + .orderByColumnIndex(orderIndex) + .orderByDescending() .build(); - WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() + WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() .minPeriods(1) - .preceding(1) + .preceding(preceding1) .unboundedFollowing() - .timestampColumnIndex(2) - .timestampDescending() + .orderByColumnIndex(orderIndex) + .orderByDescending() .build(); - WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() + WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() .unboundedFollowing() - .timestampColumnIndex(2) - .timestampDescending() + .orderByColumnIndex(orderIndex) + .orderByDescending() .build(); - WindowOptions 
unboundedPrecedingAndCurrentRow = WindowOptions.builder() + WindowOptions unboundedPrecedingAndCurrentRow = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() - .following(0) - .timestampColumnIndex(2) - .timestampDescending() + .following(following0) + .orderByColumnIndex(orderIndex) + .orderByDescending() .build(); - WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() + WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() .minPeriods(1) - .preceding(0) + .preceding(preceding0) .unboundedFollowing() - .timestampColumnIndex(2) - .timestampDescending() - .build(); - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverTimeRanges( - Aggregation.count().onColumn(3).overWindow(unboundedPrecedingOneFollowing), - Aggregation.count().onColumn(3).overWindow(onePrecedingUnboundedFollowing), - Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndFollowing), - Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(3).overWindow(currentRowAndUnboundedFollowing)); - ColumnVector expect_0 = ColumnVector.fromBoxedInts(1, 3, 3, 6, 6, 6, 1, 3, 3, 5, 5, 7, 7); - ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 5, 5, 3, 3, 3, 7, 6, 6, 4, 4, 2, 2); - ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); - ColumnVector expect_3 = ColumnVector.fromBoxedInts(1, 2, 3, 6, 6, 6, 1, 2, 3, 4, 5, 7, 7); - ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 5, 4, 3, 3, 3, 7, 6, 5, 4, 3, 2, 2)) { - - assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); - assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); - assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); - assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); + .orderByColumnIndex(orderIndex) + .orderByDescending() + .build();) { + + try (Table 
windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverRanges( + Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), + Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), + Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), + Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), + Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); + ColumnVector expect_0 = ColumnVector.fromBoxedInts(1, 3, 3, 6, 6, 6, 1, 3, 3, 5, 5, 7, 7); + ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 5, 5, 3, 3, 3, 7, 6, 6, 4, 4, 2, 2); + ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); + ColumnVector expect_3 = ColumnVector.fromBoxedInts(1, 2, 3, 6, 6, 6, 1, 2, 3, 4, 5, 7, 7); + ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 5, 4, 3, 3, 3, 7, 6, 5, 4, 3, 2, 2)) { + + assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); + assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); + assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); + assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); + } + } + } } } } From c681211df6253e1ceee9203658108980e7e93e3c Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Thu, 13 May 2021 17:43:47 -0600 Subject: [PATCH 02/12] Fix doxygens and comments for various APIs (#8201) This PR fixes doxygens and comments in several places. In particular: * Fixes the example codes in the doxygens for the `strings::concatenate_list_elements` APIs. In particular, it changes from calling `concatenate()` to `strings::concatenate_list_elements`. In addition, there are some other minor changes in indentation. * Fixes some comments in `lists/interleave_columns.cu` and `lists/concatenate_rows.cu`. No implementation code is changed in this PR. 
Authors: - Nghia Truong (https://github.com/ttnghia) Approvers: - Conor Hoekstra (https://github.com/codereport) - Karthikeyan (https://github.com/karthikeyann) URL: https://github.com/rapidsai/cudf/pull/8201 --- cpp/include/cudf/strings/combine.hpp | 50 +++++++++++++--------------- cpp/src/lists/concatenate_rows.cu | 19 ++++++----- cpp/src/lists/interleave_columns.cu | 11 +++--- 3 files changed, 41 insertions(+), 39 deletions(-) diff --git a/cpp/include/cudf/strings/combine.hpp b/cpp/include/cudf/strings/combine.hpp index 113b6d64f9a..6887ef0e670 100644 --- a/cpp/include/cudf/strings/combine.hpp +++ b/cpp/include/cudf/strings/combine.hpp @@ -184,28 +184,27 @@ std::unique_ptr concatenate( * s = [ {'aa', 'bb', 'cc'}, null, {'', 'dd'}, {'ee', null}, {'ff', 'gg'} ] * sep = ['::', '%%', '!', '*', null] * - * r1 = concatenate(s, sep) + * r1 = strings::concatenate_list_elements(s, sep) * r1 is ['aa::bb::cc', null, '!dd', null, null] * - * r2 = concatenate(s, sep, ':', '_') + * r2 = strings::concatenate_list_elements(s, sep, ':', '_') * r2 is ['aa::bb::cc', null, '!dd', 'ee*_', 'ff:gg'] * @endcode * * @throw cudf::logic_error if input column is not lists of strings column. 
* @throw cudf::logic_error if the number of rows from `separators` and `lists_strings_column` do - * not match - * - * @param lists_strings_column Column containing lists of strings to concatenate - * @param separators Strings column that provides separators for concatenation - * @param separator_narep String that should be used to replace null separator, default is an - * invalid-scalar denoting that rows containing null separator will result in null string in the - * corresponding output rows - * @param string_narep String that should be used to replace null strings in any - * non-null list row, default is an invalid-scalar denoting that list rows containing null strings - * will result in null string in the corresponding output rows - * @param mr Device memory resource used to allocate the returned column's - * device memory - * @return New strings column with concatenated results + * not match + * + * @param lists_strings_column Column containing lists of strings to concatenate. + * @param separators Strings column that provides separators for concatenation. + * @param separator_narep String that should be used to replace null separator, default is an + * invalid-scalar denoting that rows containing null separator will result in null string in + * the corresponding output rows. + * @param string_narep String that should be used to replace null strings in any non-null list row, + * default is an invalid-scalar denoting that list rows containing null strings will result + * in null string in the corresponding output rows. + * @param mr Device memory resource used to allocate the returned column's device memory. + * @return New strings column with concatenated results. 
 */ std::unique_ptr concatenate_list_elements( const lists_column_view& lists_strings_column, @@ -229,25 +228,24 @@ std::unique_ptr concatenate_list_elements( * Example: * s = [ {'aa', 'bb', 'cc'}, null, {'', 'dd'}, {'ee', null}, {'ff'} ] * - * r1 = concatenate(s) + * r1 = strings::concatenate_list_elements(s) * r1 is ['aabbcc', null, 'dd', null, 'ff'] * - * r2 = concatenate(s, ':', '_') + * r2 = strings::concatenate_list_elements(s, ':', '_') * r2 is ['aa:bb:cc', null, ':dd', 'ee:_', 'ff'] * @endcode * * @throw cudf::logic_error if input column is not lists of strings column. * @throw cudf::logic_error if separator is not valid. * - * @param lists_strings_column Column containing lists of strings to concatenate - * @param separator String that should inserted between strings of each list row, - * default is an empty string - * @param narep String that should be used to replace null strings in any non-null - * list row, default is an invalid-scalar denoting that list rows containing null strings will - * result in null string in the corresponding output rows - * @param mr Device memory resource used to allocate the returned column's - * device memory - * @return New strings column with concatenated results + * @param lists_strings_column Column containing lists of strings to concatenate. + * @param separator String that should be inserted between strings of each list row, default is an + * empty string. + * @param narep String that should be used to replace null strings in any non-null list row, default + * is an invalid-scalar denoting that list rows containing null strings will result in null + * string in the corresponding output rows. + * @param mr Device memory resource used to allocate the returned column's device memory. + * @return New strings column with concatenated results.
*/ std::unique_ptr concatenate_list_elements( const lists_column_view& lists_strings_column, diff --git a/cpp/src/lists/concatenate_rows.cu b/cpp/src/lists/concatenate_rows.cu index 8942bcc898d..8528a7680f7 100644 --- a/cpp/src/lists/concatenate_rows.cu +++ b/cpp/src/lists/concatenate_rows.cu @@ -37,6 +37,10 @@ namespace cudf { namespace lists { namespace detail { namespace { +/** + * @brief Concatenate lists within the same row into one list, ignoring any null list during + * concatenation. + */ std::unique_ptr concatenate_rows_ignore_null(table_view const& input, bool has_null_mask, rmm::cuda_stream_view stream, @@ -57,7 +61,7 @@ std::unique_ptr concatenate_rows_ignore_null(table_view const& input, auto const d_offsets = list_offsets->mutable_view().template begin(); // The array of int8_t to store validities for list elements. - // Since we combine multiple lists, we need to recompute list validities. + // Since we combine multiple lists, we may need to recompute list validities. auto validities = rmm::device_uvector(has_null_mask ? num_output_lists : 0, stream); // For an input table of `n` columns, if after interleaving we have the list offsets are @@ -169,8 +173,8 @@ generate_list_offsets_and_validities(table_view const& input, * * This functor is called only when (has_null_mask == true and null_policy == NULLIFY_OUTPUT_ROW). * It is executed twice. In the first pass, the sizes and validities of the output strings will be - * computed. In the second pass, this will concatenate the lists of strings of the given table of - * lists columns in a row-wise manner. + * computed. In the second pass, this will concatenate the lists of strings on the same row from the + * given input table. */ struct compute_string_sizes_and_concatenate_lists_fn { table_device_view const table_dv; @@ -182,7 +186,7 @@ struct compute_string_sizes_and_concatenate_lists_fn { offset_type* d_offsets{nullptr}; // If d_chars == nullptr: only compute sizes and validities of the output strings. 
- // If d_chars != nullptr: only concatenate strings. + // If d_chars != nullptr: only concatenate lists of strings. char* d_chars{nullptr}; // We need to set `1` or `0` for the validities of the strings in the child column. @@ -190,8 +194,7 @@ struct compute_string_sizes_and_concatenate_lists_fn { __device__ void operator()(size_type const idx) { - // The current row contain null, which has been identified during `dst_list_offsets` - // computation. + // The current row contains null, which has been identified during offsets computation. if (dst_list_offsets[idx + 1] == dst_list_offsets[idx]) { return; } // read_idx and write_idx are indices of string elements. @@ -205,7 +208,7 @@ struct compute_string_sizes_and_concatenate_lists_fn { auto const str_offsets = str_col.child(strings_column_view::offsets_column_index).template data(); - // The indices of the strings within the source list. + // The range of indices of the strings within the source list. auto const start_str_idx = list_offsets[idx]; auto const end_str_idx = list_offsets[idx + 1]; @@ -305,7 +308,7 @@ struct concatenate_lists_fn { lists_col.offset(); auto const& data_col = lists_col.child(lists_column_view::child_column_index); - // The indices of the entries within the source list. + // The range of indices of the entries within the source list. auto const start_idx = list_offsets[idx]; auto const end_idx = list_offsets[idx + 1]; diff --git a/cpp/src/lists/interleave_columns.cu b/cpp/src/lists/interleave_columns.cu index caadcf90f16..222c37507c4 100644 --- a/cpp/src/lists/interleave_columns.cu +++ b/cpp/src/lists/interleave_columns.cu @@ -106,7 +106,7 @@ struct compute_string_sizes_and_interleave_lists_fn { offset_type* d_offsets{nullptr}; // If d_chars == nullptr: only compute sizes and validities of the output strings. - // If d_chars != nullptr: only concatenate strings. + // If d_chars != nullptr: only interleave lists of strings.
char* d_chars{nullptr}; // We need to set `1` or `0` for the validities of the strings in the child column. @@ -128,7 +128,7 @@ struct compute_string_sizes_and_interleave_lists_fn { auto const str_offsets = str_col.child(strings_column_view::offsets_column_index).template data(); - // The indices of the strings within the source list. + // The range of indices of the strings within the source list. auto const start_str_idx = list_offsets[list_id]; auto const end_str_idx = list_offsets[list_id + 1]; @@ -243,9 +243,10 @@ struct interleave_list_entries_fn { lists_col.offset(); auto const& data_col = lists_col.child(lists_column_view::child_column_index); - // The indices of the entries within the source list. - auto const start_idx = list_offsets[list_id]; - auto const end_idx = list_offsets[list_id + 1]; + // The range of indices of the entries within the source list. + auto const start_idx = list_offsets[list_id]; + auto const end_idx = list_offsets[list_id + 1]; + auto const write_start = d_offsets[idx]; // Fill the validities array if necessary. 
From d20675565f8a02568ea2d43f34871681248599d3 Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Thu, 13 May 2021 21:49:21 -0500 Subject: [PATCH 03/12] Fix cython flag to use c++17 (#8243) This PR changes cython to compile with c++17 Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Keith Kraus (https://github.com/kkraus14) URL: https://github.com/rapidsai/cudf/pull/8243 --- python/cudf/setup.py | 2 +- python/cudf_kafka/setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cudf/setup.py b/python/cudf/setup.py index b8c7dc5868f..54921396b6f 100644 --- a/python/cudf/setup.py +++ b/python/cudf/setup.py @@ -192,7 +192,7 @@ def run(self): ), libraries=["cudart", "cudf"] + pa.get_libraries() + ["arrow_cuda"], language="c++", - extra_compile_args=["-std=c++14"], + extra_compile_args=["-std=c++17"], ) ] diff --git a/python/cudf_kafka/setup.py b/python/cudf_kafka/setup.py index f7523dda503..f16b7b42e4e 100644 --- a/python/cudf_kafka/setup.py +++ b/python/cudf_kafka/setup.py @@ -72,7 +72,7 @@ library_dirs=([get_python_lib(), os.path.join(os.sys.prefix, "lib")]), libraries=["cudf", "cudf_kafka"], language="c++", - extra_compile_args=["-std=c++14"], + extra_compile_args=["-std=c++17"], ) ] From 304f460cbf54a04f47d365516d26e357db08fa98 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Fri, 14 May 2021 13:07:37 +0800 Subject: [PATCH 04/12] Revert #7909 add java bindings for non-timestamps range window queries (#8245) Since it's a breaking PR, and the corresponding spark-rapids PR has some issues. Just revert this cudf PR. Sorry This reverts commit 2a169c801adc7f5d45144f8291a81fc21b9bf759. 
Authors: - Bobby Wang (https://github.com/wbo4958) Approvers: - Liangcai Li (https://github.com/firestarman) URL: https://github.com/rapidsai/cudf/pull/8245 --- java/src/main/java/ai/rapids/cudf/Scalar.java | 5 - java/src/main/java/ai/rapids/cudf/Table.java | 87 +- .../java/ai/rapids/cudf/WindowOptions.java | 182 +- java/src/main/native/src/TableJni.cpp | 77 +- .../java/ai/rapids/cudf/ColumnVectorTest.java | 153 +- .../test/java/ai/rapids/cudf/TableTest.java | 1618 +++++++---------- 6 files changed, 806 insertions(+), 1316 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Scalar.java b/java/src/main/java/ai/rapids/cudf/Scalar.java index 62dd9bda13b..ec20f39af27 100644 --- a/java/src/main/java/ai/rapids/cudf/Scalar.java +++ b/java/src/main/java/ai/rapids/cudf/Scalar.java @@ -620,7 +620,6 @@ public int hashCode() { case UINT32: case TIMESTAMP_DAYS: case DECIMAL32: - case DURATION_DAYS: valueHash = getInt(); break; case INT64: @@ -630,10 +629,6 @@ public int hashCode() { case TIMESTAMP_MICROSECONDS: case TIMESTAMP_NANOSECONDS: case DECIMAL64: - case DURATION_MICROSECONDS: - case DURATION_SECONDS: - case DURATION_MILLISECONDS: - case DURATION_NANOSECONDS: valueHash = Long.hashCode(getLong()); break; case FLOAT32: diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index e939411eece..b2f2ad5bad1 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -475,10 +475,10 @@ private static native long[] rollingWindowAggregate( int[] following, boolean ignoreNullKeys) throws CudfException; - private static native long[] rangeRollingWindowAggregate(long inputTable, int[] keyIndices, int[] orderByIndices, boolean[] isOrderByAscending, - int[] aggColumnsIndices, long[] aggInstances, int[] minPeriods, - long[] preceding, long[] following, boolean[] unboundedPreceding, boolean[] unboundedFollowing, - boolean ignoreNullKeys) throws CudfException; + private static 
native long[] timeRangeRollingWindowAggregate(long inputTable, int[] keyIndices, int[] timestampIndices, boolean[] isTimesampAscending, + int[] aggColumnsIndices, long[] aggInstances, int[] minPeriods, + int[] preceding, int[] following, boolean[] unboundedPreceding, boolean[] unboundedFollowing, + boolean ignoreNullKeys) throws CudfException; private static native long sortOrder(long inputTable, long[] sortKeys, boolean[] isDescending, boolean[] areNullsSmallest) throws CudfException; @@ -2457,7 +2457,7 @@ public Table aggregateWindows(AggregationOverWindow... windowAggregates) { } /** - * Computes range-based window aggregation functions on the Table/projection, + * Computes time-range-based window aggregation functions on the Table/projection, * based on windows specified in the argument. * * This method enables queries such as the following SQL: @@ -2506,10 +2506,10 @@ public Table aggregateWindows(AggregationOverWindow... windowAggregates) { * @param windowAggregates the window-aggregations to be performed * @return Table instance, with each column containing the result of each aggregation. * @throws IllegalArgumentException if the window arguments are not of type - * {@link WindowOptions.FrameType#RANGE} or the orderBys are not of (Boolean-exclusive) integral type + * {@link WindowOptions.FrameType#RANGE}, * i.e. the timestamp-column was not specified for the aggregation. */ - public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregates) { + public Table aggregateWindowsOverTimeRanges(AggregationOverWindow... windowAggregates) { // To improve performance and memory we want to remove duplicate operations // and also group the operations by column so hopefully cudf can do multiple aggregations // in a single pass. @@ -2521,76 +2521,51 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... 
windowAggregate for (int outputIndex = 0; outputIndex < windowAggregates.length; outputIndex++) { AggregationOverWindow agg = windowAggregates[outputIndex]; if (agg.getWindowOptions().getFrameType() != WindowOptions.FrameType.RANGE) { - throw new IllegalArgumentException("Expected range-based window specification. Unexpected window type: " - + agg.getWindowOptions().getFrameType()); - } - - DType orderByType = operation.table.getColumn(agg.getWindowOptions().getOrderByColumnIndex()).getType(); - switch (orderByType.getTypeId()) { - case INT8: - case INT16: - case INT32: - case INT64: - case UINT8: - case UINT16: - case UINT32: - case UINT64: - case TIMESTAMP_MILLISECONDS: - case TIMESTAMP_SECONDS: - case TIMESTAMP_DAYS: - case TIMESTAMP_NANOSECONDS: - case TIMESTAMP_MICROSECONDS: - break; - default: - throw new IllegalArgumentException("Expected range-based window orderBy's " + - "type: integral (Boolean-exclusive) and timestamp"); + throw new IllegalArgumentException("Expected time-range-based window specification. 
Unexpected window type: " + + agg.getWindowOptions().getFrameType()); } - ColumnWindowOps ops = groupedOps.computeIfAbsent(agg.getColumnIndex(), (idx) -> new ColumnWindowOps()); totalOps += ops.add(agg, outputIndex); } int[] aggColumnIndexes = new int[totalOps]; - int[] orderByColumnIndexes = new int[totalOps]; - boolean[] isOrderByOrderAscending = new boolean[totalOps]; + int[] timestampColumnIndexes = new int[totalOps]; + boolean[] isTimestampOrderAscending = new boolean[totalOps]; long[] aggInstances = new long[totalOps]; - long[] aggPrecedingWindows = new long[totalOps]; - long[] aggFollowingWindows = new long[totalOps]; try { + int[] aggPrecedingWindows = new int[totalOps]; + int[] aggFollowingWindows = new int[totalOps]; boolean[] aggPrecedingWindowsUnbounded = new boolean[totalOps]; boolean[] aggFollowingWindowsUnbounded = new boolean[totalOps]; int[] aggMinPeriods = new int[totalOps]; int opIndex = 0; for (Map.Entry entry: groupedOps.entrySet()) { int columnIndex = entry.getKey(); - for (AggregationOverWindow op: entry.getValue().operations()) { + for (AggregationOverWindow operation: entry.getValue().operations()) { aggColumnIndexes[opIndex] = columnIndex; - aggInstances[opIndex] = op.createNativeInstance(); - aggPrecedingWindows[opIndex] = op.getWindowOptions().getPrecedingScalar() == - null ? 0 : op.getWindowOptions().getPrecedingScalar().getScalarHandle(); - aggFollowingWindows[opIndex] = op.getWindowOptions().getFollowingScalar() == - null ? 
0 : op.getWindowOptions().getFollowingScalar().getScalarHandle(); - aggPrecedingWindowsUnbounded[opIndex] = op.getWindowOptions().isUnboundedPreceding(); - aggFollowingWindowsUnbounded[opIndex] = op.getWindowOptions().isUnboundedFollowing(); - aggMinPeriods[opIndex] = op.getWindowOptions().getMinPeriods(); - assert (op.getWindowOptions().getFrameType() == WindowOptions.FrameType.RANGE); - orderByColumnIndexes[opIndex] = op.getWindowOptions().getOrderByColumnIndex(); - isOrderByOrderAscending[opIndex] = op.getWindowOptions().isOrderByOrderAscending(); - if (op.getDefaultOutput() != 0) { + aggInstances[opIndex] = operation.createNativeInstance(); + aggPrecedingWindows[opIndex] = operation.getWindowOptions().getPreceding(); + aggFollowingWindows[opIndex] = operation.getWindowOptions().getFollowing(); + aggPrecedingWindowsUnbounded[opIndex] = operation.getWindowOptions().isUnboundedPreceding(); + aggFollowingWindowsUnbounded[opIndex] = operation.getWindowOptions().isUnboundedFollowing(); + aggMinPeriods[opIndex] = operation.getWindowOptions().getMinPeriods(); + assert (operation.getWindowOptions().getFrameType() == WindowOptions.FrameType.RANGE); + timestampColumnIndexes[opIndex] = operation.getWindowOptions().getTimestampColumnIndex(); + isTimestampOrderAscending[opIndex] = operation.getWindowOptions().isTimestampOrderAscending(); + if (operation.getDefaultOutput() != 0) { throw new IllegalArgumentException("Operations with a default output are not " + "supported on time based rolling windows"); } - opIndex++; } } assert opIndex == totalOps : opIndex + " == " + totalOps; - try (Table aggregate = new Table(rangeRollingWindowAggregate( + try (Table aggregate = new Table(timeRangeRollingWindowAggregate( operation.table.nativeHandle, operation.indices, - orderByColumnIndexes, - isOrderByOrderAscending, + timestampColumnIndexes, + isTimestampOrderAscending, aggColumnIndexes, aggInstances, aggMinPeriods, aggPrecedingWindows, aggFollowingWindows, 
aggPrecedingWindowsUnbounded, aggFollowingWindowsUnbounded, @@ -2655,14 +2630,6 @@ public ContiguousTable[] contiguousSplitGroups() { groupByOptions.getKeysDescending(), groupByOptions.getKeysNullSmallest()); } - - /** - * @deprecated use aggregateWindowsOverRanges - */ - @Deprecated - public Table aggregateWindowsOverTimeRanges(AggregationOverWindow... windowAggregates) { - return aggregateWindowsOverRanges(windowAggregates); - } } public static final class TableOperation { diff --git a/java/src/main/java/ai/rapids/cudf/WindowOptions.java b/java/src/main/java/ai/rapids/cudf/WindowOptions.java index 826784a33f1..429d4e1d978 100644 --- a/java/src/main/java/ai/rapids/cudf/WindowOptions.java +++ b/java/src/main/java/ai/rapids/cudf/WindowOptions.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2020, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,50 +21,33 @@ /** * Options for rolling windows. 
*/ -public class WindowOptions implements AutoCloseable { +public class WindowOptions { enum FrameType {ROWS, RANGE} - private final int minPeriods; private final int preceding; + private final int minPeriods; private final int following; - private final Scalar precedingScalar; - private final Scalar followingScalar; private final ColumnVector precedingCol; private final ColumnVector followingCol; - private final int orderByColumnIndex; - private final boolean orderByOrderAscending; + private final int timestampColumnIndex; + private final boolean timestampOrderAscending; private final FrameType frameType; private final boolean isUnboundedPreceding; private final boolean isUnboundedFollowing; private WindowOptions(Builder builder) { - this.minPeriods = builder.minPeriods; this.preceding = builder.preceding; + this.minPeriods = builder.minPeriods; this.following = builder.following; - this.precedingScalar = builder.precedingScalar; - if (precedingScalar != null) { - precedingScalar.incRefCount(); - } - this.followingScalar = builder.followingScalar; - if (followingScalar != null) { - followingScalar.incRefCount(); - } this.precedingCol = builder.precedingCol; - if (precedingCol != null) { - precedingCol.incRefCount(); - } this.followingCol = builder.followingCol; - if (followingCol != null) { - followingCol.incRefCount(); - } - this.orderByColumnIndex = builder.orderByColumnIndex; - this.orderByOrderAscending = builder.orderByOrderAscending; - this.frameType = orderByColumnIndex == -1? FrameType.ROWS : FrameType.RANGE; + this.timestampColumnIndex = builder.timestampColumnIndex; + this.timestampOrderAscending = builder.timestampOrderAscending; + this.frameType = timestampColumnIndex == -1? 
FrameType.ROWS : FrameType.RANGE; this.isUnboundedPreceding = builder.isUnboundedPreceding; this.isUnboundedFollowing = builder.isUnboundedFollowing; - } @Override @@ -76,8 +59,8 @@ public boolean equals(Object other) { boolean ret = this.preceding == o.preceding && this.following == o.following && this.minPeriods == o.minPeriods && - this.orderByColumnIndex == o.orderByColumnIndex && - this.orderByOrderAscending == o.orderByOrderAscending && + this.timestampColumnIndex == o.timestampColumnIndex && + this.timestampOrderAscending == o.timestampOrderAscending && this.frameType == o.frameType && this.isUnboundedPreceding == o.isUnboundedPreceding && this.isUnboundedFollowing == o.isUnboundedFollowing; @@ -87,12 +70,6 @@ public boolean equals(Object other) { if (followingCol != null) { ret = ret && followingCol.equals(o.followingCol); } - if (precedingScalar != null) { - ret = ret && precedingScalar.equals(o.precedingScalar); - } - if (followingScalar != null) { - ret = ret && followingScalar.equals(o.followingScalar); - } return ret; } return false; @@ -104,8 +81,8 @@ public int hashCode() { ret = 31 * ret + preceding; ret = 31 * ret + following; ret = 31 * ret + minPeriods; - ret = 31 * ret + orderByColumnIndex; - ret = 31 * ret + Boolean.hashCode(orderByOrderAscending); + ret = 31 * ret + timestampColumnIndex; + ret = 31 * ret + Boolean.hashCode(timestampOrderAscending); ret = 31 * ret + frameType.hashCode(); if (precedingCol != null) { ret = 31 * ret + precedingCol.hashCode(); @@ -113,12 +90,6 @@ public int hashCode() { if (followingCol != null) { ret = 31 * ret + followingCol.hashCode(); } - if (precedingScalar != null) { - ret = 31 * ret + precedingScalar.hashCode(); - } - if (followingScalar != null) { - ret = 31 * ret + followingScalar.hashCode(); - } ret = 31 * ret + Boolean.hashCode(isUnboundedPreceding); ret = 31 * ret + Boolean.hashCode(isUnboundedFollowing); return ret; @@ -134,23 +105,13 @@ public static Builder builder(){ int getFollowing() { return 
this.following; } - Scalar getPrecedingScalar() { return this.precedingScalar; } - - Scalar getFollowingScalar() { return this.followingScalar; } - ColumnVector getPrecedingCol() { return precedingCol; } ColumnVector getFollowingCol() { return this.followingCol; } - @Deprecated - int getTimestampColumnIndex() { return getOrderByColumnIndex(); } + int getTimestampColumnIndex() { return this.timestampColumnIndex; } - int getOrderByColumnIndex() { return this.orderByColumnIndex; } - - @Deprecated - boolean isTimestampOrderAscending() { return isOrderByOrderAscending(); } - - boolean isOrderByOrderAscending() { return this.orderByOrderAscending; } + boolean isTimestampOrderAscending() { return this.timestampOrderAscending; } boolean isUnboundedPreceding() { return this.isUnboundedPreceding; } @@ -162,14 +123,11 @@ public static class Builder { private int minPeriods = 1; private int preceding = 0; private int following = 1; - // for range window - private Scalar precedingScalar = null; - private Scalar followingScalar = null; boolean staticSet = false; private ColumnVector precedingCol = null; private ColumnVector followingCol = null; - private int orderByColumnIndex = -1; - private boolean orderByOrderAscending = true; + private int timestampColumnIndex = -1; + private boolean timestampOrderAscending = true; private boolean isUnboundedPreceding = false; private boolean isUnboundedFollowing = false; @@ -189,10 +147,8 @@ public Builder minPeriods(int minPeriods) { * Set the size of the window, one entry per row. This does not take ownership of the * columns passed in so you have to be sure that the life time of the column outlives * this operation. - * @param precedingCol the number of rows preceding the current row and - * precedingCol will be live outside of WindowOptions. - * @param followingCol the number of rows following the current row and - * following will be live outside of WindowOptions. + * @param precedingCol the number of rows preceding the current row. 
+ * @param followingCol the number of rows following the current row. */ public Builder window(ColumnVector precedingCol, ColumnVector followingCol) { assert (precedingCol != null && precedingCol.getNullCount() == 0); @@ -202,58 +158,19 @@ public Builder window(ColumnVector precedingCol, ColumnVector followingCol) { return this; } - /** - * Set the size of the range window. - * @param precedingScalar the relative number preceding the current row and - * the precedingScalar will be live outside of WindowOptions. - * @param followingScalar the relative number following the current row and - * the followingScalar will be live outside of WindowOptions - */ - public Builder window(Scalar precedingScalar, Scalar followingScalar) { - assert (precedingScalar != null && precedingScalar.isValid()); - assert (followingScalar != null && followingScalar.isValid()); - this.precedingScalar = precedingScalar; - this.followingScalar = followingScalar; - return this; - } - - /** - * @deprecated Use orderByColumnIndex(int index) - */ - @Deprecated public Builder timestampColumnIndex(int index) { - return orderByColumnIndex(index); - } - - public Builder orderByColumnIndex(int index) { - this.orderByColumnIndex = index; + this.timestampColumnIndex = index; return this; } - /** - * @deprecated Use orderByAscending() - */ - @Deprecated public Builder timestampAscending() { - return orderByAscending(); - } - - public Builder orderByAscending() { - this.orderByOrderAscending = true; - return this; - } - - public Builder orderByDescending() { - this.orderByOrderAscending = false; + this.timestampOrderAscending = true; return this; } - /** - * @deprecated Use orderByDescending() - */ - @Deprecated public Builder timestampDescending() { - return orderByDescending(); + this.timestampOrderAscending = false; + return this; } public Builder unboundedPreceding() { @@ -276,26 +193,6 @@ public Builder following(int following) { return this; } - /** - * Set the relative number preceding the current 
row for range window - * @param preceding - * @return Builder - */ - public Builder preceding(Scalar preceding) { - this.precedingScalar = preceding; - return this; - } - - /** - * Set the relative number following the current row for range window - * @param following - * @return Builder - */ - public Builder following(Scalar following) { - this.followingScalar = following; - return this; - } - /** * Set the size of the window. * @param preceding the number of rows preceding the current row @@ -312,40 +209,7 @@ public WindowOptions build() { if (staticSet && precedingCol != null) { throw new IllegalArgumentException("Cannot set both a static window and a non-static window"); } - return new WindowOptions(this); } } - - public synchronized WindowOptions incRefCount() { - if (precedingScalar != null) { - precedingScalar.incRefCount(); - } - if (followingScalar != null) { - followingScalar.incRefCount(); - } - if (precedingCol != null) { - precedingCol.incRefCount(); - } - if (followingCol != null) { - followingCol.incRefCount(); - } - return this; - } - - @Override - public void close() { - if (precedingScalar != null) { - precedingScalar.close(); - } - if (followingScalar != null) { - followingScalar.close(); - } - if (precedingCol != null) { - precedingCol.close(); - } - if (followingCol != null) { - followingCol.close(); - } - } } \ No newline at end of file diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 4b01745382b..3799a5dbab3 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -728,16 +728,6 @@ bool valid_window_parameters(native_jintArray const &values, values.size() == preceding.size() && values.size() == following.size(); } -// Check that window parameters are valid. 
-bool valid_window_parameters(native_jintArray const &values, - native_jpointerArray const &ops, - native_jintArray const &min_periods, - native_jpointerArray const &preceding, - native_jpointerArray const &following) { - return values.size() == ops.size() && values.size() == min_periods.size() && - values.size() == preceding.size() && values.size() == following.size(); -} - // Generate gather maps needed to manifest the result of a join between two tables. // The resulting Java long array contains the following at each index: // 0: Size of each gather map in bytes @@ -2325,22 +2315,20 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rollingWindowAggregate( CATCH_STD(env, NULL); } -JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggregate( +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_timeRangeRollingWindowAggregate( JNIEnv *env, jclass, jlong j_input_table, jintArray j_keys, - jintArray j_orderby_column_indices, jbooleanArray j_is_orderby_ascending, + jintArray j_timestamp_column_indices, jbooleanArray j_is_timestamp_ascending, jintArray j_aggregate_column_indices, jlongArray j_agg_instances, jintArray j_min_periods, - jlongArray j_preceding, jlongArray j_following, - jbooleanArray j_unbounded_preceding, jbooleanArray j_unbounded_following, + jintArray j_preceding, jintArray j_following, + jbooleanArray j_unbounded_preceding, jbooleanArray j_unbounded_following, jboolean ignore_null_keys) { JNI_NULL_CHECK(env, j_input_table, "input table is null", NULL); JNI_NULL_CHECK(env, j_keys, "input keys are null", NULL); - JNI_NULL_CHECK(env, j_orderby_column_indices, "input orderby_column_indices are null", NULL); - JNI_NULL_CHECK(env, j_is_orderby_ascending, "input orderby_ascending is null", NULL); + JNI_NULL_CHECK(env, j_timestamp_column_indices, "input timestamp_column_indices are null", NULL); + JNI_NULL_CHECK(env, j_is_timestamp_ascending, "input timestamp_ascending is null", NULL); JNI_NULL_CHECK(env, 
j_aggregate_column_indices, "input aggregate_column_indices are null", NULL); JNI_NULL_CHECK(env, j_agg_instances, "agg_instances are null", NULL); - JNI_NULL_CHECK(env, j_preceding, "preceding are null", NULL); - JNI_NULL_CHECK(env, j_following, "following are null", NULL); try { cudf::jni::auto_set_device(env); @@ -2350,15 +2338,15 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega // Convert from j-types to native. cudf::table_view *input_table{reinterpret_cast(j_input_table)}; cudf::jni::native_jintArray keys{env, j_keys}; - cudf::jni::native_jintArray orderbys{env, j_orderby_column_indices}; - cudf::jni::native_jbooleanArray orderbys_ascending{env, j_is_orderby_ascending}; + cudf::jni::native_jintArray timestamps{env, j_timestamp_column_indices}; + cudf::jni::native_jbooleanArray timestamp_ascending{env, j_is_timestamp_ascending}; cudf::jni::native_jintArray values{env, j_aggregate_column_indices}; cudf::jni::native_jpointerArray agg_instances(env, j_agg_instances); cudf::jni::native_jintArray min_periods{env, j_min_periods}; + cudf::jni::native_jintArray preceding{env, j_preceding}; + cudf::jni::native_jintArray following{env, j_following}; cudf::jni::native_jbooleanArray unbounded_preceding{env, j_unbounded_preceding}; cudf::jni::native_jbooleanArray unbounded_following{env, j_unbounded_following}; - cudf::jni::native_jpointerArray preceding(env, j_preceding); - cudf::jni::native_jpointerArray following(env, j_following); if (not valid_window_parameters(values, agg_instances, min_periods, preceding, following)) { JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", @@ -2373,48 +2361,21 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega std::vector> result_columns; for (int i(0); i < values.size(); ++i) { int agg_column_index = values[i]; - cudf::column_view const &order_by_column = input_table->column(orderbys[i]); - cudf::data_type order_by_type = order_by_column.type(); - cudf::data_type 
unbounded_type = order_by_type; - - if (unbounded_preceding[i] || unbounded_following[i]) { - switch (order_by_type.id()) { - case cudf::type_id::TIMESTAMP_DAYS: - unbounded_type = cudf::data_type{cudf::type_id::DURATION_DAYS}; - break; - case cudf::type_id::TIMESTAMP_SECONDS: - unbounded_type =cudf::data_type{cudf::type_id::DURATION_SECONDS}; - break; - case cudf::type_id::TIMESTAMP_MILLISECONDS: - unbounded_type = cudf::data_type{cudf::type_id::DURATION_MILLISECONDS}; - break; - case cudf::type_id::TIMESTAMP_MICROSECONDS: - unbounded_type = cudf::data_type{cudf::type_id::DURATION_MICROSECONDS}; - break; - case cudf::type_id::TIMESTAMP_NANOSECONDS: - unbounded_type = cudf::data_type{cudf::type_id::DURATION_NANOSECONDS}; - break; - default: - break; - } - } cudf::rolling_aggregation * agg = dynamic_cast(agg_instances[i]); JNI_ARG_CHECK(env, agg != nullptr, "aggregation is not an instance of rolling_aggregation", nullptr); result_columns.emplace_back( std::move( - cudf::grouped_range_rolling_window( - groupby_keys, - order_by_column, - orderbys_ascending[i] ? cudf::order::ASCENDING : cudf::order::DESCENDING, - input_table->column(agg_column_index), - unbounded_preceding[i] ? cudf::range_window_bounds::unbounded(unbounded_type) : - cudf::range_window_bounds::get(*preceding[i]), - unbounded_following[i] ? cudf::range_window_bounds::unbounded(unbounded_type) : - cudf::range_window_bounds::get(*following[i]), - min_periods[i], - *agg + cudf::grouped_time_range_rolling_window( + groupby_keys, + input_table->column(timestamps[i]), + timestamp_ascending[i] ? cudf::order::ASCENDING : cudf::order::DESCENDING, + input_table->column(agg_column_index), + unbounded_preceding[i] ? cudf::window_bounds::unbounded() : cudf::window_bounds::get(preceding[i]), + unbounded_following[i] ? 
cudf::window_bounds::unbounded() : cudf::window_bounds::get(following[i]), + min_periods[i], + *agg ) ) ); diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 20508bdb23b..4c5ee7295d9 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -2263,62 +2263,60 @@ void testPrefixSumErrors() { @Test void testWindowStatic() { - try (WindowOptions options = WindowOptions.builder().window(2, 1) - .minPeriods(2).build()) { - try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8)) { - try (ColumnVector expected = ColumnVector.fromLongs(9, 16, 17, 21, 14); - ColumnVector result = v1.rollingWindow(Aggregation.sum(), options)) { - assertColumnsAreEqual(expected, result); - } + WindowOptions options = WindowOptions.builder().window(2, 1) + .minPeriods(2).build(); + try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8)) { + try (ColumnVector expected = ColumnVector.fromLongs(9, 16, 17, 21, 14); + ColumnVector result = v1.rollingWindow(Aggregation.sum(), options)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector expected = ColumnVector.fromInts(4, 4, 4, 6, 6); - ColumnVector result = v1.rollingWindow(Aggregation.min(), options)) { - assertColumnsAreEqual(expected, result); - } + try (ColumnVector expected = ColumnVector.fromInts(4, 4, 4, 6, 6); + ColumnVector result = v1.rollingWindow(Aggregation.min(), options)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector expected = ColumnVector.fromInts(5, 7, 7, 8, 8); - ColumnVector result = v1.rollingWindow(Aggregation.max(), options)) { - assertColumnsAreEqual(expected, result); - } + try (ColumnVector expected = ColumnVector.fromInts(5, 7, 7, 8, 8); + ColumnVector result = v1.rollingWindow(Aggregation.max(), options)) { + assertColumnsAreEqual(expected, result); + } - // The rolling window produces the same result type as the 
input - try (ColumnVector expected = ColumnVector.fromDoubles(4.5, 16.0 / 3, 17.0 / 3, 7, 7); - ColumnVector result = v1.rollingWindow(Aggregation.mean(), options)) { - assertColumnsAreEqual(expected, result); - } + // The rolling window produces the same result type as the input + try (ColumnVector expected = ColumnVector.fromDoubles(4.5, 16.0 / 3, 17.0 / 3, 7, 7); + ColumnVector result = v1.rollingWindow(Aggregation.mean(), options)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector expected = ColumnVector.fromBoxedInts(4, 7, 6, 8, null); - ColumnVector result = v1.rollingWindow(Aggregation.lead(1), options)) { - assertColumnsAreEqual(expected, result); - } + try (ColumnVector expected = ColumnVector.fromBoxedInts(4, 7, 6, 8, null); + ColumnVector result = v1.rollingWindow(Aggregation.lead(1), options)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector expected = ColumnVector.fromBoxedInts(null, 5, 4, 7, 6); - ColumnVector result = v1.rollingWindow(Aggregation.lag(1), options)) { - assertColumnsAreEqual(expected, result); - } + try (ColumnVector expected = ColumnVector.fromBoxedInts(null, 5, 4, 7, 6); + ColumnVector result = v1.rollingWindow(Aggregation.lag(1), options)) { + assertColumnsAreEqual(expected, result); + } - try (ColumnVector defaultOutput = ColumnVector.fromInts(-1, -2, -3, -4, -5); - ColumnVector expected = ColumnVector.fromBoxedInts(-1, 5, 4, 7, 6); - ColumnVector result = v1.rollingWindow(Aggregation.lag(1, defaultOutput), options)) { - assertColumnsAreEqual(expected, result); - } + try (ColumnVector defaultOutput = ColumnVector.fromInts(-1, -2, -3, -4, -5); + ColumnVector expected = ColumnVector.fromBoxedInts(-1, 5, 4, 7, 6); + ColumnVector result = v1.rollingWindow(Aggregation.lag(1, defaultOutput), options)) { + assertColumnsAreEqual(expected, result); } } } @Test void testWindowStaticCounts() { - try (WindowOptions options = WindowOptions.builder().window(2, 1) - .minPeriods(2).build()) { - try 
(ColumnVector v1 = ColumnVector.fromBoxedInts(5, 4, null, 6, 8)) { - try (ColumnVector expected = ColumnVector.fromInts(2, 2, 2, 2, 2); - ColumnVector result = v1.rollingWindow(Aggregation.count(NullPolicy.EXCLUDE), options)) { - assertColumnsAreEqual(expected, result); - } - try (ColumnVector expected = ColumnVector.fromInts(2, 3, 3, 3, 2); - ColumnVector result = v1.rollingWindow(Aggregation.count(NullPolicy.INCLUDE), options)) { - assertColumnsAreEqual(expected, result); - } + WindowOptions options = WindowOptions.builder().window(2, 1) + .minPeriods(2).build(); + try (ColumnVector v1 = ColumnVector.fromBoxedInts(5, 4, null, 6, 8)) { + try (ColumnVector expected = ColumnVector.fromInts(2, 2, 2, 2, 2); + ColumnVector result = v1.rollingWindow(Aggregation.count(NullPolicy.EXCLUDE), options)) { + assertColumnsAreEqual(expected, result); + } + try (ColumnVector expected = ColumnVector.fromInts(2, 3, 3, 3, 2); + ColumnVector result = v1.rollingWindow(Aggregation.count(NullPolicy.INCLUDE), options)) { + assertColumnsAreEqual(expected, result); } } } @@ -2327,26 +2325,24 @@ void testWindowStaticCounts() { void testWindowDynamicNegative() { try (ColumnVector precedingCol = ColumnVector.fromInts(3, 3, 3, 4, 4); ColumnVector followingCol = ColumnVector.fromInts(-1, -1, -1, -1, 0)) { - try (WindowOptions window = WindowOptions.builder() - .minPeriods(2).window(precedingCol, followingCol).build()) { - try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); - ColumnVector expected = ColumnVector.fromBoxedLongs(null, null, 9L, 16L, 25L); - ColumnVector result = v1.rollingWindow(Aggregation.sum(), window)) { - assertColumnsAreEqual(expected, result); - } + WindowOptions window = WindowOptions.builder() + .minPeriods(2).window(precedingCol, followingCol).build(); + try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); + ColumnVector expected = ColumnVector.fromBoxedLongs(null, null, 9L, 16L, 25L); + ColumnVector result = v1.rollingWindow(Aggregation.sum(), window)) 
{ + assertColumnsAreEqual(expected, result); } } } @Test void testWindowLag() { - try (WindowOptions window = WindowOptions.builder().minPeriods(1) - .window(2, -1).build()) { - try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); - ColumnVector expected = ColumnVector.fromBoxedInts(null, 5, 4, 7, 6); - ColumnVector result = v1.rollingWindow(Aggregation.max(), window)) { - assertColumnsAreEqual(expected, result); - } + WindowOptions window = WindowOptions.builder().minPeriods(1) + .window(2, -1).build(); + try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); + ColumnVector expected = ColumnVector.fromBoxedInts(null, 5, 4, 7, 6); + ColumnVector result = v1.rollingWindow(Aggregation.max(), window)) { + assertColumnsAreEqual(expected, result); } } @@ -2354,13 +2350,12 @@ void testWindowLag() { void testWindowDynamic() { try (ColumnVector precedingCol = ColumnVector.fromInts(1, 2, 3, 1, 2); ColumnVector followingCol = ColumnVector.fromInts(2, 2, 2, 2, 2)) { - try (WindowOptions window = WindowOptions.builder().minPeriods(2) - .window(precedingCol, followingCol).build()) { - try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); - ColumnVector expected = ColumnVector.fromLongs(16, 22, 30, 14, 14); - ColumnVector result = v1.rollingWindow(Aggregation.sum(), window)) { - assertColumnsAreEqual(expected, result); - } + WindowOptions window = WindowOptions.builder().minPeriods(2) + .window(precedingCol, followingCol).build(); + try (ColumnVector v1 = ColumnVector.fromInts(5, 4, 7, 6, 8); + ColumnVector expected = ColumnVector.fromLongs(16, 22, 30, 14, 14); + ColumnVector result = v1.rollingWindow(Aggregation.sum(), window)) { + assertColumnsAreEqual(expected, result); } } } @@ -2368,23 +2363,17 @@ void testWindowDynamic() { @Test void testWindowThrowsException() { try (ColumnVector arraywindowCol = ColumnVector.fromBoxedInts(1, 2, 3 ,1, 1)) { - assertThrows(IllegalArgumentException.class, () -> { - try (WindowOptions options = WindowOptions.builder() - 
.window(3, 2).minPeriods(3) - .window(arraywindowCol, arraywindowCol).build()) { - - } - }); - - assertThrows(IllegalArgumentException.class, () -> { - try (WindowOptions options = WindowOptions.builder() - .window(2, 1) - .minPeriods(1) - .orderByColumnIndex(0) - .build()) { - arraywindowCol.rollingWindow(Aggregation.sum(), options); - } - }); + assertThrows(IllegalArgumentException.class, () -> WindowOptions.builder() + .window(3, 2).minPeriods(3) + .window(arraywindowCol, arraywindowCol).build()); + + assertThrows(IllegalArgumentException.class, + () -> arraywindowCol.rollingWindow(Aggregation.sum(), + WindowOptions.builder() + .window(2, 1) + .minPeriods(1) + .timestampColumnIndex(0) + .build())); } } diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 041f386f3f9..735dc86af17 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -2750,19 +2750,18 @@ void testWindowingCount() { ColumnVector decSortedAggColumn = decSorted.getColumn(3); assertColumnsAreEqual(expectSortedAggColumn, decSortedAggColumn); - try (WindowOptions window = WindowOptions.builder() + WindowOptions window = WindowOptions.builder() .minPeriods(1) .window(2, 1) - .build()) { + .build(); - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.count().onColumn(3).overWindow(window)); - Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation.count().onColumn(3).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(2, 3, 3, 2, 2, 3, 3, 2, 2, 3, 3, 2)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expect, decWindowAggResults.getColumn(0)); - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation.count().onColumn(3).overWindow(window)); + Table decWindowAggResults = decSorted.groupBy(0, 4) + 
.aggregateWindows(Aggregation.count().onColumn(3).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(2, 3, 3, 2, 2, 3, 3, 2, 2, 3, 3, 2)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect, decWindowAggResults.getColumn(0)); } } } @@ -2788,20 +2787,19 @@ void testWindowingMin() { ColumnVector decSortedAggColumn = decSorted.getColumn(6); assertColumnsAreEqual(expectDecSortedAggCol, decSortedAggColumn); - try (WindowOptions window = WindowOptions.builder() + WindowOptions window = WindowOptions.builder() .minPeriods(1) .window(2, 1) - .build()) { + .build(); - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.min().onColumn(3).overWindow(window)); - Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation.min().onColumn(6).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(5, 1, 1, 1, 7, 7, 2, 2, 0, 0, 0, 6); - ColumnVector decExpect = ColumnVector.decimalFromLongs(2, 5, 1, 1, 1, 7, 7, 2, 2, 0, 0, 0, 6)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpect, decWindowAggResults.getColumn(0)); - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation.min().onColumn(3).overWindow(window)); + Table decWindowAggResults = decSorted.groupBy(0, 4) + .aggregateWindows(Aggregation.min().onColumn(6).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(5, 1, 1, 1, 7, 7, 2, 2, 0, 0, 0, 6); + ColumnVector decExpect = ColumnVector.decimalFromLongs(2, 5, 1, 1, 1, 7, 7, 2, 2, 0, 0, 0, 6)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpect, decWindowAggResults.getColumn(0)); } } } @@ -2827,20 +2825,19 @@ void testWindowingMax() { ColumnVector decSortedAggColumn = decSorted.getColumn(6); assertColumnsAreEqual(expectDecSortedAggCol, decSortedAggColumn); - try (WindowOptions window = 
WindowOptions.builder() + WindowOptions window = WindowOptions.builder() .minPeriods(1) .window(2, 1) - .build()) { + .build(); - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.max().onColumn(3).overWindow(window)); - Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation.max().onColumn(6).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6); - ColumnVector decExpect = ColumnVector.decimalFromLongs(2, 7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpect, decWindowAggResults.getColumn(0)); - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation.max().onColumn(3).overWindow(window)); + Table decWindowAggResults = decSorted.groupBy(0, 4) + .aggregateWindows(Aggregation.max().onColumn(6).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6); + ColumnVector decExpect = ColumnVector.decimalFromLongs(2, 7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpect, decWindowAggResults.getColumn(0)); } } } @@ -2859,16 +2856,15 @@ void testWindowingSum() { ColumnVector sortedAggColumn = sorted.getColumn(3); assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - try (WindowOptions window = WindowOptions.builder() + WindowOptions window = WindowOptions.builder() .minPeriods(1) .window(2, 1) - .build()) { + .build(); - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.sum().onColumn(3).overWindow(window)); - ColumnVector expectAggResult = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 10L, 16L, 24L, 19L, 10L, 8L, 14L, 12L, 12L)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - } + try (Table windowAggResults = sorted.groupBy(0, 1) + 
.aggregateWindows(Aggregation.sum().onColumn(3).overWindow(window)); + ColumnVector expectAggResult = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 10L, 16L, 24L, 19L, 10L, 8L, 14L, 12L, 12L)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); } } } @@ -2896,58 +2892,49 @@ void testWindowingRowNumber() { WindowOptions.Builder windowBuilder = WindowOptions.builder().minPeriods(1); - try (WindowOptions options = windowBuilder.window(2, 1).build(); - WindowOptions options1 = windowBuilder.window(2, 1).build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .rowNumber() - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .rowNumber() - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(1, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expectAggResult, decWindowAggResults.getColumn(0)); - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .rowNumber() + .onColumn(3) + .overWindow(windowBuilder.window(2, 1).build())); + Table decWindowAggResults = decSorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .rowNumber() + .onColumn(6) + .overWindow(windowBuilder.window(2, 1).build())); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(1, 2, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expectAggResult, decWindowAggResults.getColumn(0)); } - try (WindowOptions options = windowBuilder.window(3, 2).build(); - WindowOptions options1 = windowBuilder.window(3, 2).build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .rowNumber() - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = decSorted.groupBy(0, 4) - 
.aggregateWindows(Aggregation - .rowNumber() - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expectAggResult, decWindowAggResults.getColumn(0)); - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .rowNumber() + .onColumn(3) + .overWindow(windowBuilder.window(3, 2).build())); + Table decWindowAggResults = decSorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .rowNumber() + .onColumn(6) + .overWindow(windowBuilder.window(3, 2).build())); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(1, 2, 3, 3, 1, 2, 3, 3, 1, 2, 3, 3)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expectAggResult, decWindowAggResults.getColumn(0)); } - try (WindowOptions options = windowBuilder.window(4, 3).build(); - WindowOptions options1 = windowBuilder.window(4, 3).build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .rowNumber() - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .rowNumber() - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expectAggResult, decWindowAggResults.getColumn(0)); - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .rowNumber() + .onColumn(3) + .overWindow(windowBuilder.window(4, 3).build())); + Table decWindowAggResults = decSorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .rowNumber() + .onColumn(6) + .overWindow(windowBuilder.window(4, 3).build())); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(1, 2, 3, 4, 1, 
2, 3, 4, 1, 2, 3, 4)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expectAggResult, decWindowAggResults.getColumn(0)); } } } @@ -2957,71 +2944,69 @@ void testWindowingRowNumber() { void testWindowingCollectList() { Aggregation aggCollectWithNulls = Aggregation.collectList(NullPolicy.INCLUDE); Aggregation aggCollect = Aggregation.collectList(); - try (WindowOptions winOpts = WindowOptions.builder() - .minPeriods(1) - .window(2, 1) - .build()) { - StructType nestedType = new StructType(false, - new BasicType(false, DType.INT32), new BasicType(false, DType.STRING)); - try (Table raw = new Table.TestBuilder() - .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // GBY Key - .column(1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8) // OBY Key - .column(7, 5, 1, 9, 7, 9, 8, 2, null, 0, 6, null) // Agg Column of INT32 - .column(nestedType, // Agg Column of Struct - new StructData(1, "s1"), new StructData(2, "s2"), new StructData(3, "s3"), - new StructData(4, "s4"), new StructData(11, "s11"), new StructData(22, "s22"), - new StructData(33, "s33"), new StructData(44, "s44"), new StructData(111, "s111"), - new StructData(222, "s222"), new StructData(333, "s333"), new StructData(444, "s444") - ).build(); - ColumnVector expectSortedAggColumn = ColumnVector - .fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, null, 0, 6, null)) { - try (Table sorted = raw.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2))) { - ColumnVector sortedAggColumn = sorted.getColumn(3); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - - // Primitive type: INT32 - // a) including nulls - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(aggCollectWithNulls.onColumn(3).overWindow(winOpts)); - ColumnVector expected = ColumnVector.fromLists( - new ListType(false, new BasicType(false, DType.INT32)), - Arrays.asList(7, 5), Arrays.asList(7, 5, 1), Arrays.asList(5, 1, 9), 
Arrays.asList(1, 9), - Arrays.asList(7, 9), Arrays.asList(7, 9, 8), Arrays.asList(9, 8, 2), Arrays.asList(8, 2), - Arrays.asList(null, 0), Arrays.asList(null, 0, 6), Arrays.asList(0, 6, null), Arrays.asList(6, null))) { - assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); - } - // b) excluding nulls - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(aggCollect.onColumn(3).overWindow(winOpts)); - ColumnVector expected = ColumnVector.fromLists( - new ListType(false, new BasicType(false, DType.INT32)), - Arrays.asList(7, 5), Arrays.asList(7, 5, 1), Arrays.asList(5, 1, 9), Arrays.asList(1, 9), - Arrays.asList(7, 9), Arrays.asList(7, 9, 8), Arrays.asList(9, 8, 2), Arrays.asList(8, 2), - Arrays.asList(0), Arrays.asList(0, 6), Arrays.asList(0, 6), Arrays.asList(6))) { - assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); - } + WindowOptions winOpts = WindowOptions.builder() + .minPeriods(1) + .window(2, 1).build(); + StructType nestedType = new StructType(false, + new BasicType(false, DType.INT32), new BasicType(false, DType.STRING)); + try (Table raw = new Table.TestBuilder() + .column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column( 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // GBY Key + .column( 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8) // OBY Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, null, 0, 6, null) // Agg Column of INT32 + .column(nestedType, // Agg Column of Struct + new StructData(1, "s1"), new StructData(2, "s2"), new StructData(3, "s3"), + new StructData(4, "s4"), new StructData(11, "s11"), new StructData(22, "s22"), + new StructData(33, "s33"), new StructData(44, "s44"), new StructData(111, "s111"), + new StructData(222, "s222"), new StructData(333, "s333"), new StructData(444, "s444") + ).build(); + ColumnVector expectSortedAggColumn = ColumnVector + .fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, null, 0, 6, null)) { + try (Table sorted = raw.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2))) { + ColumnVector 
sortedAggColumn = sorted.getColumn(3); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - // Nested type: Struct - List[] expectedNestedData = new List[12]; - expectedNestedData[0] = Arrays.asList(new StructData(1, "s1"), new StructData(2, "s2")); - expectedNestedData[1] = Arrays.asList(new StructData(1, "s1"), new StructData(2, "s2"), new StructData(3, "s3")); - expectedNestedData[2] = Arrays.asList(new StructData(2, "s2"), new StructData(3, "s3"), new StructData(4, "s4")); - expectedNestedData[3] = Arrays.asList(new StructData(3, "s3"), new StructData(4, "s4")); - expectedNestedData[4] = Arrays.asList(new StructData(11, "s11"), new StructData(22, "s22")); - expectedNestedData[5] = Arrays.asList(new StructData(11, "s11"), new StructData(22, "s22"), new StructData(33, "s33")); - expectedNestedData[6] = Arrays.asList(new StructData(22, "s22"), new StructData(33, "s33"), new StructData(44, "s44")); - expectedNestedData[7] = Arrays.asList(new StructData(33, "s33"), new StructData(44, "s44")); - expectedNestedData[8] = Arrays.asList(new StructData(111, "s111"), new StructData(222, "s222")); - expectedNestedData[9] = Arrays.asList(new StructData(111, "s111"), new StructData(222, "s222"), new StructData(333, "s333")); - expectedNestedData[10] = Arrays.asList(new StructData(222, "s222"), new StructData(333, "s333"), new StructData(444, "s444")); - expectedNestedData[11] = Arrays.asList(new StructData(333, "s333"), new StructData(444, "s444")); - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(aggCollect.onColumn(4).overWindow(winOpts)); - ColumnVector expected = ColumnVector.fromLists( - new ListType(false, nestedType), expectedNestedData)) { - assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); - } + // Primitive type: INT32 + // a) including nulls + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(aggCollectWithNulls.onColumn(3).overWindow(winOpts)); + ColumnVector expected = 
ColumnVector.fromLists( + new ListType(false, new BasicType(false, DType.INT32)), + Arrays.asList(7,5), Arrays.asList(7,5,1), Arrays.asList(5,1,9), Arrays.asList(1,9), + Arrays.asList(7,9), Arrays.asList(7,9,8), Arrays.asList(9,8,2), Arrays.asList(8,2), + Arrays.asList(null,0), Arrays.asList(null,0,6), Arrays.asList(0,6,null), Arrays.asList(6,null))) { + assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); + } + // b) excluding nulls + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(aggCollect.onColumn(3).overWindow(winOpts)); + ColumnVector expected = ColumnVector.fromLists( + new ListType(false, new BasicType(false, DType.INT32)), + Arrays.asList(7,5), Arrays.asList(7,5,1), Arrays.asList(5,1,9), Arrays.asList(1,9), + Arrays.asList(7,9), Arrays.asList(7,9,8), Arrays.asList(9,8,2), Arrays.asList(8,2), + Arrays.asList(0), Arrays.asList(0,6), Arrays.asList(0,6), Arrays.asList(6))) { + assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); + } + + // Nested type: Struct + List[] expectedNestedData = new List[12]; + expectedNestedData[0] = Arrays.asList(new StructData(1, "s1"),new StructData(2, "s2")); + expectedNestedData[1] = Arrays.asList(new StructData(1, "s1"),new StructData(2, "s2"),new StructData(3, "s3")); + expectedNestedData[2] = Arrays.asList(new StructData(2, "s2"),new StructData(3, "s3"),new StructData(4, "s4")); + expectedNestedData[3] = Arrays.asList(new StructData(3, "s3"),new StructData(4, "s4")); + expectedNestedData[4] = Arrays.asList(new StructData(11, "s11"),new StructData(22, "s22")); + expectedNestedData[5] = Arrays.asList(new StructData(11, "s11"),new StructData(22, "s22"),new StructData(33, "s33")); + expectedNestedData[6] = Arrays.asList(new StructData(22, "s22"),new StructData(33, "s33"), new StructData(44, "s44")); + expectedNestedData[7] = Arrays.asList(new StructData(33, "s33"), new StructData(44, "s44")); + expectedNestedData[8] = Arrays.asList(new StructData(111, "s111"),new StructData(222, 
"s222")); + expectedNestedData[9] = Arrays.asList(new StructData(111, "s111"),new StructData(222, "s222"),new StructData(333, "s333")); + expectedNestedData[10] = Arrays.asList(new StructData(222, "s222"),new StructData(333, "s333"),new StructData(444, "s444")); + expectedNestedData[11] = Arrays.asList(new StructData(333, "s333"),new StructData(444, "s444")); + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(aggCollect.onColumn(4).overWindow(winOpts)); + ColumnVector expected = ColumnVector.fromLists( + new ListType(false, nestedType), expectedNestedData)) { + assertColumnsAreEqual(expected, windowAggResults.getColumn(0)); } } } @@ -3050,83 +3035,71 @@ void testWindowingLead() { WindowOptions.Builder windowBuilder = WindowOptions.builder().minPeriods(1); - try (WindowOptions options = windowBuilder.window(2, 1).build(); - WindowOptions options1 = windowBuilder.window(2, 1).build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lead(0) - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = decSorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lead(0) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); - ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lead(0) + .onColumn(3) + .overWindow(windowBuilder.window(2, 1).build())); + Table decWindowAggResults = decSorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lead(0) + .onColumn(6) + .overWindow(windowBuilder.window(2, 1).build())); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); + ColumnVector 
decExpectAggResult = ColumnVector.decimalFromLongs(-2, 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); } - try (WindowOptions options = windowBuilder.window(0,1).build(); - WindowOptions options1 = windowBuilder.window(0,1).build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lead(1) - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lead(1) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, null); - ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, 5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, null)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lead(1) + .onColumn(3) + .overWindow(windowBuilder.window(0,1).build())); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lead(1) + .onColumn(6) + .overWindow(windowBuilder.window(0,1).build())); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, null); + ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, 5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, null)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); } - try (WindowOptions options = windowBuilder.window(0,1).build(); - WindowOptions options1 = windowBuilder.window(0,1).build()) { - try (ColumnVector defaultOutput = ColumnVector.fromBoxedInts(0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); - ColumnVector decDefaultOutput = 
ColumnVector.decimalFromLongs(-2, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); - Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lead(1, defaultOutput) - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lead(1, decDefaultOutput) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(5, 1, 9, -3, 9, 8, 2, -7, 0, 6, 6, -11); - ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 5, 1, 9, -3, 9, 8, 2, -7, 0, 6, 6, -11)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + try (ColumnVector defaultOutput = ColumnVector.fromBoxedInts(0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); + ColumnVector decDefaultOutput = ColumnVector.decimalFromLongs(-2, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); + Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lead(1, defaultOutput) + .onColumn(3) + .overWindow(windowBuilder.window(0,1).build())); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lead(1, decDefaultOutput) + .onColumn(6) + .overWindow(windowBuilder.window(0,1).build())); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(5, 1, 9, -3, 9, 8, 2, -7, 0, 6, 6, -11); + ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 5, 1, 9, -3, 9, 8, 2, -7, 0, 6, 6, -11)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); } // Outside bounds - try (WindowOptions options = windowBuilder.window(0,1).build(); - WindowOptions options1 = windowBuilder.window(0,1).build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lead(3) - .onColumn(3) - .overWindow(options)); - Table 
decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lead(3) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, null, null, null, null, null, null, null, null, null, null, null); - ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, null, null, null, null, null, null, null, null, null, null, null)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lead(3) + .onColumn(3) + .overWindow(windowBuilder.window(0,1).build())); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lead(3) + .onColumn(6) + .overWindow(windowBuilder.window(0,1).build())); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, null, null, null, null, null, null, null, null, null, null, null); + ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, null, null, null, null, null, null, null, null, null, null, null)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); } } } @@ -3155,83 +3128,71 @@ void testWindowingLag() { WindowOptions.Builder windowBuilder = WindowOptions.builder().minPeriods(1); - try (WindowOptions options = windowBuilder.window(2,1).build(); - WindowOptions options1 = windowBuilder.window(2,1).build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lag(0) - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lag(0) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); - ColumnVector decExpectAggResult = 
ColumnVector.decimalFromLongs(-2, 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lag(0) + .onColumn(3) + .overWindow(windowBuilder.window(2,1).build())); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lag(0) + .onColumn(6) + .overWindow(windowBuilder.window(2,1).build())); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6); + ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); } - try (WindowOptions options = windowBuilder.window(2,0).build(); - WindowOptions options1 = windowBuilder.window(2,0).build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lag(1) - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lag(1) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6); - ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lag(1) + .onColumn(3) + .overWindow(windowBuilder.window(2,0).build())); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lag(1) + .onColumn(6) + 
.overWindow(windowBuilder.window(2,0).build())); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6); + ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, 7, 5, 1, null, 7, 9, 8, null, 8, 0, 6)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); } - try (WindowOptions options = windowBuilder.window(2, 0).build(); - WindowOptions options1 = windowBuilder.window(2, 0).build()) { - try (ColumnVector defaultOutput = ColumnVector.fromBoxedInts(0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); - ColumnVector decDefaultOutput = ColumnVector.decimalFromLongs(-2, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); - Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lag(1, defaultOutput) - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lag(1, decDefaultOutput) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(0, 7, 5, 1, -4, 7, 9, 8, -8, 8, 0, 6); - ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 0, 7, 5, 1, -4, 7, 9, 8, -8, 8, 0, 6)) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + try (ColumnVector defaultOutput = ColumnVector.fromBoxedInts(0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); + ColumnVector decDefaultOutput = ColumnVector.decimalFromLongs(-2, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11); + Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lag(1, defaultOutput) + .onColumn(3) + .overWindow(windowBuilder.window(2, 0).build())); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lag(1, decDefaultOutput) + .onColumn(6) + 
.overWindow(windowBuilder.window(2, 0).build())); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(0, 7, 5, 1, -4, 7, 9, 8, -8, 8, 0, 6); + ColumnVector decExpectAggResult = ColumnVector.decimalFromLongs(-2, 0, 7, 5, 1, -4, 7, 9, 8, -8, 8, 0, 6)) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); } // Outside bounds - try (WindowOptions options = windowBuilder.window(1, 0).build(); - WindowOptions options1 = windowBuilder.window(1, 0).build()) { - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation - .lag(3) - .onColumn(3) - .overWindow(options)); - Table decWindowAggResults = sorted.groupBy(0, 4) - .aggregateWindows(Aggregation - .lag(3) - .onColumn(6) - .overWindow(options1)); - ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, null, null, null, null, null, null, null, null, null, null, null); - ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, null, null, null, null, null, null, null, null, null, null, null);) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - assertColumnsAreEqual(decExpectAggResult, decWindowAggResults.getColumn(0)); - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation + .lag(3) + .onColumn(3) + .overWindow(windowBuilder.window(1, 0).build())); + Table decWindowAggResults = sorted.groupBy(0, 4) + .aggregateWindows(Aggregation + .lag(3) + .onColumn(6) + .overWindow(windowBuilder.window(1, 0).build())); + ColumnVector expectAggResult = ColumnVector.fromBoxedInts(null, null, null, null, null, null, null, null, null, null, null, null); + ColumnVector decExpectAggResult = decimalFromBoxedInts(true, -2, null, null, null, null, null, null, null, null, null, null, null, null);) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); + assertColumnsAreEqual(decExpectAggResult, 
decWindowAggResults.getColumn(0)); } } } @@ -3240,25 +3201,24 @@ void testWindowingLag() { @Test void testWindowingMean() { try (Table unsorted = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // GBY Key - .column( 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6) // OBY Key - .column( 7, 5, 3, 7, 7, 9, 8, 4, 8, 0, 4, 8) // Agg Column - .build()) { + .column( 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // GBY Key + .column( 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6) // OBY Key + .column( 7, 5, 3, 7, 7, 9, 8, 4, 8, 0, 4, 8) // Agg Column + .build()) { try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); ColumnVector expectedSortedAggCol = ColumnVector.fromBoxedInts(7, 5, 3, 7, 7, 9, 8, 4, 8, 0, 4, 8)) { ColumnVector sortedAggColumn = sorted.getColumn(3); assertColumnsAreEqual(expectedSortedAggCol, sortedAggColumn); - try (WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(2, 1) - .build();) { + WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(2, 1) + .build(); - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.mean().onColumn(3).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedDoubles(6.0d, 5.0d, 5.0d, 5.0d, 8.0d, 8.0d, 7.0d, 6.0d, 4.0d, 4.0d, 4.0d, 6.0d)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation.mean().onColumn(3).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedDoubles(6.0d, 5.0d, 5.0d, 5.0d, 8.0d, 8.0d, 7.0d, 6.0d, 4.0d, 4.0d, 4.0d, 6.0d)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } } } @@ -3267,50 +3227,48 @@ void testWindowingMean() { @Test void testWindowingOnMultipleDifferentColumns() { try (Table unsorted = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 1, 1, 1, 1, 2, 2, 2, 
2, 3, 3, 3, 3) // GBY Key - .column( 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6) // OBY Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6) // Agg Column - .build()) { + .column( 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3) // GBY Key + .column( 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6) // OBY Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6) // Agg Column + .build()) { try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); ColumnVector expectedSortedAggCol = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { ColumnVector sortedAggColumn = sorted.getColumn(3); assertColumnsAreEqual(expectedSortedAggCol, sortedAggColumn); - try ( // Window (1,1), with a minimum of 1 reading. WindowOptions window_1 = WindowOptions.builder() - .minPeriods(1) - .window(2, 1) - .build(); + .minPeriods(1) + .window(2, 1) + .build(); // Window (2,2), with a minimum of 2 readings. WindowOptions window_2 = WindowOptions.builder() - .minPeriods(2) - .window(3, 2) - .build(); + .minPeriods(2) + .window(3, 2) + .build(); // Window (1,1), with a minimum of 3 readings. 
WindowOptions window_3 = WindowOptions.builder() - .minPeriods(3) - .window(2, 1) - .build();) { - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows( - Aggregation.sum().onColumn(3).overWindow(window_1), - Aggregation.max().onColumn(3).overWindow(window_1), - Aggregation.sum().onColumn(3).overWindow(window_2), - Aggregation.min().onColumn(2).overWindow(window_3) - ); - ColumnVector expect_0 = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 10L, 16L, 24L, 19L, 10L, 8L, 14L, 12L, 12L); - ColumnVector expect_1 = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6); - ColumnVector expect_2 = ColumnVector.fromBoxedLongs(13L, 22L, 22L, 15L, 24L, 26L, 26L, 19L, 14L, 20L, 20L, 12L); - ColumnVector expect_3 = ColumnVector.fromBoxedInts(null, 1, 1, null, null, 3, 3, null, null, 5, 5, null)) { - assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); - assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); - assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); - } + .minPeriods(3) + .window(2, 1) + .build(); + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows( + Aggregation.sum().onColumn(3).overWindow(window_1), + Aggregation.max().onColumn(3).overWindow(window_1), + Aggregation.sum().onColumn(3).overWindow(window_2), + Aggregation.min().onColumn(2).overWindow(window_3) + ); + ColumnVector expect_0 = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 10L, 16L, 24L, 19L, 10L, 8L, 14L, 12L, 12L); + ColumnVector expect_1 = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 6); + ColumnVector expect_2 = ColumnVector.fromBoxedLongs(13L, 22L, 22L, 15L, 24L, 26L, 26L, 19L, 14L, 20L, 20L, 12L); + ColumnVector expect_3 = ColumnVector.fromBoxedInts(null, 1, 1, null, null, 3, 3, null, null, 5, 5, null)) { + assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); + 
assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); + assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); } } } @@ -3319,800 +3277,556 @@ void testWindowingOnMultipleDifferentColumns() { @Test void testWindowingWithoutGroupByColumns() { try (Table unsorted = new Table.TestBuilder().column( 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6) // OBY Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6) // Agg Column - .build(); + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6) // Agg Column + .build(); ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6)) { try (Table sorted = unsorted.orderBy(OrderByArg.asc(0))) { ColumnVector sortedAggColumn = sorted.getColumn(1); assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - try (WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(2, 1) - .build();) { + WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(2, 1) + .build(); - try (Table windowAggResults = sorted.groupBy().aggregateWindows( - Aggregation.sum().onColumn(1).overWindow(window)); - ColumnVector expectAggResult = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 17L, 25L, 24L, 19L, 18L, 10L, 14L, 12L, 12L); - ) { - assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); - } + try (Table windowAggResults = sorted.groupBy().aggregateWindows( + Aggregation.sum().onColumn(1).overWindow(window)); + ColumnVector expectAggResult = ColumnVector.fromBoxedLongs(12L, 13L, 15L, 17L, 25L, 24L, 19L, 18L, 10L, 14L, 12L, 12L); + ) { + assertColumnsAreEqual(expectAggResult, windowAggResults.getColumn(0)); } } } } - private Scalar getScalar(DType type, long value) { - if (type.equals(DType.INT32)) { - return Scalar.fromInt((int) value); - } else if (type.equals(DType.INT64)) { - return Scalar.fromLong(value); - } else if (type.equals(DType.INT16)) { - return Scalar.fromShort((short) value); - } else if (type.equals(DType.INT8)) { - return Scalar.fromByte((byte) value); - 
} else if (type.equals(DType.UINT8)) { - return Scalar.fromUnsignedByte((byte) value); - } else if (type.equals(DType.UINT16)) { - return Scalar.fromUnsignedShort((short) value); - } else if (type.equals(DType.UINT32)) { - return Scalar.fromUnsignedInt((int) value); - } else if (type.equals(DType.UINT64)) { - return Scalar.fromUnsignedLong(value); - } else if (type.equals(DType.TIMESTAMP_DAYS)) { - return Scalar.durationFromLong(DType.DURATION_DAYS, value); - } else if (type.equals(DType.TIMESTAMP_SECONDS)) { - return Scalar.durationFromLong(DType.DURATION_SECONDS, value); - } else if (type.equals(DType.TIMESTAMP_MILLISECONDS)) { - return Scalar.durationFromLong(DType.DURATION_MILLISECONDS, value); - } else if (type.equals(DType.TIMESTAMP_MICROSECONDS)) { - return Scalar.durationFromLong(DType.DURATION_MICROSECONDS, value); - } else if (type.equals(DType.TIMESTAMP_NANOSECONDS)) { - return Scalar.durationFromLong(DType.DURATION_NANOSECONDS, value); - } else { - return Scalar.fromNull(type); - } - } - - @Test - void testRangeWindowingCount() { - try ( - Table unsorted = new Table.TestBuilder() - .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key - .column((short) 1, (short)1, (short)2, (short)3, (short)3, (short)3, (short)4, (short)4, (short)5, (short)5, (short)6, (short)6, (short)7) // orderBy Key - .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key - .column((byte) 1, (byte)1, (byte)2, (byte)3, (byte)3, (byte)3, (byte)4, (byte)4, (byte)5, (byte)5, (byte)6, (byte)6, (byte)7) // orderBy Key - .timestampDayColumn(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // timestamp orderBy Key - .timestampSecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampMicrosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - 
.timestampMillisecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .build()) { + @Test + void testTimeRangeWindowingCount() { + try (Table unsorted = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + .timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .build()) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(3); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(2); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(1, 1) + .timestampColumnIndex(2) + .build(); - DType type = unsorted.getColumn(orderIndex).getType(); - try (Scalar preceding = getScalar(type, 1L); - Scalar following = getScalar(type, 1L)) { - try (WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(preceding, following) - .orderByColumnIndex(orderIndex) - .build()) { - try (Table windowAggResults = sorted.groupBy(0, 1).aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 4, 2, 4, 4, 4, 4, 4, 4, 5, 5, 3)) { - assertColumnsAreEqual(expect, 
windowAggResults.getColumn(0)); - } - } - } + try (Table windowAggResults = sorted.groupBy(0, 1).aggregateWindowsOverTimeRanges( + Aggregation.count().onColumn(3).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 4, 2, 4, 4, 4, 4, 4, 4, 5, 5, 3)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } } } } @Test - void testRangeWindowingLead() { + void testTimeRangeWindowingLead() { try (Table unsorted = new Table.TestBuilder() - .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key - .column((short) 1, (short)1, (short)2, (short)3, (short)3, (short)3, (short)4, (short)4, (short)5, (short)5, (short)6, (short)6, (short)7) // orderBy Key - .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key - .column((byte) 1, (byte)1, (byte)2, (byte)3, (byte)3, (byte)3, (byte)4, (byte)4, (byte)5, (byte)5, (byte)6, (byte)6, (byte)7) // orderBy Key - .timestampDayColumn(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp orderBy Key - .timestampSecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampMicrosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampMillisecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) + .column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + .timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column .build()) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + 
ColumnVector sortedAggColumn = sorted.getColumn(3); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(2); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(1, 1) + .timestampColumnIndex(2) + .build(); - DType type = unsorted.getColumn(orderIndex).getType(); - try (Scalar preceding = getScalar(type, 1L); - Scalar following = getScalar(type, 1L)) { - try (WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(preceding, following) - .orderByColumnIndex(orderIndex) - .build()) { - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.lead(1) - .onColumn(2) - .overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, 8, null)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - } - } - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverTimeRanges(Aggregation.lead(1) + .onColumn(3) + .overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(5, 1, 9, null, 9, 8, 2, null, 0, 6, 6, 8, null)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } } } } @Test - void testRangeWindowingMax() { - try (Table unsorted = new Table.TestBuilder() - .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key - .column((short) 1, (short)1, (short)2, (short)3, 
(short)3, (short)3, (short)4, (short)4, (short)5, (short)5, (short)6, (short)6, (short)7) // orderBy Key - .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key - .column((byte) 1, (byte)1, (byte)2, (byte)3, (byte)3, (byte)3, (byte)4, (byte)4, (byte)5, (byte)5, (byte)6, (byte)6, (byte)7) // orderBy Key - .timestampDayColumn(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp orderBy Key - .timestampSecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampMicrosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampMillisecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .build()) { - - for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(2); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + void testTimeRangeWindowingMax() { + try (Table unsorted = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + .timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .build()) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(3); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - DType type = unsorted.getColumn(orderIndex).getType(); - try (Scalar preceding = getScalar(type, 1L); - Scalar following = getScalar(type, 1L)) { - try 
(WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(preceding, following) - .orderByColumnIndex(orderIndex) - .build()) { - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.max().onColumn(2).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 9, 8, 8, 8, 8, 8)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - } - } + WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(1, 1) + .timestampColumnIndex(2) + .build(); - try (WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(2, 1) - .build();) { - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindows(Aggregation.max().onColumn(2).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 8, 8)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - } - } - } + try (Table windowAggResults = sorted.groupBy(0, 1).aggregateWindowsOverTimeRanges( + Aggregation.max().onColumn(3).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts( 7, 7, 9, 9, 9, 9, 9, 9, 8, 8, 8, 8, 8)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } - } - } - } - - @Test - void testRangeWindowingRowNumber() { - try (Table unsorted = new Table.TestBuilder() - .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key - .column((short) 1, (short)1, (short)2, (short)3, (short)3, (short)3, (short)4, (short)4, (short)5, (short)5, (short)6, (short)6, (short)7) // orderBy Key - .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key - .column((byte) 1, (byte)1, (byte)2, (byte)3, (byte)3, (byte)3, (byte)4, (byte)4, (byte)5, (byte)5, (byte)6, (byte)6, (byte)7) // orderBy Key - 
.timestampDayColumn(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp orderBy Key - .timestampSecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampMicrosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampMillisecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .build()) { - for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(2); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + window = WindowOptions.builder() + .minPeriods(1) + .window(2, 1) + .build(); - DType type = unsorted.getColumn(orderIndex).getType(); - try (Scalar preceding = getScalar(type, 2L); - Scalar following = getScalar(type, 0L)) { - try (WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(preceding, following) - .orderByColumnIndex(orderIndex) - .build()) { - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.rowNumber().onColumn(2).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 5)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - } - } - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindows(Aggregation.max().onColumn(3).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts( 7, 7, 9, 9, 9, 9, 9, 8, 8, 8, 6, 8, 8)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } } } } @Test - void testRangeWindowingCountDescendingTimestamps() { - try (Table unsorted = new Table.TestBuilder() - .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - 
.column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .column((short)7, (short)6, (short)6, (short)5, (short)5, (short)4, (short)4, (short)3, (short)3, (short)3, (short)2, (short)1, (short)1) - .column(7L, 6L, 6L, 5L, 5L, 4L, 4L, 3L, 3L, 3L, 2L, 1L, 1L) - .column(7, 6, 6, 5, 5, 4, 4, 3, 3, 3, 2, 1, 1) - .column((byte)7, (byte)6, (byte)6, (byte)5, (byte)5, (byte)4, (byte)4, (byte)3, (byte)3, (byte)3, (byte)2, (byte)1, (byte)1) - .timestampDayColumn(7, 6, 6, 5, 5, 4, 4, 3, 3, 3, 2, 1, 1) // Timestamp Key - .timestampSecondsColumn(7L, 6L, 6L, 5L, 5L, 4L, 4L, 3L, 3L, 3L, 2L, 1L, 1L) - .timestampMicrosecondsColumn(7L, 6L, 6L, 5L, 5L, 4L, 4L, 3L, 3L, 3L, 2L, 1L, 1L) - .timestampMillisecondsColumn(7L, 6L, 6L, 5L, 5L, 4L, 4L, 3L, 3L, 3L, 2L, 1L, 1L) - .timestampNanosecondsColumn(7L, 6L, 6L, 5L, 5L, 4L, 4L, 3L, 3L, 3L, 2L, 1L, 1L) - .build()) { + void testTimeRangeWindowingRowNumber() { + try (Table unsorted = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + .timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .build()) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(3); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.desc(orderIndex)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(2); - 
assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - - DType type = unsorted.getColumn(orderIndex).getType(); - try (Scalar preceding_0 = getScalar(type, 2L); - Scalar following_0 = getScalar(type, 1L); - Scalar preceding_1 = getScalar(type, 3L); - Scalar following_1 = getScalar(type, 0L)) { - - try (WindowOptions window_0 = WindowOptions.builder() - .minPeriods(1) - .window(preceding_0, following_0) - .orderByColumnIndex(orderIndex) - .orderByDescending() - .build(); - - WindowOptions window_1 = WindowOptions.builder() - .minPeriods(1) - .window(preceding_1, following_1) - .orderByColumnIndex(orderIndex) - .orderByDescending() - .build();) { - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(window_0), - Aggregation.sum().onColumn(2).overWindow(window_1)); - ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 4, 4, 4, 3, 4, 4, 4, 3, 3, 5, 5, 5); - ColumnVector expect_1 = ColumnVector.fromBoxedLongs(7L, 13L, 13L, 22L, 7L, 24L, 24L, 26L, 8L, 8L, 14L, 28L, 28L)) { - assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); - } - } - } + WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(3, 0) + .timestampColumnIndex(2) + .build(); + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverTimeRanges(Aggregation.rowNumber().onColumn(3).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 5)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); } } } } @Test - void testRangeWindowingWithoutGroupByColumns() { - try (Table unsorted = new Table.TestBuilder() - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key - .column((short) 1, (short)1, (short)2, (short)3, (short)3, (short)3, (short)4, (short)4, (short)5, 
(short)5, (short)6, (short)6, (short)7) // orderBy Key - .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key - .column((byte) 1, (byte)1, (byte)2, (byte)3, (byte)3, (byte)3, (byte)4, (byte)4, (byte)5, (byte)5, (byte)6, (byte)6, (byte)7) // orderBy Key - .timestampDayColumn(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp orderBy Key - .timestampSecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampMicrosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampMillisecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .build()) { + void testTimeRangeWindowingCountDescendingTimestamps() { + try (Table unsorted = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + .timestampDayColumn( 7, 6, 6, 5, 5, 4, 4, 3, 3, 3, 2, 1, 1) // Timestamp Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .build()) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.desc(2)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(3); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(orderIndex)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(0); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + WindowOptions window_0 = WindowOptions.builder() + .minPeriods(1) + .window(2, 1) + .timestampColumnIndex(2) + .timestampDescending() + .build(); - DType type = unsorted.getColumn(orderIndex).getType(); - try (Scalar 
preceding = getScalar(type, 1L); - Scalar following = getScalar(type, 1L)) { - try (WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .window(preceding, following) - .orderByColumnIndex(orderIndex) - .build();) { - - try (Table windowAggResults = sorted.groupBy() - .aggregateWindowsOverRanges(Aggregation.count().onColumn(1).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 6, 6, 6, 6, 7, 7, 6, 6, 5, 5, 3)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - } - } - } + WindowOptions window_1 = WindowOptions.builder() + .minPeriods(1) + .window(3, 0) + .timestampColumnIndex(2) + .timestampDescending() + .build(); + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverTimeRanges( + Aggregation.count().onColumn(3).overWindow(window_0), + Aggregation.sum().onColumn(3).overWindow(window_1)); + ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 4, 4, 4, 3, 4, 4, 4, 3, 3, 5, 5, 5); + ColumnVector expect_1 = ColumnVector.fromBoxedLongs(7L, 13L, 13L, 22L, 7L, 24L, 24L, 26L, 8L, 8L, 14L, 28L, 28L)) { + assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); } } } } @Test - void testRangeWindowingOrderByUnsupportedDataTypeExceptions() { - try (Table table = new Table.TestBuilder() - .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .column(true, false, true, false, true, false, true, false, false, false, false, false, false) // orderBy Key - .build()) { + void testTimeRangeWindowingWithoutGroupByColumns() { + try (Table unsorted = new Table.TestBuilder().timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .build()) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0)); + ColumnVector 
expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(1); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - try (WindowOptions rangeBasedWindow = WindowOptions.builder() - .minPeriods(1) - .window(1, 1) - .orderByColumnIndex(3) - .build();) { + WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .window(1, 1) + .timestampColumnIndex(0) + .build(); - assertThrows(IllegalArgumentException.class, - () -> table - .groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.max().onColumn(2).overWindow(rangeBasedWindow))); + try (Table windowAggResults = sorted.groupBy() + .aggregateWindowsOverTimeRanges(Aggregation.count().onColumn(1).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 6, 6, 6, 6, 7, 7, 6, 6, 5, 5, 3)) { + assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); + } } } } @Test void testInvalidWindowTypeExceptions() { - try (Table table = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key - .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .build()) { + try (Table table = new Table.TestBuilder().column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + .timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .build()) { + WindowOptions rowBasedWindow = WindowOptions.builder().minPeriods(1).window(1,1).build(); + assertThrows(IllegalArgumentException.class, () -> + table.groupBy(0, 1) + .aggregateWindowsOverTimeRanges(Aggregation.max().onColumn(3).overWindow(rowBasedWindow))); - try (WindowOptions rowBasedWindow = WindowOptions.builder().minPeriods(1).window(1,1).build();) { - 
assertThrows(IllegalArgumentException.class, () -> table.groupBy(0, 1).aggregateWindowsOverRanges(Aggregation.max().onColumn(3).overWindow(rowBasedWindow))); - } + WindowOptions rangeBasedWindow = WindowOptions.builder().minPeriods(1).window(1,1).timestampColumnIndex(2).build(); + assertThrows(IllegalArgumentException.class, () -> + table.groupBy(0, 1) + .aggregateWindows(Aggregation.max().onColumn(3).overWindow(rangeBasedWindow))); - try (WindowOptions rangeBasedWindow = WindowOptions.builder().minPeriods(1).window(1,1).orderByColumnIndex(2).build();) { - assertThrows(IllegalArgumentException.class, () -> table.groupBy(0, 1).aggregateWindows(Aggregation.max().onColumn(3).overWindow(rangeBasedWindow))); } - } } @Test - void testRangeWindowingCountUnboundedPreceding() { - try (Table unsorted = new Table.TestBuilder() - .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key - .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .column(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) // orderBy Key - .column((short) 1, (short)1, (short)2, (short)3, (short)3, (short)3, (short)4, (short)4, (short)5, (short)5, (short)6, (short)6, (short)7) // orderBy Key - .column(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // orderBy Key - .column((byte) 1, (byte)1, (byte)2, (byte)3, (byte)3, (byte)3, (byte)4, (byte)4, (byte)5, (byte)5, (byte)6, (byte)6, (byte)7) // orderBy Key - .timestampDayColumn(1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp orderBy Key - .timestampSecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampMicrosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampMillisecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .timestampNanosecondsColumn(1L, 1L, 2L, 3L, 3L, 3L, 4L, 4L, 5L, 5L, 6L, 6L, 7L) - .build()) { + void testTimeRangeWindowingCountUnboundedPreceding() { + try (Table unsorted = new Table.TestBuilder().column( 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column( 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2) // GBY Key + .timestampDayColumn( 1, 1, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7) // Timestamp Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .build()) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(3); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(2); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + WindowOptions window = WindowOptions.builder() + .minPeriods(1) + .unboundedPreceding() + .following(1) + .timestampColumnIndex(2) + .build(); - DType type = unsorted.getColumn(orderIndex).getType(); - try (Scalar following = getScalar(type, 1L)) { - try (WindowOptions window = WindowOptions.builder() - .minPeriods(1) - .unboundedPreceding() - .following(following) - .orderByColumnIndex(orderIndex) - .build();) { - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges(Aggregation.count().onColumn(2).overWindow(window)); - ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5)) { - assertColumnsAreEqual(expect, windowAggResults.getColumn(0)); - } - } - } + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverTimeRanges(Aggregation.count().onColumn(3).overWindow(window)); + ColumnVector expect = ColumnVector.fromBoxedInts(3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5)) { + assertColumnsAreEqual(expect, 
windowAggResults.getColumn(0)); } } } } @Test - void testRangeWindowingCountUnboundedASCWithNullsFirst() { + void testTimeRangeWindowingCountUnboundedASCWithNullsFirst() { + Integer X = null; try (Table unsorted = new Table.TestBuilder() - .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .column( null, null, null, 2, 3, 5, null, null, 1, 2, 4, 5, 7) // Timestamp Key - .column( null, null, null, 2L, 3L, 5L, null, null, 1L, 2L, 4L, 5L, 7L) // orderBy Key - .column( null, null, null, (short)2, (short)3, (short)5, null, null, (short)1, (short)2, (short)4, (short)5, (short)7) // orderBy Key - .column( null, null, null, (byte)2, (byte)3, (byte)5, null, null, (byte)1, (byte)2, (byte)4, (byte)5, (byte)7) // orderBy Key - .timestampDayColumn( null, null, null, 2, 3, 5, null, null, 1, 2, 4, 5, 7) // Timestamp orderBy Key - .timestampSecondsColumn( null, null, null, 2L, 3L, 5L, null, null, 1L, 2L, 4L, 5L, 7L) - .timestampMicrosecondsColumn( null, null, null, 2L, 3L, 5L, null, null, 1L, 2L, 4L, 5L, 7L) - .timestampMillisecondsColumn( null, null, null, 2L, 3L, 5L, null, null, 1L, 2L, 4L, 5L, 7L) - .timestampNanosecondsColumn( null, null, null, 2L, 3L, 5L, null, null, 1L, 2L, 4L, 5L, 7L) - .build()) { + .column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column( 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .timestampDayColumn( X, X, X, 2, 3, 5, X, X, 1, 2, 4, 5, 7) // Timestamp Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .build()) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2, true)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(3); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - for (int orderIndex = 3; orderIndex < 
unsorted.getNumberOfColumns(); orderIndex++) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex, true)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(2); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - - DType type = unsorted.getColumn(orderIndex).getType(); - try (Scalar following1 = getScalar(type, 1L); - Scalar preceding1 = getScalar(type, 1L); - Scalar following0 = getScalar(type, 0L); - Scalar preceding0 = getScalar(type, 0L);) { - try (WindowOptions unboundedPrecedingOneFollowing = WindowOptions.builder() - .minPeriods(1) - .unboundedPreceding() - .following(following1) - .orderByColumnIndex(orderIndex) - .build(); + WindowOptions unboundedPrecedingOneFollowing = WindowOptions.builder() + .minPeriods(1) + .unboundedPreceding() + .following(1) + .timestampColumnIndex(2) + .build(); - WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() + WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() .minPeriods(1) - .preceding(preceding1) + .preceding(1) .unboundedFollowing() - .orderByColumnIndex(orderIndex) + .timestampColumnIndex(2) .build(); - WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() + WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() .unboundedFollowing() - .orderByColumnIndex(orderIndex) + .timestampColumnIndex(2) .build(); - WindowOptions unboundedPrecedingAndCurrentRow = WindowOptions.builder() + WindowOptions unboundedPrecedingAndCurrentRow = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() - .following(following0) - .orderByColumnIndex(orderIndex) + .following(0) + .timestampColumnIndex(2) .build(); - WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() + WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() 
.minPeriods(1) - .preceding(preceding0) + .preceding(0) .unboundedFollowing() - .orderByColumnIndex(orderIndex) - .build();) { - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), - Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); - ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 3, 3, 5, 5, 6, 2, 2, 4, 4, 6, 6, 7); - ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 3, 1, 7, 7, 5, 5, 3, 3, 1); - ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); - ColumnVector expect_3 = ColumnVector.fromBoxedInts(3, 3, 3, 4, 5, 6, 2, 2, 3, 4, 5, 6, 7); - ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 2, 1, 7, 7, 5, 4, 3, 2, 1)) { - - assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); - assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); - assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); - assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); - } - } - } + .timestampColumnIndex(2) + .build(); + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverTimeRanges( + Aggregation.count().onColumn(3).overWindow(unboundedPrecedingOneFollowing), + Aggregation.count().onColumn(3).overWindow(onePrecedingUnboundedFollowing), + Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndFollowing), + Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndCurrentRow), + Aggregation.count().onColumn(3).overWindow(currentRowAndUnboundedFollowing)); + ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 3, 
3, 5, 5, 6, 2, 2, 4, 4, 6, 6, 7); + ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 3, 1, 7, 7, 5, 5, 3, 3, 1); + ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); + ColumnVector expect_3 = ColumnVector.fromBoxedInts(3, 3, 3, 4, 5, 6, 2, 2, 3, 4, 5, 6, 7); + ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 2, 1, 7, 7, 5, 4, 3, 2, 1)) { + + assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); + assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); + assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); + assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); } } } } @Test - void testRangeWindowingCountUnboundedDESCWithNullsFirst() { + void testTimeRangeWindowingCountUnboundedDESCWithNullsFirst() { + Integer X = null; try (Table unsorted = new Table.TestBuilder() - .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .column(null, null, null, 5, 3, 2, null, null, 7, 5, 4, 2, 1) // Timestamp Key - .column(null, null, null, 5L, 3L, 2L, null, null, 7L, 5L, 4L, 2L, 1L) // orderby Key - .column(null, null, null, (short)5, (short)3, (short)2, null, null, (short)7, (short)5, (short)4, (short)2, (short)1) // orderby Key - .column(null, null, null, (byte)5, (byte)3, (byte)2, null, null, (byte)7, (byte)5, (byte)4, (byte)2, (byte)1) // orderby Key - .timestampDayColumn(null, null, null, 5, 3, 2, null, null, 7, 5, 4, 2, 1) // Timestamp orderby Key - .timestampSecondsColumn( null, null, null, 5L, 3L, 2L, null, null, 7L, 5L, 4L, 2L, 1L) - .timestampMicrosecondsColumn( null, null, null, 5L, 3L, 2L, null, null, 7L, 5L, 4L, 2L, 1L) - .timestampMillisecondsColumn( null, null, null, 5L, 3L, 2L, null, null, 7L, 5L, 4L, 2L, 1L) - .timestampNanosecondsColumn( null, null, null, 5L, 3L, 2L, 
null, null, 7L, 5L, 4L, 2L, 1L) - .build()) { - - for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.desc(orderIndex, false)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(2); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - - DType type = unsorted.getColumn(orderIndex).getType(); - try (Scalar following1 = getScalar(type, 1L); - Scalar preceding1 = getScalar(type, 1L); - Scalar following0 = getScalar(type, 0L); - Scalar preceding0 = getScalar(type, 0L);) { + .column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column( 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .timestampDayColumn( X, X, X, 5, 3, 2, X, X, 7, 5, 4, 2, 1) // Timestamp Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .build()) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.desc(2, false)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(3); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - try (WindowOptions unboundedPrecedingOneFollowing = WindowOptions.builder() + WindowOptions unboundedPrecedingOneFollowing = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() - .following(following1) - .orderByColumnIndex(orderIndex) - .orderByDescending() + .following(1) + .timestampColumnIndex(2) + .timestampDescending() .build(); - WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() + WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() .minPeriods(1) - .preceding(preceding1) + .preceding(1) .unboundedFollowing() - .orderByColumnIndex(orderIndex) - .orderByDescending() + .timestampColumnIndex(2) + 
.timestampDescending() .build(); - WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() + WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() .unboundedFollowing() - .orderByColumnIndex(orderIndex) - .orderByDescending() + .timestampColumnIndex(2) + .timestampDescending() .build(); - WindowOptions unboundedPrecedingAndCurrentRow = WindowOptions.builder() + WindowOptions unboundedPrecedingAndCurrentRow = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() - .following(following0) - .orderByColumnIndex(orderIndex) - .orderByDescending() + .following(0) + .timestampColumnIndex(2) + .timestampDescending() .build(); - WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() + WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() .minPeriods(1) - .preceding(preceding0) + .preceding(0) .unboundedFollowing() - .orderByColumnIndex(orderIndex) - .orderByDescending() - .build();) { - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), - Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); - ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 3, 3, 4, 6, 6, 2, 2, 3, 5, 5, 7, 7); - ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 2, 2, 7, 7, 5, 4, 4, 2, 2); - ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); - ColumnVector expect_3 = ColumnVector.fromBoxedInts(3, 3, 3, 4, 5, 6, 2, 2, 3, 4, 5, 6, 7); - ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 2, 1, 7, 7, 5, 4, 3, 2, 1)) { - - assertColumnsAreEqual(expect_0, 
windowAggResults.getColumn(0)); - assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); - assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); - assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); - assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); - } - } - } + .timestampColumnIndex(2) + .timestampDescending() + .build(); + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverTimeRanges( + Aggregation.count().onColumn(3).overWindow(unboundedPrecedingOneFollowing), + Aggregation.count().onColumn(3).overWindow(onePrecedingUnboundedFollowing), + Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndFollowing), + Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndCurrentRow), + Aggregation.count().onColumn(3).overWindow(currentRowAndUnboundedFollowing)); + ColumnVector expect_0 = ColumnVector.fromBoxedInts(3, 3, 3, 4, 6, 6, 2, 2, 3, 5, 5, 7, 7); + ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 2, 2, 7, 7, 5, 4, 4, 2, 2); + ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); + ColumnVector expect_3 = ColumnVector.fromBoxedInts(3, 3, 3, 4, 5, 6, 2, 2, 3, 4, 5, 6, 7); + ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 6, 6, 3, 2, 1, 7, 7, 5, 4, 3, 2, 1)) { + + assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); + assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); + assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); + assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); } } } } @Test - void testRangeWindowingCountUnboundedASCWithNullsLast() { + void testTimeRangeWindowingCountUnboundedASCWithNullsLast() { + Integer X = null; try (Table unsorted = new Table.TestBuilder() - .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(7, 5, 1, 9, 
7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .column(2, 3, 5, null, null, null, 1, 2, 4, 5, 7, null, null) // Timestamp Key - .column(2L, 3L, 5L, null, null, null, 1L, 2L, 4L, 5L, 7L, null, null) // order by Key - .column((short)2, (short)3, (short)5, null, null, null, (short)1, (short)2, (short)4, (short)5, (short)7, null, null) // order by Key - .column((byte)2, (byte)3, (byte)5, null, null, null, (byte)1, (byte)2, (byte)4, (byte)5, (byte)7, null, null) // order by Key - .timestampDayColumn( 2, 3, 5, null, null, null, 1, 2, 4, 5, 7, null, null) // Timestamp order by Key - .timestampSecondsColumn( 2L, 3L, 5L, null, null, null, 1L, 2L, 4L, 5L, 7L, null, null) - .timestampMicrosecondsColumn( 2L, 3L, 5L, null, null, null, 1L, 2L, 4L, 5L, 7L, null, null) - .timestampMillisecondsColumn( 2L, 3L, 5L, null, null, null, 1L, 2L, 4L, 5L, 7L, null, null) - .timestampNanosecondsColumn( 2L, 3L, 5L, null, null, null, 1L, 2L, 4L, 5L, 7L, null, null) - .build()) { - for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(orderIndex, false)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = sorted.getColumn(2); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - - DType type = unsorted.getColumn(orderIndex).getType(); - try (Scalar following1 = getScalar(type, 1L); - Scalar preceding1 = getScalar(type, 1L); - Scalar following0 = getScalar(type, 0L); - Scalar preceding0 = getScalar(type, 0L);) { - try (WindowOptions unboundedPrecedingOneFollowing = WindowOptions.builder() + .column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column( 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .timestampDayColumn( 2, 3, 5, X, X, X, 1, 2, 4, 5, 7, X, X) // Timestamp Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .build()) { + try (Table 
sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.asc(2, false)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(3); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + WindowOptions unboundedPrecedingOneFollowing = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() - .following(following1) - .orderByColumnIndex(orderIndex) + .following(1) + .timestampColumnIndex(2) .build(); - WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() + WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() .minPeriods(1) - .preceding(preceding1) + .preceding(1) .unboundedFollowing() - .orderByColumnIndex(orderIndex) + .timestampColumnIndex(2) .build(); - WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() + WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() .unboundedFollowing() - .orderByColumnIndex(orderIndex) + .timestampColumnIndex(2) .build(); - WindowOptions unboundedPrecedingAndCurrentRow = WindowOptions.builder() + WindowOptions unboundedPrecedingAndCurrentRow = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() - .following(following0) - .orderByColumnIndex(orderIndex) + .following(0) + .timestampColumnIndex(2) .build(); - WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() + WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() .minPeriods(1) - .preceding(preceding0) + .preceding(0) .unboundedFollowing() - .orderByColumnIndex(orderIndex) - .build();) { - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), - Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), - 
Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); - ColumnVector expect_0 = ColumnVector.fromBoxedInts(2, 2, 3, 6, 6, 6, 2, 2, 4, 4, 5, 7, 7); - ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 4, 3, 3, 3, 7, 7, 5, 5, 3, 2, 2); - ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); - ColumnVector expect_3 = ColumnVector.fromBoxedInts(1, 2, 3, 6, 6, 6, 1, 2, 3, 4, 5, 7, 7); - ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 5, 4, 3, 3, 3, 7, 6, 5, 4, 3, 2, 2)) { - - assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); - assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); - assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); - assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); - } - } - } + .timestampColumnIndex(2) + .build(); + + try (Table windowAggResults = sorted.groupBy(0, 1) + .aggregateWindowsOverTimeRanges( + Aggregation.count().onColumn(3).overWindow(unboundedPrecedingOneFollowing), + Aggregation.count().onColumn(3).overWindow(onePrecedingUnboundedFollowing), + Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndFollowing), + Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndCurrentRow), + Aggregation.count().onColumn(3).overWindow(currentRowAndUnboundedFollowing)); + ColumnVector expect_0 = ColumnVector.fromBoxedInts(2, 2, 3, 6, 6, 6, 2, 2, 4, 4, 5, 7, 7); + ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 6, 4, 3, 3, 3, 7, 7, 5, 5, 3, 2, 2); + ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); + ColumnVector expect_3 = ColumnVector.fromBoxedInts(1, 2, 3, 6, 6, 6, 1, 2, 3, 4, 5, 7, 7); + ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 5, 
4, 3, 3, 3, 7, 6, 5, 4, 3, 2, 2)) { + + assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); + assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); + assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); + assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); } } } } @Test - void testRangeWindowingCountUnboundedDESCWithNullsLast() { + void testTimeRangeWindowingCountUnboundedDESCWithNullsLast() { Integer X = null; try (Table unsorted = new Table.TestBuilder() - .column(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key - .column(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column - .column( 5, 3, 2, null, null, null, 7, 5, 4, 2, 1, null, null) // Timestamp Key - .column(5L, 3L, 2L, null, null, null, 7L, 5L, 4L, 2L, 1L, null, null) // Timestamp Key - .column((short)5, (short)3, (short)2, null, null, null, (short)7, (short)5, (short)4, (short)2, (short)1, null, null) // Timestamp Key - .column((byte)5, (byte)3, (byte)2, null, null, null, (byte)7, (byte)5, (byte)4, (byte)2, (byte)1, null, null) // Timestamp Key - .timestampDayColumn( 5, 3, 2, X, X, X, 7, 5, 4, 2, 1, X, X) // Timestamp Key - .timestampSecondsColumn( 5L, 3L, 2L, null, null, null, 7L, 5L, 4L, 2L, 1L, null, null) - .timestampMicrosecondsColumn( 5L, 3L, 2L, null, null, null, 7L, 5L, 4L, 2L, 1L, null, null) - .timestampMillisecondsColumn( 5L, 3L, 2L, null, null, null, 7L, 5L, 4L, 2L, 1L, null, null) - .timestampNanosecondsColumn( 5L, 3L, 2L, null, null, null, 7L, 5L, 4L, 2L, 1L, null, null) - .build()) { - for (int orderIndex = 3; orderIndex < unsorted.getNumberOfColumns(); orderIndex++) { - try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.desc(orderIndex, true)); - ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { - ColumnVector sortedAggColumn = 
sorted.getColumn(2); - assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); - - DType type = unsorted.getColumn(orderIndex).getType(); - try (Scalar following1 = getScalar(type, 1L); - Scalar preceding1 = getScalar(type, 1L); - Scalar following0 = getScalar(type, 0L); - Scalar preceding0 = getScalar(type, 0L);) { - try (WindowOptions unboundedPrecedingOneFollowing = WindowOptions.builder() + .column( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .column( 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1) // GBY Key + .timestampDayColumn( 5, 3, 2, X, X, X, 7, 5, 4, 2, 1, X, X) // Timestamp Key + .column( 7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8) // Agg Column + .build()) { + try (Table sorted = unsorted.orderBy(OrderByArg.asc(0), OrderByArg.asc(1), OrderByArg.desc(2, true)); + ColumnVector expectSortedAggColumn = ColumnVector.fromBoxedInts(7, 5, 1, 9, 7, 9, 8, 2, 8, 0, 6, 6, 8)) { + ColumnVector sortedAggColumn = sorted.getColumn(3); + assertColumnsAreEqual(expectSortedAggColumn, sortedAggColumn); + + WindowOptions unboundedPrecedingOneFollowing = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() - .following(following1) - .orderByColumnIndex(orderIndex) - .orderByDescending() + .following(1) + .timestampColumnIndex(2) + .timestampDescending() .build(); - WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() + WindowOptions onePrecedingUnboundedFollowing = WindowOptions.builder() .minPeriods(1) - .preceding(preceding1) + .preceding(1) .unboundedFollowing() - .orderByColumnIndex(orderIndex) - .orderByDescending() + .timestampColumnIndex(2) + .timestampDescending() .build(); - WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() + WindowOptions unboundedPrecedingAndFollowing = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() .unboundedFollowing() - .orderByColumnIndex(orderIndex) - .orderByDescending() + .timestampColumnIndex(2) + .timestampDescending() .build(); - WindowOptions unboundedPrecedingAndCurrentRow 
= WindowOptions.builder() + WindowOptions unboundedPrecedingAndCurrentRow = WindowOptions.builder() .minPeriods(1) .unboundedPreceding() - .following(following0) - .orderByColumnIndex(orderIndex) - .orderByDescending() + .following(0) + .timestampColumnIndex(2) + .timestampDescending() .build(); - WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() + WindowOptions currentRowAndUnboundedFollowing = WindowOptions.builder() .minPeriods(1) - .preceding(preceding0) + .preceding(0) .unboundedFollowing() - .orderByColumnIndex(orderIndex) - .orderByDescending() - .build();) { - - try (Table windowAggResults = sorted.groupBy(0, 1) - .aggregateWindowsOverRanges( - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingOneFollowing), - Aggregation.count().onColumn(2).overWindow(onePrecedingUnboundedFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndFollowing), - Aggregation.count().onColumn(2).overWindow(unboundedPrecedingAndCurrentRow), - Aggregation.count().onColumn(2).overWindow(currentRowAndUnboundedFollowing)); - ColumnVector expect_0 = ColumnVector.fromBoxedInts(1, 3, 3, 6, 6, 6, 1, 3, 3, 5, 5, 7, 7); - ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 5, 5, 3, 3, 3, 7, 6, 6, 4, 4, 2, 2); - ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); - ColumnVector expect_3 = ColumnVector.fromBoxedInts(1, 2, 3, 6, 6, 6, 1, 2, 3, 4, 5, 7, 7); - ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 5, 4, 3, 3, 3, 7, 6, 5, 4, 3, 2, 2)) { - - assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); - assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); - assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); - assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); - assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); - } - } - } + .timestampColumnIndex(2) + .timestampDescending() + .build(); + + try (Table windowAggResults = 
sorted.groupBy(0, 1) + .aggregateWindowsOverTimeRanges( + Aggregation.count().onColumn(3).overWindow(unboundedPrecedingOneFollowing), + Aggregation.count().onColumn(3).overWindow(onePrecedingUnboundedFollowing), + Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndFollowing), + Aggregation.count().onColumn(3).overWindow(unboundedPrecedingAndCurrentRow), + Aggregation.count().onColumn(3).overWindow(currentRowAndUnboundedFollowing)); + ColumnVector expect_0 = ColumnVector.fromBoxedInts(1, 3, 3, 6, 6, 6, 1, 3, 3, 5, 5, 7, 7); + ColumnVector expect_1 = ColumnVector.fromBoxedInts(6, 5, 5, 3, 3, 3, 7, 6, 6, 4, 4, 2, 2); + ColumnVector expect_2 = ColumnVector.fromBoxedInts(6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7); + ColumnVector expect_3 = ColumnVector.fromBoxedInts(1, 2, 3, 6, 6, 6, 1, 2, 3, 4, 5, 7, 7); + ColumnVector expect_4 = ColumnVector.fromBoxedInts(6, 5, 4, 3, 3, 3, 7, 6, 5, 4, 3, 2, 2)) { + + assertColumnsAreEqual(expect_0, windowAggResults.getColumn(0)); + assertColumnsAreEqual(expect_1, windowAggResults.getColumn(1)); + assertColumnsAreEqual(expect_2, windowAggResults.getColumn(2)); + assertColumnsAreEqual(expect_3, windowAggResults.getColumn(3)); + assertColumnsAreEqual(expect_4, windowAggResults.getColumn(4)); } } } From b7eeaf52e6e719a33946253d2d144982d3105864 Mon Sep 17 00:00:00 2001 From: David Wendt <45795991+davidwendt@users.noreply.github.com> Date: Fri, 14 May 2021 10:22:46 -0400 Subject: [PATCH 05/12] Split up scan.cu to improve compile time (#8183) Moving to thrust 1.12 from 1.10 increased compile time for `scan.cu` significantly. This is likely due to the improvements made to the scan algorithm to use CUB's DeviceScan: https://github.com/NVIDIA/thrust/pull/1304 This PR splits up scan.cu into `scan_exclusive.cu` and `scan_inclusive.cu` to help speed up build time when running parallel compiles. This PR also includes patches to libcudf's thrust's CUB source to disable compiling tuning artifacts for architectures below sm60. 
The result is about 2 minute (~11%) overall speedup on a parallel build and reduces the libcudf.so by about 25MB (17%). Authors: - David Wendt (https://github.com/davidwendt) Approvers: - AJ Schmidt (https://github.com/ajschmidt8) - Robert Maynard (https://github.com/robertmaynard) - Elias Stehle (https://github.com/elstehle) - https://github.com/nvdbaranec URL: https://github.com/rapidsai/cudf/pull/8183 --- conda/recipes/libcudf/meta.yaml | 1 + cpp/CMakeLists.txt | 4 +- cpp/cmake/thrust.patch | 39 +++ cpp/include/cudf/detail/scan.hpp | 79 ++++++ cpp/src/reductions/scan.cu | 286 ---------------------- cpp/src/reductions/scan/scan.cpp | 39 +++ cpp/src/reductions/scan/scan.cuh | 61 +++++ cpp/src/reductions/scan/scan_exclusive.cu | 103 ++++++++ cpp/src/reductions/scan/scan_inclusive.cu | 168 +++++++++++++ cpp/tests/reductions/scan_tests.cpp | 4 +- 10 files changed, 495 insertions(+), 289 deletions(-) create mode 100644 cpp/include/cudf/detail/scan.hpp delete mode 100644 cpp/src/reductions/scan.cu create mode 100644 cpp/src/reductions/scan/scan.cpp create mode 100644 cpp/src/reductions/scan/scan.cuh create mode 100644 cpp/src/reductions/scan/scan_exclusive.cu create mode 100644 cpp/src/reductions/scan/scan_inclusive.cu diff --git a/conda/recipes/libcudf/meta.yaml b/conda/recipes/libcudf/meta.yaml index ffd30758a50..0b3fb5aa549 100644 --- a/conda/recipes/libcudf/meta.yaml +++ b/conda/recipes/libcudf/meta.yaml @@ -87,6 +87,7 @@ test: - test -f $PREFIX/include/cudf/detail/reshape.hpp - test -f $PREFIX/include/cudf/detail/rolling.hpp - test -f $PREFIX/include/cudf/detail/round.hpp + - test -f $PREFIX/include/cudf/detail/scan.hpp - test -f $PREFIX/include/cudf/detail/scatter.hpp - test -f $PREFIX/include/cudf/detail/search.hpp - test -f $PREFIX/include/cudf/detail/sequence.hpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 52edcca82c6..fc209406e06 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -292,7 +292,9 @@ add_library(cudf 
src/reductions/nth_element.cu src/reductions/product.cu src/reductions/reductions.cpp - src/reductions/scan.cu + src/reductions/scan/scan.cpp + src/reductions/scan/scan_exclusive.cu + src/reductions/scan/scan_inclusive.cu src/reductions/std.cu src/reductions/sum.cu src/reductions/sum_of_squares.cu diff --git a/cpp/cmake/thrust.patch b/cpp/cmake/thrust.patch index 3cedff8b80d..c14b8cdafe5 100644 --- a/cpp/cmake/thrust.patch +++ b/cpp/cmake/thrust.patch @@ -42,6 +42,45 @@ index 1ffeef0..5e80800 100644 for (int ITEM = 1; ITEM < ITEMS_PER_THREAD; ++ITEM) { if (ITEMS_PER_THREAD * tid + ITEM < num_remaining) +diff a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh +index 41eb1d2..f2893b4 100644 +--- a/cub/device/dispatch/dispatch_radix_sort.cuh ++++ b/cub/device/dispatch/dispatch_radix_sort.cuh +@@ -723,7 +723,7 @@ struct DeviceRadixSortPolicy + + + /// SM60 (GP100) +- struct Policy600 : ChainedPolicy<600, Policy600, Policy500> ++ struct Policy600 : ChainedPolicy<600, Policy600, Policy600> + { + enum { + PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 
7 : 5, // 6.9B 32b keys/s (Quadro P100) +diff a/cub/device/dispatch/dispatch_reduce.cuh b/cub/device/dispatch/dispatch_reduce.cuh +index f6aee45..dd64301 100644 +--- a/cub/device/dispatch/dispatch_reduce.cuh ++++ b/cub/device/dispatch/dispatch_reduce.cuh +@@ -284,7 +284,7 @@ struct DeviceReducePolicy + }; + + /// SM60 +- struct Policy600 : ChainedPolicy<600, Policy600, Policy350> ++ struct Policy600 : ChainedPolicy<600, Policy600, Policy600> + { + // ReducePolicy (P100: 591 GB/s @ 64M 4B items; 583 GB/s @ 256M 1B items) + typedef AgentReducePolicy< +diff a/cub/device/dispatch/dispatch_scan.cuh b/cub/device/dispatch/dispatch_scan.cuh +index c0c6d59..937ee31 100644 +--- a/cub/device/dispatch/dispatch_scan.cuh ++++ b/cub/device/dispatch/dispatch_scan.cuh +@@ -178,7 +178,7 @@ struct DeviceScanPolicy + }; + + /// SM600 +- struct Policy600 : ChainedPolicy<600, Policy600, Policy520> ++ struct Policy600 : ChainedPolicy<600, Policy600, Policy600> + { + typedef AgentScanPolicy< + 128, 15, ///< Threads per block, items per thread diff --git a/thrust/system/cuda/detail/scan_by_key.h b/thrust/system/cuda/detail/scan_by_key.h index fe4b321c..b3974c69 100644 --- a/thrust/system/cuda/detail/scan_by_key.h diff --git a/cpp/include/cudf/detail/scan.hpp b/cpp/include/cudf/detail/scan.hpp new file mode 100644 index 00000000000..5691adecb5e --- /dev/null +++ b/cpp/include/cudf/detail/scan.hpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#include + +namespace cudf { +namespace detail { + +/** + * @brief Computes the exclusive scan of a column. + * + * The null values are skipped for the operation, and if an input element + * at `i` is null, then the output element at `i` will also be null. + * + * The identity value for the column type as per the aggregation type + * is used for the value of the first element in the output column. + * + * @throws cudf::logic_error if column data_type is not an arithmetic type. + * + * @param input The input column view for the scan + * @param agg unique_ptr to aggregation operator applied by the scan + * @param null_handling Exclude null values when computing the result if + * null_policy::EXCLUDE. Include nulls if null_policy::INCLUDE. + * Any operation with a null results in a null. + * @param stream CUDA stream used for device memory operations and kernel launches. + * @param mr Device memory resource used to allocate the returned column's device memory + * @returns Column with scan results + */ +std::unique_ptr scan_exclusive(column_view const& input, + std::unique_ptr const& agg, + null_policy null_handling, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr); + +/** + * @brief Computes the inclusive scan of a column. + * + * The null values are skipped for the operation, and if an input element + * at `i` is null, then the output element at `i` will also be null. + * + * String columns are allowed with aggregation types Min and Max. + * + * @throws cudf::logic_error if column data_type is not an arithmetic type + * or string type but the `agg` is not Min or Max + * + * @param input The input column view for the scan + * @param agg unique_ptr to aggregation operator applied by the scan + * @param null_handling Exclude null values when computing the result if + * null_policy::EXCLUDE. 
Include nulls if null_policy::INCLUDE. + * Any operation with a null results in a null. + * @param stream CUDA stream used for device memory operations and kernel launches. + * @param mr Device memory resource used to allocate the returned column's device memory + * @returns Column with scan results + */ +std::unique_ptr scan_inclusive(column_view const& input, + std::unique_ptr const& agg, + null_policy null_handling, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr); + +} // namespace detail +} // namespace cudf diff --git a/cpp/src/reductions/scan.cu b/cpp/src/reductions/scan.cu deleted file mode 100644 index c3aadf47794..00000000000 --- a/cpp/src/reductions/scan.cu +++ /dev/null @@ -1,286 +0,0 @@ -/* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace cudf { -namespace detail { - -/** - * @brief Dispatcher for running Scan operation on input column - * Dispatches scan operation on `Op` and creates output column - * - * @tparam Op device binary operator - */ -template -struct scan_dispatcher { - private: - template - static constexpr bool is_string_supported() - { - return std::is_same::value && - (std::is_same::value || std::is_same::value); - } - // return true if T is arithmetic type (including bool) - template - static constexpr bool is_supported() - { - return std::is_arithmetic::value || is_string_supported() || is_fixed_point(); - } - - // for arithmetic types - template ::value, T>* = nullptr> - auto exclusive_scan(const column_view& input_view, - null_policy null_handling, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - const size_type size = input_view.size(); - auto output_column = - detail::allocate_like(input_view, size, mask_allocation_policy::NEVER, stream, mr); - if (null_handling == null_policy::EXCLUDE) { - output_column->set_null_mask(detail::copy_bitmask(input_view, stream, mr), - input_view.null_count()); - } - mutable_column_view output = output_column->mutable_view(); - auto d_input = column_device_view::create(input_view, stream); - - auto input = - make_null_replacement_iterator(*d_input, Op::template identity(), input_view.has_nulls()); - thrust::exclusive_scan(rmm::exec_policy(stream), - input, - input + size, - output.data(), - Op::template identity(), - Op{}); - - CHECK_CUDA(stream.value()); - return output_column; - } - - // for string type - template (), T>* = nullptr> - std::unique_ptr exclusive_scan(const column_view& input_view, - null_policy null_handling, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - CUDF_FAIL("String types supports only 
inclusive min/max for `cudf::scan`"); - } - - rmm::device_buffer mask_inclusive_scan(const column_view& input_view, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - rmm::device_buffer mask = - detail::create_null_mask(input_view.size(), mask_state::UNINITIALIZED, stream, mr); - auto d_input = column_device_view::create(input_view, stream); - auto v = detail::make_validity_iterator(*d_input); - auto first_null_position = - thrust::find_if_not( - rmm::exec_policy(stream), v, v + input_view.size(), thrust::identity{}) - - v; - cudf::set_null_mask( - static_cast(mask.data()), 0, first_null_position, true); - cudf::set_null_mask( - static_cast(mask.data()), first_null_position, input_view.size(), false); - return mask; - } - - // for arithmetic types - template ::value, T>* = nullptr> - auto inclusive_scan(const column_view& input_view, - null_policy null_handling, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - const size_type size = input_view.size(); - auto output_column = - detail::allocate_like(input_view, size, mask_allocation_policy::NEVER, stream, mr); - if (null_handling == null_policy::EXCLUDE) { - output_column->set_null_mask(detail::copy_bitmask(input_view, stream, mr), - input_view.null_count()); - } else { - if (input_view.nullable()) { - output_column->set_null_mask(mask_inclusive_scan(input_view, stream, mr), - cudf::UNKNOWN_NULL_COUNT); - } - } - - auto d_input = column_device_view::create(input_view, stream); - mutable_column_view output = output_column->mutable_view(); - - auto const input = - make_null_replacement_iterator(*d_input, Op::template identity(), input_view.has_nulls()); - thrust::inclusive_scan(rmm::exec_policy(stream), input, input + size, output.data(), Op{}); - - CHECK_CUDA(stream.value()); - return output_column; - } - - // for string type - template (), T>* = nullptr> - std::unique_ptr inclusive_scan(const column_view& input_view, - null_policy null_handling, - 
rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - const size_type size = input_view.size(); - rmm::device_uvector result(size, stream); - - auto d_input = column_device_view::create(input_view, stream); - - auto input = - make_null_replacement_iterator(*d_input, Op::template identity(), input_view.has_nulls()); - thrust::inclusive_scan(rmm::exec_policy(stream), input, input + size, result.data(), Op{}); - - CHECK_CUDA(stream.value()); - - auto output_column = - cudf::make_strings_column(result, Op::template identity(), stream, mr); - if (null_handling == null_policy::EXCLUDE) { - output_column->set_null_mask(detail::copy_bitmask(input_view, stream, mr), - input_view.null_count()); - } else { - if (input_view.nullable()) { - output_column->set_null_mask(mask_inclusive_scan(input_view, stream, mr), - cudf::UNKNOWN_NULL_COUNT); - } - } - return output_column; - } - - public: - /** - * @brief creates new column from input column by applying scan operation - * - * @param input input column view - * @param inclusive inclusive or exclusive scan - * @param stream CUDA stream used for device memory operations and kernel launches. - * @param mr Device memory resource used to allocate the returned column's device memory - * @return - * - * @tparam T type of input column - */ - template (), T>* = nullptr> - std::unique_ptr operator()(const column_view& input, - scan_type inclusive, - null_policy null_handling, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - auto output = inclusive == scan_type::INCLUSIVE - ? 
inclusive_scan(input, null_handling, stream, mr) - : exclusive_scan(input, null_handling, stream, mr); - - if (null_handling == null_policy::EXCLUDE) { - CUDF_EXPECTS(input.null_count() == output->null_count(), - "Input / output column null count mismatch"); - } - - return output; - } - - template (), T>* = nullptr> - std::unique_ptr operator()(const column_view& input, - scan_type inclusive, - null_policy null_handling, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) - { - CUDF_FAIL("Non-arithmetic types not supported for `cudf::scan`"); - } -}; - -std::unique_ptr scan( - const column_view& input, - std::unique_ptr const& agg, - scan_type inclusive, - null_policy null_handling, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) -{ - CUDF_EXPECTS( - is_numeric(input.type()) || is_compound(input.type()) || is_fixed_point(input.type()), - "Unexpected non-numeric or non-string type."); - - switch (agg->kind) { - case aggregation::SUM: - return cudf::type_dispatcher(input.type(), - scan_dispatcher(), - input, - inclusive, - null_handling, - stream, - mr); - case aggregation::MIN: - return cudf::type_dispatcher(input.type(), - scan_dispatcher(), - input, - inclusive, - null_handling, - stream, - mr); - case aggregation::MAX: - return cudf::type_dispatcher(input.type(), - scan_dispatcher(), - input, - inclusive, - null_handling, - stream, - mr); - case aggregation::PRODUCT: - // a product scan on a decimal type with non-zero scale would result in each element having - // a different scale, and because scale is stored once per column, this is not possible - if (is_fixed_point(input.type())) CUDF_FAIL("decimal32/64 cannot support product scan"); - return cudf::type_dispatcher(input.type(), - scan_dispatcher(), - input, - inclusive, - null_handling, - stream, - mr); - default: CUDF_FAIL("Unsupported aggregation operator for scan"); - } -} -} // namespace detail - -std::unique_ptr scan(const 
column_view& input, - std::unique_ptr const& agg, - scan_type inclusive, - null_policy null_handling, - rmm::mr::device_memory_resource* mr) -{ - CUDF_FUNC_RANGE(); - return detail::scan(input, agg, inclusive, null_handling, rmm::cuda_stream_default, mr); -} - -} // namespace cudf diff --git a/cpp/src/reductions/scan/scan.cpp b/cpp/src/reductions/scan/scan.cpp new file mode 100644 index 00000000000..f40a3fd5c75 --- /dev/null +++ b/cpp/src/reductions/scan/scan.cpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include + +namespace cudf { + +std::unique_ptr scan(column_view const& input, + std::unique_ptr const& agg, + scan_type inclusive, + null_policy null_handling, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return inclusive == scan_type::EXCLUSIVE + ? detail::scan_exclusive(input, agg, null_handling, rmm::cuda_stream_default, mr) + : detail::scan_inclusive(input, agg, null_handling, rmm::cuda_stream_default, mr); +} + +} // namespace cudf diff --git a/cpp/src/reductions/scan/scan.cuh b/cpp/src/reductions/scan/scan.cuh new file mode 100644 index 00000000000..39fed60735f --- /dev/null +++ b/cpp/src/reductions/scan/scan.cuh @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace cudf { +namespace detail { + +template