diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index 8ebc85e5736..326fc2f1119 100755 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -74,7 +74,7 @@ conda config --show-sources conda list --show-channel-urls gpuci_logger "Install dependencies" -gpuci_conda_retry install -y \ +gpuci_mamba_retry install -y \ "cudatoolkit=$CUDA_REL" \ "rapids-build-env=$MINOR_VERSION.*" \ "rapids-notebook-env=$MINOR_VERSION.*" \ @@ -83,8 +83,8 @@ gpuci_conda_retry install -y \ "ucx-py=0.21.*" # https://docs.rapids.ai/maintainers/depmgmt/ -# gpuci_conda_retry remove --force rapids-build-env rapids-notebook-env -# gpuci_conda_retry install -y "your-pkg=1.0.0" +# gpuci_mamba_retry remove --force rapids-build-env rapids-notebook-env +# gpuci_mamba_retry install -y "your-pkg=1.0.0" gpuci_logger "Check compiler versions" diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 887470d29cf..c82826b8c60 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -377,7 +377,9 @@ ConfigureTest(STRINGS_TEST ################################################################################################### # - structs test ---------------------------------------------------------------------------------- -ConfigureTest(STRUCTS_TEST structs/structs_column_tests.cu) +ConfigureTest(STRUCTS_TEST + structs/structs_column_tests.cpp + ) ################################################################################################### # - nvtext test ----------------------------------------------------------------------------------- diff --git a/cpp/tests/structs/structs_column_tests.cu b/cpp/tests/structs/structs_column_tests.cpp similarity index 95% rename from cpp/tests/structs/structs_column_tests.cu rename to cpp/tests/structs/structs_column_tests.cpp index e1438c33044..548284d6c87 100644 --- a/cpp/tests/structs/structs_column_tests.cu +++ b/cpp/tests/structs/structs_column_tests.cpp @@ -192,6 +192,10 @@ TYPED_TEST(TypedStructColumnWrapperTest, TestStructsContainingLists) auto struct_col = cudf::test::structs_column_wrapper{{names_col, lists_col}, {1, 1, 1, 1, 0, 0}}.release(); + EXPECT_EQ(struct_col->size(), num_rows); + EXPECT_EQ(struct_col->view().child(0).size(), num_rows); + EXPECT_EQ(struct_col->view().child(1).size(), num_rows); + // Check that the last two rows are null for all members. // For `Name` member, indices 4 and 5 are null. @@ -200,15 +204,9 @@ TYPED_TEST(TypedStructColumnWrapperTest, TestStructsContainingLists) return i < 4; })}.release(); - cudf::test::expect_columns_equivalent(struct_col->view().child(0), expected_names_col->view()); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(struct_col->view().child(0), expected_names_col->view()); // For the `List` member, indices 4, 5 should be null. - // FIXME: The way list columns are currently compared is not ideal for testing - // structs' list members. Rather than comparing for equivalence, - // column_comparator_impl currently checks that list's data (child) - // and offsets match perfectly. - // This causes two "equivalent lists" to compare unequal, if the data columns - // have different values at an index where the value is null. auto expected_last_two_lists_col = cudf::test::lists_column_wrapper{ { {1, 2, 3}, @@ -218,14 +216,11 @@ TYPED_TEST(TypedStructColumnWrapperTest, TestStructsContainingLists) {7, 8}, // Null. {9} // Null. }, - cudf::detail::make_counting_transform_iterator(0, [](auto i) { - return i == 0; - })}.release(); + cudf::detail::make_counting_transform_iterator( + 0, [](auto i) { return i < 4; })}.release(); - // FIXME: Uncomment after list comparison is fixed. - // cudf::test::expect_columns_equivalent( - // struct_col->view().child(1), - // expected_last_two_lists_col->view()); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(struct_col->view().child(1), + expected_last_two_lists_col->view()); } TYPED_TEST(TypedStructColumnWrapperTest, StructOfStructs) @@ -255,6 +250,10 @@ TYPED_TEST(TypedStructColumnWrapperTest, StructOfStructs) auto struct_2 = cudf::test::structs_column_wrapper{{is_human_col, struct_1}, {0, 1, 1, 1, 1, 1}}.release(); + EXPECT_EQ(struct_2->size(), num_rows); + EXPECT_EQ(struct_2->view().child(0).size(), num_rows); + EXPECT_EQ(struct_2->view().child(1).size(), num_rows); + // Verify that the child/grandchild columns are as expected. auto expected_names_col = cudf::test::strings_column_wrapper( @@ -327,6 +326,10 @@ TYPED_TEST(TypedStructColumnWrapperTest, TestNullMaskPropagationForNonNullStruct } .release(); + EXPECT_EQ(struct_2->size(), num_rows); + EXPECT_EQ(struct_2->view().child(0).size(), num_rows); + EXPECT_EQ(struct_2->view().child(1).size(), num_rows); + // Verify that the child/grandchild columns are as expected. // Top-struct has 1 null (at index 0). @@ -387,9 +390,9 @@ TYPED_TEST(TypedStructColumnWrapperTest, StructsWithMembersWithDifferentRowCount TYPED_TEST(TypedStructColumnWrapperTest, TestListsOfStructs) { - // Test structs with two members: + // Test list containing structs with two members // 1. Name: String - // 2. List: List + // 2. Age: TypeParam std::initializer_list names = {"Samuel Vimes", "Carrot Ironfoundersson", @@ -398,7 +401,7 @@ TYPED_TEST(TypedStructColumnWrapperTest, TestListsOfStructs) "Detritus", "Mr Slant"}; - auto num_rows{std::distance(names.begin(), names.end())}; + auto num_struct_rows{std::distance(names.begin(), names.end())}; // `Name` column has all valid values. auto names_col = cudf::test::strings_column_wrapper{names.begin(), names.end()}; @@ -410,6 +413,9 @@ TYPED_TEST(TypedStructColumnWrapperTest, TestListsOfStructs) auto struct_col = cudf::test::structs_column_wrapper({names_col, ages_col}, {1, 1, 1, 0, 0, 1}).release(); + EXPECT_EQ(struct_col->size(), num_struct_rows); + EXPECT_EQ(struct_col->view().child(0).size(), num_struct_rows); + auto expected_unchanged_struct_col = cudf::column(*struct_col); auto list_offsets_column = diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 627a2a36e9e..96a9b608f06 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -23,6 +23,7 @@ import ai.rapids.cudf.HostColumnVector.ListType; import ai.rapids.cudf.HostColumnVector.StructData; import ai.rapids.cudf.HostColumnVector.StructType; +import ai.rapids.cudf.ast.CompiledExpression; import java.io.File; import java.math.BigDecimal; @@ -523,6 +524,26 @@ private static native long[] leftAntiJoin(long leftTable, int[] leftJoinCols, lo private static native long[] leftAntiJoinGatherMap(long leftKeys, long rightKeys, boolean compareNullsEqual) throws CudfException; + private static native long[] conditionalLeftJoinGatherMaps(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual) throws CudfException; + + private static native long[] conditionalInnerJoinGatherMaps(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual) throws CudfException; + + private static native long[] conditionalFullJoinGatherMaps(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual) throws CudfException; + + private static native long[] conditionalLeftSemiJoinGatherMap(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual) throws CudfException; + + private static native long[] conditionalLeftAntiJoinGatherMap(long leftTable, long rightTable, + long condition, + boolean compareNullsEqual) throws CudfException; + private static native long[] crossJoin(long leftTable, long rightTable) throws CudfException; private static native long[] concatenate(long[] cudfTablePointers) throws CudfException; @@ -1969,6 +1990,30 @@ public GatherMap[] leftJoinGatherMaps(Table rightKeys, boolean compareNullsEqual return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the gather maps that can be used to manifest the result of a left join between + * two tables when a conditional expression is true. It is assumed this table instance holds + * the columns from the left table, and the table argument represents the columns from the + * right table. Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the left join. + * It is the responsibility of the caller to close the resulting gather map instances. + * @param rightTable the right side table of the join in the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @return left and right table gather maps + */ + public GatherMap[] leftJoinGatherMaps(Table rightTable, CompiledExpression condition, + boolean compareNullsEqual) { + if (getNumberOfColumns() != rightTable.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightTable.getNumberOfColumns()); + } + long[] gatherMapData = + conditionalLeftJoinGatherMaps(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual); + return buildJoinGatherMaps(gatherMapData); + } + /** * Computes the gather maps that can be used to manifest the result of an inner equi-join between * two tables. It is assumed this table instance holds the key columns from the left table, and @@ -1990,6 +2035,30 @@ public GatherMap[] innerJoinGatherMaps(Table rightKeys, boolean compareNullsEqua return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the gather maps that can be used to manifest the result of an inner join between + * two tables when a conditional expression is true. It is assumed this table instance holds + * the columns from the left table, and the table argument represents the columns from the + * right table. Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the inner join. + * It is the responsibility of the caller to close the resulting gather map instances. + * @param rightTable the right side table of the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @return left and right table gather maps + */ + public GatherMap[] innerJoinGatherMaps(Table rightTable, CompiledExpression condition, + boolean compareNullsEqual) { + if (getNumberOfColumns() != rightTable.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightTable.getNumberOfColumns()); + } + long[] gatherMapData = + conditionalInnerJoinGatherMaps(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual); + return buildJoinGatherMaps(gatherMapData); + } + /** * Computes the gather maps that can be used to manifest the result of an full equi-join between * two tables. It is assumed this table instance holds the key columns from the left table, and @@ -2011,6 +2080,30 @@ public GatherMap[] fullJoinGatherMaps(Table rightKeys, boolean compareNullsEqual return buildJoinGatherMaps(gatherMapData); } + /** + * Computes the gather maps that can be used to manifest the result of a full join between + * two tables when a conditional expression is true. It is assumed this table instance holds + * the columns from the left table, and the table argument represents the columns from the + * right table. Two {@link GatherMap} instances will be returned that can be used to gather + * the left and right tables, respectively, to produce the result of the full join. + * It is the responsibility of the caller to close the resulting gather map instances. + * @param rightTable the right side table of the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @return left and right table gather maps + */ + public GatherMap[] fullJoinGatherMaps(Table rightTable, CompiledExpression condition, + boolean compareNullsEqual) { + if (getNumberOfColumns() != rightTable.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightTable.getNumberOfColumns()); + } + long[] gatherMapData = + conditionalFullJoinGatherMaps(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual); + return buildJoinGatherMaps(gatherMapData); + } + private GatherMap buildSemiJoinGatherMap(long[] gatherMapData) { long bufferSize = gatherMapData[0]; long leftAddr = gatherMapData[1]; @@ -2039,6 +2132,30 @@ public GatherMap leftSemiJoinGatherMap(Table rightKeys, boolean compareNullsEqua return buildSemiJoinGatherMap(gatherMapData); } + /** + * Computes the gather map that can be used to manifest the result of a left semi join between + * two tables when a conditional expression is true. It is assumed this table instance holds + * the columns from the left table, and the table argument represents the columns from the + * right table. The {@link GatherMap} instance returned can be used to gather the left table + * to produce the result of the left semi join. + * It is the responsibility of the caller to close the resulting gather map instance. + * @param rightTable the right side table of the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @return left table gather map + */ + public GatherMap leftSemiJoinGatherMap(Table rightTable, CompiledExpression condition, + boolean compareNullsEqual) { + if (getNumberOfColumns() != rightTable.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightTable.getNumberOfColumns()); + } + long[] gatherMapData = + conditionalLeftSemiJoinGatherMap(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual); + return buildSemiJoinGatherMap(gatherMapData); + } + /** * Computes the gather map that can be used to manifest the result of a left anti-join between * two tables. It is assumed this table instance holds the key columns from the left table, and @@ -2060,6 +2177,30 @@ public GatherMap leftAntiJoinGatherMap(Table rightKeys, boolean compareNullsEqua return buildSemiJoinGatherMap(gatherMapData); } + /** + * Computes the gather map that can be used to manifest the result of a left anti join between + * two tables when a conditional expression is true. It is assumed this table instance holds + * the columns from the left table, and the table argument represents the columns from the + * right table. The {@link GatherMap} instance returned can be used to gather the left table + * to produce the result of the left anti join. + * It is the responsibility of the caller to close the resulting gather map instance. + * @param rightTable the right side table of the join + * @param condition conditional expression to evaluate during the join + * @param compareNullsEqual true if null key values should match otherwise false + * @return left table gather map + */ + public GatherMap leftAntiJoinGatherMap(Table rightTable, CompiledExpression condition, + boolean compareNullsEqual) { + if (getNumberOfColumns() != rightTable.getNumberOfColumns()) { + throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + + "rightKeys: " + rightTable.getNumberOfColumns()); + } + long[] gatherMapData = + conditionalLeftAntiJoinGatherMap(getNativeView(), rightTable.getNativeView(), + condition.getNativeHandle(), compareNullsEqual); + return buildSemiJoinGatherMap(gatherMapData); + } + /** * Convert this table of columns into a row major format that is useful for interacting with other * systems that do row major processing of the data. Currently only fixed-width column types are diff --git a/java/src/main/java/ai/rapids/cudf/ast/CompiledExpression.java b/java/src/main/java/ai/rapids/cudf/ast/CompiledExpression.java index 0d2a0052e29..0949b09cbb0 100644 --- a/java/src/main/java/ai/rapids/cudf/ast/CompiledExpression.java +++ b/java/src/main/java/ai/rapids/cudf/ast/CompiledExpression.java @@ -94,6 +94,11 @@ public synchronized void close() { isClosed = true; } + /** Returns the native address of a compiled expression. Intended for internal cudf use only. */ + public long getNativeHandle() { + return cleaner.nativeHandle; + } + private static native long compile(byte[] serializedExpression); private static native long computeColumn(long astHandle, long tableHandle); private static native void destroy(long handle); diff --git a/java/src/main/native/src/CompiledExpression.cpp b/java/src/main/native/src/CompiledExpression.cpp index a28160b32a3..31f3184f107 100644 --- a/java/src/main/native/src/CompiledExpression.cpp +++ b/java/src/main/native/src/CompiledExpression.cpp @@ -26,59 +26,7 @@ #include #include "cudf_jni_apis.hpp" - -namespace cudf { -namespace jni { -namespace ast { - -/** - * A class to capture all of the resources associated with a compiled AST expression. - * AST nodes do not own their child nodes, so every node in the expression tree - * must be explicitly tracked in order to free the underlying resources for each node. - * - * This should be cleaned up a bit after the libcudf AST refactoring in - * https://github.com/rapidsai/cudf/pull/8815 when a virtual destructor is added to the - * base AST node type. Then we do not have to track every AST node type separately. - */ -class compiled_expr { - /** All literal nodes within the expression tree */ - std::vector> literals; - - /** All column reference nodes within the expression tree */ - std::vector> column_refs; - - /** All expression nodes within the expression tree */ - std::vector> expressions; - - /** GPU scalar instances that correspond to literal nodes */ - std::vector> scalars; - -public: - cudf::ast::literal &add_literal(std::unique_ptr literal_ptr, - std::unique_ptr scalar_ptr) { - literals.push_back(std::move(literal_ptr)); - scalars.push_back(std::move(scalar_ptr)); - return *literals.back(); - } - - cudf::ast::column_reference & - add_column_ref(std::unique_ptr ref_ptr) { - column_refs.push_back(std::move(ref_ptr)); - return *column_refs.back(); - } - - cudf::ast::expression &add_expression(std::unique_ptr expr_ptr) { - expressions.push_back(std::move(expr_ptr)); - return *expressions.back(); - } - - /** Return the expression node at the top of the tree */ - cudf::ast::expression &get_top_expression() const { return *expressions.back(); } -}; - -} // namespace ast -} // namespace jni -} // namespace cudf +#include "jni_compiled_expr.hpp" namespace { diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 78e1bf88a1c..c092450da1c 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -43,6 +43,7 @@ #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" +#include "jni_compiled_expr.hpp" #include "jni_utils.hpp" #include "row_conversion.hpp" @@ -744,7 +745,7 @@ bool valid_window_parameters(native_jintArray const &values, values.size() == preceding.size() && values.size() == following.size(); } -// Generate gather maps needed to manifest the result of a join between two tables. +// Generate gather maps needed to manifest the result of an equi-join between two tables. // The resulting Java long array contains the following at each index: // 0: Size of each gather map in bytes // 1: Device address of the gather map for the left table @@ -779,7 +780,44 @@ jlongArray join_gather_maps(JNIEnv *env, jlong j_left_keys, jlong j_right_keys, CATCH_STD(env, NULL); } -// Generate gather maps needed to manifest the result of a join between two tables. +// Generate gather maps needed to manifest the result of a conditional join between two tables. +// The resulting Java long array contains the following at each index: +// 0: Size of each gather map in bytes +// 1: Device address of the gather map for the left table +// 2: Host address of the rmm::device_buffer instance that owns the left gather map data +// 3: Device address of the gather map for the right table +// 4: Host address of the rmm::device_buffer instance that owns the right gather map data +template +jlongArray cond_join_gather_maps(JNIEnv *env, jlong j_left_table, jlong j_right_table, + jlong j_condition, jboolean compare_nulls_equal, T join_func) { + JNI_NULL_CHECK(env, j_left_table, "left_table is null", NULL); + JNI_NULL_CHECK(env, j_right_table, "right_table is null", NULL); + JNI_NULL_CHECK(env, j_condition, "condition is null", NULL); + try { + cudf::jni::auto_set_device(env); + auto left_table = reinterpret_cast(j_left_table); + auto right_table = reinterpret_cast(j_right_table); + auto condition = reinterpret_cast(j_condition); + auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + std::pair>, + std::unique_ptr>> + join_maps = join_func(*left_table, *right_table, condition->get_top_expression(), nulleq); + + // release the underlying device buffer to Java + auto left_map_buffer = std::make_unique(join_maps.first->release()); + auto right_map_buffer = std::make_unique(join_maps.second->release()); + cudf::jni::native_jlongArray result(env, 5); + result[0] = static_cast(left_map_buffer->size()); + result[1] = reinterpret_cast(left_map_buffer->data()); + result[2] = reinterpret_cast(left_map_buffer.release()); + result[3] = reinterpret_cast(right_map_buffer->data()); + result[4] = reinterpret_cast(right_map_buffer.release()); + return result.get_jArray(); + } + CATCH_STD(env, NULL); +} + +// Generate a gather map needed to manifest the result of a semi/anti join between two tables. // The resulting Java long array contains the following at each index: // 0: Size of the gather map in bytes // 1: Device address of the gather map @@ -808,6 +846,39 @@ jlongArray join_gather_single_map(JNIEnv *env, jlong j_left_keys, jlong j_right_ CATCH_STD(env, NULL); } +// Generate a gather map needed to manifest the result of a conditional semi/anti join +// between two tables. +// The resulting Java long array contains the following at each index: +// 0: Size of the gather map in bytes +// 1: Device address of the gather map +// 2: Host address of the rmm::device_buffer instance that owns the gather map data +template +jlongArray cond_join_gather_single_map(JNIEnv *env, jlong j_left_table, jlong j_right_table, + jlong j_condition, jboolean compare_nulls_equal, + T join_func) { + JNI_NULL_CHECK(env, j_left_table, "left_table is null", NULL); + JNI_NULL_CHECK(env, j_right_table, "right_table is null", NULL); + JNI_NULL_CHECK(env, j_condition, "condition is null", NULL); + try { + cudf::jni::auto_set_device(env); + auto left_table = reinterpret_cast(j_left_table); + auto right_table = reinterpret_cast(j_right_table); + auto condition = reinterpret_cast(j_condition); + auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; + std::unique_ptr> join_map = + join_func(*left_table, *right_table, condition->get_top_expression(), nulleq); + + // release the underlying device buffer to Java + auto gather_map_buffer = std::make_unique(join_map->release()); + cudf::jni::native_jlongArray result(env, 3); + result[0] = static_cast(gather_map_buffer->size()); + result[1] = reinterpret_cast(gather_map_buffer->data()); + result[2] = reinterpret_cast(gather_map_buffer.release()); + return result.get_jArray(); + } + CATCH_STD(env, NULL); +} + // Returns a table view containing only the columns at the specified indices cudf::table_view const get_keys_table(cudf::table_view const *t, native_jintArray const &key_indices) { @@ -1853,6 +1924,17 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftJoinGatherMaps( }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinGatherMaps( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal) { + return cudf::jni::cond_join_gather_maps( + env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + [](cudf::table_view const &left, cudf::table_view const &right, + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + return cudf::conditional_left_join(left, right, cond_expr, nulleq); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerJoinGatherMaps( JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { return cudf::jni::join_gather_maps( @@ -1862,6 +1944,17 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerJoinGatherMaps( }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinGatherMaps( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal) { + return cudf::jni::cond_join_gather_maps( + env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + [](cudf::table_view const &left, cudf::table_view const &right, + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + return cudf::conditional_inner_join(left, right, cond_expr, nulleq); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullJoinGatherMaps( JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { return cudf::jni::join_gather_maps( @@ -1871,6 +1964,17 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullJoinGatherMaps( }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalFullJoinGatherMaps( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal) { + return cudf::jni::cond_join_gather_maps( + env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + [](cudf::table_view const &left, cudf::table_view const &right, + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + return cudf::conditional_full_join(left, right, cond_expr, nulleq); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftSemiJoinGatherMap( JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { return cudf::jni::join_gather_single_map( @@ -1880,6 +1984,17 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftSemiJoinGatherMap( }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftSemiJoinGatherMap( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal) { + return cudf::jni::cond_join_gather_single_map( + env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + [](cudf::table_view const &left, cudf::table_view const &right, + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + return cudf::conditional_left_semi_join(left, right, cond_expr, nulleq); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftAntiJoinGatherMap( JNIEnv *env, jclass, jlong j_left_keys, jlong j_right_keys, jboolean compare_nulls_equal) { return cudf::jni::join_gather_single_map( @@ -1889,6 +2004,17 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftAntiJoinGatherMap( }); } +JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalLeftAntiJoinGatherMap( + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_table, jlong j_condition, + jboolean compare_nulls_equal) { + return cudf::jni::cond_join_gather_single_map( + env, j_left_table, j_right_table, j_condition, compare_nulls_equal, + [](cudf::table_view const &left, cudf::table_view const &right, + cudf::ast::expression const &cond_expr, cudf::null_equality nulleq) { + return cudf::conditional_left_anti_join(left, right, cond_expr, nulleq); + }); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_crossJoin(JNIEnv *env, jclass, jlong left_table, jlong right_table) { diff --git a/java/src/main/native/src/jni_compiled_expr.hpp b/java/src/main/native/src/jni_compiled_expr.hpp new file mode 100644 index 00000000000..e42e5a37fba --- /dev/null +++ b/java/src/main/native/src/jni_compiled_expr.hpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace cudf { +namespace jni { +namespace ast { + +/** + * A class to capture all of the resources associated with a compiled AST expression. + * AST nodes do not own their child nodes, so every node in the expression tree + * must be explicitly tracked in order to free the underlying resources for each node. + * + * This should be cleaned up a bit after the libcudf AST refactoring in + * https://github.com/rapidsai/cudf/pull/8815 when a virtual destructor is added to the + * base AST node type. Then we do not have to track every AST node type separately. + */ +class compiled_expr { + /** All literal nodes within the expression tree */ + std::vector> literals; + + /** All column reference nodes within the expression tree */ + std::vector> column_refs; + + /** All expression nodes within the expression tree */ + std::vector> expressions; + + /** GPU scalar instances that correspond to literal nodes */ + std::vector> scalars; + +public: + cudf::ast::literal &add_literal(std::unique_ptr literal_ptr, + std::unique_ptr scalar_ptr) { + literals.push_back(std::move(literal_ptr)); + scalars.push_back(std::move(scalar_ptr)); + return *literals.back(); + } + + cudf::ast::column_reference & + add_column_ref(std::unique_ptr ref_ptr) { + column_refs.push_back(std::move(ref_ptr)); + return *column_refs.back(); + } + + cudf::ast::expression &add_expression(std::unique_ptr expr_ptr) { + expressions.push_back(std::move(expr_ptr)); + return *expressions.back(); + } + + /** Return the expression node at the top of the tree */ + cudf::ast::expression &get_top_expression() const { return *expressions.back(); } +}; + +} // namespace ast +} // namespace jni +} // namespace cudf diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 7507bc8a286..360f3c04f5b 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -25,6 +25,11 @@ import ai.rapids.cudf.HostColumnVector.StructData; import ai.rapids.cudf.HostColumnVector.StructType; +import ai.rapids.cudf.ast.BinaryExpression; +import ai.rapids.cudf.ast.BinaryOperator; +import ai.rapids.cudf.ast.ColumnReference; +import ai.rapids.cudf.ast.CompiledExpression; +import ai.rapids.cudf.ast.TableReference; import org.junit.jupiter.api.Test; import java.io.ByteArrayInputStream; @@ -1484,6 +1489,58 @@ void testLeftJoinGatherMapsNulls() { } } + @Test + void testConditionalLeftJoinGatherMaps() { + final int inv = Integer.MIN_VALUE; + BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table right = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 2, 2, 3, 4, 5, 5, 6, 7, 8, 9, 9) + .column(inv, inv, 0, 1, 3, inv, inv, 0, 1, inv, 1, inv, 0, 1) + .build(); + CompiledExpression condition = expr.compile()) { + GatherMap[] maps = left.leftJoinGatherMaps(right, condition, false); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testConditionalLeftJoinGatherMapsNulls() { + final int inv = Integer.MIN_VALUE; + BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table right = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column( 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 8, 9) // left + .column(inv, inv, 2, inv, inv, inv, inv, 0, 1, 0, 1, 3) // right + .build(); + CompiledExpression condition = expr.compile()) { + GatherMap[] maps = left.leftJoinGatherMaps(right, condition, true); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + @Test void testInnerJoinGatherMaps() { try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -1526,6 +1583,56 @@ void testInnerJoinGatherMapsNulls() { } } + @Test + void testConditionalInnerJoinGatherMaps() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table right = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + Table expected = new Table.TestBuilder() + .column(2, 2, 2, 5, 5, 7, 9, 9) + .column(0, 1, 3, 0, 1, 1, 0, 1) + .build(); + CompiledExpression condition = expr.compile()) { + GatherMap[] maps = left.innerJoinGatherMaps(right, condition, false); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testConditionalInnerJoinGatherMapsNulls() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table right = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column(2, 7, 7, 8, 8, 9) // left + .column(2, 0, 1, 0, 1, 3) // right + .build(); + CompiledExpression condition = expr.compile()) { + GatherMap[] maps = left.innerJoinGatherMaps(right, condition, true); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + @Test void testFullJoinGatherMaps() { final int inv = Integer.MIN_VALUE; @@ -1570,6 +1677,58 @@ void testFullJoinGatherMapsNulls() { } } + @Test + void testConditionalFullJoinGatherMaps() { + final int inv = Integer.MIN_VALUE; + BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table right = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + Table expected = new Table.TestBuilder() + .column(inv, inv, inv, 0, 1, 2, 2, 2, 3, 4, 5, 5, 6, 7, 8, 9, 9) + .column( 2, 4, 5, inv, inv, 0, 1, 3, inv, inv, 0, 1, inv, 1, inv, 0, 1) + .build(); + CompiledExpression condition = expr.compile()) { + GatherMap[] maps = left.fullJoinGatherMaps(right, condition, false); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + + @Test + void testConditionalFullJoinGatherMapsNulls() { + final int inv = Integer.MIN_VALUE; + BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table right = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column(inv, inv, 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 8, 9) // left + .column( 4, 5, inv, inv, 2, inv, inv, inv, inv, 0, 1, 0, 1, 3) // right + .build(); + CompiledExpression condition = expr.compile()) { + GatherMap[] maps = left.fullJoinGatherMaps(right, condition, true); + try { + verifyJoinGatherMaps(maps, expected); + } finally { + for (GatherMap map : maps) { + map.close(); + } + } + } + } + @Test void testLeftSemiJoinGatherMap() { try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -1598,6 +1757,42 @@ void testLeftSemiJoinGatherMapNulls() { } } + @Test + void testConditionalLeftSemiJoinGatherMap() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table right = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + Table expected = new Table.TestBuilder() + .column(2, 5, 7, 9) // left + .build(); + CompiledExpression condition = expr.compile(); + GatherMap map = left.leftSemiJoinGatherMap(right, condition, false)) { + verifySemiJoinGatherMap(map, expected); + } + } + + @Test + void testConditionalLeftSemiJoinGatherMapNulls() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table right = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column(2, 7, 8, 9) // left + .build(); + CompiledExpression condition = expr.compile(); + GatherMap map = left.leftSemiJoinGatherMap(right, condition, true)) { + verifySemiJoinGatherMap(map, expected); + } + } + @Test void testAntiSemiJoinGatherMap() { try (Table leftKeys = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); @@ -1626,6 +1821,42 @@ void testAntiSemiJoinGatherMapNulls() { } } + @Test + void testConditionalLeftAntiJoinGatherMap() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.GREATER, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder().column(2, 3, 9, 0, 1, 7, 4, 6, 5, 8).build(); + Table right = new Table.TestBuilder().column(6, 5, 9, 8, 10, 32).build(); + Table expected = new Table.TestBuilder() + .column(0, 1, 3, 4, 6, 8) // left + .build(); + CompiledExpression condition = expr.compile(); + GatherMap map = left.leftAntiJoinGatherMap(right, condition, false)) { + verifySemiJoinGatherMap(map, expected); + } + } + + @Test + void testConditionalAntiSemiJoinGatherMapNulls() { + BinaryExpression expr = new BinaryExpression(BinaryOperator.EQUAL, + new ColumnReference(0, TableReference.LEFT), + new ColumnReference(0, TableReference.RIGHT)); + try (Table left = new Table.TestBuilder() + .column(2, 3, 9, 0, 1, 7, 4, null, null, 8) + .build(); + Table right = new Table.TestBuilder() + .column(null, null, 9, 8, 10, 32) + .build(); + Table expected = new Table.TestBuilder() + .column(0, 1, 3, 4, 5, 6) // left + .build(); + CompiledExpression condition = expr.compile(); + GatherMap map = left.leftAntiJoinGatherMap(right, condition, true)) { + verifySemiJoinGatherMap(map, expected); + } + } + @Test void testBoundsNulls() { boolean[] descFlags = new boolean[1]; diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index d449d52927e..a5e49b026f3 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -70,6 +70,7 @@ get_time_unit, min_unsigned_type, np_to_pa_dtype, + pandas_dtypes_to_cudf_dtypes, ) from cudf.utils.utils import mask_dtype @@ -877,6 +878,10 @@ def can_cast_safely(self, to_dtype: Dtype) -> bool: raise NotImplementedError() def astype(self, dtype: Dtype, **kwargs) -> ColumnBase: + if is_categorical_dtype(dtype): + return self.as_categorical_column(dtype, **kwargs) + + dtype = pandas_dtypes_to_cudf_dtypes.get(dtype, dtype) if _is_non_decimal_numeric_dtype(dtype): return self.as_numerical_column(dtype, **kwargs) elif is_categorical_dtype(dtype): diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index ecd31afd9e8..9acf6783095 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -3633,6 +3633,23 @@ def test_one_row_head(): assert_eq(head_pdf, head_gdf) +@pytest.mark.parametrize("dtype", ALL_TYPES) +@pytest.mark.parametrize( + "np_dtype,pd_dtype", + [ + tuple(item) + for item in cudf.utils.dtypes.cudf_dtypes_to_pandas_dtypes.items() + ], +) +def test_series_astype_pandas_nullable(dtype, np_dtype, pd_dtype): + source = cudf.Series([0, 1, None], dtype=dtype) + + expect = source.astype(np_dtype) + got = source.astype(pd_dtype) + + assert_eq(expect, got) + + @pytest.mark.parametrize("dtype", NUMERIC_TYPES) @pytest.mark.parametrize("as_dtype", NUMERIC_TYPES) def test_series_astype_numeric_to_numeric(dtype, as_dtype): diff --git a/python/cudf/cudf/tests/test_joining.py b/python/cudf/cudf/tests/test_joining.py index 7b56f864272..4ae7c40ead8 100644 --- a/python/cudf/cudf/tests/test_joining.py +++ b/python/cudf/cudf/tests/test_joining.py @@ -1529,7 +1529,7 @@ def test_categorical_typecast_inner_one_cat(dtype): data = np.array([1, 2, 3], dtype=dtype) left = make_categorical_dataframe(data) - right = left.astype(left["key"].dtype.categories) + right = left.astype(left["key"].dtype.categories.dtype) result = left.merge(right, on="key", how="inner") assert result["key"].dtype == left["key"].dtype.categories.dtype @@ -1541,7 +1541,7 @@ def test_categorical_typecast_left_one_cat(dtype): data = np.array([1, 2, 3], dtype=dtype) left = make_categorical_dataframe(data) - right = left.astype(left["key"].dtype.categories) + right = left.astype(left["key"].dtype.categories.dtype) result = left.merge(right, on="key", how="left") assert result["key"].dtype == left["key"].dtype @@ -1553,7 +1553,7 @@ def test_categorical_typecast_outer_one_cat(dtype): data = np.array([1, 2, 3], dtype=dtype) left = make_categorical_dataframe(data) - right = left.astype(left["key"].dtype.categories) + right = left.astype(left["key"].dtype.categories.dtype) result = left.merge(right, on="key", how="outer") assert result["key"].dtype == left["key"].dtype.categories.dtype