Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JNI changes for range-extents in window functions. #13199

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ private static native long[] rollingWindowAggregate(

private static native long[] rangeRollingWindowAggregate(long inputTable, int[] keyIndices, int[] orderByIndices, boolean[] isOrderByAscending,
int[] aggColumnsIndices, long[] aggInstances, int[] minPeriods,
long[] preceding, long[] following, boolean[] unboundedPreceding, boolean[] unboundedFollowing,
long[] preceding, long[] following, int[] precedingRangeExtent, int[] followingRangeExtent,
boolean ignoreNullKeys) throws CudfException;

private static native long sortOrder(long inputTable, long[] sortKeys, boolean[] isDescending,
Expand Down Expand Up @@ -3981,10 +3981,11 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregate
case DECIMAL32:
case DECIMAL64:
case DECIMAL128:
case STRING:
break;
default:
throw new IllegalArgumentException("Expected range-based window orderBy's " +
"type: integral (Boolean-exclusive), decimal, and timestamp");
"type: integral (Boolean-exclusive), decimal, timestamp, and string");
}

ColumnWindowOps ops = groupedOps.computeIfAbsent(agg.getColumnIndex(), (idx) -> new ColumnWindowOps());
Expand All @@ -3998,27 +3999,28 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregate
long[] aggPrecedingWindows = new long[totalOps];
long[] aggFollowingWindows = new long[totalOps];
try {
boolean[] aggPrecedingWindowsUnbounded = new boolean[totalOps];
boolean[] aggFollowingWindowsUnbounded = new boolean[totalOps];
int[] aggPrecedingWindowsExtent = new int[totalOps];
int[] aggFollowingWindowsExtent = new int[totalOps];
int[] aggMinPeriods = new int[totalOps];
int opIndex = 0;
for (Map.Entry<Integer, ColumnWindowOps> entry: groupedOps.entrySet()) {
int columnIndex = entry.getKey();
for (AggregationOverWindow op: entry.getValue().operations()) {
aggColumnIndexes[opIndex] = columnIndex;
aggInstances[opIndex] = op.createNativeInstance();
Scalar p = op.getWindowOptions().getPrecedingScalar();
Scalar f = op.getWindowOptions().getFollowingScalar();
if ((p == null || !p.isValid()) && !op.getWindowOptions().isUnboundedPreceding()) {
WindowOptions windowOptions = op.getWindowOptions();
Scalar p = windowOptions.getPrecedingScalar();
Scalar f = windowOptions.getFollowingScalar();
if ((p == null || !p.isValid()) && !(windowOptions.isUnboundedPreceding() || windowOptions.isCurrentRowPreceding())) {
throw new IllegalArgumentException("Some kind of preceding must be set and a preceding column is not currently supported");
}
if ((f == null || !f.isValid()) && !op.getWindowOptions().isUnboundedFollowing()) {
if ((f == null || !f.isValid()) && !(windowOptions.isUnboundedFollowing() || windowOptions.isCurrentRowFollowing())) {
throw new IllegalArgumentException("some kind of following must be set and a follow column is not currently supported");
}
aggPrecedingWindows[opIndex] = p == null ? 0 : p.getScalarHandle();
aggFollowingWindows[opIndex] = f == null ? 0 : f.getScalarHandle();
aggPrecedingWindowsUnbounded[opIndex] = op.getWindowOptions().isUnboundedPreceding();
aggFollowingWindowsUnbounded[opIndex] = op.getWindowOptions().isUnboundedFollowing();
aggPrecedingWindowsExtent[opIndex] = windowOptions.getPrecedingBoundsExtent().nominalValue;
aggFollowingWindowsExtent[opIndex] = windowOptions.getFollowingBoundsExtent().nominalValue;
aggMinPeriods[opIndex] = op.getWindowOptions().getMinPeriods();
assert (op.getWindowOptions().getFrameType() == WindowOptions.FrameType.RANGE);
orderByColumnIndexes[opIndex] = op.getWindowOptions().getOrderByColumnIndex();
Expand All @@ -4040,7 +4042,7 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregate
isOrderByOrderAscending,
aggColumnIndexes,
aggInstances, aggMinPeriods, aggPrecedingWindows, aggFollowingWindows,
aggPrecedingWindowsUnbounded, aggFollowingWindowsUnbounded,
aggPrecedingWindowsExtent, aggFollowingWindowsExtent,
groupByOptions.getIgnoreNullKeys()))) {
// prepare the final table
ColumnVector[] finalCols = new ColumnVector[windowAggregates.length];
Expand Down
81 changes: 60 additions & 21 deletions java/src/main/java/ai/rapids/cudf/WindowOptions.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,6 +25,23 @@ public class WindowOptions implements AutoCloseable {

enum FrameType {ROWS, RANGE}

/**
* Extent of (range) window bounds.
* Analogous to cudf::range_window_bounds::extent_type.
*/
enum RangeExtentType {
CURRENT_ROW(0), // Bounds defined as the first/last row that matches the current row.
BOUNDED(1), // Bounds defined as the first/last row that falls within
// a specified range from the current row.
UNBOUNDED(2); // Bounds stretching to the first/last row in the entire group.

public final int nominalValue;

RangeExtentType(int n) {
this.nominalValue = n;
}
}

private final int minPeriods;
private final Scalar precedingScalar;
private final Scalar followingScalar;
Expand All @@ -33,8 +50,8 @@ enum FrameType {ROWS, RANGE}
private final int orderByColumnIndex;
private final boolean orderByOrderAscending;
private final FrameType frameType;
private final boolean isUnboundedPreceding;
private final boolean isUnboundedFollowing;
private final RangeExtentType precedingBoundsExtent;
private final RangeExtentType followingBoundsExtent;

private WindowOptions(Builder builder) {
this.minPeriods = builder.minPeriods;
Expand All @@ -57,9 +74,8 @@ private WindowOptions(Builder builder) {
this.orderByColumnIndex = builder.orderByColumnIndex;
this.orderByOrderAscending = builder.orderByOrderAscending;
this.frameType = orderByColumnIndex == -1? FrameType.ROWS : FrameType.RANGE;
this.isUnboundedPreceding = builder.isUnboundedPreceding;
this.isUnboundedFollowing = builder.isUnboundedFollowing;

this.precedingBoundsExtent = builder.precedingBoundsExtent;
this.followingBoundsExtent = builder.followingBoundsExtent;
}

@Override
Expand All @@ -72,8 +88,8 @@ public boolean equals(Object other) {
this.orderByColumnIndex == o.orderByColumnIndex &&
this.orderByOrderAscending == o.orderByOrderAscending &&
this.frameType == o.frameType &&
this.isUnboundedPreceding == o.isUnboundedPreceding &&
this.isUnboundedFollowing == o.isUnboundedFollowing;
this.precedingBoundsExtent == o.precedingBoundsExtent &&
this.followingBoundsExtent == o.followingBoundsExtent;
if (precedingCol != null) {
ret = ret && precedingCol.equals(o.precedingCol);
}
Expand Down Expand Up @@ -110,8 +126,8 @@ public int hashCode() {
if (followingScalar != null) {
ret = 31 * ret + followingScalar.hashCode();
}
ret = 31 * ret + Boolean.hashCode(isUnboundedPreceding);
ret = 31 * ret + Boolean.hashCode(isUnboundedFollowing);
ret = 31 * ret + precedingBoundsExtent.hashCode();
ret = 31 * ret + followingBoundsExtent.hashCode();
return ret;
}

Expand Down Expand Up @@ -139,9 +155,16 @@ public static Builder builder(){

boolean isOrderByOrderAscending() { return this.orderByOrderAscending; }

boolean isUnboundedPreceding() { return this.isUnboundedPreceding; }
boolean isUnboundedPreceding() { return this.precedingBoundsExtent == RangeExtentType.UNBOUNDED; }

boolean isUnboundedFollowing() { return this.isUnboundedFollowing; }
boolean isUnboundedFollowing() { return this.followingBoundsExtent == RangeExtentType.UNBOUNDED; }

boolean isCurrentRowPreceding() { return this.precedingBoundsExtent == RangeExtentType.CURRENT_ROW; }

boolean isCurrentRowFollowing() { return this.followingBoundsExtent == RangeExtentType.CURRENT_ROW; }

RangeExtentType getPrecedingBoundsExtent() { return this.precedingBoundsExtent; }
RangeExtentType getFollowingBoundsExtent() { return this.followingBoundsExtent; }

FrameType getFrameType() { return frameType; }

Expand All @@ -154,8 +177,8 @@ public static class Builder {
private ColumnVector followingCol = null;
private int orderByColumnIndex = -1;
private boolean orderByOrderAscending = true;
private boolean isUnboundedPreceding = false;
private boolean isUnboundedFollowing = false;
private RangeExtentType precedingBoundsExtent = RangeExtentType.BOUNDED;
private RangeExtentType followingBoundsExtent = RangeExtentType.BOUNDED;

/**
* Set the minimum number of observation required to evaluate an element. If there are not
Expand All @@ -171,7 +194,7 @@ public Builder minPeriods(int minPeriods) {

/**
* Set the size of the window, one entry per row. This does not take ownership of the
* columns passed in so you have to be sure that the life time of the column outlives
* columns passed in so you have to be sure that the lifetime of the column outlives
* this operation.
* @param precedingCol the number of rows preceding the current row and
* precedingCol will be live outside of WindowOptions.
Expand All @@ -185,10 +208,10 @@ public Builder window(ColumnVector precedingCol, ColumnVector followingCol) {
if (followingCol == null || followingCol.hasNulls()) {
throw new IllegalArgumentException("following cannot be null or have nulls");
}
if (isUnboundedPreceding || precedingScalar != null) {
if (precedingBoundsExtent != RangeExtentType.BOUNDED || precedingScalar != null) {
throw new IllegalStateException("preceding has already been set a different way");
}
if (isUnboundedFollowing || followingScalar != null) {
if (followingBoundsExtent != RangeExtentType.BOUNDED || followingScalar != null) {
throw new IllegalStateException("following has already been set a different way");
}
this.precedingCol = precedingCol;
Expand Down Expand Up @@ -246,19 +269,35 @@ public Builder timestampDescending() {
return orderByDescending();
}

public Builder currentRowPreceding() {
if (precedingCol != null || precedingScalar != null) {
throw new IllegalStateException("preceding has already been set a different way");
}
this.precedingBoundsExtent = RangeExtentType.CURRENT_ROW;
return this;
}

public Builder currentRowFollowing() {
if (followingCol != null || followingScalar != null) {
throw new IllegalStateException("following has already been set a different way");
}
this.followingBoundsExtent = RangeExtentType.CURRENT_ROW;
return this;
}

public Builder unboundedPreceding() {
if (precedingCol != null || precedingScalar != null) {
throw new IllegalStateException("preceding has already been set a different way");
}
this.isUnboundedPreceding = true;
this.precedingBoundsExtent = RangeExtentType.UNBOUNDED;
return this;
}

public Builder unboundedFollowing() {
if (followingCol != null || followingScalar != null) {
throw new IllegalStateException("following has already been set a different way");
}
this.isUnboundedFollowing = true;
this.followingBoundsExtent = RangeExtentType.UNBOUNDED;
return this;
}

Expand All @@ -270,7 +309,7 @@ public Builder preceding(Scalar preceding) {
if (preceding == null || !preceding.isValid()) {
throw new IllegalArgumentException("preceding cannot be null");
}
if (isUnboundedPreceding || precedingCol != null) {
if (precedingBoundsExtent != RangeExtentType.BOUNDED || precedingCol != null) {
throw new IllegalStateException("preceding has already been set a different way");
}
this.precedingScalar = preceding;
Expand All @@ -285,7 +324,7 @@ public Builder following(Scalar following) {
if (following == null || !following.isValid()) {
throw new IllegalArgumentException("following cannot be null");
}
if (isUnboundedFollowing || followingCol != null) {
if (followingBoundsExtent != RangeExtentType.BOUNDED || followingCol != null) {
throw new IllegalStateException("following has already been set a different way");
}
this.followingScalar = following;
Expand Down
49 changes: 32 additions & 17 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3221,8 +3221,8 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega
JNIEnv *env, jclass, jlong j_input_table, jintArray j_keys, jintArray j_orderby_column_indices,
jbooleanArray j_is_orderby_ascending, jintArray j_aggregate_column_indices,
jlongArray j_agg_instances, jintArray j_min_periods, jlongArray j_preceding,
jlongArray j_following, jbooleanArray j_unbounded_preceding,
jbooleanArray j_unbounded_following, jboolean ignore_null_keys) {
jlongArray j_following, jintArray j_preceding_extent, jintArray j_following_extent,
jboolean ignore_null_keys) {

JNI_NULL_CHECK(env, j_input_table, "input table is null", NULL);
JNI_NULL_CHECK(env, j_keys, "input keys are null", NULL);
Expand All @@ -3246,8 +3246,8 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega
cudf::jni::native_jintArray values{env, j_aggregate_column_indices};
cudf::jni::native_jpointerArray<cudf::aggregation> agg_instances(env, j_agg_instances);
cudf::jni::native_jintArray min_periods{env, j_min_periods};
cudf::jni::native_jbooleanArray unbounded_preceding{env, j_unbounded_preceding};
cudf::jni::native_jbooleanArray unbounded_following{env, j_unbounded_following};
cudf::jni::native_jintArray preceding_extent{env, j_preceding_extent};
cudf::jni::native_jintArray following_extent{env, j_following_extent};
cudf::jni::native_jpointerArray<cudf::scalar> preceding(env, j_preceding);
cudf::jni::native_jpointerArray<cudf::scalar> following(env, j_following);

Expand All @@ -3266,24 +3266,32 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega
int agg_column_index = values[i];
cudf::column_view const &order_by_column = input_table->column(orderbys[i]);
cudf::data_type order_by_type = order_by_column.type();
cudf::data_type unbounded_type = order_by_type;

if (unbounded_preceding[i] || unbounded_following[i]) {
cudf::data_type duration_type = order_by_type;

// Range extents are defined as:
// a) 0 == CURRENT ROW
// b) 1 == BOUNDED
// c) 2 == UNBOUNDED
// Must set unbounded_type for only the BOUNDED case.
auto constexpr CURRENT_ROW = 0;
auto constexpr BOUNDED = 1;
auto constexpr UNBOUNDED = 2;
if (preceding_extent[i] != BOUNDED || following_extent[i] != BOUNDED) {
switch (order_by_type.id()) {
case cudf::type_id::TIMESTAMP_DAYS:
unbounded_type = cudf::data_type{cudf::type_id::DURATION_DAYS};
duration_type = cudf::data_type{cudf::type_id::DURATION_DAYS};
break;
case cudf::type_id::TIMESTAMP_SECONDS:
unbounded_type = cudf::data_type{cudf::type_id::DURATION_SECONDS};
duration_type = cudf::data_type{cudf::type_id::DURATION_SECONDS};
break;
case cudf::type_id::TIMESTAMP_MILLISECONDS:
unbounded_type = cudf::data_type{cudf::type_id::DURATION_MILLISECONDS};
duration_type = cudf::data_type{cudf::type_id::DURATION_MILLISECONDS};
break;
case cudf::type_id::TIMESTAMP_MICROSECONDS:
unbounded_type = cudf::data_type{cudf::type_id::DURATION_MICROSECONDS};
duration_type = cudf::data_type{cudf::type_id::DURATION_MICROSECONDS};
break;
case cudf::type_id::TIMESTAMP_NANOSECONDS:
unbounded_type = cudf::data_type{cudf::type_id::DURATION_NANOSECONDS};
duration_type = cudf::data_type{cudf::type_id::DURATION_NANOSECONDS};
break;
default: break;
}
Expand All @@ -3293,15 +3301,22 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega
JNI_ARG_CHECK(env, agg != nullptr, "aggregation is not an instance of rolling_aggregation",
nullptr);

auto const make_window_bounds = [&](auto const &range_extent, auto const *p_scalar) {
if (range_extent == CURRENT_ROW) {
return cudf::range_window_bounds::current_row(duration_type);
} else if (range_extent == UNBOUNDED) {
return cudf::range_window_bounds::unbounded(duration_type);
} else {
return cudf::range_window_bounds::get(*p_scalar);
}
};

result_columns.emplace_back(cudf::grouped_range_rolling_window(
groupby_keys, order_by_column,
orderbys_ascending[i] ? cudf::order::ASCENDING : cudf::order::DESCENDING,
input_table->column(agg_column_index),
unbounded_preceding[i] ? cudf::range_window_bounds::unbounded(unbounded_type) :
cudf::range_window_bounds::get(*preceding[i]),
unbounded_following[i] ? cudf::range_window_bounds::unbounded(unbounded_type) :
cudf::range_window_bounds::get(*following[i]),
min_periods[i], *agg));
make_window_bounds(preceding_extent[i], preceding[i]),
make_window_bounds(following_extent[i], following[i]), min_periods[i], *agg));
}

auto result_table = std::make_unique<cudf::table>(std::move(result_columns));
Expand Down
Loading