Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update JNI for cudf::hash_join probe-time null equality parameter removal #10268

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 18 additions & 39 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -539,14 +539,11 @@ private static native long[] leftJoin(long leftTable, int[] leftJoinCols, long r
private static native long[] leftJoinGatherMaps(long leftKeys, long rightKeys,
boolean compareNullsEqual) throws CudfException;

private static native long leftJoinRowCount(long leftTable, long rightHashJoin,
boolean nullsEqual) throws CudfException;
private static native long leftJoinRowCount(long leftTable, long rightHashJoin) throws CudfException;

private static native long[] leftHashJoinGatherMaps(long leftTable, long rightHashJoin,
boolean nullsEqual) throws CudfException;
private static native long[] leftHashJoinGatherMaps(long leftTable, long rightHashJoin) throws CudfException;

private static native long[] leftHashJoinGatherMapsWithCount(long leftTable, long rightHashJoin,
boolean nullsEqual,
long outputRowCount) throws CudfException;

private static native long[] innerJoin(long leftTable, int[] leftJoinCols, long rightTable,
Expand All @@ -555,14 +552,11 @@ private static native long[] innerJoin(long leftTable, int[] leftJoinCols, long
private static native long[] innerJoinGatherMaps(long leftKeys, long rightKeys,
boolean compareNullsEqual) throws CudfException;

private static native long innerJoinRowCount(long table, long hashJoin,
boolean nullsEqual) throws CudfException;
private static native long innerJoinRowCount(long table, long hashJoin) throws CudfException;

private static native long[] innerHashJoinGatherMaps(long table, long hashJoin,
boolean nullsEqual) throws CudfException;
private static native long[] innerHashJoinGatherMaps(long table, long hashJoin) throws CudfException;

private static native long[] innerHashJoinGatherMapsWithCount(long table, long hashJoin,
boolean nullsEqual,
long outputRowCount) throws CudfException;

private static native long[] fullJoin(long leftTable, int[] leftJoinCols, long rightTable,
Expand All @@ -571,14 +565,11 @@ private static native long[] fullJoin(long leftTable, int[] leftJoinCols, long r
private static native long[] fullJoinGatherMaps(long leftKeys, long rightKeys,
boolean compareNullsEqual) throws CudfException;

private static native long fullJoinRowCount(long leftTable, long rightHashJoin,
boolean nullsEqual) throws CudfException;
private static native long fullJoinRowCount(long leftTable, long rightHashJoin) throws CudfException;

private static native long[] fullHashJoinGatherMaps(long leftTable, long rightHashJoin,
boolean nullsEqual) throws CudfException;
private static native long[] fullHashJoinGatherMaps(long leftTable, long rightHashJoin) throws CudfException;

private static native long[] fullHashJoinGatherMapsWithCount(long leftTable, long rightHashJoin,
boolean nullsEqual,
long outputRowCount) throws CudfException;

private static native long[] leftSemiJoin(long leftTable, int[] leftJoinCols, long rightTable,
Expand Down Expand Up @@ -2318,8 +2309,7 @@ public long leftJoinRowCount(HashJoin rightHash) {
throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() +
"rightKeys: " + rightHash.getNumberOfColumns());
}
return leftJoinRowCount(getNativeView(), rightHash.getNativeView(),
rightHash.getCompareNulls());
return leftJoinRowCount(getNativeView(), rightHash.getNativeView());
}

/**
Expand All @@ -2337,9 +2327,7 @@ public GatherMap[] leftJoinGatherMaps(HashJoin rightHash) {
throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() +
"rightKeys: " + rightHash.getNumberOfColumns());
}
long[] gatherMapData =
leftHashJoinGatherMaps(getNativeView(), rightHash.getNativeView(),
rightHash.getCompareNulls());
long[] gatherMapData = leftHashJoinGatherMaps(getNativeView(), rightHash.getNativeView());
return buildJoinGatherMaps(gatherMapData);
}

Expand All @@ -2363,9 +2351,8 @@ public GatherMap[] leftJoinGatherMaps(HashJoin rightHash, long outputRowCount) {
throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() +
"rightKeys: " + rightHash.getNumberOfColumns());
}
long[] gatherMapData =
leftHashJoinGatherMapsWithCount(getNativeView(), rightHash.getNativeView(),
rightHash.getCompareNulls(), outputRowCount);
long[] gatherMapData = leftHashJoinGatherMapsWithCount(getNativeView(),
rightHash.getNativeView(), outputRowCount);
return buildJoinGatherMaps(gatherMapData);
}

Expand Down Expand Up @@ -2545,8 +2532,7 @@ public long innerJoinRowCount(HashJoin otherHash) {
throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() +
"otherKeys: " + otherHash.getNumberOfColumns());
}
return innerJoinRowCount(getNativeView(), otherHash.getNativeView(),
otherHash.getCompareNulls());
return innerJoinRowCount(getNativeView(), otherHash.getNativeView());
}

/**
Expand All @@ -2564,9 +2550,7 @@ public GatherMap[] innerJoinGatherMaps(HashJoin rightHash) {
throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() +
"rightKeys: " + rightHash.getNumberOfColumns());
}
long[] gatherMapData =
innerHashJoinGatherMaps(getNativeView(), rightHash.getNativeView(),
rightHash.getCompareNulls());
long[] gatherMapData = innerHashJoinGatherMaps(getNativeView(), rightHash.getNativeView());
return buildJoinGatherMaps(gatherMapData);
}

Expand All @@ -2590,9 +2574,8 @@ public GatherMap[] innerJoinGatherMaps(HashJoin rightHash, long outputRowCount)
throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() +
"rightKeys: " + rightHash.getNumberOfColumns());
}
long[] gatherMapData =
innerHashJoinGatherMapsWithCount(getNativeView(), rightHash.getNativeView(),
rightHash.getCompareNulls(), outputRowCount);
long[] gatherMapData = innerHashJoinGatherMapsWithCount(getNativeView(),
rightHash.getNativeView(), outputRowCount);
return buildJoinGatherMaps(gatherMapData);
}

Expand Down Expand Up @@ -2778,8 +2761,7 @@ public long fullJoinRowCount(HashJoin rightHash) {
throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() +
"rightKeys: " + rightHash.getNumberOfColumns());
}
return fullJoinRowCount(getNativeView(), rightHash.getNativeView(),
rightHash.getCompareNulls());
return fullJoinRowCount(getNativeView(), rightHash.getNativeView());
}

/**
Expand All @@ -2797,9 +2779,7 @@ public GatherMap[] fullJoinGatherMaps(HashJoin rightHash) {
throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() +
"rightKeys: " + rightHash.getNumberOfColumns());
}
long[] gatherMapData =
fullHashJoinGatherMaps(getNativeView(), rightHash.getNativeView(),
rightHash.getCompareNulls());
long[] gatherMapData = fullHashJoinGatherMaps(getNativeView(), rightHash.getNativeView());
return buildJoinGatherMaps(gatherMapData);
}

Expand All @@ -2823,9 +2803,8 @@ public GatherMap[] fullJoinGatherMaps(HashJoin rightHash, long outputRowCount) {
throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() +
"rightKeys: " + rightHash.getNumberOfColumns());
}
long[] gatherMapData =
fullHashJoinGatherMapsWithCount(getNativeView(), rightHash.getNativeView(),
rightHash.getCompareNulls(), outputRowCount);
long[] gatherMapData = fullHashJoinGatherMapsWithCount(getNativeView(),
rightHash.getNativeView(), outputRowCount);
return buildJoinGatherMaps(gatherMapData);
}

Expand Down
92 changes: 38 additions & 54 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -812,15 +812,14 @@ jlongArray join_gather_maps(JNIEnv *env, jlong j_left_keys, jlong j_right_keys,
// a hash table built from the join's right table.
template <typename T>
jlongArray hash_join_gather_maps(JNIEnv *env, jlong j_left_keys, jlong j_right_hash_join,
jboolean compare_nulls_equal, T join_func) {
T join_func) {
JNI_NULL_CHECK(env, j_left_keys, "left table is null", NULL);
JNI_NULL_CHECK(env, j_right_hash_join, "hash join is null", NULL);
try {
cudf::jni::auto_set_device(env);
auto left_keys = reinterpret_cast<cudf::table_view const *>(j_left_keys);
auto hash_join = reinterpret_cast<cudf::hash_join const *>(j_right_hash_join);
auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL;
return gather_maps_to_java(env, join_func(*left_keys, *hash_join, nulleq));
return gather_maps_to_java(env, join_func(*left_keys, *hash_join));
}
CATCH_STD(env, NULL);
}
Expand Down Expand Up @@ -2172,41 +2171,36 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftJoinGatherMaps(

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_leftJoinRowCount(JNIEnv *env, jclass,
jlong j_left_table,
jlong j_right_hash_join,
jboolean compare_nulls_equal) {
jlong j_right_hash_join) {
JNI_NULL_CHECK(env, j_left_table, "left table is null", 0);
JNI_NULL_CHECK(env, j_right_hash_join, "right hash join is null", 0);
try {
cudf::jni::auto_set_device(env);
auto left_table = reinterpret_cast<cudf::table_view const *>(j_left_table);
auto hash_join = reinterpret_cast<cudf::hash_join const *>(j_right_hash_join);
auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL;
auto row_count = hash_join->left_join_size(*left_table, nulleq);
auto row_count = hash_join->left_join_size(*left_table);
return static_cast<jlong>(row_count);
}
CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftHashJoinGatherMaps(
JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join,
jboolean compare_nulls_equal) {
JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join) {
return cudf::jni::hash_join_gather_maps(
env, j_left_table, j_right_hash_join, compare_nulls_equal,
[](cudf::table_view const &left, cudf::hash_join const &hash, cudf::null_equality nulleq) {
return hash.left_join(left, nulleq);
env, j_left_table, j_right_hash_join,
[](cudf::table_view const &left, cudf::hash_join const &hash) {
return hash.left_join(left);
});
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftHashJoinGatherMapsWithCount(
JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, jboolean compare_nulls_equal,
jlong j_output_row_count) {
JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, jlong j_output_row_count) {
auto output_row_count = static_cast<std::size_t>(j_output_row_count);
return cudf::jni::hash_join_gather_maps(env, j_left_table, j_right_hash_join, compare_nulls_equal,
[output_row_count](cudf::table_view const &left,
cudf::hash_join const &hash,
cudf::null_equality nulleq) {
return hash.left_join(left, nulleq, output_row_count);
});
return cudf::jni::hash_join_gather_maps(
env, j_left_table, j_right_hash_join,
[output_row_count](cudf::table_view const &left, cudf::hash_join const &hash) {
return hash.left_join(left, output_row_count);
});
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinRowCount(JNIEnv *env, jclass,
Expand Down Expand Up @@ -2305,41 +2299,36 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerJoinGatherMaps(

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_innerJoinRowCount(JNIEnv *env, jclass,
jlong j_left_table,
jlong j_right_hash_join,
jboolean compare_nulls_equal) {
jlong j_right_hash_join) {
JNI_NULL_CHECK(env, j_left_table, "left table is null", 0);
JNI_NULL_CHECK(env, j_right_hash_join, "right hash join is null", 0);
try {
cudf::jni::auto_set_device(env);
auto left_table = reinterpret_cast<cudf::table_view const *>(j_left_table);
auto hash_join = reinterpret_cast<cudf::hash_join const *>(j_right_hash_join);
auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL;
auto row_count = hash_join->inner_join_size(*left_table, nulleq);
auto row_count = hash_join->inner_join_size(*left_table);
return static_cast<jlong>(row_count);
}
CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerHashJoinGatherMaps(
JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join,
jboolean compare_nulls_equal) {
JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join) {
return cudf::jni::hash_join_gather_maps(
env, j_left_table, j_right_hash_join, compare_nulls_equal,
[](cudf::table_view const &left, cudf::hash_join const &hash, cudf::null_equality nulleq) {
return hash.inner_join(left, nulleq);
env, j_left_table, j_right_hash_join,
[](cudf::table_view const &left, cudf::hash_join const &hash) {
return hash.inner_join(left);
});
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerHashJoinGatherMapsWithCount(
JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, jboolean compare_nulls_equal,
jlong j_output_row_count) {
JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, jlong j_output_row_count) {
auto output_row_count = static_cast<std::size_t>(j_output_row_count);
return cudf::jni::hash_join_gather_maps(env, j_left_table, j_right_hash_join, compare_nulls_equal,
[output_row_count](cudf::table_view const &left,
cudf::hash_join const &hash,
cudf::null_equality nulleq) {
return hash.inner_join(left, nulleq, output_row_count);
});
return cudf::jni::hash_join_gather_maps(
env, j_left_table, j_right_hash_join,
[output_row_count](cudf::table_view const &left, cudf::hash_join const &hash) {
return hash.inner_join(left, output_row_count);
});
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinRowCount(JNIEnv *env, jclass,
Expand Down Expand Up @@ -2438,41 +2427,36 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullJoinGatherMaps(

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_fullJoinRowCount(JNIEnv *env, jclass,
jlong j_left_table,
jlong j_right_hash_join,
jboolean compare_nulls_equal) {
jlong j_right_hash_join) {
JNI_NULL_CHECK(env, j_left_table, "left table is null", 0);
JNI_NULL_CHECK(env, j_right_hash_join, "right hash join is null", 0);
try {
cudf::jni::auto_set_device(env);
auto left_table = reinterpret_cast<cudf::table_view const *>(j_left_table);
auto hash_join = reinterpret_cast<cudf::hash_join const *>(j_right_hash_join);
auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL;
auto row_count = hash_join->full_join_size(*left_table, nulleq);
auto row_count = hash_join->full_join_size(*left_table);
return static_cast<jlong>(row_count);
}
CATCH_STD(env, 0);
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullHashJoinGatherMaps(
JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join,
jboolean compare_nulls_equal) {
JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join) {
return cudf::jni::hash_join_gather_maps(
env, j_left_table, j_right_hash_join, compare_nulls_equal,
[](cudf::table_view const &left, cudf::hash_join const &hash, cudf::null_equality nulleq) {
return hash.full_join(left, nulleq);
env, j_left_table, j_right_hash_join,
[](cudf::table_view const &left, cudf::hash_join const &hash) {
return hash.full_join(left);
});
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullHashJoinGatherMapsWithCount(
JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, jboolean compare_nulls_equal,
jlong j_output_row_count) {
JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, jlong j_output_row_count) {
auto output_row_count = static_cast<std::size_t>(j_output_row_count);
return cudf::jni::hash_join_gather_maps(env, j_left_table, j_right_hash_join, compare_nulls_equal,
[output_row_count](cudf::table_view const &left,
cudf::hash_join const &hash,
cudf::null_equality nulleq) {
return hash.full_join(left, nulleq, output_row_count);
});
return cudf::jni::hash_join_gather_maps(
env, j_left_table, j_right_hash_join,
[output_row_count](cudf::table_view const &left, cudf::hash_join const &hash) {
return hash.full_join(left, output_row_count);
});
}

JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalFullJoinGatherMaps(
Expand Down