Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add in JNI APIs for scan, replace_nulls, group_by.scan, and group_by.replace_nulls [skip ci] #8503

Merged
merged 7 commits into from
Jun 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,10 @@ public final ColumnVector replaceNulls(ColumnView replacements) {
return new ColumnVector(replaceNullsColumn(getNativeView(), replacements.getNativeView()));
}

public final ColumnVector replaceNulls(ReplacePolicy policy) {
return new ColumnVector(replaceNullsPolicy(getNativeView(), policy.isPreceding));
}

/**
* For a BOOL8 vector, computes a vector whose rows are selected from two other vectors
* based on the boolean value of this vector in the corresponding row.
Expand Down Expand Up @@ -1384,17 +1388,50 @@ public final ColumnVector rollingWindow(RollingAggregation op, WindowOptions opt
}

/**
* Compute the cumulative sum/prefix sum of the values in this column.
* This is similar to a rolling window SUM with unbounded preceding and none following.
* Input values 1, 2, 3
* Output values 1, 3, 6
* This currently only works for long values that are not nullable as this is currently a
* very simple implementation. It may be expanded in the future if needed.
* Compute the prefix sum (aka cumulative sum) of the values in this column.
* This is just a convenience method for an inclusive scan with a SUM aggregation.
*/
public final ColumnVector prefixSum() {
return new ColumnVector(prefixSum(getNativeView()));
return scan(Aggregation.sum());
}

/**
* Computes a scan for a column. This is very similar to a running window on the column.
* @param aggregation the aggregation to perform
* @param scanType should the scan be inclusive, include the current row, or exclusive.
* @param nullPolicy how should nulls be treated. Note that some aggregations also include a
* null policy too. Currently none of those aggregations are supported so
* it is undefined how they would interact with each other.
*/
public final ColumnVector scan(Aggregation aggregation, ScanType scanType, NullPolicy nullPolicy) {
long nativeId = aggregation.createNativeInstance();
try {
return new ColumnVector(scan(getNativeView(), nativeId,
scanType.isInclusive, nullPolicy.includeNulls));
} finally {
Aggregation.close(nativeId);
}
}

/**
* Computes a scan for a column that excludes nulls.
* @param aggregation the aggregation to perform
* @param scanType should the scan be inclusive, include the current row, or exclusive.
*/
public final ColumnVector scan(Aggregation aggregation, ScanType scanType) {
return scan(aggregation, scanType, NullPolicy.EXCLUDE);
}

/**
* Computes an inclusive scan for a column that excludes nulls.
* @param aggregation the aggregation to perform
*/
public final ColumnVector scan(Aggregation aggregation) {
return scan(aggregation, ScanType.INCLUSIVE, NullPolicy.EXCLUDE);
}



/////////////////////////////////////////////////////////////////////////////
// LOGICAL
/////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -3217,7 +3254,8 @@ private static native long rollingWindow(
long preceding_col,
long following_col);

private static native long prefixSum(long viewHandle) throws CudfException;
private static native long scan(long viewHandle, long aggregation,
boolean isInclusive, boolean includeNulls) throws CudfException;

private static native long nansToNulls(long viewHandle) throws CudfException;

Expand All @@ -3227,6 +3265,8 @@ private static native long rollingWindow(

private static native long replaceNullsColumn(long viewHandle, long replaceViewHandle) throws CudfException;

private static native long replaceNullsPolicy(long nativeView, boolean isPreceding) throws CudfException;

private static native long ifElseVV(long predVec, long trueVec, long falseVec) throws CudfException;

private static native long ifElseVS(long predVec, long trueVec, long falseScalar) throws CudfException;
Expand Down
12 changes: 10 additions & 2 deletions java/src/main/java/ai/rapids/cudf/NaNEquality.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,19 @@

package ai.rapids.cudf;

/*
* This is analogous to the native 'nan_equality'.
/**
* How should NaNs be compared in an operation. In floating point there are multiple
* different binary representations for NaN.
*/
public enum NaNEquality {
/**
* No NaN representation is considered equal to any NaN representation, even for the
* exact same representation.
*/
UNEQUAL(false),
/**
* All representations of NaN are considered to be equal.
*/
ALL_EQUAL(true);

NaNEquality(boolean nansEqual) {
Expand Down
4 changes: 2 additions & 2 deletions java/src/main/java/ai/rapids/cudf/NullEquality.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

package ai.rapids.cudf;

/*
* This is analogous to the native 'null_equality'.
/**
* How should nulls be compared in an operation.
*/
public enum NullEquality {
UNEQUAL(false),
Expand Down
4 changes: 2 additions & 2 deletions java/src/main/java/ai/rapids/cudf/NullPolicy.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

package ai.rapids.cudf;

/*
* This is analogous to the native 'null_policy'.
/**
* Specify whether to include nulls or exclude nulls in an operation.
*/
public enum NullPolicy {
EXCLUDE(false),
Expand Down
46 changes: 46 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ReplacePolicy.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

/**
* Policy to specify the position of replacement values relative to null rows.
*/
public enum ReplacePolicy {
/**
* The replacement value is the first non-null value preceding the null row.
*/
PRECEDING(true),
/**
* The replacement value is the first non-null value following the null row.
*/
FOLLOWING(false);

ReplacePolicy(boolean isPreceding) {
this.isPreceding = isPreceding;
}

final boolean isPreceding;

/**
* Indicate which column the replacement should happen on.
*/
public ReplacePolicyWithColumn onColumn(int columnNumber) {
return new ReplacePolicyWithColumn(columnNumber, this);
}
}
46 changes: 46 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ReplacePolicyWithColumn.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

/**
* A replacement policy for a specific column
*/
public class ReplacePolicyWithColumn {
final int column;
final ReplacePolicy policy;

ReplacePolicyWithColumn(int column, ReplacePolicy policy) {
this.column = column;
this.policy = policy;
}

@Override
public boolean equals(Object other) {
if (!(other instanceof ReplacePolicyWithColumn)) {
return false;
}
ReplacePolicyWithColumn ro = (ReplacePolicyWithColumn)other;
return this.column == ro.column && this.policy.equals(ro.policy);
}

@Override
public int hashCode() {
return 31 * column + policy.hashCode();
}
}
39 changes: 39 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ScanType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

/**
* Scan operation type.
*/
public enum ScanType {
/**
* Include the current row in the scan.
*/
INCLUSIVE(true),
/**
* Exclude the current row from the scan.
*/
EXCLUSIVE(false);

ScanType(boolean isInclusive) {
this.isInclusive = isInclusive;
}

final boolean isInclusive;
}
110 changes: 102 additions & 8 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,7 @@
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.*;

/**
* Class to represent a collection of ColumnVectors and operations that can be performed on them
Expand Down Expand Up @@ -464,6 +457,17 @@ private static native long[] groupByAggregate(long inputTable, int[] keyIndices,
boolean keySorted, boolean[] keysDescending,
boolean[] keysNullSmallest) throws CudfException;

private static native long[] groupByScan(long inputTable, int[] keyIndices, int[] aggColumnsIndices,
long[] aggInstances, boolean ignoreNullKeys,
boolean keySorted, boolean[] keysDescending,
boolean[] keysNullSmallest) throws CudfException;

private static native long[] groupByReplaceNulls(long inputTable, int[] keyIndices,
int[] replaceColumnsIndices,
boolean[] isPreceding, boolean ignoreNullKeys,
boolean keySorted, boolean[] keysDescending,
boolean[] keysNullSmallest) throws CudfException;

private static native long[] rollingWindowAggregate(
long inputTable,
int[] keyIndices,
Expand Down Expand Up @@ -2663,6 +2667,96 @@ public Table aggregateWindowsOverRanges(AggregationOverWindow... windowAggregate
}
}

public Table scan(AggregationOnColumn... aggregates) {
assert aggregates != null;

// To improve performance and memory we want to remove duplicate operations
// and also group the operations by column so hopefully cudf can do multiple aggregations
// in a single pass.

// Use a tree map to make debugging simpler (columns are all in the same order)
TreeMap<Integer, ColumnOps> groupedOps = new TreeMap<>();
// Total number of operations that will need to be done.
int keysLength = operation.indices.length;
int totalOps = 0;
for (int outputIndex = 0; outputIndex < aggregates.length; outputIndex++) {
AggregationOnColumn agg = aggregates[outputIndex];
ColumnOps ops = groupedOps.computeIfAbsent(agg.getColumnIndex(), (idx) -> new ColumnOps());
totalOps += ops.add(agg, outputIndex + keysLength);
}
int[] aggColumnIndexes = new int[totalOps];
long[] aggOperationInstances = new long[totalOps];
try {
int opIndex = 0;
for (Map.Entry<Integer, ColumnOps> entry: groupedOps.entrySet()) {
int columnIndex = entry.getKey();
for (Aggregation operation: entry.getValue().operations()) {
aggColumnIndexes[opIndex] = columnIndex;
aggOperationInstances[opIndex] = operation.createNativeInstance();
opIndex++;
}
}
assert opIndex == totalOps : opIndex + " == " + totalOps;

try (Table aggregate = new Table(groupByScan(
operation.table.nativeHandle,
operation.indices,
aggColumnIndexes,
aggOperationInstances,
groupByOptions.getIgnoreNullKeys(),
groupByOptions.getKeySorted(),
groupByOptions.getKeysDescending(),
groupByOptions.getKeysNullSmallest()))) {
// prepare the final table
ColumnVector[] finalCols = new ColumnVector[keysLength + aggregates.length];

// get the key columns
for (int aggIndex = 0; aggIndex < keysLength; aggIndex++) {
finalCols[aggIndex] = aggregate.getColumn(aggIndex);
}

int inputColumn = keysLength;
// Now get the aggregation columns
for (ColumnOps ops: groupedOps.values()) {
for (List<Integer> indices: ops.outputIndices()) {
for (int outIndex: indices) {
finalCols[outIndex] = aggregate.getColumn(inputColumn);
}
inputColumn++;
}
}
return new Table(finalCols);
}
} finally {
Aggregation.close(aggOperationInstances);
}
}

public Table replaceNulls(ReplacePolicyWithColumn... replacements) {
assert replacements != null;

// TODO in the future perhaps to improve performance and memory we want to
// remove duplicate operations.

boolean[] isPreceding = new boolean[replacements.length];
int [] columnIndexes = new int[replacements.length];

for (int index = 0; index < replacements.length; index++) {
isPreceding[index] = replacements[index].policy.isPreceding;
columnIndexes[index] = replacements[index].column;
}

return new Table(groupByReplaceNulls(
operation.table.nativeHandle,
operation.indices,
columnIndexes,
isPreceding,
groupByOptions.getIgnoreNullKeys(),
groupByOptions.getKeySorted(),
groupByOptions.getKeysDescending(),
groupByOptions.getKeysNullSmallest()));
}

/**
* Splits the groups in a single table into separate tables according to the grouping keys.
* Each split table represents a single group.
Expand Down
Loading