Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JNI for cudf::drop_duplicates #9841

Merged
merged 8 commits into from
Dec 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 60 additions & 32 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,10 @@ private static native long[] conditionalLeftAntiJoinGatherMapWithCount(long left

private static native long[] filter(long input, long mask);

// Native binding for cudf::drop_duplicates: returns handles to the columns of a new table in
// which rows that are duplicated in the given key columns have been reduced to a single row.
// keepFirst selects KEEP_FIRST vs KEEP_LAST; nullsEqual and nullsBefore map to the native
// null_equality and null_order options. Throws CudfException on native-side failure.
private static native long[] dropDuplicates(long nativeHandle, int[] keyColumns,
boolean keepFirst, boolean nullsEqual,
boolean nullsBefore) throws CudfException;

private static native long[] gather(long tableHandle, long gatherView, boolean checkBounds);

private static native long[] convertToRows(long nativeHandle);
Expand Down Expand Up @@ -1820,6 +1824,30 @@ public Table filter(ColumnView mask) {
return new Table(filter(nativeHandle, mask.getNativeView()));
}

/**
 * Copy rows of the current table to an output table such that duplicate rows in the key columns
 * are ignored (i.e., only one row from the duplicate ones will be copied). These key columns are
 * a subset of the current table columns and their indices are specified by an input array.
 *
 * Currently, the output table is sorted by key columns, using stable sort. However, this is not
 * guaranteed in the future.
 *
 * @param keyColumns Array of indices representing key columns from the current table.
 * @param keepFirst If it is true, the first row with a duplicated key will be copied. Otherwise,
 *                  copy the last row with a duplicated key.
 * @param nullsEqual Flag to denote whether nulls are treated as equal when comparing rows of the
 *                   key columns to check for uniqueness.
 * @param nullsBefore Flag to specify whether nulls in the key columns will appear before or
 *                    after non-null elements when sorting the table.
 *
 * @return Table with unique keys.
 * @throws IllegalArgumentException if {@code keyColumns} is null or empty.
 */
public Table dropDuplicates(int[] keyColumns, boolean keepFirst, boolean nullsEqual,
boolean nullsBefore) {
  // Validate with an explicit exception rather than `assert`: Java assertions are disabled by
  // default at runtime, so an invalid argument would otherwise reach the native layer unchecked.
  if (keyColumns == null || keyColumns.length == 0) {
    throw new IllegalArgumentException(
        "Input keyColumns must contain indices of at least one column");
  }
  return new Table(dropDuplicates(nativeHandle, keyColumns, keepFirst, nullsEqual, nullsBefore));
}

/**
* Split a table at given boundaries, but the result of each split has memory that is laid out
* in a contiguous range of memory. This allows for us to optimize copying the data in a single
Expand Down Expand Up @@ -3005,27 +3033,27 @@ public Table aggregate(GroupByAggregationOnColumn... aggregates) {
}

/**
* Computes row-based window aggregation functions on the Table/projection,
* Computes row-based window aggregation functions on the Table/projection,
* based on windows specified in the argument.
*
*
* This method enables queries such as the following SQL:
*
* SELECT user_id,
* MAX(sales_amt) OVER(PARTITION BY user_id ORDER BY date
*
* SELECT user_id,
* MAX(sales_amt) OVER(PARTITION BY user_id ORDER BY date
* ROWS BETWEEN 1 PRECEDING and 1 FOLLOWING)
* FROM my_sales_table WHERE ...
*
*
* Each window-aggregation is represented by a different {@link AggregationOverWindow} argument,
* indicating:
* 1. the {@link Aggregation.Kind},
* 2. the number of rows preceding and following the current row, within a window,
* 3. the minimum number of observations within the defined window
*
*
* This method returns a {@link Table} instance, with one result column for each specified
* window aggregation.
*
*
* In this example, for the following input:
*
*
* [ // user_id, sales_amt
* { "user1", 10 },
* { "user2", 20 },
Expand All @@ -3037,19 +3065,19 @@ public Table aggregate(GroupByAggregationOnColumn... aggregates) {
* { "user1", 60 },
* { "user2", 40 }
* ]
*
* Partitioning (grouping) by `user_id` yields the following `sales_amt` vector
*
* Partitioning (grouping) by `user_id` yields the following `sales_amt` vector
* (with 2 groups, one for each distinct `user_id`):
*
*
* [ 10, 20, 10, 50, 60, 20, 30, 80, 40 ]
* <-------user1-------->|<------user2------->
*
*
* The SUM aggregation is applied with 1 preceding and 1 following
* row, with a minimum of 1 period. The aggregation window is thus 3 rows wide,
* yielding the following column:
*
*
* [ 30, 40, 80, 120, 110, 50, 130, 150, 120 ]
*
*
* @param windowAggregates the window-aggregations to be performed
* @return Table instance, with each column containing the result of each aggregation.
* @throws IllegalArgumentException if the window arguments are not of type
Expand All @@ -3068,7 +3096,7 @@ public Table aggregateWindows(AggregationOverWindow... windowAggregates) {
for (int outputIndex = 0; outputIndex < windowAggregates.length; outputIndex++) {
AggregationOverWindow agg = windowAggregates[outputIndex];
if (agg.getWindowOptions().getFrameType() != WindowOptions.FrameType.ROWS) {
throw new IllegalArgumentException("Expected ROWS-based window specification. Unexpected window type: "
throw new IllegalArgumentException("Expected ROWS-based window specification. Unexpected window type: "
+ agg.getWindowOptions().getFrameType());
}
ColumnWindowOps ops = groupedOps.computeIfAbsent(agg.getColumnIndex(), (idx) -> new ColumnWindowOps());
Expand Down Expand Up @@ -3129,27 +3157,27 @@ public Table aggregateWindows(AggregationOverWindow... windowAggregates) {
/**
* Computes range-based window aggregation functions on the Table/projection,
* based on windows specified in the argument.
*
*
* This method enables queries such as the following SQL:
*
* SELECT user_id,
* MAX(sales_amt) OVER(PARTITION BY user_id ORDER BY date
*
* SELECT user_id,
* MAX(sales_amt) OVER(PARTITION BY user_id ORDER BY date
* RANGE BETWEEN INTERVAL 1 DAY PRECEDING and CURRENT ROW)
* FROM my_sales_table WHERE ...
*
*
* Each window-aggregation is represented by a different {@link AggregationOverWindow} argument,
* indicating:
* 1. the {@link Aggregation.Kind},
* 2. the index for the timestamp column to base the window definitions on
* 2. the number of DAYS preceding and following the current row's date, to consider in the window
* 3. the minimum number of observations within the defined window
*
*
* This method returns a {@link Table} instance, with one result column for each specified
* window aggregation.
*
*
* In this example, for the following input:
*
* [ // user, sales_amt, YYYYMMDD (date)
*
* [ // user, sales_amt, YYYYMMDD (date)
* { "user1", 10, 20200101 },
* { "user2", 20, 20200101 },
* { "user1", 20, 20200102 },
Expand All @@ -3160,19 +3188,19 @@ public Table aggregateWindows(AggregationOverWindow... windowAggregates) {
* { "user1", 60, 20200107 },
* { "user2", 40, 20200104 }
* ]
*
* Partitioning (grouping) by `user_id`, and ordering by `date` yields the following `sales_amt` vector
*
* Partitioning (grouping) by `user_id`, and ordering by `date` yields the following `sales_amt` vector
* (with 2 groups, one for each distinct `user_id`):
*
*
* Date :(202001-) [ 01, 02, 03, 07, 07, 01, 01, 02, 04 ]
* Input: [ 10, 20, 10, 50, 60, 20, 30, 80, 40 ]
* <-------user1-------->|<---------user2--------->
*
* The SUM aggregation is applied, with 1 day preceding, and 1 day following, with a minimum of 1 period.
*
* The SUM aggregation is applied, with 1 day preceding, and 1 day following, with a minimum of 1 period.
* The aggregation window is thus 3 *days* wide, yielding the following output column:
*
*
* Results: [ 30, 40, 30, 110, 110, 130, 130, 130, 40 ]
*
*
* @param windowAggregates the window-aggregations to be performed
* @return Table instance, with each column containing the result of each aggregation.
* @throws IllegalArgumentException if the window arguments are not of type
Expand Down
26 changes: 26 additions & 0 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2676,6 +2676,32 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_filter(JNIEnv *env, jclas
CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_dropDuplicates(
    JNIEnv *env, jclass, jlong input_jtable, jintArray key_columns, jboolean keep_first,
    jboolean nulls_equal, jboolean nulls_before) {
  JNI_NULL_CHECK(env, input_jtable, "input table is null", 0);
  JNI_NULL_CHECK(env, key_columns, "input key_columns is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    auto const *input_table = reinterpret_cast<cudf::table_view const *>(input_jtable);

    // jint and cudf::size_type must have identical size for the element copy below to be valid.
    static_assert(sizeof(jint) == sizeof(cudf::size_type), "Integer types mismatched.");
    auto const jni_keys = cudf::jni::native_jintArray(env, key_columns);
    auto const keys = std::vector<cudf::size_type>(jni_keys.begin(), jni_keys.end());

    // Translate the JNI boolean flags into the corresponding libcudf enum options.
    auto const keep_option = keep_first ? cudf::duplicate_keep_option::KEEP_FIRST :
                                          cudf::duplicate_keep_option::KEEP_LAST;
    auto const null_eq = nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL;
    auto const null_prec = nulls_before ? cudf::null_order::BEFORE : cudf::null_order::AFTER;

    auto result = cudf::drop_duplicates(*input_table, keys, keep_option, null_eq, null_prec,
                                        rmm::mr::get_current_device_resource());
    return cudf::jni::convert_table_for_return(env, result);
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_gather(JNIEnv *env, jclass, jlong j_input,
jlong j_map, jboolean check_bounds) {
JNI_NULL_CHECK(env, j_input, "input table is null", 0);
Expand Down
26 changes: 26 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6592,6 +6592,32 @@ void testTableBasedFilter() {
}
}

@Test
void testDropDuplicates() {
  final int[] keys = new int[]{ 1 };

  try (ColumnVector col1 = ColumnVector.fromBoxedInts(5, null, 3, 5, 8, 1);
       ColumnVector col2 = ColumnVector.fromBoxedInts(20, null, null, 19, 21, 19);
       Table input = new Table(col1, col2)) {

    // keepFirst == true: the first occurrence of each duplicated key is retained.
    try (Table result = input.dropDuplicates(keys, true, true, true);
         ColumnVector expected1 = ColumnVector.fromBoxedInts(null, 5, 5, 8);
         ColumnVector expected2 = ColumnVector.fromBoxedInts(null, 19, 20, 21);
         Table expected = new Table(expected1, expected2)) {
      assertTablesAreEqual(expected, result);
    }

    // keepFirst == false: the last occurrence of each duplicated key is retained.
    try (Table result = input.dropDuplicates(keys, false, true, true);
         ColumnVector expected1 = ColumnVector.fromBoxedInts(3, 1, 5, 8);
         ColumnVector expected2 = ColumnVector.fromBoxedInts(null, 19, 20, 21);
         Table expected = new Table(expected1, expected2)) {
      assertTablesAreEqual(expected, result);
    }
  }
}

private enum Columns {
BOOL("BOOL"),
INT("INT"),
Expand Down