Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JNI bindings for distinct_hash_join #15019

Merged
merged 9 commits into from
Feb 28, 2024
25 changes: 25 additions & 0 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,9 @@ private static native long[] leftHashJoinGatherMapsWithCount(long leftTable, lon
private static native long[] innerJoinGatherMaps(long leftKeys, long rightKeys,
boolean compareNullsEqual) throws CudfException;

private static native long[] innerDistinctJoinGatherMaps(long leftKeys, long rightKeys,
boolean compareNullsEqual) throws CudfException;

private static native long innerJoinRowCount(long table, long hashJoin) throws CudfException;

private static native long[] innerHashJoinGatherMaps(long table, long hashJoin) throws CudfException;
Expand Down Expand Up @@ -3150,6 +3153,28 @@ public GatherMap[] innerJoinGatherMaps(Table rightKeys, boolean compareNullsEqua
return buildJoinGatherMaps(gatherMapData);
}

/**
* Computes the gather maps that can be used to manifest the result of an inner equi-join between
* two tables where the right table is guaranteed to not contain any duplicated join keys. It is
* assumed this table instance holds the key columns from the left table, and the table argument
* represents the key columns from the right table. Two {@link GatherMap} instances will be
* returned that can be used to gather the left and right tables, respectively, to produce the
* result of the inner join.
* It is the responsibility of the caller to close the resulting gather map instances.
* @param rightKeys join key columns from the right table
jlowe marked this conversation as resolved.
Show resolved Hide resolved
* @param compareNullsEqual true if null key values should match otherwise false
* @return left and right table gather maps
*/
public GatherMap[] innerDistinctJoinGatherMaps(Table rightKeys, boolean compareNullsEqual) {
if (getNumberOfColumns() != rightKeys.getNumberOfColumns()) {
throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() +
jlowe marked this conversation as resolved.
Show resolved Hide resolved
"rightKeys: " + rightKeys.getNumberOfColumns());
}
long[] gatherMapData =
innerDistinctJoinGatherMaps(getNativeView(), rightKeys.getNativeView(), compareNullsEqual);
return buildJoinGatherMaps(gatherMapData);
}

/**
* Computes the number of rows resulting from an inner equi-join between two tables.
* @param otherHash hash table built from join key columns from the other table
Expand Down
26 changes: 24 additions & 2 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -702,9 +702,9 @@ jlongArray gather_maps_to_java(JNIEnv *env,
jlongArray gather_map_to_java(JNIEnv *env,
std::unique_ptr<rmm::device_uvector<cudf::size_type>> map) {
// release the underlying device buffer to Java
auto gather_map_buffer = std::make_unique<rmm::device_buffer>(map->release());
cudf::jni::native_jlongArray result(env, 3);
result[0] = static_cast<jlong>(gather_map_buffer->size());
result[0] = static_cast<jlong>(map->size() * sizeof(cudf::size_type));
auto gather_map_buffer = std::make_unique<rmm::device_buffer>(map->release());
result[1] = ptr_as_jlong(gather_map_buffer->data());
result[2] = release_as_jlong(gather_map_buffer);
return result.get_jArray();
Expand Down Expand Up @@ -2550,6 +2550,28 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerJoinGatherMaps(
});
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerDistinctJoinGatherMaps(
JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) {
return cudf::jni::join_gather_maps(
env, j_left_keys, j_right_keys, compare_nulls_equal,
[](cudf::table_view const &left, cudf::table_view const &right, cudf::null_equality nulleq) {
auto has_nulls = cudf::has_nested_nulls(left) || cudf::has_nested_nulls(right)
? cudf::nullable_join::YES : cudf::nullable_join::NO;
std::pair<std::unique_ptr<rmm::device_uvector<cudf::size_type>>,
std::unique_ptr<rmm::device_uvector<cudf::size_type>>> maps;
if (cudf::detail::has_nested_columns(right)) {
cudf::distinct_hash_join<cudf::has_nested::YES> hash(right, left, has_nulls, nulleq);
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
maps = hash.inner_join();
} else {
cudf::distinct_hash_join<cudf::has_nested::NO> hash(right, left, has_nulls, nulleq);
maps = hash.inner_join();
}
// Unique join returns {right map, left map} but all the other joins
// return {left map, right map}. Swap here to make it consistent.
return std::make_pair(std::move(maps.second), std::move(maps.first));
});
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_innerJoinRowCount(JNIEnv *env, jclass,
jlong j_left_table,
jlong j_right_hash_join) {
Expand Down
111 changes: 110 additions & 1 deletion java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import com.google.common.base.Charsets;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.avro.SchemaBuilder;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.hadoop.ParquetFileReader;
Expand Down Expand Up @@ -2085,6 +2084,116 @@ void testInnerJoinGatherMapsNulls() {
}
}

private void checkInnerDistinctJoin(Table leftKeys, Table rightKeys, Table expected,
boolean compareNullsEqual) {
GatherMap[] maps = leftKeys.innerDistinctJoinGatherMaps(rightKeys, compareNullsEqual);
try {
verifyJoinGatherMaps(maps, expected);
} finally {
for (GatherMap map : maps) {
map.close();
}
}
jlowe marked this conversation as resolved.
Show resolved Hide resolved
}

@Test
void testInnerDistinctJoinGatherMaps() {
try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8, 6).build();
Table rightKeys = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build();
Table expected = new Table.TestBuilder()
.column(2, 7, 8, 9, 10) // left
.column(2, 0, 1, 3, 0) // right
.build()) {
checkInnerDistinctJoin(leftKeys, rightKeys, expected, false);
}
}

@Test
void testInnerDistinctJoinGatherMapsWithNested() {
StructType structType = new StructType(false,
new BasicType(false, DType.STRING),
new BasicType(false, DType.INT32));
StructData[] leftData = new StructData[]{
new StructData("abc", 1),
new StructData("xyz", 1),
new StructData("abc", 2),
new StructData("xyz", 2),
new StructData("abc", 1),
new StructData("abc", 3),
new StructData("xyz", 3)
};
StructData[] rightData = new StructData[]{
new StructData("abc", 1),
new StructData("xyz", 4),
new StructData("xyz", 2),
new StructData("abc", -1),
};
try (Table leftKeys = new Table.TestBuilder().column(structType, leftData).build();
Table rightKeys = new Table.TestBuilder().column(structType, rightData).build();
Table expected = new Table.TestBuilder()
.column(0, 3, 4)
.column(0, 2, 0)
.build()) {
checkInnerDistinctJoin(leftKeys, rightKeys, expected, false);
}
}

@Test
void testInnerDistinctJoinGatherMapsNullsEqual() {
try (Table leftKeys = new Table.TestBuilder()
.column(2, 3, 9, 0, 1, 7, 4, null, null, 8)
.build();
Table rightKeys = new Table.TestBuilder()
.column(null, 9, 8, 10, 32)
.build();
Table expected = new Table.TestBuilder()
.column(2, 7, 8, 9) // left
.column(1, 0, 0, 2) // right
.build()) {
checkInnerDistinctJoin(leftKeys, rightKeys, expected, true);
}
}

@Test
void testInnerDistinctJoinGatherMapsWithNestedNullsEqual() {
StructType structType = new StructType(true,
new BasicType(true, DType.STRING),
new BasicType(true, DType.INT32));
StructData[] leftData = new StructData[]{
new StructData("abc", 1),
null,
new StructData("xyz", 1),
new StructData("abc", 2),
new StructData("xyz", null),
null,
new StructData("abc", 1),
new StructData("abc", 3),
new StructData("xyz", 3),
new StructData(null, null),
new StructData(null, 1)
};
StructData[] rightData = new StructData[]{
null,
new StructData("abc", 1),
new StructData("xyz", 4),
new StructData("xyz", 2),
new StructData(null, null),
new StructData(null, 2),
new StructData(null, 1),
new StructData("xyz", null),
new StructData("abc", null),
new StructData("abc", -1)
};
try (Table leftKeys = new Table.TestBuilder().column(structType, leftData).build();
Table rightKeys = new Table.TestBuilder().column(structType, rightData).build();
Table expected = new Table.TestBuilder()
.column(0, 1, 4, 5, 6, 9, 10)
.column(1, 0, 7, 0, 1, 4, 6)
.build()) {
checkInnerDistinctJoin(leftKeys, rightKeys, expected, true);
}
}

@Test
void testInnerHashJoinGatherMaps() {
try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build();
Expand Down
Loading