diff --git a/java/src/main/java/ai/rapids/cudf/MixedJoinSize.java b/java/src/main/java/ai/rapids/cudf/MixedJoinSize.java new file mode 100644 index 00000000000..811f0b9a0b0 --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/MixedJoinSize.java @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.rapids.cudf; + +/** This class tracks size information associated with a mixed table join. */ +public final class MixedJoinSize implements AutoCloseable { + private final long outputRowCount; + // This is in flux, avoid exposing publicly until the dust settles. + private ColumnVector matches; + + MixedJoinSize(long outputRowCount, ColumnVector matches) { + this.outputRowCount = outputRowCount; + this.matches = matches; + } + + /** Return the number of output rows that would be generated from the mixed join */ + public long getOutputRowCount() { + return outputRowCount; + } + + ColumnVector getMatches() { + return matches; + } + + @Override + public synchronized void close() { + matches.close(); + } +} diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index dcd7953fa2e..a021ded4588 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -640,6 +640,36 @@ private static native long[] conditionalLeftAntiJoinGatherMapWithCount(long left long condition, long rowCount) throws CudfException; + private static native long[] mixedLeftJoinSize(long leftKeysTable, long rightKeysTable, + long leftConditionTable, long rightConditionTable, + long condition, boolean compareNullsEqual); + + private static native long[] mixedLeftJoinGatherMaps(long leftKeysTable, long rightKeysTable, + long leftConditionTable, long rightConditionTable, + long condition, boolean compareNullsEqual); + + private static native long[] mixedLeftJoinGatherMapsWithSize(long leftKeysTable, long rightKeysTable, + long leftConditionTable, long rightConditionTable, + long condition, boolean compareNullsEqual, + long outputRowCount, long matchesColumnView); + + private static native long[] mixedInnerJoinSize(long leftKeysTable, long rightKeysTable, + long leftConditionTable, long rightConditionTable, + long condition, boolean compareNullsEqual); + + private static native long[] mixedInnerJoinGatherMaps(long leftKeysTable, long rightKeysTable, + long leftConditionTable, long rightConditionTable, + long condition, boolean compareNullsEqual); + + private static native long[] mixedInnerJoinGatherMapsWithSize(long leftKeysTable, long rightKeysTable, + long leftConditionTable, long rightConditionTable, + long condition, boolean compareNullsEqual, + long outputRowCount, long matchesColumnView); + + private static native long[] mixedFullJoinGatherMaps(long leftKeysTable, long rightKeysTable, + long leftConditionTable, long rightConditionTable, + long condition, boolean compareNullsEqual); + private static native long[] crossJoin(long leftTable, long rightTable) throws CudfException; private static native long[] concatenate(long[] cudfTablePointers) throws CudfException; @@ -2221,7 +2251,7 @@ public static Table scatter(Scalar[] source, ColumnView scatterMap, Table target target.getNativeView(), checkBounds)); } - private GatherMap[] buildJoinGatherMaps(long[] gatherMapData) { + private static GatherMap[] buildJoinGatherMaps(long[] gatherMapData) { long bufferSize = gatherMapData[0]; long leftAddr = gatherMapData[1]; long leftHandle = gatherMapData[2]; @@ -2374,6 +2404,94 @@ public GatherMap[] conditionalLeftJoinGatherMaps(Table rightTable, return buildJoinGatherMaps(gatherMapData); } + /** + * Computes output size information for a left join between two tables using a mix of equality + * and inequality conditions. The entire join condition is assumed to be a logical AND of the + * equality condition and inequality condition. + * NOTE: It is the responsibility of the caller to close the resulting size information object + * or native resources can be leaked! + * @param leftKeys the left table's key columns for the equality condition + * @param rightKeys the right table's key columns for the equality condition + * @param leftConditional the left table's columns needed to evaluate the inequality condition + * @param rightConditional the right table's columns needed to evaluate the inequality condition + * @param condition the inequality condition of the join + * @param nullEquality whether nulls should compare as equal + * @return size information for the join + */ + public static MixedJoinSize mixedLeftJoinSize(Table leftKeys, Table rightKeys, + Table leftConditional, Table rightConditional, + CompiledExpression condition, + NullEquality nullEquality) { + long[] mixedSizeInfo = mixedLeftJoinSize( + leftKeys.getNativeView(), rightKeys.getNativeView(), + leftConditional.getNativeView(), rightConditional.getNativeView(), + condition.getNativeHandle(), nullEquality == NullEquality.EQUAL); + assert mixedSizeInfo.length == 2; + long outputRowCount = mixedSizeInfo[0]; + long matchesColumnHandle = mixedSizeInfo[1]; + return new MixedJoinSize(outputRowCount, new ColumnVector(matchesColumnHandle)); + } + + /** + * Computes the gather maps that can be used to manifest the result of a left join between + * two tables using a mix of equality and inequality conditions. The entire join condition is + * assumed to be a logical AND of the equality condition and inequality condition. + * Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the left join. + * It is the responsibility of the caller to close the resulting gather map instances. + * @param leftKeys the left table's key columns for the equality condition + * @param rightKeys the right table's key columns for the equality condition + * @param leftConditional the left table's columns needed to evaluate the inequality condition + * @param rightConditional the right table's columns needed to evaluate the inequality condition + * @param condition the inequality condition of the join + * @param nullEquality whether nulls should compare as equal + * @return left and right table gather maps + */ + public static GatherMap[] mixedLeftJoinGatherMaps(Table leftKeys, Table rightKeys, + Table leftConditional, Table rightConditional, + CompiledExpression condition, + NullEquality nullEquality) { + long[] gatherMapData = mixedLeftJoinGatherMaps( + leftKeys.getNativeView(), rightKeys.getNativeView(), + leftConditional.getNativeView(), rightConditional.getNativeView(), + condition.getNativeHandle(), + nullEquality == NullEquality.EQUAL); + return buildJoinGatherMaps(gatherMapData); + } + + /** + * Computes the gather maps that can be used to manifest the result of a left join between + * two tables using a mix of equality and inequality conditions. The entire join condition is + * assumed to be a logical AND of the equality condition and inequality condition. + * Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the left join. + * It is the responsibility of the caller to close the resulting gather map instances. + * This interface allows passing the size result from + * {@link #mixedLeftJoinSize(Table, Table, Table, Table, CompiledExpression, NullEquality)} + * when the output size was computed previously. + * @param leftKeys the left table's key columns for the equality condition + * @param rightKeys the right table's key columns for the equality condition + * @param leftConditional the left table's columns needed to evaluate the inequality condition + * @param rightConditional the right table's columns needed to evaluate the inequality condition + * @param condition the inequality condition of the join + * @param nullEquality whether nulls should compare as equal + * @param joinSize mixed join size result + * @return left and right table gather maps + */ + public static GatherMap[] mixedLeftJoinGatherMaps(Table leftKeys, Table rightKeys, + Table leftConditional, Table rightConditional, + CompiledExpression condition, + NullEquality nullEquality, + MixedJoinSize joinSize) { + long[] gatherMapData = mixedLeftJoinGatherMapsWithSize( + leftKeys.getNativeView(), rightKeys.getNativeView(), + leftConditional.getNativeView(), rightConditional.getNativeView(), + condition.getNativeHandle(), + nullEquality == NullEquality.EQUAL, + joinSize.getOutputRowCount(), joinSize.getMatches().getNativeView()); + return buildJoinGatherMaps(gatherMapData); + } + /** * Computes the gather maps that can be used to manifest the result of an inner equi-join between * two tables. It is assumed this table instance holds the key columns from the left table, and @@ -2514,6 +2632,94 @@ public GatherMap[] conditionalInnerJoinGatherMaps(Table rightTable, return buildJoinGatherMaps(gatherMapData); } + /** + * Computes output size information for an inner join between two tables using a mix of equality + * and inequality conditions. The entire join condition is assumed to be a logical AND of the + * equality condition and inequality condition. + * NOTE: It is the responsibility of the caller to close the resulting size information object + * or native resources can be leaked! + * @param leftKeys the left table's key columns for the equality condition + * @param rightKeys the right table's key columns for the equality condition + * @param leftConditional the left table's columns needed to evaluate the inequality condition + * @param rightConditional the right table's columns needed to evaluate the inequality condition + * @param condition the inequality condition of the join + * @param nullEquality whether nulls should compare as equal + * @return size information for the join + */ + public static MixedJoinSize mixedInnerJoinSize(Table leftKeys, Table rightKeys, + Table leftConditional, Table rightConditional, + CompiledExpression condition, + NullEquality nullEquality) { + long[] mixedSizeInfo = mixedInnerJoinSize( + leftKeys.getNativeView(), rightKeys.getNativeView(), + leftConditional.getNativeView(), rightConditional.getNativeView(), + condition.getNativeHandle(), nullEquality == NullEquality.EQUAL); + assert mixedSizeInfo.length == 2; + long outputRowCount = mixedSizeInfo[0]; + long matchesColumnHandle = mixedSizeInfo[1]; + return new MixedJoinSize(outputRowCount, new ColumnVector(matchesColumnHandle)); + } + + /** + * Computes the gather maps that can be used to manifest the result of an inner join between + * two tables using a mix of equality and inequality conditions. The entire join condition is + * assumed to be a logical AND of the equality condition and inequality condition. + * Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the inner join. + * It is the responsibility of the caller to close the resulting gather map instances. + * @param leftKeys the left table's key columns for the equality condition + * @param rightKeys the right table's key columns for the equality condition + * @param leftConditional the left table's columns needed to evaluate the inequality condition + * @param rightConditional the right table's columns needed to evaluate the inequality condition + * @param condition the inequality condition of the join + * @param nullEquality whether nulls should compare as equal + * @return left and right table gather maps + */ + public static GatherMap[] mixedInnerJoinGatherMaps(Table leftKeys, Table rightKeys, + Table leftConditional, Table rightConditional, + CompiledExpression condition, + NullEquality nullEquality) { + long[] gatherMapData = mixedInnerJoinGatherMaps( + leftKeys.getNativeView(), rightKeys.getNativeView(), + leftConditional.getNativeView(), rightConditional.getNativeView(), + condition.getNativeHandle(), + nullEquality == NullEquality.EQUAL); + return buildJoinGatherMaps(gatherMapData); + } + + /** + * Computes the gather maps that can be used to manifest the result of an inner join between + * two tables using a mix of equality and inequality conditions. The entire join condition is + * assumed to be a logical AND of the equality condition and inequality condition. + * Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the inner join. + * It is the responsibility of the caller to close the resulting gather map instances. + * This interface allows passing the size result from + * {@link #mixedInnerJoinSize(Table, Table, Table, Table, CompiledExpression, NullEquality)} + * when the output size was computed previously. + * @param leftKeys the left table's key columns for the equality condition + * @param rightKeys the right table's key columns for the equality condition + * @param leftConditional the left table's columns needed to evaluate the inequality condition + * @param rightConditional the right table's columns needed to evaluate the inequality condition + * @param condition the inequality condition of the join + * @param nullEquality whether nulls should compare as equal + * @param joinSize mixed join size result + * @return left and right table gather maps + */ + public static GatherMap[] mixedInnerJoinGatherMaps(Table leftKeys, Table rightKeys, + Table leftConditional, Table rightConditional, + CompiledExpression condition, + NullEquality nullEquality, + MixedJoinSize joinSize) { + long[] gatherMapData = mixedInnerJoinGatherMapsWithSize( + leftKeys.getNativeView(), rightKeys.getNativeView(), + leftConditional.getNativeView(), rightConditional.getNativeView(), + condition.getNativeHandle(), + nullEquality == NullEquality.EQUAL, + joinSize.getOutputRowCount(), joinSize.getMatches().getNativeView()); + return buildJoinGatherMaps(gatherMapData); + } + /** * Computes the gather maps that can be used to manifest the result of an full equi-join between * two tables. It is assumed this table instance holds the key columns from the left table, and @@ -2620,6 +2826,33 @@ public GatherMap[] conditionalFullJoinGatherMaps(Table rightTable, return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the gather maps that can be used to manifest the result of a full join between + * two tables using a mix of equality and inequality conditions. The entire join condition is + * assumed to be a logical AND of the equality condition and inequality condition. + * Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the full join. + * It is the responsibility of the caller to close the resulting gather map instances. + * @param leftKeys the left table's key columns for the equality condition + * @param rightKeys the right table's key columns for the equality condition + * @param leftConditional the left table's columns needed to evaluate the inequality condition + * @param rightConditional the right table's columns needed to evaluate the inequality condition + * @param condition the inequality condition of the join + * @param nullEquality whether nulls should compare as equal + * @return left and right table gather maps + */ + public static GatherMap[] mixedFullJoinGatherMaps(Table leftKeys, Table rightKeys, + Table leftConditional, Table rightConditional, + CompiledExpression condition, + NullEquality nullEquality) { + long[] gatherMapData = mixedFullJoinGatherMaps( + leftKeys.getNativeView(), rightKeys.getNativeView(), + leftConditional.getNativeView(), rightConditional.getNativeView(), + condition.getNativeHandle(), + nullEquality == NullEquality.EQUAL); + return buildJoinGatherMaps(gatherMapData); + } + private GatherMap buildSemiJoinGatherMap(long[] gatherMapData) { long bufferSize = gatherMapData[0]; long leftAddr = gatherMapData[1]; diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 828d163fe07..03faf9be021 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -41,6 +41,7 @@ #include #include #include +#include #include #include "cudf_jni_apis.hpp" @@ -886,6 +887,76 @@ jlongArray cond_join_gather_single_map(JNIEnv *env, jlong j_left_table, jlong j_ CATCH_STD(env, NULL); } +template +jlongArray mixed_join_size(JNIEnv *env, jlong j_left_keys, jlong j_right_keys, + jlong j_left_condition, jlong j_right_condition, jlong j_condition, + jboolean j_nulls_equal, T join_size_func) { + JNI_NULL_CHECK(env, j_left_keys, "left keys table is null", 0); + JNI_NULL_CHECK(env, j_right_keys, "right keys table is null", 0); + JNI_NULL_CHECK(env, j_left_condition, "left condition table is null", 0); + JNI_NULL_CHECK(env, j_right_condition, "right condition table is null", 0); + JNI_NULL_CHECK(env, j_condition, "condition is null", 0); + try { + cudf::jni::auto_set_device(env); + auto const left_keys = reinterpret_cast(j_left_keys); + auto const right_keys = reinterpret_cast(j_right_keys); + auto const left_condition = reinterpret_cast(j_left_condition); + auto const right_condition = reinterpret_cast(j_right_condition); + auto const condition = reinterpret_cast(j_condition); + auto const nulls_equal = + j_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + std::pair>> join_size_info = + join_size_func(*left_keys, *right_keys, *left_condition, *right_condition, + condition->get_top_expression(), nulls_equal); + if (join_size_info.second->size() > std::numeric_limits::max()) { + throw std::runtime_error("Too many values in device buffer to convert into a column"); + } + auto col_size = join_size_info.second->size(); + auto col_data = join_size_info.second->release(); + auto col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, col_size, + std::move(col_data), rmm::device_buffer{}, 0); + cudf::jni::native_jlongArray result(env, 2); + result[0] = static_cast(join_size_info.first); + result[1] = reinterpret_cast(col.release()); + return result.get_jArray(); + } + CATCH_STD(env, NULL); +} + +template +jlongArray mixed_join_gather_maps(JNIEnv *env, jlong j_left_keys, jlong j_right_keys, + jlong j_left_condition, jlong j_right_condition, + jlong j_condition, jboolean j_nulls_equal, T join_func) { + JNI_NULL_CHECK(env, j_left_keys, "left keys table is null", 0); + JNI_NULL_CHECK(env, j_right_keys, "right keys table is null", 0); + JNI_NULL_CHECK(env, j_left_condition, "left condition table is null", 0); + JNI_NULL_CHECK(env, j_right_condition, "right condition table is null", 0); + JNI_NULL_CHECK(env, j_condition, "condition is null", 0); + try { + cudf::jni::auto_set_device(env); + auto const left_keys = reinterpret_cast(j_left_keys); + auto const right_keys = reinterpret_cast(j_right_keys); + auto const left_condition = reinterpret_cast(j_left_condition); + auto const right_condition = reinterpret_cast(j_right_condition); + auto const condition = reinterpret_cast(j_condition); + auto const nulls_equal = + j_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + return gather_maps_to_java(env, + join_func(*left_keys, *right_keys, *left_condition, *right_condition, + condition->get_top_expression(), nulls_equal)); + } + CATCH_STD(env, NULL); +} + +std::pair> +get_mixed_size_info(JNIEnv *env, jlong j_output_row_count, jlong j_matches_view) { + auto const row_count = static_cast(j_output_row_count); + auto const matches = reinterpret_cast(j_matches_view); + return std::pair>( + row_count, cudf::device_span(matches->template data(), + matches->size())); +} + // Returns a table view containing only the columns at the specified indices cudf::table_view const get_keys_table(cudf::table_view const *t, native_jintArray const &key_indices) { @@ -2227,6 +2298,50 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinGather }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_mixedLeftJoinSize( + JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jlong j_left_condition, + jlong j_right_condition, jlong j_condition, jboolean j_nulls_equal) { + return cudf::jni::mixed_join_size( + env, j_left_keys, j_right_keys, j_left_condition, j_right_condition, j_condition, + j_nulls_equal, + [](cudf::table_view const &left_keys, cudf::table_view const &right_keys, + cudf::table_view const &left_condition, cudf::table_view const &right_condition, + cudf::ast::expression const &condition, cudf::null_equality nulls_equal) { + return cudf::mixed_left_join_size(left_keys, right_keys, left_condition, right_condition, + condition, nulls_equal); + }); +} + +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_mixedLeftJoinGatherMaps( + JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jlong j_left_condition, + jlong j_right_condition, jlong j_condition, jboolean j_nulls_equal) { + return cudf::jni::mixed_join_gather_maps( + env, j_left_keys, j_right_keys, j_left_condition, j_right_condition, j_condition, + j_nulls_equal, + [](cudf::table_view const &left_keys, cudf::table_view const &right_keys, + cudf::table_view const &left_condition, cudf::table_view const &right_condition, + cudf::ast::expression const &condition, cudf::null_equality nulls_equal) { + return cudf::mixed_left_join(left_keys, right_keys, left_condition, right_condition, + condition, nulls_equal); + }); +} + +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_mixedLeftJoinGatherMapsWithSize( + JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jlong j_left_condition, + jlong j_right_condition, jlong j_condition, jboolean j_nulls_equal, jlong j_output_row_count, + jlong j_matches_view) { + auto size_info = cudf::jni::get_mixed_size_info(env, j_output_row_count, j_matches_view); + return cudf::jni::mixed_join_gather_maps( + env, j_left_keys, j_right_keys, j_left_condition, j_right_condition, j_condition, + j_nulls_equal, + [&size_info](cudf::table_view const &left_keys, cudf::table_view const &right_keys, + cudf::table_view const &left_condition, cudf::table_view const &right_condition, + cudf::ast::expression const &condition, cudf::null_equality nulls_equal) { + return cudf::mixed_left_join(left_keys, right_keys, left_condition, right_condition, + condition, nulls_equal, size_info); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerJoinGatherMaps( JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { return cudf::jni::join_gather_maps( @@ -2316,6 +2431,50 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinGathe }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_mixedInnerJoinSize( + JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jlong j_left_condition, + jlong j_right_condition, jlong j_condition, jboolean j_nulls_equal) { + return cudf::jni::mixed_join_size( + env, j_left_keys, j_right_keys, j_left_condition, j_right_condition, j_condition, + j_nulls_equal, + [](cudf::table_view const &left_keys, cudf::table_view const &right_keys, + cudf::table_view const &left_condition, cudf::table_view const &right_condition, + cudf::ast::expression const &condition, cudf::null_equality nulls_equal) { + return cudf::mixed_inner_join_size(left_keys, right_keys, left_condition, right_condition, + condition, nulls_equal); + }); +} + +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_mixedInnerJoinGatherMaps( + JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jlong j_left_condition, + jlong j_right_condition, jlong j_condition, jboolean j_nulls_equal) { + return cudf::jni::mixed_join_gather_maps( + env, j_left_keys, j_right_keys, j_left_condition, j_right_condition, j_condition, + j_nulls_equal, + [](cudf::table_view const &left_keys, cudf::table_view const &right_keys, + cudf::table_view const &left_condition, cudf::table_view const &right_condition, + cudf::ast::expression const &condition, cudf::null_equality nulls_equal) { + return cudf::mixed_inner_join(left_keys, right_keys, left_condition, right_condition, + condition, nulls_equal); + }); +} + +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_mixedInnerJoinGatherMapsWithSize( + JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jlong j_left_condition, + jlong j_right_condition, jlong j_condition, jboolean j_nulls_equal, jlong j_output_row_count, + jlong j_matches_view) { + auto size_info = cudf::jni::get_mixed_size_info(env, j_output_row_count, j_matches_view); + return cudf::jni::mixed_join_gather_maps( + env, j_left_keys, j_right_keys, j_left_condition, j_right_condition, j_condition, + j_nulls_equal, + [&size_info](cudf::table_view const &left_keys, cudf::table_view const &right_keys, + cudf::table_view const &left_condition, cudf::table_view const &right_condition, + cudf::ast::expression const &condition, cudf::null_equality nulls_equal) { + return cudf::mixed_inner_join(left_keys, right_keys, left_condition, right_condition, + condition, nulls_equal, size_info); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullJoinGatherMaps( JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { return cudf::jni::join_gather_maps( @@ -2374,6 +2533,20 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalFullJoinGather }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_mixedFullJoinGatherMaps( + JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jlong j_left_condition, + jlong j_right_condition, jlong j_condition, jboolean j_nulls_equal) { + return cudf::jni::mixed_join_gather_maps( + env, j_left_keys, j_right_keys, j_left_condition, j_right_condition, j_condition, + j_nulls_equal, + [](cudf::table_view const &left_keys, cudf::table_view const &right_keys, + cudf::table_view const &left_condition, cudf::table_view const &right_condition, + cudf::ast::expression const &condition, cudf::null_equality nulls_equal) { + return cudf::mixed_full_join(left_keys, right_keys, left_condition, right_condition, + condition, nulls_equal); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftSemiJoinGatherMap( JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { return cudf::jni::join_gather_single_map( diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 7fe69d2d7fc..8e074a5e4ff 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -1578,6 +1578,144 @@ void testConditionalLeftJoinGatherMapsNullsWithCount() { } } + @Test + void testMixedLeftJoinGatherMaps() { + final int inv = Integer.MIN_VALUE; + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8) + .column(1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(6, 5, 9, 8, 10, 32) + .column(0, 1, 2, 3, 4, 5) + .column(7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, inv) + .build()) { + GatherMap[] maps = Table.mixedLeftJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.UNEQUAL); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testMixedLeftJoinGatherMapsNulls() { + final int inv = Integer.MIN_VALUE; + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(null, 3, 9, 0, 1, 7, 4, null, 5, 8) + .column( 1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(null, 5, null, 8, 10, 32) + .column( 0, 1, 2, 3, 4, 5) + .column( 7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9) + .column(0, inv, inv, inv, inv, inv, inv, 0, 2, 1, inv) + .build()) { + GatherMap[] maps = Table.mixedLeftJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.EQUAL); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testMixedLeftJoinGatherMapsWithSize() { + final int inv = Integer.MIN_VALUE; + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8) + .column(1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(6, 5, 9, 8, 10, 32) + .column(0, 1, 2, 3, 4, 5) + .column(7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, inv) + .build(); + MixedJoinSize sizeInfo = Table.mixedLeftJoinSize(leftKeys, rightKeys, left, right, + condition, NullEquality.UNEQUAL)) { + assertEquals(expected.getRowCount(), sizeInfo.getOutputRowCount()); + GatherMap[] maps = Table.mixedLeftJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.UNEQUAL, sizeInfo); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testMixedLeftJoinGatherMapsNullsWithSize() { + final int inv = Integer.MIN_VALUE; + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(null, 3, 9, 0, 1, 7, 4, null, 5, 8) + .column( 1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(null, 5, null, 8, 10, 32) + .column( 0, 1, 2, 3, 4, 5) + .column( 7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9) + .column(0, inv, inv, inv, inv, inv, inv, 0, 2, 1, inv) + .build(); + MixedJoinSize sizeInfo = Table.mixedLeftJoinSize(leftKeys, rightKeys, left, right, + condition, NullEquality.EQUAL)) { + assertEquals(expected.getRowCount(), sizeInfo.getOutputRowCount()); + GatherMap[] maps = Table.mixedLeftJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.EQUAL, sizeInfo); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + @Test void testInnerJoinGatherMaps() { try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -1848,6 +1986,140 @@ void testConditionalInnerJoinGatherMapsNullsWithCount() { } } + @Test + void testMixedInnerJoinGatherMaps() { + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8) + .column(1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(6, 5, 9, 8, 10, 32) + .column(0, 1, 2, 3, 4, 5) + .column(7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(2, 7, 8) + .column(2, 0, 1) + .build()) { + GatherMap[] maps = Table.mixedInnerJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.UNEQUAL); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testMixedInnerJoinGatherMapsNulls() { + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(null, 3, 9, 0, 1, 7, 4, null, 5, 8) + .column( 1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(null, 5, null, 8, 10, 32) + .column( 0, 1, 2, 3, 4, 5) + .column( 7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(0, 7, 7, 8) + .column(0, 0, 2, 1) + .build()) { + GatherMap[] maps = Table.mixedInnerJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.EQUAL); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testMixedInnerJoinGatherMapsWithSize() { + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8) + .column(1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(6, 5, 9, 8, 10, 32) + .column(0, 1, 2, 3, 4, 5) + .column(7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(2, 7, 8) + .column(2, 0, 1) + .build(); + MixedJoinSize sizeInfo = Table.mixedInnerJoinSize(leftKeys, rightKeys, left, right, + condition, NullEquality.UNEQUAL)) { + assertEquals(expected.getRowCount(), sizeInfo.getOutputRowCount()); + GatherMap[] maps = Table.mixedInnerJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.UNEQUAL, sizeInfo); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testMixedInnerJoinGatherMapsNullsWithSize() { + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(null, 3, 9, 0, 1, 7, 4, null, 5, 8) + .column( 1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(null, 5, null, 8, 10, 32) + .column( 0, 1, 2, 3, 4, 5) + .column( 7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(0, 7, 7, 8) + .column(0, 0, 2, 1) + .build(); + MixedJoinSize sizeInfo = Table.mixedInnerJoinSize(leftKeys, rightKeys, left, right, + condition, NullEquality.EQUAL)) { + assertEquals(expected.getRowCount(), sizeInfo.getOutputRowCount()); + GatherMap[] maps = Table.mixedInnerJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.EQUAL, sizeInfo); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + @Test void testFullJoinGatherMaps() { final int inv = Integer.MIN_VALUE; @@ -2042,6 +2314,72 @@ void testConditionalFullJoinGatherMapsNulls() { } } + @Test + void testMixedFullJoinGatherMaps() { + final int inv = Integer.MIN_VALUE; + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8) + .column(1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(6, 5, 9, 8, 10, 32) + .column(0, 1, 2, 3, 4, 5) + .column(7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(inv, inv, inv, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9) + .column( 3, 4, 5, inv, inv, 2, inv, inv, inv, inv, 0, 1, inv) + .build()) { + GatherMap[] maps = Table.mixedFullJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.UNEQUAL); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testMixedFullJoinGatherMapsNulls() { + final int inv = Integer.MIN_VALUE; + BinaryOperation expr = new BinaryOperation(BinaryOperator.GREATER, + new ColumnReference(1, TableReference.LEFT), + new ColumnReference(1, TableReference.RIGHT)); + try (CompiledExpression condition = expr.compile(); + Table left = new Table.TestBuilder() + .column(null, 3, 9, 0, 1, 7, 4, null, 5, 8) + .column( 1, 2, 3, 4, 5, 6, 7, 8, 9, 0) + .build(); + Table leftKeys = new Table(left.getColumn(0)); + Table right = new Table.TestBuilder() + .column(null, 5, null, 8, 10, 32) + .column( 0, 1, 2, 3, 4, 5) + .column( 7, 8, 9, 0, 1, 2).build(); + Table rightKeys = new Table(right.getColumn(0)); + Table expected = new Table.TestBuilder() + .column(inv, inv, inv, 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9) + .column( 3, 4, 5, 0, inv, inv, inv, inv, inv, inv, 0, 2, 1, inv) + .build()) { + GatherMap[] maps = Table.mixedFullJoinGatherMaps(leftKeys, rightKeys, left, right, condition, + NullEquality.EQUAL); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + @Test void testLeftSemiJoinGatherMap() { try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build();