From bde4f2a99f1e4d39a2b014e1e19d8abaca812788 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 4 Aug 2021 11:29:06 -0500 Subject: [PATCH 1/2] Java bindings for conditional join output sizes --- java/src/main/java/ai/rapids/cudf/Table.java | 237 +++++++++++++++++- java/src/main/native/src/TableJni.cpp | 120 +++++++++ .../test/java/ai/rapids/cudf/TableTest.java | 226 ++++++++++++++++- 3 files changed, 563 insertions(+), 20 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 861c6485a5c..5386c1404b9 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -524,26 +524,67 @@ private static native long[] leftAntiJoin(long leftTable, int[] leftJoinCols, lo private static native long[] leftAntiJoinGatherMap(long leftKeys, long rightKeys, boolean compareNullsEqual) throws CudfException; + private static native long conditionalLeftJoinRowCount(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual) throws CudfException; + private static native long[] conditionalLeftJoinGatherMaps(long leftTable, long rightTable, long condition, boolean compareNullsEqual) throws CudfException; + private static native long[] conditionalLeftJoinGatherMapsWithCount(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual, + long rowCount) throws CudfException; + + private static native long conditionalInnerJoinRowCount(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual) throws CudfException; + private static native long[] conditionalInnerJoinGatherMaps(long leftTable, long rightTable, long condition, boolean compareNullsEqual) throws CudfException; + private static native long[] conditionalInnerJoinGatherMapsWithCount(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual, + long rowCount) throws CudfException; + private static native long[] conditionalFullJoinGatherMaps(long leftTable, long rightTable, long condition, boolean compareNullsEqual) throws CudfException; + private static native long[] conditionalFullJoinGatherMapsWithCount(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual, + long rowCount) throws CudfException; + + private static native long conditionalLeftSemiJoinRowCount(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual) throws CudfException; + private static native long[] conditionalLeftSemiJoinGatherMap(long leftTable, long rightTable, long condition, boolean compareNullsEqual) throws CudfException; + private static native long[] conditionalLeftSemiJoinGatherMapWithCount(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual, + long rowCount) throws CudfException; + + private static native long conditionalLeftAntiJoinRowCount(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual) throws CudfException; + private static native long[] conditionalLeftAntiJoinGatherMap(long leftTable, long rightTable, long condition, boolean compareNullsEqual) throws CudfException; + private static native long[] conditionalLeftAntiJoinGatherMapWithCount(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual, + long rowCount) throws CudfException; + private static native long[] crossJoin(long leftTable, long rightTable) throws CudfException; private static native long[] concatenate(long[] cudfTablePointers) throws CudfException; @@ -1990,6 +2031,21 @@ public GatherMap[] leftJoinGatherMaps(Table rightKeys, boolean compareNullsEqual return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the number of rows from the result of a left join between two tables when a + * conditional expression is true. It is assumed this table instance holds the columns from + * the left table, and the table argument represents the columns from the right table. + * @param rightTable the right side table of the join in the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @return row count for the join result + */ + public long conditionalLeftJoinRowCount(Table rightTable, CompiledExpression condition, + boolean compareNullsEqual) { + return conditionalLeftJoinRowCount(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual); + } + /** * Computes the gather maps that can be used to manifest the result of a left join between * two tables when a conditional expression is true. It is assumed this table instance holds @@ -2002,14 +2058,42 @@ public GatherMap[] leftJoinGatherMaps(Table rightKeys, boolean compareNullsEqual * @param compareNullsEqual true if null key values should match otherwise false * @return left and right table gather maps */ - public GatherMap[] leftJoinGatherMaps(Table rightTable, CompiledExpression condition, - boolean compareNullsEqual) { + public GatherMap[] conditionalLeftJoinGatherMaps(Table rightTable, + CompiledExpression condition, + boolean compareNullsEqual) { long[] gatherMapData = conditionalLeftJoinGatherMaps(getNativeView(), rightTable.getNativeView(), condition.getNativeHandle(), compareNullsEqual); return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the gather maps that can be used to manifest the result of a left join between + * two tables when a conditional expression is true. It is assumed this table instance holds + * the columns from the left table, and the table argument represents the columns from the + * right table. Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the left join. + * It is the responsibility of the caller to close the resulting gather map instances. + * This interface allows passing an output row count that was previously computed from + * {@link #conditionalLeftJoinRowCount(Table, CompiledExpression, boolean)}. + * WARNING: Passing a row count that is smaller than the actual row count will result + * in undefined behavior. + * @param rightTable the right side table of the join in the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @param outputRowCount number of output rows in the join result + * @return left and right table gather maps + */ + public GatherMap[] conditionalLeftJoinGatherMaps(Table rightTable, + CompiledExpression condition, + boolean compareNullsEqual, + long outputRowCount) { + long[] gatherMapData = + conditionalLeftJoinGatherMapsWithCount(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual, outputRowCount); + return buildJoinGatherMaps(gatherMapData); + } + /** * Computes the gather maps that can be used to manifest the result of an inner equi-join between * two tables. It is assumed this table instance holds the key columns from the left table, and @@ -2031,6 +2115,22 @@ public GatherMap[] innerJoinGatherMaps(Table rightKeys, boolean compareNullsEqua return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the number of rows from the result of an inner join between two tables when a + * conditional expression is true. It is assumed this table instance holds the columns from + * the left table, and the table argument represents the columns from the right table. + * @param rightTable the right side table of the join in the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @return row count for the join result + */ + public long conditionalInnerJoinRowCount(Table rightTable, + CompiledExpression condition, + boolean compareNullsEqual) { + return conditionalInnerJoinRowCount(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual); + } + /** * Computes the gather maps that can be used to manifest the result of an inner join between * two tables when a conditional expression is true. It is assumed this table instance holds @@ -2043,14 +2143,42 @@ public GatherMap[] innerJoinGatherMaps(Table rightKeys, boolean compareNullsEqua * @param compareNullsEqual true if null key values should match otherwise false * @return left and right table gather maps */ - public GatherMap[] innerJoinGatherMaps(Table rightTable, CompiledExpression condition, - boolean compareNullsEqual) { + public GatherMap[] conditionalInnerJoinGatherMaps(Table rightTable, + CompiledExpression condition, + boolean compareNullsEqual) { long[] gatherMapData = conditionalInnerJoinGatherMaps(getNativeView(), rightTable.getNativeView(), condition.getNativeHandle(), compareNullsEqual); return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the gather maps that can be used to manifest the result of an inner join between + * two tables when a conditional expression is true. It is assumed this table instance holds + * the columns from the left table, and the table argument represents the columns from the + * right table. Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the inner join. + * It is the responsibility of the caller to close the resulting gather map instances. + * This interface allows passing an output row count that was previously computed from + * {@link #conditionalInnerJoinRowCount(Table, CompiledExpression, boolean)}. + * WARNING: Passing a row count that is smaller than the actual row count will result + * in undefined behavior. + * @param rightTable the right side table of the join in the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @param outputRowCount number of output rows in the join result + * @return left and right table gather maps + */ + public GatherMap[] conditionalInnerJoinGatherMaps(Table rightTable, + CompiledExpression condition, + boolean compareNullsEqual, + long outputRowCount) { + long[] gatherMapData = + conditionalInnerJoinGatherMapsWithCount(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual, outputRowCount); + return buildJoinGatherMaps(gatherMapData); + } + /** * Computes the gather maps that can be used to manifest the result of an full equi-join between * two tables. It is assumed this table instance holds the key columns from the left table, and @@ -2084,8 +2212,9 @@ public GatherMap[] fullJoinGatherMaps(Table rightKeys, boolean compareNullsEqual * @param compareNullsEqual true if null key values should match otherwise false * @return left and right table gather maps */ - public GatherMap[] fullJoinGatherMaps(Table rightTable, CompiledExpression condition, - boolean compareNullsEqual) { + public GatherMap[] conditionalFullJoinGatherMaps(Table rightTable, + CompiledExpression condition, + boolean compareNullsEqual) { long[] gatherMapData = conditionalFullJoinGatherMaps(getNativeView(), rightTable.getNativeView(), condition.getNativeHandle(), compareNullsEqual); @@ -2120,6 +2249,22 @@ public GatherMap leftSemiJoinGatherMap(Table rightKeys, boolean compareNullsEqua return buildSemiJoinGatherMap(gatherMapData); } + /** + * Computes the number of rows from the result of a left semi join between two tables when a + * conditional expression is true. It is assumed this table instance holds the columns from + * the left table, and the table argument represents the columns from the right table. + * @param rightTable the right side table of the join in the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @return row count for the join result + */ + public long conditionalLeftSemiJoinRowCount(Table rightTable, + CompiledExpression condition, + boolean compareNullsEqual) { + return conditionalLeftSemiJoinRowCount(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual); + } + /** * Computes the gather map that can be used to manifest the result of a left semi join between * two tables when a conditional expression is true. It is assumed this table instance holds @@ -2132,14 +2277,42 @@ public GatherMap leftSemiJoinGatherMap(Table rightKeys, boolean compareNullsEqua * @param compareNullsEqual true if null key values should match otherwise false * @return left table gather map */ - public GatherMap leftSemiJoinGatherMap(Table rightTable, CompiledExpression condition, - boolean compareNullsEqual) { + public GatherMap conditionalLeftSemiJoinGatherMap(Table rightTable, + CompiledExpression condition, + boolean compareNullsEqual) { long[] gatherMapData = conditionalLeftSemiJoinGatherMap(getNativeView(), rightTable.getNativeView(), condition.getNativeHandle(), compareNullsEqual); return buildSemiJoinGatherMap(gatherMapData); } + /** + * Computes the gather map that can be used to manifest the result of a left semi join between + * two tables when a conditional expression is true. It is assumed this table instance holds + * the columns from the left table, and the table argument represents the columns from the + * right table. The {@link GatherMap} instance returned can be used to gather the left table + * to produce the result of the left semi join. + * It is the responsibility of the caller to close the resulting gather map instance. + * This interface allows passing an output row count that was previously computed from + * {@link #conditionalLeftSemiJoinRowCount(Table, CompiledExpression, boolean)}. + * WARNING: Passing a row count that is smaller than the actual row count will result + * in undefined behavior. + * @param rightTable the right side table of the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @param outputRowCount number of output rows in the join result + * @return left table gather map + */ + public GatherMap conditionalLeftSemiJoinGatherMap(Table rightTable, + CompiledExpression condition, + boolean compareNullsEqual, + long outputRowCount) { + long[] gatherMapData = + conditionalLeftSemiJoinGatherMapWithCount(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual, outputRowCount); + return buildSemiJoinGatherMap(gatherMapData); + } + /** * Computes the gather map that can be used to manifest the result of a left anti-join between * two tables. It is assumed this table instance holds the key columns from the left table, and @@ -2161,6 +2334,22 @@ public GatherMap leftAntiJoinGatherMap(Table rightKeys, boolean compareNullsEqua return buildSemiJoinGatherMap(gatherMapData); } + /** + * Computes the number of rows from the result of a left anti join between two tables when a + * conditional expression is true. It is assumed this table instance holds the columns from + * the left table, and the table argument represents the columns from the right table. + * @param rightTable the right side table of the join in the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @return row count for the join result + */ + public long conditionalLeftAntiJoinRowCount(Table rightTable, + CompiledExpression condition, + boolean compareNullsEqual) { + return conditionalLeftAntiJoinRowCount(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual); + } + /** * Computes the gather map that can be used to manifest the result of a left anti join between * two tables when a conditional expression is true. It is assumed this table instance holds @@ -2173,14 +2362,42 @@ public GatherMap leftAntiJoinGatherMap(Table rightKeys, boolean compareNullsEqua * @param compareNullsEqual true if null key values should match otherwise false * @return left table gather map */ - public GatherMap leftAntiJoinGatherMap(Table rightTable, CompiledExpression condition, - boolean compareNullsEqual) { + public GatherMap conditionalLeftAntiJoinGatherMap(Table rightTable, + CompiledExpression condition, + boolean compareNullsEqual) { long[] gatherMapData = conditionalLeftAntiJoinGatherMap(getNativeView(), rightTable.getNativeView(), condition.getNativeHandle(), compareNullsEqual); return buildSemiJoinGatherMap(gatherMapData); } + /** + * Computes the gather map that can be used to manifest the result of a left anti join between + * two tables when a conditional expression is true. It is assumed this table instance holds + * the columns from the left table, and the table argument represents the columns from the + * right table. The {@link GatherMap} instance returned can be used to gather the left table + * to produce the result of the left anti join. + * It is the responsibility of the caller to close the resulting gather map instance. + * This interface allows passing an output row count that was previously computed from + * {@link #conditionalLeftAntiJoinRowCount(Table, CompiledExpression, boolean)}. + * WARNING: Passing a row count that is smaller than the actual row count will result + * in undefined behavior. + * @param rightTable the right side table of the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @param outputRowCount number of output rows in the join result + * @return left table gather map + */ + public GatherMap conditionalLeftAntiJoinGatherMap(Table rightTable, + CompiledExpression condition, + boolean compareNullsEqual, + long outputRowCount) { + long[] gatherMapData = + conditionalLeftAntiJoinGatherMapWithCount(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual, outputRowCount); + return buildSemiJoinGatherMap(gatherMapData); + } + /** * Convert this table of columns into a row major format that is useful for interacting with other * systems that do row major processing of the data. Currently only fixed-width column types are diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index c092450da1c..bac71cb95ed 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -1924,6 +1924,24 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftJoinGatherMaps( }); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinRowCount( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal) { + JNI_NULL_CHECK(env, j_left_table, "left_table is null", 0); + JNI_NULL_CHECK(env, j_right_table, "right_table is null", 0); + JNI_NULL_CHECK(env, j_condition, "condition is null", 0); + try { + cudf::jni::auto_set_device(env); + auto left_table = reinterpret_cast(j_left_table); + auto right_table = reinterpret_cast(j_right_table); + auto condition = reinterpret_cast(j_condition); + auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + auto row_count = cudf::conditional_left_join_size(*left_table, *right_table, condition->get_top_expression(), nulleq); + return static_cast(row_count); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinGatherMaps( JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, jboolean compare_nulls_equal) { @@ -1935,6 +1953,18 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinGather }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinGatherMapsWithCount( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal, jlong j_row_count) { + auto row_count = static_cast(j_row_count); + return cudf::jni::cond_join_gather_maps( + env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + [row_count](cudf::table_view const &left, cudf::table_view const &right, + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + return cudf::conditional_left_join(left, right, cond_expr, nulleq, row_count); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerJoinGatherMaps( JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { return cudf::jni::join_gather_maps( @@ -1944,6 +1974,24 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerJoinGatherMaps( }); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinRowCount( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal) { + JNI_NULL_CHECK(env, j_left_table, "left_table is null", 0); + JNI_NULL_CHECK(env, j_right_table, "right_table is null", 0); + JNI_NULL_CHECK(env, j_condition, "condition is null", 0); + try { + cudf::jni::auto_set_device(env); + auto left_table = reinterpret_cast(j_left_table); + auto right_table = reinterpret_cast(j_right_table); + auto condition = reinterpret_cast(j_condition); + auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + auto row_count = cudf::conditional_inner_join_size(*left_table, *right_table, condition->get_top_expression(), nulleq); + return static_cast(row_count); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinGatherMaps( JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, jboolean compare_nulls_equal) { @@ -1955,6 +2003,18 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinGathe }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinGatherMapsWithCount( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal, jlong j_row_count) { + auto row_count = static_cast(j_row_count); + return cudf::jni::cond_join_gather_maps( + env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + [row_count](cudf::table_view const &left, cudf::table_view const &right, + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + return cudf::conditional_inner_join(left, right, cond_expr, nulleq, row_count); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullJoinGatherMaps( JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { return cudf::jni::join_gather_maps( @@ -1984,6 +2044,24 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftSemiJoinGatherMap( }); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftSemiJoinRowCount( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal) { + JNI_NULL_CHECK(env, j_left_table, "left_table is null", 0); + JNI_NULL_CHECK(env, j_right_table, "right_table is null", 0); + JNI_NULL_CHECK(env, j_condition, "condition is null", 0); + try { + cudf::jni::auto_set_device(env); + auto left_table = reinterpret_cast(j_left_table); + auto right_table = reinterpret_cast(j_right_table); + auto condition = reinterpret_cast(j_condition); + auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + auto row_count = cudf::conditional_left_semi_join_size(*left_table, *right_table, condition->get_top_expression(), nulleq); + return static_cast(row_count); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftSemiJoinGatherMap( JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, jboolean compare_nulls_equal) { @@ -1995,6 +2073,18 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftSemiJoinGa }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftSemiJoinGatherMapWithCount( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal, jlong j_row_count) { + auto row_count = static_cast(j_row_count); + return cudf::jni::cond_join_gather_single_map( + env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + [row_count](cudf::table_view const &left, cudf::table_view const &right, + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + return cudf::conditional_left_semi_join(left, right, cond_expr, nulleq, row_count); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftAntiJoinGatherMap( JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { return cudf::jni::join_gather_single_map( @@ -2004,6 +2094,24 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftAntiJoinGatherMap( }); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftAntiJoinRowCount( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal) { + JNI_NULL_CHECK(env, j_left_table, "left_table is null", 0); + JNI_NULL_CHECK(env, j_right_table, "right_table is null", 0); + JNI_NULL_CHECK(env, j_condition, "condition is null", 0); + try { + cudf::jni::auto_set_device(env); + auto left_table = reinterpret_cast(j_left_table); + auto right_table = reinterpret_cast(j_right_table); + auto condition = reinterpret_cast(j_condition); + auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + auto row_count = cudf::conditional_left_anti_join_size(*left_table, *right_table, condition->get_top_expression(), nulleq); + return static_cast(row_count); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftAntiJoinGatherMap( JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, jboolean compare_nulls_equal) { @@ -2015,6 +2123,18 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftAntiJoinGa }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftAntiJoinGatherMapWithCount( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal, jlong j_row_count) { + auto row_count = static_cast(j_row_count); + return cudf::jni::cond_join_gather_single_map( + env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + [row_count](cudf::table_view const &left, cudf::table_view const &right, + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + return cudf::conditional_left_anti_join(left, right, cond_expr, nulleq, row_count); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_crossJoin(JNIEnv *env, jclass, jlong left_table, jlong right_table) { diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 6b347897f82..127855aa403 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -1504,7 +1504,7 @@ void testConditionalLeftJoinGatherMaps() { .column(inv, inv, 0, 1, 3, inv, inv, 0, 1, inv, 1, inv, 0, 1) .build(); CompiledExpression condition = expr.compile()) { - GatherMap[] maps = left.leftJoinGatherMaps(right, condition, false); + GatherMap[] maps = left.conditionalLeftJoinGatherMaps(right, condition, false); try { verifyJoinGatherMaps(maps, expected); } finally { @@ -1532,7 +1532,65 @@ void testConditionalLeftJoinGatherMapsNulls() { .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, 0, 1, 3) // right .build(); CompiledExpression condition = expr.compile()) { - GatherMap[] maps = left.leftJoinGatherMaps(right, condition, true); + GatherMap[] maps = left.conditionalLeftJoinGatherMaps(right, condition, true); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testConditionalLeftJoinGatherMapsWithCount() { + final int inv = Integer.MIN_VALUE; + BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table right = new Table.TestBuilder() + .column(6, 5, 9, 8, 10, 32) + .column(0, 1, 2, 3, 4, 5).build(); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 2, 2, 3, 4, 5, 5, 6, 7, 8, 9, 9) + .column(inv, inv, 0, 1, 3, inv, inv, 0, 1, inv, 1, inv, 0, 1) + .build(); + CompiledExpression condition = expr.compile()) { + long rowCount = left.conditionalLeftJoinRowCount(right, condition, false); + assertEquals(expected.getRowCount(), rowCount); + GatherMap[] maps = left.conditionalLeftJoinGatherMaps(right, condition, false, rowCount); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testConditionalLeftJoinGatherMapsNullsWithCount() { + final int inv = Integer.MIN_VALUE; + BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table right = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 8, 9) // left + .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, 0, 1, 3) // right + .build(); + CompiledExpression condition = expr.compile()) { + long rowCount = left.conditionalLeftJoinRowCount(right, condition, true); + assertEquals(expected.getRowCount(), rowCount); + GatherMap[] maps = left.conditionalLeftJoinGatherMaps(right, condition, true, rowCount); try { verifyJoinGatherMaps(maps, expected); } finally { @@ -1599,7 +1657,7 @@ void testConditionalInnerJoinGatherMaps() { .column(0, 1, 3, 0, 1, 1, 0, 1) .build(); CompiledExpression condition = expr.compile()) { - GatherMap[] maps = left.innerJoinGatherMaps(right, condition, false); + GatherMap[] maps = left.conditionalInnerJoinGatherMaps(right, condition, false); try { verifyJoinGatherMaps(maps, expected); } finally { @@ -1626,7 +1684,63 @@ void testConditionalInnerJoinGatherMapsNulls() { .column(2, 0, 1, 0, 1, 3) // right .build(); CompiledExpression condition = expr.compile()) { - GatherMap[] maps = left.innerJoinGatherMaps(right, condition, true); + GatherMap[] maps = left.conditionalInnerJoinGatherMaps(right, condition, true); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testConditionalInnerJoinGatherMapsWithCount() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table right = new Table.TestBuilder() + .column(6, 5, 9, 8, 10, 32) + .column(0, 1, 2, 3, 4, 5).build(); + Table expected = new Table.TestBuilder() + .column(2, 2, 2, 5, 5, 7, 9, 9) + .column(0, 1, 3, 0, 1, 1, 0, 1) + .build(); + CompiledExpression condition = expr.compile()) { + long rowCount = left.conditionalInnerJoinRowCount(right, condition, false); + assertEquals(expected.getRowCount(), rowCount); + GatherMap[] maps = left.conditionalInnerJoinGatherMaps(right, condition, false, rowCount); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testConditionalInnerJoinGatherMapsNullsWithCount() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table right = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column(2, 7, 7, 8, 8, 9) // left + .column(2, 0, 1, 0, 1, 3) // right + .build(); + CompiledExpression condition = expr.compile()) { + long rowCount = left.conditionalInnerJoinRowCount(right, condition, true); + assertEquals(expected.getRowCount(), rowCount); + GatherMap[] maps = left.conditionalInnerJoinGatherMaps(right, condition, true, rowCount); try { verifyJoinGatherMaps(maps, expected); } finally { @@ -1696,7 +1810,7 @@ void testConditionalFullJoinGatherMaps() { .column( 2, 4, 5, inv, inv, 0, 1, 3, inv, inv, 0, 1, inv, 1, inv, 0, 1) .build(); CompiledExpression condition = expr.compile()) { - GatherMap[] maps = left.fullJoinGatherMaps(right, condition, false); + GatherMap[] maps = left.conditionalFullJoinGatherMaps(right, condition, false); try { verifyJoinGatherMaps(maps, expected); } finally { @@ -1724,7 +1838,7 @@ void testConditionalFullJoinGatherMapsNulls() { .column( 4, 5, inv, inv, 2, inv, inv, inv, inv, 0, 1, 0, 1, 3) // right .build(); CompiledExpression condition = expr.compile()) { - GatherMap[] maps = left.fullJoinGatherMaps(right, condition, true); + GatherMap[] maps = left.conditionalFullJoinGatherMaps(right, condition, true); try { verifyJoinGatherMaps(maps, expected); } finally { @@ -1776,7 +1890,7 @@ void testConditionalLeftSemiJoinGatherMap() { .column(2, 5, 7, 9) // left .build(); CompiledExpression condition = expr.compile(); - GatherMap map = left.leftSemiJoinGatherMap(right, condition, false)) { + GatherMap map = left.conditionalLeftSemiJoinGatherMap(right, condition, false)) { verifySemiJoinGatherMap(map, expected); } } @@ -1796,11 +1910,57 @@ void testConditionalLeftSemiJoinGatherMapNulls() { .column(2, 7, 8, 9) // left .build(); CompiledExpression condition = expr.compile(); - GatherMap map = left.leftSemiJoinGatherMap(right, condition, true)) { + GatherMap map = left.conditionalLeftSemiJoinGatherMap(right, condition, true)) { verifySemiJoinGatherMap(map, expected); } } + @Test + void testConditionalLeftSemiJoinGatherMapWithCount() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table right = new Table.TestBuilder() + .column(6, 5, 9, 8, 10, 32) + .column(0, 1, 2, 3, 4, 5).build(); + Table expected = new Table.TestBuilder() + .column(2, 5, 7, 9) // left + .build(); + CompiledExpression condition = expr.compile()) { + long rowCount = left.conditionalLeftSemiJoinRowCount(right, condition, false); + assertEquals(expected.getRowCount(), rowCount); + try (GatherMap map = + left.conditionalLeftSemiJoinGatherMap(right, condition, false, rowCount)) { + verifySemiJoinGatherMap(map, expected); + } + } + } + + @Test + void testConditionalLeftSemiJoinGatherMapNullsWithCount() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table right = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column(2, 7, 8, 9) // left + .build(); + CompiledExpression condition = expr.compile()) { + long rowCount = left.conditionalLeftSemiJoinRowCount(right, condition, true); + assertEquals(expected.getRowCount(), rowCount); + try (GatherMap map = + left.conditionalLeftSemiJoinGatherMap(right, condition, true, rowCount)) { + verifySemiJoinGatherMap(map, expected); + } + } + } + @Test void testAntiSemiJoinGatherMap() { try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -1842,7 +2002,7 @@ void testConditionalLeftAntiJoinGatherMap() { .column(0, 1, 3, 4, 6, 8) // left .build(); CompiledExpression condition = expr.compile(); - GatherMap map = left.leftAntiJoinGatherMap(right, condition, false)) { + GatherMap map = left.conditionalLeftAntiJoinGatherMap(right, condition, false)) { verifySemiJoinGatherMap(map, expected); } } @@ -1862,11 +2022,57 @@ void testConditionalAntiSemiJoinGatherMapNulls() { .column(0, 1, 3, 4, 5, 6) // left .build(); CompiledExpression condition = expr.compile(); - GatherMap map = left.leftAntiJoinGatherMap(right, condition, true)) { + GatherMap map = left.conditionalLeftAntiJoinGatherMap(right, condition, true)) { verifySemiJoinGatherMap(map, expected); } } + @Test + void testConditionalLeftAntiJoinGatherMapWithCount() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table right = new Table.TestBuilder() + .column(6, 5, 9, 8, 10, 32) + .column(0, 1, 2, 3, 4, 5).build(); + Table expected = new Table.TestBuilder() + .column(0, 1, 3, 4, 6, 8) // left + .build(); + CompiledExpression condition = expr.compile()) { + long rowCount = left.conditionalLeftAntiJoinRowCount(right, condition, false); + assertEquals(expected.getRowCount(), rowCount); + try (GatherMap map = + left.conditionalLeftAntiJoinGatherMap(right, condition, false, rowCount)) { + verifySemiJoinGatherMap(map, expected); + } + } + } + + @Test + void testConditionalAntiSemiJoinGatherMapNullsWithCount() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table right = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column(0, 1, 3, 4, 5, 6) // left + .build(); + CompiledExpression condition = expr.compile()) { + long rowCount = left.conditionalLeftAntiJoinRowCount(right, condition, true); + assertEquals(expected.getRowCount(), rowCount); + try (GatherMap map = + left.conditionalLeftAntiJoinGatherMap(right, condition, true, rowCount)) { + verifySemiJoinGatherMap(map, expected); + } + } + } + @Test void testBoundsNulls() { boolean[] descFlags = new boolean[1]; From 0cfc4131944fc9b99b661b165aa8547fae7780e5 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Mon, 16 Aug 2021 10:36:34 -0500 Subject: [PATCH 2/2] fix JNI code style --- java/src/main/native/src/TableJni.cpp | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index bac71cb95ed..c9649c2d9ee 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -1936,7 +1936,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinRowCount( auto right_table = reinterpret_cast(j_right_table); auto condition = reinterpret_cast(j_condition); auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; - auto row_count = cudf::conditional_left_join_size(*left_table, *right_table, condition->get_top_expression(), nulleq); + auto row_count = cudf::conditional_left_join_size(*left_table, *right_table, + condition->get_top_expression(), nulleq); return static_cast(row_count); } CATCH_STD(env, 0); @@ -1960,7 +1961,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinGather return cudf::jni::cond_join_gather_maps( env, j_left_table, j_right_table, j_condition, compare_nulls_equal, [row_count](cudf::table_view const &left, cudf::table_view const &right, - cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { return cudf::conditional_left_join(left, right, cond_expr, nulleq, row_count); }); } @@ -1986,7 +1987,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinRowCount( auto right_table = reinterpret_cast(j_right_table); auto condition = reinterpret_cast(j_condition); auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; - auto row_count = cudf::conditional_inner_join_size(*left_table, *right_table, condition->get_top_expression(), nulleq); + auto row_count = cudf::conditional_inner_join_size(*left_table, *right_table, + condition->get_top_expression(), nulleq); return static_cast(row_count); } CATCH_STD(env, 0); @@ -2010,7 +2012,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinGathe return cudf::jni::cond_join_gather_maps( env, j_left_table, j_right_table, j_condition, compare_nulls_equal, [row_count](cudf::table_view const &left, cudf::table_view const &right, - cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { return cudf::conditional_inner_join(left, right, cond_expr, nulleq, row_count); }); } @@ -2056,7 +2058,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftSemiJoinRowCoun auto right_table = reinterpret_cast(j_right_table); auto condition = reinterpret_cast(j_condition); auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; - auto row_count = cudf::conditional_left_semi_join_size(*left_table, *right_table, condition->get_top_expression(), nulleq); + auto row_count = cudf::conditional_left_semi_join_size(*left_table, *right_table, + condition->get_top_expression(), nulleq); return static_cast(row_count); } CATCH_STD(env, 0); @@ -2080,7 +2083,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftSemiJoinGa return cudf::jni::cond_join_gather_single_map( env, j_left_table, j_right_table, j_condition, compare_nulls_equal, [row_count](cudf::table_view const &left, cudf::table_view const &right, - cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { return cudf::conditional_left_semi_join(left, right, cond_expr, nulleq, row_count); }); } @@ -2106,7 +2109,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftAntiJoinRowCoun auto right_table = reinterpret_cast(j_right_table); auto condition = reinterpret_cast(j_condition); auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; - auto row_count = cudf::conditional_left_anti_join_size(*left_table, *right_table, condition->get_top_expression(), nulleq); + auto row_count = cudf::conditional_left_anti_join_size(*left_table, *right_table, + condition->get_top_expression(), nulleq); return static_cast(row_count); } CATCH_STD(env, 0); @@ -2130,7 +2134,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftAntiJoinGa return cudf::jni::cond_join_gather_single_map( env, j_left_table, j_right_table, j_condition, compare_nulls_equal, [row_count](cudf::table_view const &left, cudf::table_view const &right, - cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { return cudf::conditional_left_anti_join(left, right, cond_expr, nulleq, row_count); }); }