Skip to content

Commit

Permalink
Add in JNI APIs for scan, replace_nulls, group_by.scan, and group_by.…
Browse files Browse the repository at this point in the history
…replace_nulls (#8503)

To be able to prototype a running-window test, I added APIs for `scan`, `group_by.scan`, and `group_by.replace_nulls`. I also added a version of `replace_nulls` that Java was missing. It is still not decided exactly how we are going to support running windows, but I thought I should get these in, in case we do want to use them.

Authors:
  - Robert (Bobby) Evans (https://github.com/revans2)

Approvers:
  - Jason Lowe (https://github.com/jlowe)

URL: #8503
  • Loading branch information
revans2 authored Jun 24, 2021
1 parent 1a58f45 commit 11e021d
Show file tree
Hide file tree
Showing 15 changed files with 629 additions and 135 deletions.
56 changes: 48 additions & 8 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,10 @@ public final ColumnVector replaceNulls(ColumnView replacements) {
return new ColumnVector(replaceNullsColumn(getNativeView(), replacements.getNativeView()));
}

/**
 * Replaces the nulls in this column according to a {@link ReplacePolicy}, i.e. with the
 * first non-null value that either precedes or follows each null row.
 * @param policy selects whether the preceding or the following non-null value is used.
 * @return a new ColumnVector wrapping the result of the native replacement.
 */
public final ColumnVector replaceNulls(ReplacePolicy policy) {
  final long viewHandle = getNativeView();
  final boolean usePreceding = policy.isPreceding;
  final long replacedHandle = replaceNullsPolicy(viewHandle, usePreceding);
  return new ColumnVector(replacedHandle);
}

/**
* For a BOOL8 vector, computes a vector whose rows are selected from two other vectors
* based on the boolean value of this vector in the corresponding row.
Expand Down Expand Up @@ -1384,17 +1388,50 @@ public final ColumnVector rollingWindow(RollingAggregation op, WindowOptions opt
}

/**
* Compute the cumulative sum/prefix sum of the values in this column.
* This is similar to a rolling window SUM with unbounded preceding and none following.
* Input values 1, 2, 3
* Output values 1, 3, 6
* This currently only works for long values that are not nullable as this is currently a
* very simple implementation. It may be expanded in the future if needed.
* Compute the prefix sum (aka cumulative sum) of the values in this column.
* This is just a convenience method for an inclusive scan with a SUM aggregation.
*/
public final ColumnVector prefixSum() {
return new ColumnVector(prefixSum(getNativeView()));
return scan(Aggregation.sum());
}

/**
 * Runs a scan over this column. This is very similar to a running window on the column.
 * The native aggregation instance created for the call is closed before this method
 * returns, whether or not the scan itself succeeds.
 * @param aggregation the aggregation to perform
 * @param scanType whether the scan is inclusive (includes the current row) or exclusive.
 * @param nullPolicy how nulls should be treated. Note that some aggregations also include a
 *                   null policy too. Currently none of those aggregations are supported so
 *                   it is undefined how they would interact with each other.
 * @return a new ColumnVector wrapping the result of the native scan.
 */
public final ColumnVector scan(Aggregation aggregation, ScanType scanType, NullPolicy nullPolicy) {
  final long aggHandle = aggregation.createNativeInstance();
  try {
    final long viewHandle = getNativeView();
    final boolean inclusive = scanType.isInclusive;
    final boolean withNulls = nullPolicy.includeNulls;
    final long resultHandle = scan(viewHandle, aggHandle, inclusive, withNulls);
    return new ColumnVector(resultHandle);
  } finally {
    // The native aggregation is only needed for the duration of the call.
    Aggregation.close(aggHandle);
  }
}

/**
 * Runs a scan over this column using {@link NullPolicy#EXCLUDE} for nulls.
 * @param aggregation the aggregation to perform
 * @param scanType whether the scan is inclusive (includes the current row) or exclusive.
 * @return a new ColumnVector holding the result of the scan.
 */
public final ColumnVector scan(Aggregation aggregation, ScanType scanType) {
  final NullPolicy nullHandling = NullPolicy.EXCLUDE;
  return scan(aggregation, scanType, nullHandling);
}

/**
 * Runs an inclusive scan over this column using {@link NullPolicy#EXCLUDE} for nulls.
 * @param aggregation the aggregation to perform
 * @return a new ColumnVector holding the result of the scan.
 */
public final ColumnVector scan(Aggregation aggregation) {
  final ScanType scanKind = ScanType.INCLUSIVE;
  final NullPolicy nullHandling = NullPolicy.EXCLUDE;
  return scan(aggregation, scanKind, nullHandling);
}



/////////////////////////////////////////////////////////////////////////////
// LOGICAL
/////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -3217,7 +3254,8 @@ private static native long rollingWindow(
long preceding_col,
long following_col);

private static native long prefixSum(long viewHandle) throws CudfException;
/**
 * Native entry point behind {@code scan(Aggregation, ScanType, NullPolicy)}.
 * @param viewHandle native view handle of the column to scan.
 * @param aggregation handle of a native aggregation instance; the Java caller creates it
 *                    before this call and closes it afterwards.
 * @param isInclusive the caller passes {@code ScanType.isInclusive} here.
 * @param includeNulls the caller passes {@code NullPolicy.includeNulls} here.
 * @return a native handle that the caller wraps in a new ColumnVector.
 * @throws CudfException declared by the signature; the conditions are decided in native code.
 */
private static native long scan(long viewHandle, long aggregation,
boolean isInclusive, boolean includeNulls) throws CudfException;

/**
 * Native entry point taking a single native view handle and returning a long.
 * NOTE(review): the Java caller is not visible in this excerpt; presumably the returned
 * value is a native column handle with NaNs turned into nulls - confirm against the caller.
 * @throws CudfException declared by the signature; the conditions are decided in native code.
 */
private static native long nansToNulls(long viewHandle) throws CudfException;

Expand All @@ -3227,6 +3265,8 @@ private static native long rollingWindow(

/**
 * Native entry point behind {@code replaceNulls(ColumnView)}.
 * @param viewHandle native view handle of the column whose nulls are replaced.
 * @param replaceViewHandle native view handle of the column supplying the replacements.
 * @return a native handle that the caller wraps in a new ColumnVector.
 * @throws CudfException declared by the signature; the conditions are decided in native code.
 */
private static native long replaceNullsColumn(long viewHandle, long replaceViewHandle) throws CudfException;

/**
 * Native entry point behind {@code replaceNulls(ReplacePolicy)}.
 * @param nativeView native view handle of the column whose nulls are replaced.
 * @param isPreceding the caller passes {@code ReplacePolicy.isPreceding} here, which is true
 *                    for PRECEDING and false for FOLLOWING.
 * @return a native handle that the caller wraps in a new ColumnVector.
 * @throws CudfException declared by the signature; the conditions are decided in native code.
 */
private static native long replaceNullsPolicy(long nativeView, boolean isPreceding) throws CudfException;

/**
 * Native entry point taking three native handles (predicate, true-side and false-side
 * vectors) and returning a long.
 * NOTE(review): the Java caller is not visible in this excerpt; presumably this backs the
 * BOOL8 row selection between two vectors - confirm against the caller.
 * @throws CudfException declared by the signature; the conditions are decided in native code.
 */
private static native long ifElseVV(long predVec, long trueVec, long falseVec) throws CudfException;

/**
 * Native entry point taking a predicate vector handle, a true-side vector handle and a
 * false-side scalar handle, and returning a long.
 * NOTE(review): the Java caller is not visible in this excerpt; presumably the
 * vector/scalar variant of the BOOL8 row selection - confirm against the caller.
 * @throws CudfException declared by the signature; the conditions are decided in native code.
 */
private static native long ifElseVS(long predVec, long trueVec, long falseScalar) throws CudfException;
Expand Down
12 changes: 10 additions & 2 deletions java/src/main/java/ai/rapids/cudf/NaNEquality.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,19 @@

package ai.rapids.cudf;

/*
* This is analogous to the native 'nan_equality'.
/**
* How should NaNs be compared in an operation. In floating point there are multiple
* different binary representations for NaN.
*/
public enum NaNEquality {
/**
* No NaN representation is considered equal to any NaN representation, even for the
* exact same representation.
*/
UNEQUAL(false),
/**
* All representations of NaN are considered to be equal.
*/
ALL_EQUAL(true);

NaNEquality(boolean nansEqual) {
Expand Down
4 changes: 2 additions & 2 deletions java/src/main/java/ai/rapids/cudf/NullEquality.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

package ai.rapids.cudf;

/*
* This is analogous to the native 'null_equality'.
/**
* How should nulls be compared in an operation.
*/
public enum NullEquality {
UNEQUAL(false),
Expand Down
4 changes: 2 additions & 2 deletions java/src/main/java/ai/rapids/cudf/NullPolicy.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

package ai.rapids.cudf;

/*
* This is analogous to the native 'null_policy'.
/**
* Specify whether to include nulls or exclude nulls in an operation.
*/
public enum NullPolicy {
EXCLUDE(false),
Expand Down
46 changes: 46 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ReplacePolicy.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

/**
 * Policy to specify where a replacement value is taken from, relative to a null row.
 */
public enum ReplacePolicy {
  /**
   * A null row is replaced by the first non-null value that precedes it.
   */
  PRECEDING(true),
  /**
   * A null row is replaced by the first non-null value that follows it.
   */
  FOLLOWING(false);

  /** {@code true} for {@link #PRECEDING}, {@code false} for {@link #FOLLOWING}. */
  final boolean isPreceding;

  ReplacePolicy(boolean preceding) {
    isPreceding = preceding;
  }

  /**
   * Indicate which column the replacement should happen on.
   * @param columnNumber the column this policy should be applied to.
   * @return this policy paired with the given column.
   */
  public ReplacePolicyWithColumn onColumn(int columnNumber) {
    final ReplacePolicy self = this;
    return new ReplacePolicyWithColumn(columnNumber, self);
  }
}
46 changes: 46 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ReplacePolicyWithColumn.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

/**
 * Pairs a {@link ReplacePolicy} with the column it should be applied to.
 * Two instances are equal when both the column and the policy are equal.
 */
public class ReplacePolicyWithColumn {
  /** The column the policy applies to. */
  final int column;
  /** The replacement policy for that column. */
  final ReplacePolicy policy;

  ReplacePolicyWithColumn(int column, ReplacePolicy policy) {
    this.column = column;
    this.policy = policy;
  }

  @Override
  public boolean equals(Object other) {
    if (other instanceof ReplacePolicyWithColumn) {
      ReplacePolicyWithColumn that = (ReplacePolicyWithColumn) other;
      // Compare the cheap primitive first; the policy is only consulted when columns match.
      if (column != that.column) {
        return false;
      }
      return policy.equals(that.policy);
    }
    return false;
  }

  @Override
  public int hashCode() {
    return policy.hashCode() + 31 * column;
  }
}
39 changes: 39 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ScanType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

/**
 * Selects whether a scan takes the current row into account.
 */
public enum ScanType {
  /**
   * The current row is part of the scan.
   */
  INCLUSIVE(true),
  /**
   * The current row is left out of the scan.
   */
  EXCLUSIVE(false);

  /** {@code true} for {@link #INCLUSIVE}, {@code false} for {@link #EXCLUSIVE}. */
  final boolean isInclusive;

  ScanType(boolean inclusive) {
    isInclusive = inclusive;
  }
}
110 changes: 102 additions & 8 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,7 @@
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.*;

/**
* Class to represent a collection of ColumnVectors and operations that can be performed on them
Expand Down Expand Up @@ -464,6 +457,17 @@ private static native long[] groupByAggregate(long inputTable, int[] keyIndices,
boolean keySorted, boolean[] keysDescending,
boolean[] keysNullSmallest) throws CudfException;

/**
 * Native entry point behind the group-by {@code scan(AggregationOnColumn...)}.
 * @param inputTable native handle of the table being grouped.
 * @param keyIndices indices of the grouping key columns.
 * @param aggColumnsIndices for each operation, the index of the column it is applied to;
 *                          parallel to {@code aggInstances}.
 * @param aggInstances for each operation, the handle of a native aggregation instance; the
 *                     Java caller creates them and closes them after the call.
 * @param ignoreNullKeys passed through from the group-by options.
 * @param keySorted passed through from the group-by options.
 * @param keysDescending passed through from the group-by options.
 * @param keysNullSmallest passed through from the group-by options.
 * @return native column handles that the caller wraps in a Table; the caller reads the key
 *         columns first, followed by one result column per operation.
 * @throws CudfException declared by the signature; the conditions are decided in native code.
 */
private static native long[] groupByScan(long inputTable, int[] keyIndices, int[] aggColumnsIndices,
long[] aggInstances, boolean ignoreNullKeys,
boolean keySorted, boolean[] keysDescending,
boolean[] keysNullSmallest) throws CudfException;

/**
 * Native entry point behind the group-by {@code replaceNulls(ReplacePolicyWithColumn...)}.
 * @param inputTable native handle of the table being grouped.
 * @param keyIndices indices of the grouping key columns.
 * @param replaceColumnsIndices for each replacement, the index of the column it applies to;
 *                              parallel to {@code isPreceding}.
 * @param isPreceding for each replacement, the {@code ReplacePolicy.isPreceding} flag of its
 *                    policy.
 * @param ignoreNullKeys passed through from the group-by options.
 * @param keySorted passed through from the group-by options.
 * @param keysDescending passed through from the group-by options.
 * @param keysNullSmallest passed through from the group-by options.
 * @return native handles that the caller wraps directly in a Table.
 *         NOTE(review): the column layout of the result is not visible here - confirm
 *         against the native implementation.
 * @throws CudfException declared by the signature; the conditions are decided in native code.
 */
private static native long[] groupByReplaceNulls(long inputTable, int[] keyIndices,
int[] replaceColumnsIndices,
boolean[] isPreceding, boolean ignoreNullKeys,
boolean keySorted, boolean[] keysDescending,
boolean[] keysNullSmallest) throws CudfException;

private static native long[] rollingWindowAggregate(
long inputTable,
int[] keyIndices,
Expand Down Expand Up @@ -2663,6 +2667,96 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregate
}
}

/**
 * Computes scans over the groups of this group-by operation.
 * The returned table holds the key columns first, followed by one column per entry of
 * {@code aggregates}, in the order the aggregates were passed in.
 * @param aggregates the aggregations to scan with, each bound to an input column.
 * @return a new Table with the key columns followed by the scan results.
 */
public Table scan(AggregationOnColumn... aggregates) {
  assert aggregates != null;

  // To improve performance and memory we want to remove duplicate operations
  // and also group the operations by column so hopefully cudf can do multiple aggregations
  // in a single pass.

  // A tree map keeps the columns in the same order every time, which makes debugging simpler.
  TreeMap<Integer, ColumnOps> opsByColumn = new TreeMap<>();
  int keysLength = operation.indices.length;
  // Total number of operations that will need to be done.
  int opCount = 0;
  int requestIndex = 0;
  for (AggregationOnColumn request : aggregates) {
    ColumnOps columnOps =
        opsByColumn.computeIfAbsent(request.getColumnIndex(), idx -> new ColumnOps());
    // Results are placed after the key columns, in the order they were requested.
    opCount += columnOps.add(request, keysLength + requestIndex);
    requestIndex++;
  }

  int[] aggColumnIndexes = new int[opCount];
  long[] aggOperationInstances = new long[opCount];
  try {
    int filled = 0;
    for (Map.Entry<Integer, ColumnOps> byColumn : opsByColumn.entrySet()) {
      int columnIndex = byColumn.getKey();
      for (Aggregation aggOp : byColumn.getValue().operations()) {
        aggColumnIndexes[filled] = columnIndex;
        aggOperationInstances[filled] = aggOp.createNativeInstance();
        filled++;
      }
    }
    assert filled == opCount : filled + " == " + opCount;

    try (Table aggregate = new Table(groupByScan(
        operation.table.nativeHandle,
        operation.indices,
        aggColumnIndexes,
        aggOperationInstances,
        groupByOptions.getIgnoreNullKeys(),
        groupByOptions.getKeySorted(),
        groupByOptions.getKeysDescending(),
        groupByOptions.getKeysNullSmallest()))) {
      // Assemble the columns of the table handed back to the caller.
      ColumnVector[] finalCols = new ColumnVector[keysLength + aggregates.length];

      // The key columns come first and keep their positions.
      for (int keyIndex = 0; keyIndex < keysLength; keyIndex++) {
        finalCols[keyIndex] = aggregate.getColumn(keyIndex);
      }

      // Each native result column is fanned out to every output slot that asked for it.
      int resultColumn = keysLength;
      for (ColumnOps columnOps : opsByColumn.values()) {
        for (List<Integer> destinations : columnOps.outputIndices()) {
          for (int destination : destinations) {
            finalCols[destination] = aggregate.getColumn(resultColumn);
          }
          resultColumn++;
        }
      }
      return new Table(finalCols);
    }
  } finally {
    // Release every native aggregation instance, whether or not the scan succeeded.
    Aggregation.close(aggOperationInstances);
  }
}

/**
 * Replaces nulls within the groups of this group-by operation, using for each requested
 * column the policy that was paired with it.
 * @param replacements the columns to operate on, each with its replacement policy.
 * @return a new Table wrapping the result of the native replacement.
 */
public Table replaceNulls(ReplacePolicyWithColumn... replacements) {
  assert replacements != null;

  // TODO in the future perhaps to improve performance and memory we want to
  // remove duplicate operations.

  final int count = replacements.length;
  boolean[] precedingFlags = new boolean[count];
  int[] targetColumns = new int[count];

  // Unpack each (column, policy) pair into the parallel arrays the native call expects.
  for (int i = 0; i < count; i++) {
    ReplacePolicyWithColumn replacement = replacements[i];
    precedingFlags[i] = replacement.policy.isPreceding;
    targetColumns[i] = replacement.column;
  }

  return new Table(groupByReplaceNulls(
      operation.table.nativeHandle,
      operation.indices,
      targetColumns,
      precedingFlags,
      groupByOptions.getIgnoreNullKeys(),
      groupByOptions.getKeySorted(),
      groupByOptions.getKeysDescending(),
      groupByOptions.getKeysNullSmallest()));
}

/**
* Splits the groups in a single table into separate tables according to the grouping keys.
* Each split table represents a single group.
Expand Down
Loading

0 comments on commit 11e021d

Please sign in to comment.