From 083eb2aefc3fc6fe3b68d6ddba1e248054c501cf Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Tue, 16 Feb 2021 09:03:17 -0600 Subject: [PATCH] Added in JNI support for out of core sort algorithm (#7381) This makes changes to be able to support out of core sort on the Spark plugin. It adds sort order as an API so sorting that is not out of core does not need to gather columns that are only used for sort ordering. It adds in versions of lower bound and upper bound that match more closely the APIs for sort so the sort order can be reused between them. It updates the merge API to take an array of tables not just a list for simpler integration with Scala. And makes a few minor changes to the sort order class so it can work better with Spark and debugging. Authors: - Robert (Bobby) Evans (@revans2) Approvers: - Jason Lowe (@jlowe) URL: https://github.com/rapidsai/cudf/pull/7381 --- java/src/main/java/ai/rapids/cudf/Table.java | 143 +++++++++++++++--- java/src/main/native/src/TableJni.cpp | 120 ++++++++++----- .../test/java/ai/rapids/cudf/TableTest.java | 18 +++ 3 files changed, 217 insertions(+), 64 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 0637ae6de1e..474a9da53bf 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -25,6 +25,7 @@ import ai.rapids.cudf.HostColumnVector.StructType; import java.io.File; +import java.io.Serializable; import java.math.BigDecimal; import java.math.RoundingMode; import java.nio.ByteBuffer; @@ -466,6 +467,9 @@ private static native long[] timeRangeRollingWindowAggregate(long inputTable, in int[] preceding, int[] following, boolean[] unboundedPreceding, boolean[] unboundedFollowing, boolean ignoreNullKeys) throws CudfException; + private static native long sortOrder(long inputTable, long[] sortKeys, boolean[] isDescending, + boolean[] areNullsSmallest) throws CudfException; + private static 
native long[] orderBy(long inputTable, long[] sortKeys, boolean[] isDescending, boolean[] areNullsSmallest) throws CudfException; @@ -1247,7 +1251,8 @@ public Table repeat(ColumnVector counts, boolean checkCount) { } /** - * Given a sorted table return the lower bound. + * Find smallest indices in a sorted table where values should be inserted to maintain order. + * <pre>
    * Example:
    *
    *  Single column:
@@ -1265,14 +1270,11 @@ public Table repeat(ColumnVector counts, boolean checkCount) {
    *                      { .7 },
    *                      { 61 }}
    *   result          = {  3 }
-   * NaNs in column values produce incorrect results.
+   * </pre>
* The input table and the values table need to be non-empty (row count > 0) - * The column data types of the tables' have to match in order. - * Strings and String categories do not work for this method. If the input table is - * unsorted the results are wrong. Types of columns can be of mixed data types. - * @param areNullsSmallest true if nulls are assumed smallest - * @param valueTable the table of values that need to be inserted - * @param descFlags indicates the ordering of the column(s), true if descending + * @param areNullsSmallest per column, true if nulls are assumed smallest + * @param valueTable the table of values to find insertion locations for + * @param descFlags per column indicates the ordering, true if descending. * @return ColumnVector with lower bound indices for all rows in valueTable */ public ColumnVector lowerBound(boolean[] areNullsSmallest, @@ -1283,7 +1285,34 @@ public ColumnVector lowerBound(boolean[] areNullsSmallest, } /** + * Find smallest indices in a sorted table where values should be inserted to maintain order. + * This is a convenience method. It pulls out the columns indicated by the args and sets up the + * ordering properly to call `lowerBound`. + * @param valueTable the table of values to find insertion locations for + * @param args the sort order used to sort this table. + * @return ColumnVector with lower bound indices for all rows in valueTable + */ + public ColumnVector lowerBound(Table valueTable, OrderByArg... 
args) { + boolean[] areNullsSmallest = new boolean[args.length]; + boolean[] descFlags = new boolean[args.length]; + ColumnVector[] inputColumns = new ColumnVector[args.length]; + ColumnVector[] searchColumns = new ColumnVector[args.length]; + for (int i = 0; i < args.length; i++) { + areNullsSmallest[i] = args[i].isNullSmallest; + descFlags[i] = args[i].isDescending; + inputColumns[i] = columns[args[i].index]; + searchColumns[i] = valueTable.columns[args[i].index]; + } + try (Table input = new Table(inputColumns); + Table search = new Table(searchColumns)) { + return input.lowerBound(areNullsSmallest, search, descFlags); + } + } + + /** + * Find largest indices in a sorted table where values should be inserted to maintain order. * Given a sorted table return the upper bound. + * <pre>
    * Example:
    *
    *  Single column:
@@ -1301,14 +1330,11 @@ public ColumnVector lowerBound(boolean[] areNullsSmallest,
    *                      { .7 },
    *                      { 61 }}
    *   result          = {  5 }
-   * NaNs in column values produce incorrect results.
+   * </pre>
* The input table and the values table need to be non-empty (row count > 0) - * The column data types of the tables' have to match in order. - * Strings and String categories do not work for this method. If the input table is - * unsorted the results are wrong. Types of columns can be of mixed data types. - * @param areNullsSmallest true if nulls are assumed smallest - * @param valueTable the table of values that need to be inserted - * @param descFlags indicates the ordering of the column(s), true if descending + * @param areNullsSmallest per column, true if nulls are assumed smallest + * @param valueTable the table of values to find insertion locations for + * @param descFlags per column indicates the ordering, true if descending. * @return ColumnVector with upper bound indices for all rows in valueTable */ public ColumnVector upperBound(boolean[] areNullsSmallest, @@ -1318,6 +1344,31 @@ public ColumnVector upperBound(boolean[] areNullsSmallest, descFlags, areNullsSmallest, true)); } + /** + * Find largest indices in a sorted table where values should be inserted to maintain order. + * This is a convenience method. It pulls out the columns indicated by the args and sets up the + * ordering properly to call `upperBound`. + * @param valueTable the table of values to find insertion locations for + * @param args the sort order used to sort this table. + * @return ColumnVector with upper bound indices for all rows in valueTable + */ + public ColumnVector upperBound(Table valueTable, OrderByArg... 
args) { + boolean[] areNullsSmallest = new boolean[args.length]; + boolean[] descFlags = new boolean[args.length]; + ColumnVector[] inputColumns = new ColumnVector[args.length]; + ColumnVector[] searchColumns = new ColumnVector[args.length]; + for (int i = 0; i < args.length; i++) { + areNullsSmallest[i] = args[i].isNullSmallest; + descFlags[i] = args[i].isDescending; + inputColumns[i] = columns[args[i].index]; + searchColumns[i] = valueTable.columns[args[i].index]; + } + try (Table input = new Table(inputColumns); + Table search = new Table(searchColumns)) { + return input.upperBound(areNullsSmallest, search, descFlags); + } + } + private void assertForBounds(Table valueTable) { assert this.getRowCount() != 0 : "Input table cannot be empty"; assert valueTable.getRowCount() != 0 : "Value table cannot be empty"; @@ -1342,17 +1393,39 @@ public Table crossJoin(Table right) { // TABLE MANIPULATION APIs ///////////////////////////////////////////////////////////////////////////// + /** + * Get back a gather map that can be used to sort the data. This allows you to sort by data + * that does not appear in the final result and not pay the cost of gathering the data that + * is only needed for sorting. + * @param args what order to sort the data by + * @return a gather map + */ + public ColumnVector sortOrder(OrderByArg... 
args) { + long[] sortKeys = new long[args.length]; + boolean[] isDescending = new boolean[args.length]; + boolean[] areNullsSmallest = new boolean[args.length]; + for (int i = 0; i < args.length; i++) { + int index = args[i].index; + assert (index >= 0 && index < columns.length) : + "index is out of range 0 <= " + index + " < " + columns.length; + isDescending[i] = args[i].isDescending; + areNullsSmallest[i] = args[i].isNullSmallest; + sortKeys[i] = columns[index].getNativeView(); + } + + return new ColumnVector(sortOrder(nativeHandle, sortKeys, isDescending, areNullsSmallest)); + } + /** * Orders the table using the sortkeys returning a new allocated table. The caller is * responsible for cleaning up * the {@link ColumnVector} returned as part of the output {@link Table} *

* Example usage: orderBy(true, Table.asc(0), Table.desc(3)...); - * @param args - Suppliers to initialize sortKeys. + * @param args Suppliers to initialize sortKeys. * @return Sorted Table */ public Table orderBy(OrderByArg... args) { - assert args.length <= columns.length; long[] sortKeys = new long[args.length]; boolean[] isDescending = new boolean[args.length]; boolean[] areNullsSmallest = new boolean[args.length]; @@ -1377,13 +1450,13 @@ public Table orderBy(OrderByArg... args) { * initially. * @return a combined sorted table. */ - public static Table merge(List tables, OrderByArg... args) { - assert !tables.isEmpty(); - long[] tableHandles = new long[tables.size()]; - Table first = tables.get(0); + public static Table merge(Table[] tables, OrderByArg... args) { + assert tables.length > 0; + long[] tableHandles = new long[tables.length]; + Table first = tables[0]; assert args.length <= first.columns.length; - for (int i = 0; i < tables.size(); i++) { - Table t = tables.get(i); + for (int i = 0; i < tables.length; i++) { + Table t = tables[i]; assert t != null; assert t.columns.length == first.columns.length; tableHandles[i] = t.nativeHandle; @@ -1394,7 +1467,7 @@ public static Table merge(List
tables, OrderByArg... args) { for (int i = 0; i < args.length; i++) { int index = args[i].index; assert (index >= 0 && index < first.columns.length) : - "index is out of range 0 <= " + index + " < " + first.columns.length; + "index is out of range 0 <= " + index + " < " + first.columns.length; isDescending[i] = args[i].isDescending; areNullsSmallest[i] = args[i].isNullSmallest; sortKeyIndexes[i] = index; @@ -1403,6 +1476,19 @@ public static Table merge(List
tables, OrderByArg... args) { return new Table(merge(tableHandles, sortKeyIndexes, isDescending, areNullsSmallest)); } + /** + * Merge multiple already sorted tables keeping the sort order the same. + * This is a more efficient version of concatenate followed by orderBy, but requires that + * the input already be sorted. + * @param tables the tables that should be merged. + * @param args the ordering of the tables. Should match how they were sorted + * initially. + * @return a combined sorted table. + */ + public static Table merge(List
tables, OrderByArg... args) { + return merge(tables.toArray(new Table[tables.size()]), args); + } + public static OrderByArg asc(final int index) { return new OrderByArg(index, false, false); } @@ -1852,7 +1938,7 @@ public static Table fromPackedTable(ByteBuffer metadata, DeviceMemoryBuffer data // HELPER CLASSES ///////////////////////////////////////////////////////////////////////////// - public static final class OrderByArg { + public static final class OrderByArg implements Serializable { final int index; final boolean isDescending; final boolean isNullSmallest; @@ -1862,6 +1948,13 @@ public static final class OrderByArg { this.isDescending = isDescending; this.isNullSmallest = isNullSmallest; } + + @Override + public String toString() { + return "ORDER BY " + index + + (isDescending ? " DESC " : " ASC ") + + (isNullSmallest ? "NULL SMALLEST" : "NULL LARGEST"); + } } /** diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 30222452804..96b6d1d9a74 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -578,7 +578,7 @@ bool valid_window_parameters(native_jintArray const &values, native_jintArray co extern "C" { JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_createCudfTableView(JNIEnv *env, - jclass class_object, + jclass, jlongArray j_cudf_columns) { JNI_NULL_CHECK(env, j_cudf_columns, "columns are null", 0); @@ -596,7 +596,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_createCudfTableView(JNIEnv *en CATCH_STD(env, 0); } -JNIEXPORT void JNICALL Java_ai_rapids_cudf_Table_deleteCudfTable(JNIEnv *env, jclass class_object, +JNIEXPORT void JNICALL Java_ai_rapids_cudf_Table_deleteCudfTable(JNIEnv *env, jclass, jlong j_cudf_table_view) { JNI_NULL_CHECK(env, j_cudf_table_view, "table view handle is null", ); try { @@ -638,7 +638,58 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_columnViewsFromPacked(JNI CATCH_STD(env, nullptr); } -JNIEXPORT jlongArray JNICALL 
Java_ai_rapids_cudf_Table_orderBy(JNIEnv *env, jclass j_class_object, +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_sortOrder(JNIEnv *env, jclass, + jlong j_input_table, + jlongArray j_sort_keys_columns, + jbooleanArray j_is_descending, + jbooleanArray j_are_nulls_smallest) { + + // input validations & verifications + JNI_NULL_CHECK(env, j_input_table, "input table is null", 0); + JNI_NULL_CHECK(env, j_sort_keys_columns, "sort keys columns is null", 0); + JNI_NULL_CHECK(env, j_is_descending, "sort order array is null", 0); + JNI_NULL_CHECK(env, j_are_nulls_smallest, "null order array is null", 0); + + try { + cudf::jni::auto_set_device(env); + cudf::jni::native_jpointerArray n_sort_keys_columns(env, + j_sort_keys_columns); + jsize num_columns = n_sort_keys_columns.size(); + const cudf::jni::native_jbooleanArray n_is_descending(env, j_is_descending); + jsize num_columns_is_desc = n_is_descending.size(); + + JNI_ARG_CHECK(env, num_columns_is_desc == num_columns, + "columns and is_descending lengths don't match", 0); + + const cudf::jni::native_jbooleanArray n_are_nulls_smallest(env, j_are_nulls_smallest); + jsize num_columns_null_smallest = n_are_nulls_smallest.size(); + + JNI_ARG_CHECK(env, num_columns_null_smallest == num_columns, + "columns and is_descending lengths don't match", 0); + + std::vector order(n_is_descending.size()); + for (int i = 0; i < n_is_descending.size(); i++) { + order[i] = n_is_descending[i] ? cudf::order::DESCENDING : cudf::order::ASCENDING; + } + std::vector null_order(n_are_nulls_smallest.size()); + for (int i = 0; i < n_are_nulls_smallest.size(); i++) { + null_order[i] = n_are_nulls_smallest[i] ? 
cudf::null_order::BEFORE : cudf::null_order::AFTER; + } + + std::vector columns(num_columns); + for (int i = 0; i < num_columns; i++) { + columns[i] = *n_sort_keys_columns[i]; + } + cudf::table_view keys(columns); + + auto sorted_col = cudf::sorted_order(keys, order, null_order); + return reinterpret_cast(sorted_col.release()); + } + CATCH_STD(env, 0); +} + + +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_orderBy(JNIEnv *env, jclass, jlong j_input_table, jlongArray j_sort_keys_columns, jbooleanArray j_is_descending, @@ -646,7 +697,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_orderBy(JNIEnv *env, jcla // input validations & verifications JNI_NULL_CHECK(env, j_input_table, "input table is null", NULL); - JNI_NULL_CHECK(env, j_sort_keys_columns, "input table is null", NULL); + JNI_NULL_CHECK(env, j_sort_keys_columns, "sort keys columns is null", NULL); JNI_NULL_CHECK(env, j_is_descending, "sort order array is null", NULL); JNI_NULL_CHECK(env, j_are_nulls_smallest, "null order array is null", NULL); @@ -658,18 +709,14 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_orderBy(JNIEnv *env, jcla const cudf::jni::native_jbooleanArray n_is_descending(env, j_is_descending); jsize num_columns_is_desc = n_is_descending.size(); - if (num_columns_is_desc != num_columns) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", - "columns and is_descending lengths don't match", NULL); - } + JNI_ARG_CHECK(env, num_columns_is_desc == num_columns, + "columns and is_descending lengths don't match", 0); const cudf::jni::native_jbooleanArray n_are_nulls_smallest(env, j_are_nulls_smallest); jsize num_columns_null_smallest = n_are_nulls_smallest.size(); - if (num_columns_null_smallest != num_columns) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", - "columns and areNullsSmallest lengths don't match", NULL); - } + JNI_ARG_CHECK(env, num_columns_null_smallest == num_columns, + "columns and areNullsSmallest lengths don't match", 0); 
std::vector order(n_is_descending.size()); for (int i = 0; i < n_is_descending.size(); i++) { @@ -680,10 +727,9 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_orderBy(JNIEnv *env, jcla null_order[i] = n_are_nulls_smallest[i] ? cudf::null_order::BEFORE : cudf::null_order::AFTER; } - std::vector columns; - columns.reserve(num_columns); + std::vector columns(num_columns); for (int i = 0; i < num_columns; i++) { - columns.push_back(*n_sort_keys_columns[i]); + columns[i] = *n_sort_keys_columns[i]; } cudf::table_view keys(columns); @@ -696,7 +742,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_orderBy(JNIEnv *env, jcla CATCH_STD(env, NULL); } -JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_merge(JNIEnv *env, jclass j_class_object, +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_merge(JNIEnv *env, jclass, jlongArray j_table_handles, jintArray j_sort_key_indexes, jbooleanArray j_is_descending, @@ -717,18 +763,14 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_merge(JNIEnv *env, jclass const cudf::jni::native_jbooleanArray n_is_descending(env, j_is_descending); jsize num_columns_is_desc = n_is_descending.size(); - if (num_columns_is_desc != num_columns) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", - "columns and is_descending lengths don't match", NULL); - } + JNI_ARG_CHECK(env, num_columns_is_desc == num_columns, + "columns and is_descending lengths don't match", NULL); const cudf::jni::native_jbooleanArray n_are_nulls_smallest(env, j_are_nulls_smallest); jsize num_columns_null_smallest = n_are_nulls_smallest.size(); - if (num_columns_null_smallest != num_columns) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", - "columns and areNullsSmallest lengths don't match", NULL); - } + JNI_ARG_CHECK(env, num_columns_null_smallest == num_columns, + "columns and areNullsSmallest lengths don't match", NULL); std::vector indexes(n_sort_key_indexes.size()); for (int i = 0; i < n_sort_key_indexes.size(); i++) { 
@@ -757,7 +799,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_merge(JNIEnv *env, jclass } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readCSV( - JNIEnv *env, jclass j_class_object, jobjectArray col_names, jobjectArray data_types, + JNIEnv *env, jclass, jobjectArray col_names, jobjectArray data_types, jobjectArray filter_col_names, jstring inputfilepath, jlong buffer, jlong buffer_length, jint header_row, jbyte delim, jbyte quote, jbyte comment, jobjectArray null_values, jobjectArray true_values, jobjectArray false_values) { @@ -819,7 +861,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readCSV( } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readParquet( - JNIEnv *env, jclass j_class_object, jobjectArray filter_col_names, jstring inputfilepath, + JNIEnv *env, jclass, jobjectArray filter_col_names, jstring inputfilepath, jlong buffer, jlong buffer_length, jint unit, jboolean strict_decimal_types) { bool read_buffer = true; if (buffer == 0) { @@ -998,7 +1040,7 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Table_writeParquetEnd(JNIEnv *env, jc } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readORC( - JNIEnv *env, jclass j_class_object, jobjectArray filter_col_names, jstring inputfilepath, + JNIEnv *env, jclass, jobjectArray filter_col_names, jstring inputfilepath, jlong buffer, jlong buffer_length, jboolean usingNumPyTypes, jint unit) { bool read_buffer = true; if (buffer == 0) { @@ -1352,7 +1394,7 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Table_readArrowIPCEnd(JNIEnv *env, jc } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftJoin( - JNIEnv *env, jclass clazz, jlong left_table, jintArray left_col_join_indices, jlong right_table, + JNIEnv *env, jclass, jlong left_table, jintArray left_col_join_indices, jlong right_table, jintArray right_col_join_indices, jboolean compare_nulls_equal) { JNI_NULL_CHECK(env, left_table, "left_table is null", NULL); JNI_NULL_CHECK(env, left_col_join_indices, 
"left_col_join_indices is null", NULL); @@ -1388,7 +1430,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftJoin( } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerJoin( - JNIEnv *env, jclass clazz, jlong left_table, jintArray left_col_join_indices, jlong right_table, + JNIEnv *env, jclass, jlong left_table, jintArray left_col_join_indices, jlong right_table, jintArray right_col_join_indices, jboolean compare_nulls_equal) { JNI_NULL_CHECK(env, left_table, "left_table is null", NULL); JNI_NULL_CHECK(env, left_col_join_indices, "left_col_join_indices is null", NULL); @@ -1424,7 +1466,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerJoin( } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullJoin( - JNIEnv *env, jclass clazz, jlong left_table, jintArray left_col_join_indices, jlong right_table, + JNIEnv *env, jclass, jlong left_table, jintArray left_col_join_indices, jlong right_table, jintArray right_col_join_indices, jboolean compare_nulls_equal) { JNI_NULL_CHECK(env, left_table, "left_table is null", NULL); JNI_NULL_CHECK(env, left_col_join_indices, "left_col_join_indices is null", NULL); @@ -1525,7 +1567,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftAntiJoin( CATCH_STD(env, NULL); } -JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_crossJoin(JNIEnv *env, jclass clazz, +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_crossJoin(JNIEnv *env, jclass, jlong left_table, jlong right_table) { JNI_NULL_CHECK(env, left_table, "left_table is null", NULL); @@ -1556,7 +1598,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_interleaveColumns(JNIEnv *env, CATCH_STD(env, 0); } -JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_concatenate(JNIEnv *env, jclass clazz, +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_concatenate(JNIEnv *env, jclass, jlongArray table_handles) { JNI_NULL_CHECK(env, table_handles, "input tables are null", NULL); try { @@ -1578,7 +1620,7 @@ JNIEXPORT 
jlongArray JNICALL Java_ai_rapids_cudf_Table_concatenate(JNIEnv *env, CATCH_STD(env, NULL); } -JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_hashPartition(JNIEnv *env, jclass clazz, +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_hashPartition(JNIEnv *env, jclass, jlong input_table, jintArray columns_to_hash, jint number_of_partitions, @@ -1641,7 +1683,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_roundRobinPartition( } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_groupByAggregate( - JNIEnv *env, jclass clazz, jlong input_table, jintArray keys, + JNIEnv *env, jclass, jlong input_table, jintArray keys, jintArray aggregate_column_indices, jlongArray agg_instances, jboolean ignore_null_keys) { JNI_NULL_CHECK(env, input_table, "input table is null", NULL); JNI_NULL_CHECK(env, keys, "input keys are null", NULL); @@ -1726,7 +1768,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_gather(JNIEnv *env, jclas CATCH_STD(env, 0); } -JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_convertToRows(JNIEnv *env, jclass clazz, +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_convertToRows(JNIEnv *env, jclass, jlong input_table) { JNI_NULL_CHECK(env, input_table, "input table is null", 0); @@ -1744,7 +1786,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_convertToRows(JNIEnv *env CATCH_STD(env, 0); } -JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_convertFromRows(JNIEnv *env, jclass clazz, +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_convertFromRows(JNIEnv *env, jclass, jlong input_column, jintArray types, jintArray scale) { @@ -1835,7 +1877,7 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_bound(JNIEnv *env, jclass, jlo CATCH_STD(env, 0); } -JNIEXPORT jobjectArray JNICALL Java_ai_rapids_cudf_Table_contiguousSplit(JNIEnv *env, jclass clazz, +JNIEXPORT jobjectArray JNICALL Java_ai_rapids_cudf_Table_contiguousSplit(JNIEnv *env, jclass, jlong input_table, jintArray split_indices) { 
JNI_NULL_CHECK(env, input_table, "native handle is null", 0); @@ -1862,7 +1904,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_rapids_cudf_Table_contiguousSplit(JNIEnv } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rollingWindowAggregate( - JNIEnv *env, jclass clazz, jlong j_input_table, jintArray j_keys, jlongArray j_default_output, + JNIEnv *env, jclass, jlong j_input_table, jintArray j_keys, jlongArray j_default_output, jintArray j_aggregate_column_indices, jlongArray j_agg_instances, jintArray j_min_periods, jintArray j_preceding, jintArray j_following, jboolean ignore_null_keys) { @@ -1918,7 +1960,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rollingWindowAggregate( } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_timeRangeRollingWindowAggregate( - JNIEnv *env, jclass clazz, jlong j_input_table, jintArray j_keys, + JNIEnv *env, jclass, jlong j_input_table, jintArray j_keys, jintArray j_timestamp_column_indices, jbooleanArray j_is_timestamp_ascending, jintArray j_aggregate_column_indices, jlongArray j_agg_instances, jintArray j_min_periods, jintArray j_preceding, jintArray j_following, diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 86f6bec9eef..f3f8925d3d0 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -393,6 +393,24 @@ void testOrderByAD() { } } + @Test + void testSortOrderSimple() { + try (Table table = new Table.TestBuilder() + .column(5, 3, 3, 1, 1) + .column(5, 3, 4, 1, 2) + .column(1, 3, 5, 7, 9) + .build(); + Table expected = new Table.TestBuilder() + .column(1, 1, 3, 3, 5) + .column(2, 1, 4, 3, 5) + .column(9, 7, 5, 3, 1) + .build(); + ColumnVector gatherMap = table.sortOrder(Table.asc(0), Table.desc(1)); + Table sortedTable = table.gather(gatherMap)) { + assertTablesAreEqual(expected, sortedTable); + } + } + @Test void testOrderByDD() { try (Table table = new Table.TestBuilder()