Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add in JNI support for table partition [skip ci] #7637

Merged
merged 2 commits into from
Mar 18, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ public long getDeviceMemorySize() {

private static native ContiguousTable[] contiguousSplit(long inputTable, int[] indices);

private static native long[] partition(long inputTable, long partitionView,
int numberOfPartitions, int[] outputOffsets);

private static native long[] hashPartition(long inputTable,
int[] columnsToHash,
int hashTypeId,
Expand Down Expand Up @@ -1257,6 +1260,24 @@ public Table repeat(ColumnVector counts, boolean checkCount) {
return new Table(repeatColumnCount(this.nativeHandle, counts.getNativeView(), checkCount));
}

/**
* Partition this table using the mapping in partitionMap. partitionMap must be an integer
* column. The number of rows in partitionMap must be the same as this table. Each row
* in the map will indicate which partition the rows in the table belong to.
* @param partitionMap the partitions for each row.
* @param numberOfPartitions number of partitions
* @return {@link PartitionedTable} Table that exposes a limited functionality of the
* {@link Table} class
*/
public PartitionedTable partition(ColumnView partitionMap, int numberOfPartitions) {
int[] partitionOffsets = new int[numberOfPartitions];
return new PartitionedTable(new Table(partition(
getNativeView(),
partitionMap.getNativeView(),
partitionOffsets.length,
partitionOffsets)), partitionOffsets);
}

/**
* Find smallest indices in a sorted table where values should be inserted to maintain order.
* <pre>
Expand Down
33 changes: 33 additions & 0 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1614,6 +1614,39 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_concatenate(JNIEnv *env,
CATCH_STD(env, NULL);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_partition(JNIEnv *env, jclass,
jlong input_table,
jlong partition_column,
jint number_of_partitions,
jintArray output_offsets) {

JNI_NULL_CHECK(env, input_table, "input table is null", NULL);
JNI_NULL_CHECK(env, partition_column, "partition_column is null", NULL);
JNI_NULL_CHECK(env, output_offsets, "output_offsets is null", NULL);
JNI_ARG_CHECK(env, number_of_partitions > 0, "number_of_partitions is zero", NULL);

try {
cudf::jni::auto_set_device(env);
cudf::table_view *n_input_table = reinterpret_cast<cudf::table_view *>(input_table);
cudf::column_view *n_part_column = reinterpret_cast<cudf::column_view *>(partition_column);
cudf::jni::native_jintArray n_output_offsets(env, output_offsets);

auto result = cudf::partition(*n_input_table,
*n_part_column,
number_of_partitions);

for (size_t i = 0; i < result.second.size() - 1; i++) {
// for what ever reason partition returns the length of the result at then
// end and hash partition/round robing do not, so skip the last entry for
revans2 marked this conversation as resolved.
Show resolved Hide resolved
// consistency
n_output_offsets[i] = result.second[i];
}

return cudf::jni::convert_table_for_return(env, result.first);
}
CATCH_STD(env, NULL);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_hashPartition(JNIEnv *env, jclass,
jlong input_table,
jintArray columns_to_hash,
Expand Down
17 changes: 17 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1773,6 +1773,23 @@ void testPartStability() {
}
}

@Test
void testPartition() {
try (Table t = new Table.TestBuilder()
.column(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
.build();
ColumnVector parts = ColumnVector
.fromInts(1, 2, 1, 2, 1, 2, 1, 2, 1, 2);
PartitionedTable pt = t.partition(parts, 3);
Table expected = new Table.TestBuilder()
.column(1, 3, 5, 7, 9, 2, 4, 6, 8, 10)
.build()) {
int[] partCutoffs = pt.getPartitions();
assertArrayEquals(new int[]{0, 0, 5}, partCutoffs);
assertTablesAreEqual(expected, pt.getTable());
}
}

@Test
void testIdentityHashPartition() {
final int count = 1024 * 1024;
Expand Down