Skip to content

Commit

Permalink
Add in JNI support for table partition (#7637)
Browse files Browse the repository at this point in the history
This adds in support for partition. Which will partition a table based off of a partition map.

Authors:
  - Robert (Bobby) Evans (@revans2)

Approvers:
  - Jason Lowe (@jlowe)

URL: #7637
  • Loading branch information
revans2 authored Mar 18, 2021
1 parent 873955e commit d6cc694
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 0 deletions.
21 changes: 21 additions & 0 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ public long getDeviceMemorySize() {

private static native ContiguousTable[] contiguousSplit(long inputTable, int[] indices);

private static native long[] partition(long inputTable, long partitionView,
int numberOfPartitions, int[] outputOffsets);

private static native long[] hashPartition(long inputTable,
int[] columnsToHash,
int hashTypeId,
Expand Down Expand Up @@ -1257,6 +1260,24 @@ public Table repeat(ColumnVector counts, boolean checkCount) {
return new Table(repeatColumnCount(this.nativeHandle, counts.getNativeView(), checkCount));
}

/**
* Partition this table using the mapping in partitionMap. partitionMap must be an integer
* column. The number of rows in partitionMap must be the same as this table. Each row
* in the map will indicate which partition the rows in the table belong to.
* @param partitionMap the partitions for each row.
* @param numberOfPartitions number of partitions
* @return {@link PartitionedTable} Table that exposes a limited functionality of the
* {@link Table} class
*/
public PartitionedTable partition(ColumnView partitionMap, int numberOfPartitions) {
int[] partitionOffsets = new int[numberOfPartitions];
return new PartitionedTable(new Table(partition(
getNativeView(),
partitionMap.getNativeView(),
partitionOffsets.length,
partitionOffsets)), partitionOffsets);
}

/**
* Find smallest indices in a sorted table where values should be inserted to maintain order.
* <pre>
Expand Down
33 changes: 33 additions & 0 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1614,6 +1614,39 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_concatenate(JNIEnv *env,
CATCH_STD(env, NULL);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_partition(JNIEnv *env, jclass,
jlong input_table,
jlong partition_column,
jint number_of_partitions,
jintArray output_offsets) {

JNI_NULL_CHECK(env, input_table, "input table is null", NULL);
JNI_NULL_CHECK(env, partition_column, "partition_column is null", NULL);
JNI_NULL_CHECK(env, output_offsets, "output_offsets is null", NULL);
JNI_ARG_CHECK(env, number_of_partitions > 0, "number_of_partitions is zero", NULL);

try {
cudf::jni::auto_set_device(env);
cudf::table_view *n_input_table = reinterpret_cast<cudf::table_view *>(input_table);
cudf::column_view *n_part_column = reinterpret_cast<cudf::column_view *>(partition_column);
cudf::jni::native_jintArray n_output_offsets(env, output_offsets);

auto result = cudf::partition(*n_input_table,
*n_part_column,
number_of_partitions);

for (size_t i = 0; i < result.second.size() - 1; i++) {
// for what ever reason partition returns the length of the result at then
// end and hash partition/round robin do not, so skip the last entry for
// consistency
n_output_offsets[i] = result.second[i];
}

return cudf::jni::convert_table_for_return(env, result.first);
}
CATCH_STD(env, NULL);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_hashPartition(JNIEnv *env, jclass,
jlong input_table,
jintArray columns_to_hash,
Expand Down
17 changes: 17 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1773,6 +1773,23 @@ void testPartStability() {
}
}

@Test
void testPartition() {
try (Table t = new Table.TestBuilder()
.column(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
.build();
ColumnVector parts = ColumnVector
.fromInts(1, 2, 1, 2, 1, 2, 1, 2, 1, 2);
PartitionedTable pt = t.partition(parts, 3);
Table expected = new Table.TestBuilder()
.column(1, 3, 5, 7, 9, 2, 4, 6, 8, 10)
.build()) {
int[] partCutoffs = pt.getPartitions();
assertArrayEquals(new int[]{0, 0, 5}, partCutoffs);
assertTablesAreEqual(expected, pt.getTable());
}
}

@Test
void testIdentityHashPartition() {
final int count = 1024 * 1024;
Expand Down

0 comments on commit d6cc694

Please sign in to comment.