Skip to content

Commit

Permalink
JNI Aggregation Type Changes (#8919)
Browse files Browse the repository at this point in the history
CUDF is in the process of tagging aggregations with different classes to make it a compile error to use the wrong aggregation with the wrong API.  The first set of changes were for rolling windows and when I did the corresponding java changes I ended up using generics and interfaces to try and replicate the same thing.  This didn't work out as well as I had hoped and the code to use them ended up being more cumbersome thqn I wanted.

This patch adjusts it so we have truly separate classes for each type of aggregation. It adds more code here, but it makes the code that uses these cleaner.

This is a breaking change and I will be putting up a corresponding change in the Spark plugin to deal with this.

Authors:
  - Robert (Bobby) Evans (https://github.com/revans2)

Approvers:
  - Alessandro Bellina (https://github.com/abellina)
  - Jason Lowe (https://github.com/jlowe)

URL: #8919
  • Loading branch information
revans2 authored Aug 4, 2021
1 parent fdf47af commit b9820f1
Show file tree
Hide file tree
Showing 15 changed files with 1,478 additions and 523 deletions.
172 changes: 66 additions & 106 deletions java/src/main/java/ai/rapids/cudf/Aggregation.java

Large diffs are not rendered by default.

43 changes: 19 additions & 24 deletions java/src/main/java/ai/rapids/cudf/AggregationOverWindow.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
* An Aggregation instance that also holds a column number and window metadata so the aggregation
* can be done over a specific window.
*/
public class AggregationOverWindow<T extends Aggregation & RollingAggregation<T>>
extends AggregationOnColumn<T> {
public final class AggregationOverWindow {
private final RollingAggregationOnColumn wrapped;
protected final WindowOptions windowOptions;

AggregationOverWindow(T wrapped, int columnIndex, WindowOptions windowOptions) {
super(wrapped, columnIndex);
AggregationOverWindow(RollingAggregationOnColumn wrapped, WindowOptions windowOptions) {
this.wrapped = wrapped;
this.windowOptions = windowOptions;

if (windowOptions == null) {
Expand All @@ -43,23 +43,6 @@ public WindowOptions getWindowOptions() {
return windowOptions;
}

@Override
public AggregationOnColumn<T> onColumn(int columnIndex) {
if (columnIndex == getColumnIndex()) {
return this; // NOOP
} else {
return new AggregationOverWindow(this.wrapped, columnIndex, windowOptions);
}
}

@Override
public AggregationOverWindow<T> overWindow(WindowOptions windowOptions) {
if (this.windowOptions.equals(windowOptions)) {
return this;
}
return new AggregationOverWindow(wrapped, columnIndex, windowOptions);
}

@Override
public int hashCode() {
return 31 * super.hashCode() + windowOptions.hashCode();
Expand All @@ -69,10 +52,22 @@ public int hashCode() {
public boolean equals(Object other) {
if (other == this) {
return true;
} else if (other instanceof AggregationOnColumn) {
AggregationOnColumn o = (AggregationOnColumn) other;
return wrapped.equals(o.wrapped) && columnIndex == o.columnIndex;
} else if (other instanceof AggregationOverWindow) {
AggregationOverWindow o = (AggregationOverWindow) other;
return wrapped.equals(o.wrapped) && windowOptions.equals(o.windowOptions);
}
return false;
}

int getColumnIndex() {
return wrapped.getColumnIndex();
}

long createNativeInstance() {
return wrapped.createNativeInstance();
}

long getDefaultOutput() {
return wrapped.getDefaultOutput();
}
}
41 changes: 20 additions & 21 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -1135,15 +1135,15 @@ public Scalar sum() {
* of the specified type.
*/
public Scalar sum(DType outType) {
return reduce(Aggregation.sum(), outType);
return reduce(ReductionAggregation.sum(), outType);
}

/**
* Returns the minimum of all values in the column, returning a scalar
* of the same type as this column.
*/
public Scalar min() {
return reduce(Aggregation.min(), type);
return reduce(ReductionAggregation.min(), type);
}

/**
Expand All @@ -1160,15 +1160,15 @@ public Scalar min(DType outType) {
return tmp.min(outType);
}
}
return reduce(Aggregation.min(), outType);
return reduce(ReductionAggregation.min(), outType);
}

/**
* Returns the maximum of all values in the column, returning a scalar
* of the same type as this column.
*/
public Scalar max() {
return reduce(Aggregation.max(), type);
return reduce(ReductionAggregation.max(), type);
}

/**
Expand All @@ -1185,7 +1185,7 @@ public Scalar max(DType outType) {
return tmp.max(outType);
}
}
return reduce(Aggregation.max(), outType);
return reduce(ReductionAggregation.max(), outType);
}

/**
Expand All @@ -1201,7 +1201,7 @@ public Scalar product() {
* of the specified type.
*/
public Scalar product(DType outType) {
return reduce(Aggregation.product(), outType);
return reduce(ReductionAggregation.product(), outType);
}

/**
Expand All @@ -1217,7 +1217,7 @@ public Scalar sumOfSquares() {
* scalar of the specified type.
*/
public Scalar sumOfSquares(DType outType) {
return reduce(Aggregation.sumOfSquares(), outType);
return reduce(ReductionAggregation.sumOfSquares(), outType);
}

/**
Expand All @@ -1241,7 +1241,7 @@ public Scalar mean() {
* types are currently supported.
*/
public Scalar mean(DType outType) {
return reduce(Aggregation.mean(), outType);
return reduce(ReductionAggregation.mean(), outType);
}

/**
Expand All @@ -1265,7 +1265,7 @@ public Scalar variance() {
* types are currently supported.
*/
public Scalar variance(DType outType) {
return reduce(Aggregation.variance(), outType);
return reduce(ReductionAggregation.variance(), outType);
}

/**
Expand All @@ -1290,7 +1290,7 @@ public Scalar standardDeviation() {
* types are currently supported.
*/
public Scalar standardDeviation(DType outType) {
return reduce(Aggregation.standardDeviation(), outType);
return reduce(ReductionAggregation.standardDeviation(), outType);
}

/**
Expand All @@ -1309,7 +1309,7 @@ public Scalar any() {
* Null values are skipped.
*/
public Scalar any(DType outType) {
return reduce(Aggregation.any(), outType);
return reduce(ReductionAggregation.any(), outType);
}

/**
Expand All @@ -1330,7 +1330,7 @@ public Scalar all() {
*/
@Deprecated
public Scalar all(DType outType) {
return reduce(Aggregation.all(), outType);
return reduce(ReductionAggregation.all(), outType);
}

/**
Expand All @@ -1343,7 +1343,7 @@ public Scalar all(DType outType) {
* empty or the reduction operation fails then the
* {@link Scalar#isValid()} method of the result will return false.
*/
public Scalar reduce(Aggregation aggregation) {
public Scalar reduce(ReductionAggregation aggregation) {
return reduce(aggregation, type);
}

Expand All @@ -1360,7 +1360,7 @@ public Scalar reduce(Aggregation aggregation) {
* empty or the reduction operation fails then the
* {@link Scalar#isValid()} method of the result will return false.
*/
public Scalar reduce(Aggregation aggregation, DType outType) {
public Scalar reduce(ReductionAggregation aggregation, DType outType) {
long nativeId = aggregation.createNativeInstance();
try {
return new Scalar(outType, reduce(getNativeView(), nativeId, outType.typeId.getNativeId(), outType.getScale()));
Expand Down Expand Up @@ -1390,20 +1390,19 @@ public final ColumnVector quantile(QuantileMethod method, double[] quantiles) {
* @throws IllegalArgumentException if unsupported window specification * (i.e. other than {@link WindowOptions.FrameType#ROWS} is used.
*/
public final ColumnVector rollingWindow(RollingAggregation op, WindowOptions options) {
Aggregation agg = op.getBaseAggregation();
// Check that only row-based windows are used.
if (!options.getFrameType().equals(WindowOptions.FrameType.ROWS)) {
throw new IllegalArgumentException("Expected ROWS-based window specification. Unexpected window type: "
+ options.getFrameType());
}

long nativePtr = agg.createNativeInstance();
long nativePtr = op.createNativeInstance();
try {
Scalar p = options.getPrecedingScalar();
Scalar f = options.getFollowingScalar();
return new ColumnVector(
rollingWindow(this.getNativeView(),
agg.getDefaultOutput(),
op.getDefaultOutput(),
options.getMinPeriods(),
nativePtr,
p == null || !p.isValid() ? 0 : p.getInt(),
Expand All @@ -1420,7 +1419,7 @@ public final ColumnVector rollingWindow(RollingAggregation op, WindowOptions opt
* This is just a convenience method for an inclusive scan with a SUM aggregation.
*/
public final ColumnVector prefixSum() {
return scan(Aggregation.sum());
return scan(ScanAggregation.sum());
}

/**
Expand All @@ -1431,7 +1430,7 @@ public final ColumnVector prefixSum() {
* null policy too. Currently none of those aggregations are supported so
* it is undefined how they would interact with each other.
*/
public final ColumnVector scan(Aggregation aggregation, ScanType scanType, NullPolicy nullPolicy) {
public final ColumnVector scan(ScanAggregation aggregation, ScanType scanType, NullPolicy nullPolicy) {
long nativeId = aggregation.createNativeInstance();
try {
return new ColumnVector(scan(getNativeView(), nativeId,
Expand All @@ -1446,15 +1445,15 @@ public final ColumnVector scan(Aggregation aggregation, ScanType scanType, NullP
* @param aggregation the aggregation to perform
* @param scanType should the scan be inclusive, include the current row, or exclusive.
*/
public final ColumnVector scan(Aggregation aggregation, ScanType scanType) {
public final ColumnVector scan(ScanAggregation aggregation, ScanType scanType) {
return scan(aggregation, scanType, NullPolicy.EXCLUDE);
}

/**
* Computes an inclusive scan for a column that excludes nulls.
* @param aggregation the aggregation to perform
*/
public final ColumnVector scan(Aggregation aggregation) {
public final ColumnVector scan(ScanAggregation aggregation) {
return scan(aggregation, ScanType.INCLUSIVE, NullPolicy.EXCLUDE);
}

Expand Down
Loading

0 comments on commit b9820f1

Please sign in to comment.