Skip to content

Commit

Permalink
Added in JNI support for out of core sort algorithm (#7381)
Browse files Browse the repository at this point in the history
This makes changes to be able to support out of core sort on the Spark plugin. It adds sort order as an API so sorting that is not out of core does not need to gather columns that are only used for sort ordering.

It adds in versions of lower bound and upper bound that match more closely the APIs for sort so the sort order can be reused between them.

It updates the merge API to take an array of tables not just a list for simpler integration with Scala.

And makes a few minor changes to the sort order class so it can work better with Spark and debugging.

Authors:
  - Robert (Bobby) Evans (@revans2)

Approvers:
  - Jason Lowe (@jlowe)

URL: #7381
  • Loading branch information
revans2 authored Feb 16, 2021
1 parent a08ec0e commit 083eb2a
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 64 deletions.
143 changes: 118 additions & 25 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import ai.rapids.cudf.HostColumnVector.StructType;

import java.io.File;
import java.io.Serializable;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
Expand Down Expand Up @@ -466,6 +467,9 @@ private static native long[] timeRangeRollingWindowAggregate(long inputTable, in
int[] preceding, int[] following, boolean[] unboundedPreceding, boolean[] unboundedFollowing,
boolean ignoreNullKeys) throws CudfException;

/**
 * Native (JNI) call backing {@link #sortOrder(OrderByArg...)}.
 * The Java wrapper passes this table's native handle as {@code inputTable}, one native column
 * view per sort key as {@code sortKeys}, and wraps the returned long in a new ColumnVector.
 * The three arrays are built with the same length, one entry per sort key.
 * @param inputTable native handle of the table being sorted
 * @param sortKeys native views of the columns to sort by
 * @param isDescending per sort key, true if that key is ordered descending
 * @param areNullsSmallest per sort key, true if nulls are assumed smallest
 * @return a native value that the wrapper hands to the ColumnVector constructor
 */
private static native long sortOrder(long inputTable, long[] sortKeys, boolean[] isDescending,
boolean[] areNullsSmallest) throws CudfException;

/**
 * Native (JNI) call backing {@link #orderBy(OrderByArg...)}.
 * @param inputTable native handle of the table being sorted
 * @param sortKeys one entry per sort key
 * @param isDescending per sort key, true if that key is ordered descending
 * @param areNullsSmallest per sort key, true if nulls are assumed smallest
 * @return NOTE(review): presumably the native column handles of the newly sorted table;
 *         confirm against the JNI implementation
 */
private static native long[] orderBy(long inputTable, long[] sortKeys, boolean[] isDescending,
boolean[] areNullsSmallest) throws CudfException;

Expand Down Expand Up @@ -1247,7 +1251,8 @@ public Table repeat(ColumnVector counts, boolean checkCount) {
}

/**
* Given a sorted table return the lower bound.
* Find smallest indices in a sorted table where values should be inserted to maintain order.
* <pre>
* Example:
*
* Single column:
Expand All @@ -1265,14 +1270,11 @@ public Table repeat(ColumnVector counts, boolean checkCount) {
* { .7 },
* { 61 }}
* result = { 3 }
* NaNs in column values produce incorrect results.
* </pre>
* The input table and the values table need to be non-empty (row count > 0)
* The column data types of the two tables have to match in order.
* Strings and String categories do not work for this method. If the input table is
* unsorted the results are wrong. Types of columns can be of mixed data types.
* @param areNullsSmallest true if nulls are assumed smallest
* @param valueTable the table of values that need to be inserted
* @param descFlags indicates the ordering of the column(s), true if descending
* @param areNullsSmallest per column, true if nulls are assumed smallest
* @param valueTable the table of values to find insertion locations for
* @param descFlags per column indicates the ordering, true if descending.
* @return ColumnVector with lower bound indices for all rows in valueTable
*/
public ColumnVector lowerBound(boolean[] areNullsSmallest,
Expand All @@ -1283,7 +1285,34 @@ public ColumnVector lowerBound(boolean[] areNullsSmallest,
}

/**
 * Find smallest indices in a sorted table where values should be inserted to maintain order.
 * This is a convenience method. It pulls out the columns indicated by the args and sets up the
 * ordering properly to call `lowerBound`.
 * <p>
 * Each arg's column index is applied to both this table and {@code valueTable}, so the key
 * columns have to sit at the same positions in both tables.
 * @param valueTable the table of values to find insertion locations for
 * @param args the sort order used to sort this table.
 * @return ColumnVector with lower bound indices for all rows in valueTable
 */
public ColumnVector lowerBound(Table valueTable, OrderByArg... args) {
  boolean[] areNullsSmallest = new boolean[args.length];
  boolean[] descFlags = new boolean[args.length];
  ColumnVector[] inputColumns = new ColumnVector[args.length];
  ColumnVector[] searchColumns = new ColumnVector[args.length];
  for (int i = 0; i < args.length; i++) {
    int index = args[i].index;
    // Same range check sortOrder and merge use, so a bad index reports a descriptive message
    // instead of a bare ArrayIndexOutOfBoundsException.
    assert (index >= 0 && index < columns.length) :
        "index is out of range 0 <= " + index + " < " + columns.length;
    assert (index < valueTable.columns.length) :
        "valueTable index is out of range 0 <= " + index + " < " + valueTable.columns.length;
    areNullsSmallest[i] = args[i].isNullSmallest;
    descFlags[i] = args[i].isDescending;
    inputColumns[i] = columns[index];
    searchColumns[i] = valueTable.columns[index];
  }
  // The temporary tables hold only the key columns and are closed as soon as the
  // search result has been produced.
  try (Table input = new Table(inputColumns);
       Table search = new Table(searchColumns)) {
    return input.lowerBound(areNullsSmallest, search, descFlags);
  }
}

/**
* Find largest indices in a sorted table where values should be inserted to maintain order.
* Given a sorted table return the upper bound.
* <pre>
* Example:
*
* Single column:
Expand All @@ -1301,14 +1330,11 @@ public ColumnVector lowerBound(boolean[] areNullsSmallest,
* { .7 },
* { 61 }}
* result = { 5 }
* NaNs in column values produce incorrect results.
* </pre>
* The input table and the values table need to be non-empty (row count > 0)
* The column data types of the two tables have to match in order.
* Strings and String categories do not work for this method. If the input table is
* unsorted the results are wrong. Types of columns can be of mixed data types.
* @param areNullsSmallest true if nulls are assumed smallest
* @param valueTable the table of values that need to be inserted
* @param descFlags indicates the ordering of the column(s), true if descending
* @param areNullsSmallest per column, true if nulls are assumed smallest
* @param valueTable the table of values to find insertion locations for
* @param descFlags per column indicates the ordering, true if descending.
* @return ColumnVector with upper bound indices for all rows in valueTable
*/
public ColumnVector upperBound(boolean[] areNullsSmallest,
Expand All @@ -1318,6 +1344,31 @@ public ColumnVector upperBound(boolean[] areNullsSmallest,
descFlags, areNullsSmallest, true));
}

/**
 * Find largest indices in a sorted table where values should be inserted to maintain order.
 * This is a convenience method. It pulls out the columns indicated by the args and sets up the
 * ordering properly to call `upperBound`.
 * <p>
 * Each arg's column index is applied to both this table and {@code valueTable}, so the key
 * columns have to sit at the same positions in both tables.
 * @param valueTable the table of values to find insertion locations for
 * @param args the sort order used to sort this table.
 * @return ColumnVector with upper bound indices for all rows in valueTable
 */
public ColumnVector upperBound(Table valueTable, OrderByArg... args) {
  boolean[] areNullsSmallest = new boolean[args.length];
  boolean[] descFlags = new boolean[args.length];
  ColumnVector[] inputColumns = new ColumnVector[args.length];
  ColumnVector[] searchColumns = new ColumnVector[args.length];
  for (int i = 0; i < args.length; i++) {
    int index = args[i].index;
    // Same range check sortOrder and merge use, so a bad index reports a descriptive message
    // instead of a bare ArrayIndexOutOfBoundsException.
    assert (index >= 0 && index < columns.length) :
        "index is out of range 0 <= " + index + " < " + columns.length;
    assert (index < valueTable.columns.length) :
        "valueTable index is out of range 0 <= " + index + " < " + valueTable.columns.length;
    areNullsSmallest[i] = args[i].isNullSmallest;
    descFlags[i] = args[i].isDescending;
    inputColumns[i] = columns[index];
    searchColumns[i] = valueTable.columns[index];
  }
  // The temporary tables hold only the key columns and are closed as soon as the
  // search result has been produced.
  try (Table input = new Table(inputColumns);
       Table search = new Table(searchColumns)) {
    return input.upperBound(areNullsSmallest, search, descFlags);
  }
}

private void assertForBounds(Table valueTable) {
assert this.getRowCount() != 0 : "Input table cannot be empty";
assert valueTable.getRowCount() != 0 : "Value table cannot be empty";
Expand All @@ -1342,17 +1393,39 @@ public Table crossJoin(Table right) {
// TABLE MANIPULATION APIs
/////////////////////////////////////////////////////////////////////////////

/**
 * Compute a gather map describing the sorted order of this table without moving any data.
 * This lets callers sort by columns that are not part of the final result and avoid the cost
 * of gathering columns that are needed only as sort keys.
 * @param args what order to sort the data by
 * @return a gather map
 */
public ColumnVector sortOrder(OrderByArg... args) {
  final int numKeys = args.length;
  long[] keyViews = new long[numKeys];
  boolean[] descending = new boolean[numKeys];
  boolean[] nullsSmallest = new boolean[numKeys];

  int pos = 0;
  for (OrderByArg arg : args) {
    final int index = arg.index;
    assert (index >= 0 && index < columns.length) :
        "index is out of range 0 <= " + index + " < " + columns.length;
    keyViews[pos] = columns[index].getNativeView();
    descending[pos] = arg.isDescending;
    nullsSmallest[pos] = arg.isNullSmallest;
    pos++;
  }

  // The native call returns a long that the ColumnVector constructor wraps.
  long nativeResult = sortOrder(nativeHandle, keyViews, descending, nullsSmallest);
  return new ColumnVector(nativeResult);
}

/**
* Orders the table using the sort keys, returning a newly allocated table. The caller is
* responsible for cleaning up
* the {@link ColumnVector} returned as part of the output {@link Table}
* <p>
* Example usage: orderBy(Table.asc(0), Table.desc(3)...);
* @param args - Suppliers to initialize sortKeys.
* @param args Suppliers to initialize sortKeys.
* @return Sorted Table
*/
public Table orderBy(OrderByArg... args) {
assert args.length <= columns.length;
long[] sortKeys = new long[args.length];
boolean[] isDescending = new boolean[args.length];
boolean[] areNullsSmallest = new boolean[args.length];
Expand All @@ -1377,13 +1450,13 @@ public Table orderBy(OrderByArg... args) {
* initially.
* @return a combined sorted table.
*/
public static Table merge(List<Table> tables, OrderByArg... args) {
assert !tables.isEmpty();
long[] tableHandles = new long[tables.size()];
Table first = tables.get(0);
public static Table merge(Table[] tables, OrderByArg... args) {
assert tables.length > 0;
long[] tableHandles = new long[tables.length];
Table first = tables[0];
assert args.length <= first.columns.length;
for (int i = 0; i < tables.size(); i++) {
Table t = tables.get(i);
for (int i = 0; i < tables.length; i++) {
Table t = tables[i];
assert t != null;
assert t.columns.length == first.columns.length;
tableHandles[i] = t.nativeHandle;
Expand All @@ -1394,7 +1467,7 @@ public static Table merge(List<Table> tables, OrderByArg... args) {
for (int i = 0; i < args.length; i++) {
int index = args[i].index;
assert (index >= 0 && index < first.columns.length) :
"index is out of range 0 <= " + index + " < " + first.columns.length;
"index is out of range 0 <= " + index + " < " + first.columns.length;
isDescending[i] = args[i].isDescending;
areNullsSmallest[i] = args[i].isNullSmallest;
sortKeyIndexes[i] = index;
Expand All @@ -1403,6 +1476,19 @@ public static Table merge(List<Table> tables, OrderByArg... args) {
return new Table(merge(tableHandles, sortKeyIndexes, isDescending, areNullsSmallest));
}

/**
 * Merge multiple already sorted tables keeping the sort order the same.
 * This is a more efficient version of concatenate followed by orderBy, but requires that
 * the input already be sorted. The list is copied into an array and handed to the
 * array-based overload of merge.
 * @param tables the tables that should be merged.
 * @param args the ordering of the tables. Should match how they were sorted
 *             initially.
 * @return a combined sorted table.
 */
public static Table merge(List<Table> tables, OrderByArg... args) {
  final Table[] tableArray = tables.toArray(new Table[0]);
  return merge(tableArray, args);
}

/**
 * Build an {@link OrderByArg} for the column at the given index, passing {@code false} for
 * both boolean constructor arguments.
 * NOTE(review): the constructor signature is not visible here; the flags are presumably
 * (isDescending, isNullSmallest) - confirm.
 * @param index the index of the column to order by
 * @return the new OrderByArg
 */
public static OrderByArg asc(final int index) {
  final OrderByArg ascending = new OrderByArg(index, false, false);
  return ascending;
}
Expand Down Expand Up @@ -1852,7 +1938,7 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data
// HELPER CLASSES
/////////////////////////////////////////////////////////////////////////////

public static final class OrderByArg {
public static final class OrderByArg implements Serializable {
final int index;
final boolean isDescending;
final boolean isNullSmallest;
Expand All @@ -1862,6 +1948,13 @@ public static final class OrderByArg {
this.isDescending = isDescending;
this.isNullSmallest = isNullSmallest;
}

/**
 * Render this argument for debugging, in the form
 * {@code ORDER BY <index> ASC|DESC NULL SMALLEST|NULL LARGEST}.
 */
@Override
public String toString() {
  final String direction = isDescending ? " DESC " : " ASC ";
  final String nullPlacement = isNullSmallest ? "NULL SMALLEST" : "NULL LARGEST";
  return "ORDER BY " + index + direction + nullPlacement;
}
}

/**
Expand Down
Loading

0 comments on commit 083eb2a

Please sign in to comment.