Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JNI Aggregation Type Changes #8919

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 66 additions & 106 deletions java/src/main/java/ai/rapids/cudf/Aggregation.java

Large diffs are not rendered by default.

43 changes: 19 additions & 24 deletions java/src/main/java/ai/rapids/cudf/AggregationOverWindow.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
* An Aggregation instance that also holds a column number and window metadata so the aggregation
* can be done over a specific window.
*/
public class AggregationOverWindow<T extends Aggregation & RollingAggregation<T>>
extends AggregationOnColumn<T> {
public final class AggregationOverWindow {
private final RollingAggregationOnColumn wrapped;
protected final WindowOptions windowOptions;

AggregationOverWindow(T wrapped, int columnIndex, WindowOptions windowOptions) {
super(wrapped, columnIndex);
AggregationOverWindow(RollingAggregationOnColumn wrapped, WindowOptions windowOptions) {
this.wrapped = wrapped;
this.windowOptions = windowOptions;

if (windowOptions == null) {
Expand All @@ -43,23 +43,6 @@ public WindowOptions getWindowOptions() {
return windowOptions;
}

@Override
public AggregationOnColumn<T> onColumn(int columnIndex) {
if (columnIndex == getColumnIndex()) {
return this; // NOOP
} else {
return new AggregationOverWindow(this.wrapped, columnIndex, windowOptions);
}
}

@Override
public AggregationOverWindow<T> overWindow(WindowOptions windowOptions) {
if (this.windowOptions.equals(windowOptions)) {
return this;
}
return new AggregationOverWindow(wrapped, columnIndex, windowOptions);
}

@Override
public int hashCode() {
return 31 * super.hashCode() + windowOptions.hashCode();
Expand All @@ -69,10 +52,22 @@ public int hashCode() {
public boolean equals(Object other) {
if (other == this) {
return true;
} else if (other instanceof AggregationOnColumn) {
AggregationOnColumn o = (AggregationOnColumn) other;
return wrapped.equals(o.wrapped) && columnIndex == o.columnIndex;
} else if (other instanceof AggregationOverWindow) {
AggregationOverWindow o = (AggregationOverWindow) other;
return wrapped.equals(o.wrapped) && windowOptions.equals(o.windowOptions);
}
return false;
}

int getColumnIndex() {
return wrapped.getColumnIndex();
}

long createNativeInstance() {
return wrapped.createNativeInstance();
}

long getDefaultOutput() {
return wrapped.getDefaultOutput();
}
}
41 changes: 20 additions & 21 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -1135,15 +1135,15 @@ public Scalar sum() {
* of the specified type.
*/
public Scalar sum(DType outType) {
return reduce(Aggregation.sum(), outType);
return reduce(ReductionAggregation.sum(), outType);
}

/**
* Returns the minimum of all values in the column, returning a scalar
* of the same type as this column.
*/
public Scalar min() {
return reduce(Aggregation.min(), type);
return reduce(ReductionAggregation.min(), type);
}

/**
Expand All @@ -1160,15 +1160,15 @@ public Scalar min(DType outType) {
return tmp.min(outType);
}
}
return reduce(Aggregation.min(), outType);
return reduce(ReductionAggregation.min(), outType);
}

/**
* Returns the maximum of all values in the column, returning a scalar
* of the same type as this column.
*/
public Scalar max() {
return reduce(Aggregation.max(), type);
return reduce(ReductionAggregation.max(), type);
}

/**
Expand All @@ -1185,7 +1185,7 @@ public Scalar max(DType outType) {
return tmp.max(outType);
}
}
return reduce(Aggregation.max(), outType);
return reduce(ReductionAggregation.max(), outType);
}

/**
Expand All @@ -1201,7 +1201,7 @@ public Scalar product() {
* of the specified type.
*/
public Scalar product(DType outType) {
return reduce(Aggregation.product(), outType);
return reduce(ReductionAggregation.product(), outType);
}

/**
Expand All @@ -1217,7 +1217,7 @@ public Scalar sumOfSquares() {
* scalar of the specified type.
*/
public Scalar sumOfSquares(DType outType) {
return reduce(Aggregation.sumOfSquares(), outType);
return reduce(ReductionAggregation.sumOfSquares(), outType);
}

/**
Expand All @@ -1241,7 +1241,7 @@ public Scalar mean() {
* types are currently supported.
*/
public Scalar mean(DType outType) {
return reduce(Aggregation.mean(), outType);
return reduce(ReductionAggregation.mean(), outType);
}

/**
Expand All @@ -1265,7 +1265,7 @@ public Scalar variance() {
* types are currently supported.
*/
public Scalar variance(DType outType) {
return reduce(Aggregation.variance(), outType);
return reduce(ReductionAggregation.variance(), outType);
}

/**
Expand All @@ -1290,7 +1290,7 @@ public Scalar standardDeviation() {
* types are currently supported.
*/
public Scalar standardDeviation(DType outType) {
return reduce(Aggregation.standardDeviation(), outType);
return reduce(ReductionAggregation.standardDeviation(), outType);
}

/**
Expand All @@ -1309,7 +1309,7 @@ public Scalar any() {
* Null values are skipped.
*/
public Scalar any(DType outType) {
return reduce(Aggregation.any(), outType);
return reduce(ReductionAggregation.any(), outType);
}

/**
Expand All @@ -1330,7 +1330,7 @@ public Scalar all() {
*/
@Deprecated
public Scalar all(DType outType) {
return reduce(Aggregation.all(), outType);
return reduce(ReductionAggregation.all(), outType);
}

/**
Expand All @@ -1343,7 +1343,7 @@ public Scalar all(DType outType) {
* empty or the reduction operation fails then the
* {@link Scalar#isValid()} method of the result will return false.
*/
public Scalar reduce(Aggregation aggregation) {
public Scalar reduce(ReductionAggregation aggregation) {
return reduce(aggregation, type);
}

Expand All @@ -1360,7 +1360,7 @@ public Scalar reduce(Aggregation aggregation) {
* empty or the reduction operation fails then the
* {@link Scalar#isValid()} method of the result will return false.
*/
public Scalar reduce(Aggregation aggregation, DType outType) {
public Scalar reduce(ReductionAggregation aggregation, DType outType) {
long nativeId = aggregation.createNativeInstance();
try {
return new Scalar(outType, reduce(getNativeView(), nativeId, outType.typeId.getNativeId(), outType.getScale()));
Expand Down Expand Up @@ -1390,20 +1390,19 @@ public final ColumnVector quantile(QuantileMethod method, double[] quantiles) {
* @throws IllegalArgumentException if unsupported window specification * (i.e. other than {@link WindowOptions.FrameType#ROWS} is used.
*/
public final ColumnVector rollingWindow(RollingAggregation op, WindowOptions options) {
Aggregation agg = op.getBaseAggregation();
// Check that only row-based windows are used.
if (!options.getFrameType().equals(WindowOptions.FrameType.ROWS)) {
throw new IllegalArgumentException("Expected ROWS-based window specification. Unexpected window type: "
+ options.getFrameType());
}

long nativePtr = agg.createNativeInstance();
long nativePtr = op.createNativeInstance();
try {
Scalar p = options.getPrecedingScalar();
Scalar f = options.getFollowingScalar();
return new ColumnVector(
rollingWindow(this.getNativeView(),
agg.getDefaultOutput(),
op.getDefaultOutput(),
options.getMinPeriods(),
nativePtr,
p == null || !p.isValid() ? 0 : p.getInt(),
Expand All @@ -1420,7 +1419,7 @@ public final ColumnVector rollingWindow(RollingAggregation op, WindowOptions opt
* This is just a convenience method for an inclusive scan with a SUM aggregation.
*/
public final ColumnVector prefixSum() {
return scan(Aggregation.sum());
return scan(ScanAggregation.sum());
}

/**
Expand All @@ -1431,7 +1430,7 @@ public final ColumnVector prefixSum() {
* null policy too. Currently none of those aggregations are supported so
* it is undefined how they would interact with each other.
*/
public final ColumnVector scan(Aggregation aggregation, ScanType scanType, NullPolicy nullPolicy) {
public final ColumnVector scan(ScanAggregation aggregation, ScanType scanType, NullPolicy nullPolicy) {
long nativeId = aggregation.createNativeInstance();
try {
return new ColumnVector(scan(getNativeView(), nativeId,
Expand All @@ -1446,15 +1445,15 @@ public final ColumnVector scan(Aggregation aggregation, ScanType scanType, NullP
* @param aggregation the aggregation to perform
* @param scanType should the scan be inclusive, include the current row, or exclusive.
*/
public final ColumnVector scan(Aggregation aggregation, ScanType scanType) {
public final ColumnVector scan(ScanAggregation aggregation, ScanType scanType) {
return scan(aggregation, scanType, NullPolicy.EXCLUDE);
}

/**
* Computes an inclusive scan for a column that excludes nulls.
* @param aggregation the aggregation to perform
*/
public final ColumnVector scan(Aggregation aggregation) {
public final ColumnVector scan(ScanAggregation aggregation) {
return scan(aggregation, ScanType.INCLUSIVE, NullPolicy.EXCLUDE);
}

Expand Down
Loading