From b9555040ad3ca582077e9ca2bd6021a2adb62b51 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Tue, 27 Jul 2021 14:21:17 -0500 Subject: [PATCH 1/2] Add Java bindings for conditional joins --- java/src/main/java/ai/rapids/cudf/Table.java | 141 +++++++++++ .../rapids/cudf/ast/CompiledExpression.java | 5 + .../main/native/src/CompiledExpression.cpp | 54 +--- java/src/main/native/src/TableJni.cpp | 130 +++++++++- .../src/main/native/src/jni_compiled_expr.hpp | 72 ++++++ .../test/java/ai/rapids/cudf/TableTest.java | 231 ++++++++++++++++++ 6 files changed, 578 insertions(+), 55 deletions(-) create mode 100644 java/src/main/native/src/jni_compiled_expr.hpp diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 627a2a36e9e..96a9b608f06 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -23,6 +23,7 @@ import ai.rapids.cudf.HostColumnVector.ListType; import ai.rapids.cudf.HostColumnVector.StructData; import ai.rapids.cudf.HostColumnVector.StructType; +import ai.rapids.cudf.ast.CompiledExpression; import java.io.File; import java.math.BigDecimal; @@ -523,6 +524,26 @@ private static native long[] leftAntiJoin(long leftTable, int[] leftJoinCols, lo private static native long[] leftAntiJoinGatherMap(long leftKeys, long rightKeys, boolean compareNullsEqual) throws CudfException; + private static native long[] conditionalLeftJoinGatherMaps(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual) throws CudfException; + + private static native long[] conditionalInnerJoinGatherMaps(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual) throws CudfException; + + private static native long[] conditionalFullJoinGatherMaps(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual) throws CudfException; + + private static native long[] conditionalLeftSemiJoinGatherMap(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual) throws CudfException; + + private static native long[] conditionalLeftAntiJoinGatherMap(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual) throws CudfException; + private static native long[] crossJoin(long leftTable, long rightTable) throws CudfException; private static native long[] concatenate(long[] cudfTablePointers) throws CudfException; @@ -1969,6 +1990,30 @@ public GatherMap[] leftJoinGatherMaps(Table rightKeys, boolean compareNullsEqual return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the gather maps that can be used to manifest the result of a left join between + * two tables when a conditional expression is true. It is assumed this table instance holds + * the columns from the left table, and the table argument represents the columns from the + * right table. Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the left join. + * It is the responsibility of the caller to close the resulting gather map instances. + * @param rightTable the right side table of the join in the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @return left and right table gather maps + */ + public GatherMap[] leftJoinGatherMaps(Table rightTable, CompiledExpression condition, + boolean compareNullsEqual) { + if (getNumberOfColumns() != rightTable.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightTable.getNumberOfColumns()); + } + long[] gatherMapData = + conditionalLeftJoinGatherMaps(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual); + return buildJoinGatherMaps(gatherMapData); + } + /** * Computes the gather maps that can be used to manifest the result of an inner equi-join between * two tables. It is assumed this table instance holds the key columns from the left table, and @@ -1990,6 +2035,30 @@ public GatherMap[] innerJoinGatherMaps(Table rightKeys, boolean compareNullsEqua return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the gather maps that can be used to manifest the result of an inner join between + * two tables when a conditional expression is true. It is assumed this table instance holds + * the columns from the left table, and the table argument represents the columns from the + * right table. Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the inner join. + * It is the responsibility of the caller to close the resulting gather map instances. + * @param rightTable the right side table of the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @return left and right table gather maps + */ + public GatherMap[] innerJoinGatherMaps(Table rightTable, CompiledExpression condition, + boolean compareNullsEqual) { + if (getNumberOfColumns() != rightTable.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightTable.getNumberOfColumns()); + } + long[] gatherMapData = + conditionalInnerJoinGatherMaps(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual); + return buildJoinGatherMaps(gatherMapData); + } + /** * Computes the gather maps that can be used to manifest the result of an full equi-join between * two tables. It is assumed this table instance holds the key columns from the left table, and @@ -2011,6 +2080,30 @@ public GatherMap[] fullJoinGatherMaps(Table rightKeys, boolean compareNullsEqual return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the gather maps that can be used to manifest the result of a full join between + * two tables when a conditional expression is true. It is assumed this table instance holds + * the columns from the left table, and the table argument represents the columns from the + * right table. Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the full join. + * It is the responsibility of the caller to close the resulting gather map instances. + * @param rightTable the right side table of the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @return left and right table gather maps + */ + public GatherMap[] fullJoinGatherMaps(Table rightTable, CompiledExpression condition, + boolean compareNullsEqual) { + if (getNumberOfColumns() != rightTable.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightTable.getNumberOfColumns()); + } + long[] gatherMapData = + conditionalFullJoinGatherMaps(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual); + return buildJoinGatherMaps(gatherMapData); + } + private GatherMap buildSemiJoinGatherMap(long[] gatherMapData) { long bufferSize = gatherMapData[0]; long leftAddr = gatherMapData[1]; @@ -2039,6 +2132,30 @@ public GatherMap leftSemiJoinGatherMap(Table rightKeys, boolean compareNullsEqua return buildSemiJoinGatherMap(gatherMapData); } + /** + * Computes the gather map that can be used to manifest the result of a left semi join between + * two tables when a conditional expression is true. It is assumed this table instance holds + * the columns from the left table, and the table argument represents the columns from the + * right table. The {@link GatherMap} instance returned can be used to gather the left table + * to produce the result of the left semi join. + * It is the responsibility of the caller to close the resulting gather map instance. + * @param rightTable the right side table of the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @return left table gather map + */ + public GatherMap leftSemiJoinGatherMap(Table rightTable, CompiledExpression condition, + boolean compareNullsEqual) { + if (getNumberOfColumns() != rightTable.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightTable.getNumberOfColumns()); + } + long[] gatherMapData = + conditionalLeftSemiJoinGatherMap(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual); + return buildSemiJoinGatherMap(gatherMapData); + } + /** * Computes the gather map that can be used to manifest the result of a left anti-join between * two tables. It is assumed this table instance holds the key columns from the left table, and @@ -2060,6 +2177,30 @@ public GatherMap leftAntiJoinGatherMap(Table rightKeys, boolean compareNullsEqua return buildSemiJoinGatherMap(gatherMapData); } + /** + * Computes the gather map that can be used to manifest the result of a left anti join between + * two tables when a conditional expression is true. It is assumed this table instance holds + * the columns from the left table, and the table argument represents the columns from the + * right table. The {@link GatherMap} instance returned can be used to gather the left table + * to produce the result of the left anti join. + * It is the responsibility of the caller to close the resulting gather map instance. + * @param rightTable the right side table of the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @return left table gather map + */ + public GatherMap leftAntiJoinGatherMap(Table rightTable, CompiledExpression condition, + boolean compareNullsEqual) { + if (getNumberOfColumns() != rightTable.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightTable.getNumberOfColumns()); + } + long[] gatherMapData = + conditionalLeftAntiJoinGatherMap(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual); + return buildSemiJoinGatherMap(gatherMapData); + } + /** * Convert this table of columns into a row major format that is useful for interacting with other * systems that do row major processing of the data. Currently only fixed-width column types are diff --git a/java/src/main/java/ai/rapids/cudf/ast/CompiledExpression.java b/java/src/main/java/ai/rapids/cudf/ast/CompiledExpression.java index 0d2a0052e29..0949b09cbb0 100644 --- a/java/src/main/java/ai/rapids/cudf/ast/CompiledExpression.java +++ b/java/src/main/java/ai/rapids/cudf/ast/CompiledExpression.java @@ -94,6 +94,11 @@ public synchronized void close() { isClosed = true; } + /** Returns the native address of a compiled expression. Intended for internal cudf use only. */ + public long getNativeHandle() { + return cleaner.nativeHandle; + } + private static native long compile(byte[] serializedExpression); private static native long computeColumn(long astHandle, long tableHandle); private static native void destroy(long handle); diff --git a/java/src/main/native/src/CompiledExpression.cpp b/java/src/main/native/src/CompiledExpression.cpp index a28160b32a3..31f3184f107 100644 --- a/java/src/main/native/src/CompiledExpression.cpp +++ b/java/src/main/native/src/CompiledExpression.cpp @@ -26,59 +26,7 @@ #include #include "cudf_jni_apis.hpp" - -namespace cudf { -namespace jni { -namespace ast { - -/** - * A class to capture all of the resources associated with a compiled AST expression. - * AST nodes do not own their child nodes, so every node in the expression tree - * must be explicitly tracked in order to free the underlying resources for each node. - * - * This should be cleaned up a bit after the libcudf AST refactoring in - * https://github.com/rapidsai/cudf/pull/8815 when a virtual destructor is added to the - * base AST node type. Then we do not have to track every AST node type separately. - */ -class compiled_expr { - /** All literal nodes within the expression tree */ - std::vector> literals; - - /** All column reference nodes within the expression tree */ - std::vector> column_refs; - - /** All expression nodes within the expression tree */ - std::vector> expressions; - - /** GPU scalar instances that correspond to literal nodes */ - std::vector> scalars; - -public: - cudf::ast::literal &add_literal(std::unique_ptr literal_ptr, - std::unique_ptr scalar_ptr) { - literals.push_back(std::move(literal_ptr)); - scalars.push_back(std::move(scalar_ptr)); - return *literals.back(); - } - - cudf::ast::column_reference & - add_column_ref(std::unique_ptr ref_ptr) { - column_refs.push_back(std::move(ref_ptr)); - return *column_refs.back(); - } - - cudf::ast::expression &add_expression(std::unique_ptr expr_ptr) { - expressions.push_back(std::move(expr_ptr)); - return *expressions.back(); - } - - /** Return the expression node at the top of the tree */ - cudf::ast::expression &get_top_expression() const { return *expressions.back(); } -}; - -} // namespace ast -} // namespace jni -} // namespace cudf +#include "jni_compiled_expr.hpp" namespace { diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 78e1bf88a1c..c092450da1c 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -43,6 +43,7 @@ #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" +#include "jni_compiled_expr.hpp" #include "jni_utils.hpp" #include "row_conversion.hpp" @@ -744,7 +745,7 @@ bool valid_window_parameters(native_jintArray const &values, values.size() == preceding.size() && values.size() == following.size(); } -// Generate gather maps needed to manifest the result of a join between two tables. +// Generate gather maps needed to manifest the result of an equi-join between two tables. // The resulting Java long array contains the following at each index: // 0: Size of each gather map in bytes // 1: Device address of the gather map for the left table @@ -779,7 +780,44 @@ jlongArray join_gather_maps(JNIEnv *env, jlong j_left_keys, jlong j_right_keys, CATCH_STD(env, NULL); } -// Generate gather maps needed to manifest the result of a join between two tables. +// Generate gather maps needed to manifest the result of a conditional join between two tables. +// The resulting Java long array contains the following at each index: +// 0: Size of each gather map in bytes +// 1: Device address of the gather map for the left table +// 2: Host address of the rmm::device_buffer instance that owns the left gather map data +// 3: Device address of the gather map for the right table +// 4: Host address of the rmm::device_buffer instance that owns the right gather map data +template +jlongArray cond_join_gather_maps(JNIEnv *env, jlong j_left_table, jlong j_right_table, + jlong j_condition, jboolean compare_nulls_equal, T join_func) { + JNI_NULL_CHECK(env, j_left_table, "left_table is null", NULL); + JNI_NULL_CHECK(env, j_right_table, "right_table is null", NULL); + JNI_NULL_CHECK(env, j_condition, "condition is null", NULL); + try { + cudf::jni::auto_set_device(env); + auto left_table = reinterpret_cast(j_left_table); + auto right_table = reinterpret_cast(j_right_table); + auto condition = reinterpret_cast(j_condition); + auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + std::pair>, + std::unique_ptr>> + join_maps = join_func(*left_table, *right_table, condition->get_top_expression(), nulleq); + + // release the underlying device buffer to Java + auto left_map_buffer = std::make_unique(join_maps.first->release()); + auto right_map_buffer = std::make_unique(join_maps.second->release()); + cudf::jni::native_jlongArray result(env, 5); + result[0] = static_cast(left_map_buffer->size()); + result[1] = reinterpret_cast(left_map_buffer->data()); + result[2] = reinterpret_cast(left_map_buffer.release()); + result[3] = reinterpret_cast(right_map_buffer->data()); + result[4] = reinterpret_cast(right_map_buffer.release()); + return result.get_jArray(); + } + CATCH_STD(env, NULL); +} + +// Generate a gather map needed to manifest the result of a semi/anti join between two tables. // The resulting Java long array contains the following at each index: // 0: Size of the gather map in bytes // 1: Device address of the gather map @@ -808,6 +846,39 @@ jlongArray join_gather_single_map(JNIEnv *env, jlong j_left_keys, jlong j_right_ CATCH_STD(env, NULL); } +// Generate a gather map needed to manifest the result of a conditional semi/anti join +// between two tables. +// The resulting Java long array contains the following at each index: +// 0: Size of the gather map in bytes +// 1: Device address of the gather map +// 2: Host address of the rmm::device_buffer instance that owns the gather map data +template +jlongArray cond_join_gather_single_map(JNIEnv *env, jlong j_left_table, jlong j_right_table, + jlong j_condition, jboolean compare_nulls_equal, + T join_func) { + JNI_NULL_CHECK(env, j_left_table, "left_table is null", NULL); + JNI_NULL_CHECK(env, j_right_table, "right_table is null", NULL); + JNI_NULL_CHECK(env, j_condition, "condition is null", NULL); + try { + cudf::jni::auto_set_device(env); + auto left_table = reinterpret_cast(j_left_table); + auto right_table = reinterpret_cast(j_right_table); + auto condition = reinterpret_cast(j_condition); + auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + std::unique_ptr> join_map = + join_func(*left_table, *right_table, condition->get_top_expression(), nulleq); + + // release the underlying device buffer to Java + auto gather_map_buffer = std::make_unique(join_map->release()); + cudf::jni::native_jlongArray result(env, 3); + result[0] = static_cast(gather_map_buffer->size()); + result[1] = reinterpret_cast(gather_map_buffer->data()); + result[2] = reinterpret_cast(gather_map_buffer.release()); + return result.get_jArray(); + } + CATCH_STD(env, NULL); +} + // Returns a table view containing only the columns at the specified indices cudf::table_view const get_keys_table(cudf::table_view const *t, native_jintArray const &key_indices) { @@ -1853,6 +1924,17 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftJoinGatherMaps( }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinGatherMaps( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal) { + return cudf::jni::cond_join_gather_maps( + env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + [](cudf::table_view const &left, cudf::table_view const &right, + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + return cudf::conditional_left_join(left, right, cond_expr, nulleq); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerJoinGatherMaps( JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { return cudf::jni::join_gather_maps( @@ -1862,6 +1944,17 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerJoinGatherMaps( }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinGatherMaps( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal) { + return cudf::jni::cond_join_gather_maps( + env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + [](cudf::table_view const &left, cudf::table_view const &right, + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + return cudf::conditional_inner_join(left, right, cond_expr, nulleq); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullJoinGatherMaps( JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { return cudf::jni::join_gather_maps( @@ -1871,6 +1964,17 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullJoinGatherMaps( }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalFullJoinGatherMaps( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal) { + return cudf::jni::cond_join_gather_maps( + env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + [](cudf::table_view const &left, cudf::table_view const &right, + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + return cudf::conditional_full_join(left, right, cond_expr, nulleq); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftSemiJoinGatherMap( JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { return cudf::jni::join_gather_single_map( @@ -1880,6 +1984,17 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftSemiJoinGatherMap( }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftSemiJoinGatherMap( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal) { + return cudf::jni::cond_join_gather_single_map( + env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + [](cudf::table_view const &left, cudf::table_view const &right, + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + return cudf::conditional_left_semi_join(left, right, cond_expr, nulleq); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftAntiJoinGatherMap( JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { return cudf::jni::join_gather_single_map( @@ -1889,6 +2004,17 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftAntiJoinGatherMap( }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftAntiJoinGatherMap( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal) { + return cudf::jni::cond_join_gather_single_map( + env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + [](cudf::table_view const &left, cudf::table_view const &right, + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + return cudf::conditional_left_anti_join(left, right, cond_expr, nulleq); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_crossJoin(JNIEnv *env, jclass, jlong left_table, jlong right_table) { diff --git a/java/src/main/native/src/jni_compiled_expr.hpp b/java/src/main/native/src/jni_compiled_expr.hpp new file mode 100644 index 00000000000..e42e5a37fba --- /dev/null +++ b/java/src/main/native/src/jni_compiled_expr.hpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace cudf { +namespace jni { +namespace ast { + +/** + * A class to capture all of the resources associated with a compiled AST expression. + * AST nodes do not own their child nodes, so every node in the expression tree + * must be explicitly tracked in order to free the underlying resources for each node. + * + * This should be cleaned up a bit after the libcudf AST refactoring in + * https://github.com/rapidsai/cudf/pull/8815 when a virtual destructor is added to the + * base AST node type. Then we do not have to track every AST node type separately. + */ +class compiled_expr { + /** All literal nodes within the expression tree */ + std::vector> literals; + + /** All column reference nodes within the expression tree */ + std::vector> column_refs; + + /** All expression nodes within the expression tree */ + std::vector> expressions; + + /** GPU scalar instances that correspond to literal nodes */ + std::vector> scalars; + +public: + cudf::ast::literal &add_literal(std::unique_ptr literal_ptr, + std::unique_ptr scalar_ptr) { + literals.push_back(std::move(literal_ptr)); + scalars.push_back(std::move(scalar_ptr)); + return *literals.back(); + } + + cudf::ast::column_reference & + add_column_ref(std::unique_ptr ref_ptr) { + column_refs.push_back(std::move(ref_ptr)); + return *column_refs.back(); + } + + cudf::ast::expression &add_expression(std::unique_ptr expr_ptr) { + expressions.push_back(std::move(expr_ptr)); + return *expressions.back(); + } + + /** Return the expression node at the top of the tree */ + cudf::ast::expression &get_top_expression() const { return *expressions.back(); } +}; + +} // namespace ast +} // namespace jni +} // namespace cudf diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 7507bc8a286..89820ee482c 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -25,6 +25,11 @@ import ai.rapids.cudf.HostColumnVector.StructData; import ai.rapids.cudf.HostColumnVector.StructType; +import ai.rapids.cudf.ast.BinaryExpression; +import ai.rapids.cudf.ast.BinaryOperator; +import ai.rapids.cudf.ast.ColumnReference; +import ai.rapids.cudf.ast.CompiledExpression; +import ai.rapids.cudf.ast.TableReference; import org.junit.jupiter.api.Test; import java.io.ByteArrayInputStream; @@ -1484,6 +1489,58 @@ void testLeftJoinGatherMapsNulls() { } } + @Test + void testConditionalLeftJoinGatherMaps() { + final int inv = Integer.MIN_VALUE; + BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table right = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 2, 2, 3, 4, 5, 5, 6, 7, 8, 9, 9) + .column(inv, inv, 0, 1, 3, inv, inv, 0, 1, inv, 1, inv, 0, 1) + .build(); + CompiledExpression condition = expr.compile()) { + GatherMap[] maps = left.leftJoinGatherMaps(right, condition, false); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testConditionalLeftJoinGatherMapsNulls() { + final int inv = Integer.MIN_VALUE; + BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table right = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 8, 9) // left + .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, 0, 1, 3) // right + .build(); + CompiledExpression condition = expr.compile()) { + GatherMap[] maps = left.leftJoinGatherMaps(right, condition, true); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + @Test void testInnerJoinGatherMaps() { try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -1526,6 +1583,56 @@ void testInnerJoinGatherMapsNulls() { } } + @Test + void testConditionalInnerJoinGatherMaps() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table right = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + Table expected = new Table.TestBuilder() + .column(2, 2, 2, 5, 5, 7, 9, 9) + .column(0, 1, 3, 0, 1, 1, 0, 1) + .build(); + CompiledExpression condition = expr.compile()) { + GatherMap[] maps = left.innerJoinGatherMaps(right, condition, false); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testConditionalInnerJoinGatherMapsNulls() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table right = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column(2, 7, 7, 8, 8, 9) // left + .column(2, 0, 1, 0, 1, 3) // right + .build(); + CompiledExpression condition = expr.compile()) { + GatherMap[] maps = left.innerJoinGatherMaps(right, condition, true); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + @Test void testFullJoinGatherMaps() { final int inv = Integer.MIN_VALUE; @@ -1570,6 +1677,58 @@ void testFullJoinGatherMapsNulls() { } } + @Test + void testConditionalFullJoinGatherMaps() { + final int inv = Integer.MIN_VALUE; + BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table right = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 2, 2, 3, 4, 5, 5, 6, 7, 8, 9, 9, inv, inv, inv) + .column(inv, inv, 0, 1, 3, inv, inv, 0, 1, inv, 1, inv, 0, 1, 2, 4, 5) + .build(); + CompiledExpression condition = expr.compile()) { + GatherMap[] maps = left.fullJoinGatherMaps(right, condition, false); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testConditionalFullJoinGatherMapsNulls() { + final int inv = Integer.MIN_VALUE; + BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table right = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column(inv, inv, 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 8, 9) // left + .column( 4, 5, inv, inv, 2, inv, inv, inv, inv, 0, 1, 0, 1, 3) // right + .build(); + CompiledExpression condition = expr.compile()) { + GatherMap[] maps = left.fullJoinGatherMaps(right, condition, true); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + @Test void testLeftSemiJoinGatherMap() { try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -1598,6 +1757,42 @@ void testLeftSemiJoinGatherMapNulls() { } } + @Test + void testConditionalLeftSemiJoinGatherMap() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table right = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + Table expected = new Table.TestBuilder() + .column(2, 5, 7, 9) // left + .build(); + CompiledExpression condition = expr.compile(); + GatherMap map = left.leftSemiJoinGatherMap(right, condition, false)) { + verifySemiJoinGatherMap(map, expected); + } + } + + @Test + void testConditionalLeftSemiJoinGatherMapNulls() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table right = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column(2, 7, 8, 9) // left + .build(); + CompiledExpression condition = expr.compile(); + GatherMap map = left.leftSemiJoinGatherMap(right, condition, true)) { + verifySemiJoinGatherMap(map, expected); + } + } + @Test void testAntiSemiJoinGatherMap() { try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -1626,6 +1821,42 @@ void testAntiSemiJoinGatherMapNulls() { } } + @Test + void testConditionalLeftAntiJoinGatherMap() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table right = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + Table expected = new Table.TestBuilder() + .column(0, 1, 3, 4, 6, 8) // left + .build(); + CompiledExpression condition = expr.compile(); + GatherMap map = left.leftSemiJoinGatherMap(right, condition, false)) { + verifySemiJoinGatherMap(map, expected); + } + } + + @Test + void testConditionalAntiSemiJoinGatherMapNulls() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table right = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column(0, 1, 3, 4, 5, 6) // left + .build(); + CompiledExpression condition = expr.compile(); + GatherMap map = left.leftAntiJoinGatherMap(right, condition, true)) { + verifySemiJoinGatherMap(map, expected); + } + } + @Test void testBoundsNulls() { boolean[] descFlags = new boolean[1]; From 437472caf9a7c0821f874a58294182f67ca9c3d2 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 29 Jul 2021 10:29:48 -0500 Subject: [PATCH 2/2] Fix tests --- java/src/test/java/ai/rapids/cudf/TableTest.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 89820ee482c..360f3c04f5b 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -1686,8 +1686,8 @@ void testConditionalFullJoinGatherMaps() { try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); Table right = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); Table expected = new Table.TestBuilder() - .column( 0, 1, 2, 2, 2, 3, 4, 5, 5, 6, 7, 8, 9, 9, inv, inv, inv) - .column(inv, inv, 0, 1, 3, inv, inv, 0, 1, inv, 1, inv, 0, 1, 2, 4, 5) + .column(inv, inv, inv, 0, 1, 2, 2, 2, 3, 4, 5, 5, 6, 7, 8, 9, 9) + .column( 2, 4, 5, inv, inv, 0, 1, 3, inv, inv, 0, 1, inv, 1, inv, 0, 1) .build(); CompiledExpression condition = expr.compile()) { GatherMap[] maps = left.fullJoinGatherMaps(right, condition, false); @@ -1832,7 +1832,7 @@ void testConditionalLeftAntiJoinGatherMap() { .column(0, 1, 3, 4, 6, 8) // left .build(); CompiledExpression condition = expr.compile(); - GatherMap map = left.leftSemiJoinGatherMap(right, condition, false)) { + GatherMap map = left.leftAntiJoinGatherMap(right, condition, false)) { verifySemiJoinGatherMap(map, expected); } }