diff --git a/java/src/main/java/ai/rapids/cudf/HashType.java b/java/src/main/java/ai/rapids/cudf/HashType.java index b521bc5c42c..eb31edd8222 100644 --- a/java/src/main/java/ai/rapids/cudf/HashType.java +++ b/java/src/main/java/ai/rapids/cudf/HashType.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,8 +22,8 @@ * Hash algorithm identifiers, mirroring native enum cudf::hash_id */ public enum HashType { - // TODO IDENTITY(0), - // TODO MURMUR3(1), + IDENTITY(0), + MURMUR3(1), HASH_MD5(2), HASH_SERIAL_MURMUR3(3), HASH_SPARK_MURMUR3(4); diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index fcc23777d69..7385b55d0df 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -185,6 +185,7 @@ public long getDeviceMemorySize() { private static native long[] hashPartition(long inputTable, int[] columnsToHash, + int hashTypeId, int numberOfPartitions, int[] outputOffsets) throws CudfException; @@ -2587,15 +2588,31 @@ public Table leftAntiJoin(TableOperation rightJoinIndices) { } /** - * Hash partition a table into the specified number of partitions. + * Hash partition a table into the specified number of partitions. Uses the default MURMUR3 + * hashing. * @param numberOfPartitions - number of partitions to use * @return - {@link PartitionedTable} - Table that exposes a limited functionality of the * {@link Table} class */ public PartitionedTable hashPartition(int numberOfPartitions) { + return hashPartition(HashType.MURMUR3, numberOfPartitions); + } + + /** + * Hash partition a table into the specified number of partitions. + * @param type the type of hash to use. Depending on the type of hash different restrictions + * on the hash column(s) may exist. Not all hash functions are guaranteed to work + * besides IDENTITY and MURMUR3. + * @param numberOfPartitions - number of partitions to use + * @return {@link PartitionedTable} - Table that exposes a limited functionality of the + * {@link Table} class + */ + public PartitionedTable hashPartition(HashType type, int numberOfPartitions) { int[] partitionOffsets = new int[numberOfPartitions]; - return new PartitionedTable(new Table(Table.hashPartition(operation.table.nativeHandle, + return new PartitionedTable(new Table(Table.hashPartition( + operation.table.nativeHandle, operation.indices, + type.nativeId, partitionOffsets.length, partitionOffsets)), partitionOffsets); } diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index e051f68be4e..4548156055a 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -1616,6 +1617,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_concatenate(JNIEnv *env, JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_hashPartition(JNIEnv *env, jclass, jlong input_table, jintArray columns_to_hash, + jint hash_function, jint number_of_partitions, jintArray output_offsets) { @@ -1626,6 +1628,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_hashPartition(JNIEnv *env try { cudf::jni::auto_set_device(env); + cudf::hash_id hash_func = static_cast(hash_function); cudf::table_view *n_input_table = reinterpret_cast(input_table); cudf::jni::native_jintArray n_columns_to_hash(env, columns_to_hash); cudf::jni::native_jintArray n_output_offsets(env, output_offsets); @@ -1638,7 +1641,10 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_hashPartition(JNIEnv *env } std::pair, std::vector> result = - cudf::hash_partition(*n_input_table, columns_to_hash_vec, number_of_partitions); + cudf::hash_partition(*n_input_table, + columns_to_hash_vec, + number_of_partitions, + hash_func); for (size_t i = 0; i < result.second.size(); i++) { n_output_offsets[i] = result.second[i]; diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 88196a4112a..626f7828012 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -1742,7 +1742,7 @@ void testPartStability() { final int PARTS = 5; int expectedPart = -1; try (Table start = new Table.TestBuilder().column(0).build(); - PartitionedTable out = start.onColumns(0).partition(PARTS)) { + PartitionedTable out = start.onColumns(0).hashPartition(PARTS)) { // Lets figure out what partitions this is a part of. int[] parts = out.getPartitions(); for (int i = 0; i < parts.length; i++) { @@ -1755,7 +1755,7 @@ void testPartStability() { for (int numEntries = 1; numEntries < COUNT; numEntries++) { try (ColumnVector data = ColumnVector.build(DType.INT32, numEntries, Range.appendInts(0, numEntries)); Table t = new Table(data); - PartitionedTable out = t.onColumns(0).partition(PARTS); + PartitionedTable out = t.onColumns(0).hashPartition(PARTS); HostColumnVector tmp = out.getColumn(0).copyToHost()) { // Now we need to get the range out for the partition we expect int[] parts = out.getPartitions(); @@ -1774,7 +1774,7 @@ void testPartStability() { } @Test - void testPartition() { + void testIdentityHashPartition() { final int count = 1024 * 1024; try (ColumnVector aIn = ColumnVector.build(DType.INT64, count, Range.appendLongs(count)); ColumnVector bIn = ColumnVector.build(DType.INT32, count, (b) -> { @@ -1793,7 +1793,57 @@ void testPartition() { expected.add(i); } try (Table input = new Table(new ColumnVector[]{aIn, bIn, cIn}); - PartitionedTable output = input.onColumns(0).partition(5)) { + PartitionedTable output = input.onColumns(0).hashPartition(HashType.IDENTITY, 5)) { + int[] parts = output.getPartitions(); + assertEquals(5, parts.length); + assertEquals(0, parts[0]); + int previous = 0; + long rows = 0; + for (int i = 1; i < parts.length; i++) { + assertTrue(parts[i] >= previous); + rows += parts[i] - previous; + previous = parts[i]; + } + assertTrue(rows <= count); + try (HostColumnVector aOut = output.getColumn(0).copyToHost(); + HostColumnVector bOut = output.getColumn(1).copyToHost(); + HostColumnVector cOut = output.getColumn(2).copyToHost()) { + + for (int i = 0; i < count; i++) { + long fromA = aOut.getLong(i); + long fromB = bOut.getInt(i); + String fromC = cOut.getJavaString(i); + assertTrue(expected.remove(fromA)); + assertEquals(fromA / 2, fromB); + assertEquals(String.valueOf(fromA), fromC, "At Index " + i); + } + assertTrue(expected.isEmpty()); + } + } + } + } + + @Test + void testHashPartition() { + final int count = 1024 * 1024; + try (ColumnVector aIn = ColumnVector.build(DType.INT64, count, Range.appendLongs(count)); + ColumnVector bIn = ColumnVector.build(DType.INT32, count, (b) -> { + for (int i = 0; i < count; i++) { + b.append(i / 2); + } + }); + ColumnVector cIn = ColumnVector.build(DType.STRING, count, (b) -> { + for (int i = 0; i < count; i++) { + b.appendUTF8String(String.valueOf(i).getBytes()); + } + })) { + + HashSet expected = new HashSet<>(); + for (long i = 0; i < count; i++) { + expected.add(i); + } + try (Table input = new Table(new ColumnVector[]{aIn, bIn, cIn}); + PartitionedTable output = input.onColumns(0).hashPartition(5)) { int[] parts = output.getPartitions(); assertEquals(5, parts.length); assertEquals(0, parts[0]);