Skip to content

Commit

Permalink
Add JNI support for IDENTITY hash partitioning (#7626)
Browse files Browse the repository at this point in the history
This adds support for identity hash partitioning in JNI.

Authors:
  - Robert (Bobby) Evans (@revans2)

Approvers:
  - Jason Lowe (@jlowe)

URL: #7626
  • Loading branch information
revans2 authored Mar 17, 2021
1 parent 9c6e1ba commit 0b766c5
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 10 deletions.
6 changes: 3 additions & 3 deletions java/src/main/java/ai/rapids/cudf/HashType.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,8 +22,8 @@
* Hash algorithm identifiers, mirroring native enum cudf::hash_id
*/
public enum HashType {
// TODO IDENTITY(0),
// TODO MURMUR3(1),
IDENTITY(0),
MURMUR3(1),
HASH_MD5(2),
HASH_SERIAL_MURMUR3(3),
HASH_SPARK_MURMUR3(4);
Expand Down
21 changes: 19 additions & 2 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ public long getDeviceMemorySize() {

private static native long[] hashPartition(long inputTable,
int[] columnsToHash,
int hashTypeId,
int numberOfPartitions,
int[] outputOffsets) throws CudfException;

Expand Down Expand Up @@ -2587,15 +2588,31 @@ public Table leftAntiJoin(TableOperation rightJoinIndices) {
}

/**
 * Hash partition a table into the specified number of partitions. Uses the default MURMUR3
 * hashing.
 * @param numberOfPartitions - number of partitions to use
 * @return - {@link PartitionedTable} - Table that exposes a limited functionality of the
 * {@link Table} class
 */
public PartitionedTable hashPartition(int numberOfPartitions) {
  // Delegates to the HashType overload, pinning the historical default of MURMUR3.
  return hashPartition(HashType.MURMUR3, numberOfPartitions);
}

/**
* Hash partition a table into the specified number of partitions.
* @param type the type of hash to use. Depending on the type of hash different restrictions
* on the hash column(s) may exist. Not all hash functions are guaranteed to work
* besides IDENTITY and MURMUR3.
* @param numberOfPartitions - number of partitions to use
* @return {@link PartitionedTable} - Table that exposes a limited functionality of the
* {@link Table} class
*/
public PartitionedTable hashPartition(HashType type, int numberOfPartitions) {
int[] partitionOffsets = new int[numberOfPartitions];
return new PartitionedTable(new Table(Table.hashPartition(operation.table.nativeHandle,
return new PartitionedTable(new Table(Table.hashPartition(
operation.table.nativeHandle,
operation.indices,
type.nativeId,
partitionOffsets.length,
partitionOffsets)), partitionOffsets);
}
Expand Down
8 changes: 7 additions & 1 deletion java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cudf/io/orc.hpp>
#include <cudf/io/parquet.hpp>
#include <cudf/join.hpp>
#include <cudf/lists/explode.hpp>
#include <cudf/merge.hpp>
#include <cudf/partitioning.hpp>
#include <cudf/reshape.hpp>
Expand Down Expand Up @@ -1616,6 +1617,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_concatenate(JNIEnv *env,
JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_hashPartition(JNIEnv *env, jclass,
jlong input_table,
jintArray columns_to_hash,
jint hash_function,
jint number_of_partitions,
jintArray output_offsets) {

Expand All @@ -1626,6 +1628,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_hashPartition(JNIEnv *env

try {
cudf::jni::auto_set_device(env);
cudf::hash_id hash_func = static_cast<cudf::hash_id>(hash_function);
cudf::table_view *n_input_table = reinterpret_cast<cudf::table_view *>(input_table);
cudf::jni::native_jintArray n_columns_to_hash(env, columns_to_hash);
cudf::jni::native_jintArray n_output_offsets(env, output_offsets);
Expand All @@ -1638,7 +1641,10 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_hashPartition(JNIEnv *env
}

std::pair<std::unique_ptr<cudf::table>, std::vector<cudf::size_type>> result =
cudf::hash_partition(*n_input_table, columns_to_hash_vec, number_of_partitions);
cudf::hash_partition(*n_input_table,
columns_to_hash_vec,
number_of_partitions,
hash_func);

for (size_t i = 0; i < result.second.size(); i++) {
n_output_offsets[i] = result.second[i];
Expand Down
58 changes: 54 additions & 4 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1742,7 +1742,7 @@ void testPartStability() {
final int PARTS = 5;
int expectedPart = -1;
try (Table start = new Table.TestBuilder().column(0).build();
PartitionedTable out = start.onColumns(0).partition(PARTS)) {
PartitionedTable out = start.onColumns(0).hashPartition(PARTS)) {
// Lets figure out what partitions this is a part of.
int[] parts = out.getPartitions();
for (int i = 0; i < parts.length; i++) {
Expand All @@ -1755,7 +1755,7 @@ void testPartStability() {
for (int numEntries = 1; numEntries < COUNT; numEntries++) {
try (ColumnVector data = ColumnVector.build(DType.INT32, numEntries, Range.appendInts(0, numEntries));
Table t = new Table(data);
PartitionedTable out = t.onColumns(0).partition(PARTS);
PartitionedTable out = t.onColumns(0).hashPartition(PARTS);
HostColumnVector tmp = out.getColumn(0).copyToHost()) {
// Now we need to get the range out for the partition we expect
int[] parts = out.getPartitions();
Expand All @@ -1774,7 +1774,7 @@ void testPartStability() {
}

@Test
void testPartition() {
void testIdentityHashPartition() {
final int count = 1024 * 1024;
try (ColumnVector aIn = ColumnVector.build(DType.INT64, count, Range.appendLongs(count));
ColumnVector bIn = ColumnVector.build(DType.INT32, count, (b) -> {
Expand All @@ -1793,7 +1793,57 @@ void testPartition() {
expected.add(i);
}
try (Table input = new Table(new ColumnVector[]{aIn, bIn, cIn});
PartitionedTable output = input.onColumns(0).partition(5)) {
PartitionedTable output = input.onColumns(0).hashPartition(HashType.IDENTITY, 5)) {
int[] parts = output.getPartitions();
assertEquals(5, parts.length);
assertEquals(0, parts[0]);
int previous = 0;
long rows = 0;
for (int i = 1; i < parts.length; i++) {
assertTrue(parts[i] >= previous);
rows += parts[i] - previous;
previous = parts[i];
}
assertTrue(rows <= count);
try (HostColumnVector aOut = output.getColumn(0).copyToHost();
HostColumnVector bOut = output.getColumn(1).copyToHost();
HostColumnVector cOut = output.getColumn(2).copyToHost()) {

for (int i = 0; i < count; i++) {
long fromA = aOut.getLong(i);
long fromB = bOut.getInt(i);
String fromC = cOut.getJavaString(i);
assertTrue(expected.remove(fromA));
assertEquals(fromA / 2, fromB);
assertEquals(String.valueOf(fromA), fromC, "At Index " + i);
}
assertTrue(expected.isEmpty());
}
}
}
}

@Test
void testHashPartition() {
final int count = 1024 * 1024;
try (ColumnVector aIn = ColumnVector.build(DType.INT64, count, Range.appendLongs(count));
ColumnVector bIn = ColumnVector.build(DType.INT32, count, (b) -> {
for (int i = 0; i < count; i++) {
b.append(i / 2);
}
});
ColumnVector cIn = ColumnVector.build(DType.STRING, count, (b) -> {
for (int i = 0; i < count; i++) {
b.appendUTF8String(String.valueOf(i).getBytes());
}
})) {

HashSet<Long> expected = new HashSet<>();
for (long i = 0; i < count; i++) {
expected.add(i);
}
try (Table input = new Table(new ColumnVector[]{aIn, bIn, cIn});
PartitionedTable output = input.onColumns(0).hashPartition(5)) {
int[] parts = output.getPartitions();
assertEquals(5, parts.length);
assertEquals(0, parts[0]);
Expand Down

0 comments on commit 0b766c5

Please sign in to comment.