diff --git a/cpp/benchmarks/join/join.cu b/cpp/benchmarks/join/join.cu index 55a1e524479..f21356aff02 100644 --- a/cpp/benchmarks/join/join.cu +++ b/cpp/benchmarks/join/join.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,7 +49,7 @@ void nvbench_inner_join(nvbench::state& state, cudf::null_equality compare_nulls, rmm::cuda_stream_view stream) { cudf::hash_join hj_obj(left_input.select(left_on), compare_nulls, stream); - return hj_obj.inner_join(right_input.select(right_on), compare_nulls, std::nullopt, stream); + return hj_obj.inner_join(right_input.select(right_on), std::nullopt, stream); }; BM_join(state, join); @@ -71,7 +71,7 @@ void nvbench_left_join(nvbench::state& state, cudf::null_equality compare_nulls, rmm::cuda_stream_view stream) { cudf::hash_join hj_obj(left_input.select(left_on), compare_nulls, stream); - return hj_obj.left_join(right_input.select(right_on), compare_nulls, std::nullopt, stream); + return hj_obj.left_join(right_input.select(right_on), std::nullopt, stream); }; BM_join(state, join); @@ -93,7 +93,7 @@ void nvbench_full_join(nvbench::state& state, cudf::null_equality compare_nulls, rmm::cuda_stream_view stream) { cudf::hash_join hj_obj(left_input.select(left_on), compare_nulls, stream); - return hj_obj.full_join(right_input.select(right_on), compare_nulls, std::nullopt, stream); + return hj_obj.full_join(right_input.select(right_on), std::nullopt, stream); }; BM_join(state, join); diff --git a/cpp/include/cudf/join.hpp b/cpp/include/cudf/join.hpp index f6efea5f2bb..d56f8f0e904 100644 --- a/cpp/include/cudf/join.hpp +++ b/cpp/include/cudf/join.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -530,7 +530,6 @@ class hash_join { * provided `output_size` is smaller than the actual output size. * * @param probe The probe table, from which the tuples are probed. - * @param compare_nulls Controls whether null join-key values should match or not. * @param output_size Optional value which allows users to specify the exact output size. * @param stream CUDA stream used for device memory operations and kernel launches * @param mr Device memory resource used to allocate the returned table and columns' device @@ -543,7 +542,6 @@ class hash_join { std::pair>, std::unique_ptr>> inner_join(cudf::table_view const& probe, - null_equality compare_nulls = null_equality::EQUAL, std::optional output_size = {}, rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) const; @@ -554,7 +552,6 @@ class hash_join { * provided `output_size` is smaller than the actual output size. * * @param probe The probe table, from which the tuples are probed. - * @param compare_nulls Controls whether null join-key values should match or not. * @param output_size Optional value which allows users to specify the exact output size. * @param stream CUDA stream used for device memory operations and kernel launches * @param mr Device memory resource used to allocate the returned table and columns' device @@ -567,7 +564,6 @@ class hash_join { std::pair>, std::unique_ptr>> left_join(cudf::table_view const& probe, - null_equality compare_nulls = null_equality::EQUAL, std::optional output_size = {}, rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) const; @@ -578,7 +574,6 @@ class hash_join { * provided `output_size` is smaller than the actual output size. * * @param probe The probe table, from which the tuples are probed. - * @param compare_nulls Controls whether null join-key values should match or not. * @param output_size Optional value which allows users to specify the exact output size. * @param stream CUDA stream used for device memory operations and kernel launches * @param mr Device memory resource used to allocate the returned table and columns' device @@ -591,7 +586,6 @@ class hash_join { std::pair>, std::unique_ptr>> full_join(cudf::table_view const& probe, - null_equality compare_nulls = null_equality::EQUAL, std::optional output_size = {}, rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) const; @@ -601,39 +595,32 @@ class hash_join { * probe table. * * @param probe The probe table, from which the tuples are probed. - * @param compare_nulls Controls whether null join-key values should match or not. * @param stream CUDA stream used for device memory operations and kernel launches * * @return The exact number of output when performing an inner join between two tables with * `build` and `probe` as the the join keys . */ [[nodiscard]] std::size_t inner_join_size( - cudf::table_view const& probe, - null_equality compare_nulls = null_equality::EQUAL, - rmm::cuda_stream_view stream = rmm::cuda_stream_default) const; + cudf::table_view const& probe, rmm::cuda_stream_view stream = rmm::cuda_stream_default) const; /** * Returns the exact number of matches (rows) when performing a left join with the specified probe * table. * * @param probe The probe table, from which the tuples are probed. - * @param compare_nulls Controls whether null join-key values should match or not. * @param stream CUDA stream used for device memory operations and kernel launches * * @return The exact number of output when performing a left join between two tables with `build` * and `probe` as the the join keys . */ [[nodiscard]] std::size_t left_join_size( - cudf::table_view const& probe, - null_equality compare_nulls = null_equality::EQUAL, - rmm::cuda_stream_view stream = rmm::cuda_stream_default) const; + cudf::table_view const& probe, rmm::cuda_stream_view stream = rmm::cuda_stream_default) const; /** * Returns the exact number of matches (rows) when performing a full join with the specified probe * table. * * @param probe The probe table, from which the tuples are probed. - * @param compare_nulls Controls whether null join-key values should match or not. * @param stream CUDA stream used for device memory operations and kernel launches * @param mr Device memory resource used to allocate the intermediate table and columns' device * memory. @@ -643,7 +630,6 @@ class hash_join { */ std::size_t full_join_size( cudf::table_view const& probe, - null_equality compare_nulls = null_equality::EQUAL, rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) const; diff --git a/cpp/src/join/hash_join.cu b/cpp/src/join/hash_join.cu index 7590c93f0c3..b89bcabf23e 100644 --- a/cpp/src/join/hash_join.cu +++ b/cpp/src/join/hash_join.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -140,8 +140,8 @@ probe_join_hash_table(cudf::table_device_view build_table, std::size_t get_full_join_size(cudf::table_device_view build_table, cudf::table_device_view probe_table, multimap_type const& hash_table, - bool has_nulls, - null_equality compare_nulls, + bool const has_nulls, + null_equality const compare_nulls, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -235,6 +235,7 @@ hash_join::hash_join_impl::hash_join_impl(cudf::table_view const& build, null_equality compare_nulls, rmm::cuda_stream_view stream) : _is_empty{build.num_rows() == 0}, + _nulls_equal{compare_nulls}, _hash_table{compute_hash_table_size(build.num_rows()), std::numeric_limits::max(), cudf::detail::JoinNoneValue, @@ -253,50 +254,43 @@ hash_join::hash_join_impl::hash_join_impl(cudf::table_view const& build, if (_is_empty) { return; } - build_join_hash_table(_build, _hash_table, compare_nulls, stream); + cudf::detail::build_join_hash_table(_build, _hash_table, _nulls_equal, stream); } std::pair>, std::unique_ptr>> hash_join::hash_join_impl::inner_join(cudf::table_view const& probe, - null_equality compare_nulls, std::optional output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const { CUDF_FUNC_RANGE(); - return compute_hash_join( - probe, compare_nulls, output_size, stream, mr); + return compute_hash_join(probe, output_size, stream, mr); } std::pair>, std::unique_ptr>> hash_join::hash_join_impl::left_join(cudf::table_view const& probe, - null_equality compare_nulls, std::optional output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const { CUDF_FUNC_RANGE(); - return compute_hash_join( - probe, compare_nulls, output_size, stream, mr); + return compute_hash_join(probe, output_size, stream, mr); } std::pair>, std::unique_ptr>> hash_join::hash_join_impl::full_join(cudf::table_view const& probe, - null_equality compare_nulls, std::optional output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const { CUDF_FUNC_RANGE(); - return compute_hash_join( - probe, compare_nulls, output_size, stream, mr); + return compute_hash_join(probe, output_size, stream, mr); } std::size_t hash_join::hash_join_impl::inner_join_size(cudf::table_view const& probe, - null_equality compare_nulls, rmm::cuda_stream_view stream) const { CUDF_FUNC_RANGE(); @@ -316,12 +310,11 @@ std::size_t hash_join::hash_join_impl::inner_join_size(cudf::table_view const& p *flattened_probe_table_ptr, _hash_table, cudf::has_nulls(flattened_probe_table) | cudf::has_nulls(_build), - compare_nulls, + _nulls_equal, stream); } std::size_t hash_join::hash_join_impl::left_join_size(cudf::table_view const& probe, - null_equality compare_nulls, rmm::cuda_stream_view stream) const { CUDF_FUNC_RANGE(); @@ -341,12 +334,11 @@ std::size_t hash_join::hash_join_impl::left_join_size(cudf::table_view const& pr *flattened_probe_table_ptr, _hash_table, cudf::has_nulls(flattened_probe_table) | cudf::has_nulls(_build), - compare_nulls, + _nulls_equal, stream); } std::size_t hash_join::hash_join_impl::full_join_size(cudf::table_view const& probe, - null_equality compare_nulls, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const { @@ -362,20 +354,20 @@ std::size_t hash_join::hash_join_impl::full_join_size(cudf::table_view const& pr auto build_table_ptr = cudf::table_device_view::create(_build, stream); auto flattened_probe_table_ptr = cudf::table_device_view::create(flattened_probe_table, stream); - return get_full_join_size(*build_table_ptr, - *flattened_probe_table_ptr, - _hash_table, - cudf::has_nulls(flattened_probe_table) | cudf::has_nulls(_build), - compare_nulls, - stream, - mr); + return cudf::detail::get_full_join_size( + *build_table_ptr, + *flattened_probe_table_ptr, + _hash_table, + cudf::has_nulls(flattened_probe_table) | cudf::has_nulls(_build), + _nulls_equal, + stream, + mr); } template std::pair>, std::unique_ptr>> hash_join::hash_join_impl::compute_hash_join(cudf::table_view const& probe, - null_equality compare_nulls, std::optional output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const @@ -403,42 +395,40 @@ hash_join::hash_join_impl::compute_hash_join(cudf::table_view const& probe, [](const auto& b, const auto& p) { return b.type() == p.type(); }), "Mismatch in joining column data types"); - return probe_join_indices( - flattened_probe_table, compare_nulls, output_size, stream, mr); + return probe_join_indices(flattened_probe_table, output_size, stream, mr); } template std::pair>, std::unique_ptr>> -hash_join::hash_join_impl::probe_join_indices(cudf::table_view const& probe, - null_equality compare_nulls, +hash_join::hash_join_impl::probe_join_indices(cudf::table_view const& probe_table, std::optional output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const { // Trivial left join case - exit early if (_is_empty and JoinKind != cudf::detail::join_kind::INNER_JOIN) { - return get_trivial_left_join_indices(probe, stream, mr); + return get_trivial_left_join_indices(probe_table, stream, mr); } CUDF_EXPECTS(!_is_empty, "Hash table of hash join is null."); auto build_table_ptr = cudf::table_device_view::create(_build, stream); - auto probe_table_ptr = cudf::table_device_view::create(probe, stream); - - auto join_indices = - cudf::detail::probe_join_hash_table(*build_table_ptr, - *probe_table_ptr, - _hash_table, - cudf::has_nulls(probe) | cudf::has_nulls(_build), - compare_nulls, - output_size, - stream, - mr); + auto probe_table_ptr = cudf::table_device_view::create(probe_table, stream); + + auto join_indices = cudf::detail::probe_join_hash_table( + *build_table_ptr, + *probe_table_ptr, + _hash_table, + cudf::has_nulls(probe_table) | cudf::has_nulls(_build), + _nulls_equal, + output_size, + stream, + mr); if constexpr (JoinKind == cudf::detail::join_kind::FULL_JOIN) { auto complement_indices = detail::get_left_join_indices_complement( - join_indices.second, probe.num_rows(), _build.num_rows(), stream, mr); + join_indices.second, probe_table.num_rows(), _build.num_rows(), stream, mr); join_indices = detail::concatenate_vector_pairs(join_indices, complement_indices, stream); } return join_indices; diff --git a/cpp/src/join/hash_join.cuh b/cpp/src/join/hash_join.cuh index 21bfd8120f7..9c44aeebd59 100644 --- a/cpp/src/join/hash_join.cuh +++ b/cpp/src/join/hash_join.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -89,7 +89,7 @@ class make_pair_function { * @param probe_table The left hand table * @param hash_table A hash table built on the build table that maps the index * of every row to the hash value of that row. - * @param compare_nulls Controls whether null join-key values should match or not. + * @param nulls_equal Flag to denote nulls are equal or not. * @param stream CUDA stream used for device memory operations and kernel launches * * @return The exact size of the output of the join operation @@ -98,8 +98,8 @@ template std::size_t compute_join_output_size(table_device_view build_table, table_device_view probe_table, multimap_type const& hash_table, - bool has_nulls, - null_equality compare_nulls, + bool const has_nulls, + cudf::null_equality const nulls_equal, rmm::cuda_stream_view stream) { const size_type build_table_num_rows{build_table.num_rows()}; @@ -121,7 +121,7 @@ std::size_t compute_join_output_size(table_device_view build_table, } auto const probe_nulls = cudf::nullate::DYNAMIC{has_nulls}; - pair_equality equality{probe_table, build_table, probe_nulls, compare_nulls}; + pair_equality equality{probe_table, build_table, probe_nulls, nulls_equal}; row_hash hash_probe{probe_nulls, probe_table}; auto const empty_key_sentinel = hash_table.get_empty_key_sentinel(); @@ -152,14 +152,14 @@ std::unique_ptr combine_table_pair(std::unique_ptr&& l * * @param build Table of columns used to build join hash. * @param hash_table Build hash table. - * @param compare_nulls Controls whether null join-key values should match or not. + * @param nulls_equal Flag to denote nulls are equal or not. * @param stream CUDA stream used for device memory operations and kernel launches. * */ template void build_join_hash_table(cudf::table_view const& build, MultimapType& hash_table, - null_equality compare_nulls, + null_equality const nulls_equal, rmm::cuda_stream_view stream) { auto build_table_ptr = cudf::table_device_view::create(build, stream); @@ -174,7 +174,7 @@ void build_join_hash_table(cudf::table_view const& build, auto iter = cudf::detail::make_counting_transform_iterator(0, pair_func); size_type const build_table_num_rows{build_table_ptr->num_rows()}; - if ((compare_nulls == null_equality::EQUAL) or (not nullable(build))) { + if (nulls_equal == cudf::null_equality::EQUAL or (not nullable(build))) { hash_table.insert(iter, iter + build_table_num_rows, stream.value()); } else { thrust::counting_iterator stencil(0); @@ -197,7 +197,8 @@ struct hash_join::hash_join_impl { hash_join_impl& operator=(hash_join_impl&&) = delete; private: - bool _is_empty; + bool const _is_empty; + cudf::null_equality const _nulls_equal; cudf::table_view _build; std::vector> _created_null_columns; cudf::structs::detail::flattened_table _flattened_build_table; @@ -221,7 +222,6 @@ struct hash_join::hash_join_impl { std::pair>, std::unique_ptr>> inner_join(cudf::table_view const& probe, - null_equality compare_nulls, std::optional output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const; @@ -229,7 +229,6 @@ struct hash_join::hash_join_impl { std::pair>, std::unique_ptr>> left_join(cudf::table_view const& probe, - null_equality compare_nulls, std::optional output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const; @@ -237,21 +236,17 @@ struct hash_join::hash_join_impl { std::pair>, std::unique_ptr>> full_join(cudf::table_view const& probe, - null_equality compare_nulls, std::optional output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const; [[nodiscard]] std::size_t inner_join_size(cudf::table_view const& probe, - null_equality compare_nulls, rmm::cuda_stream_view stream) const; [[nodiscard]] std::size_t left_join_size(cudf::table_view const& probe, - null_equality compare_nulls, rmm::cuda_stream_view stream) const; std::size_t full_join_size(cudf::table_view const& probe, - null_equality compare_nulls, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const; @@ -260,7 +255,6 @@ struct hash_join::hash_join_impl { std::pair>, std::unique_ptr>> compute_hash_join(cudf::table_view const& probe, - null_equality compare_nulls, std::optional output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const; @@ -276,7 +270,6 @@ struct hash_join::hash_join_impl { * @tparam JoinKind The type of join to be performed. * * @param probe_table Table of probe side columns to join. - * @param compare_nulls Controls whether null join-key values should match or not. * @param output_size Optional value which allows users to specify the exact output size. * @param stream CUDA stream used for device memory operations and kernel launches. * @param mr Device memory resource used to allocate the returned vectors. @@ -286,8 +279,7 @@ struct hash_join::hash_join_impl { template std::pair>, std::unique_ptr>> - probe_join_indices(cudf::table_view const& probe, - null_equality compare_nulls, + probe_join_indices(cudf::table_view const& probe_table, std::optional output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const; diff --git a/cpp/src/join/join.cu b/cpp/src/join/join.cu index db79075d864..ef9e7867a2d 100644 --- a/cpp/src/join/join.cu +++ b/cpp/src/join/join.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -51,11 +51,11 @@ inner_join(table_view const& left_input, // build the hash map from the smaller table. if (right.num_rows() > left.num_rows()) { cudf::hash_join hj_obj(left, compare_nulls, stream); - auto result = hj_obj.inner_join(right, compare_nulls, std::nullopt, stream, mr); + auto result = hj_obj.inner_join(right, std::nullopt, stream, mr); return std::make_pair(std::move(result.second), std::move(result.first)); } else { cudf::hash_join hj_obj(right, compare_nulls, stream); - return hj_obj.inner_join(left, compare_nulls, std::nullopt, stream, mr); + return hj_obj.inner_join(left, std::nullopt, stream, mr); } } @@ -113,7 +113,7 @@ left_join(table_view const& left_input, table_view const right = matched.second.back(); cudf::hash_join hj_obj(right, compare_nulls, stream); - return hj_obj.left_join(left, compare_nulls, std::nullopt, stream, mr); + return hj_obj.left_join(left, std::nullopt, stream, mr); } std::unique_ptr left_join(table_view const& left_input, @@ -176,7 +176,7 @@ full_join(table_view const& left_input, table_view const right = matched.second.back(); cudf::hash_join hj_obj(right, compare_nulls, stream); - return hj_obj.full_join(left, compare_nulls, std::nullopt, stream, mr); + return hj_obj.full_join(left, std::nullopt, stream, mr); } std::unique_ptr
full_join(table_view const& left_input, @@ -234,56 +234,50 @@ hash_join::hash_join(cudf::table_view const& build, std::pair>, std::unique_ptr>> hash_join::inner_join(cudf::table_view const& probe, - null_equality compare_nulls, std::optional output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const { - return impl->inner_join(probe, compare_nulls, output_size, stream, mr); + return impl->inner_join(probe, output_size, stream, mr); } std::pair>, std::unique_ptr>> hash_join::left_join(cudf::table_view const& probe, - null_equality compare_nulls, std::optional output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const { - return impl->left_join(probe, compare_nulls, output_size, stream, mr); + return impl->left_join(probe, output_size, stream, mr); } std::pair>, std::unique_ptr>> hash_join::full_join(cudf::table_view const& probe, - null_equality compare_nulls, std::optional output_size, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const { - return impl->full_join(probe, compare_nulls, output_size, stream, mr); + return impl->full_join(probe, output_size, stream, mr); } std::size_t hash_join::inner_join_size(cudf::table_view const& probe, - null_equality compare_nulls, rmm::cuda_stream_view stream) const { - return impl->inner_join_size(probe, compare_nulls, stream); + return impl->inner_join_size(probe, stream); } std::size_t hash_join::left_join_size(cudf::table_view const& probe, - null_equality compare_nulls, rmm::cuda_stream_view stream) const { - return impl->left_join_size(probe, compare_nulls, stream); + return impl->left_join_size(probe, stream); } std::size_t hash_join::full_join_size(cudf::table_view const& probe, - null_equality compare_nulls, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) const { - return impl->full_join_size(probe, compare_nulls, stream, mr); + return impl->full_join_size(probe, stream, mr); } // external APIs diff --git a/cpp/tests/join/join_tests.cpp b/cpp/tests/join/join_tests.cpp index e6ae709f009..57041e448a2 100644 --- a/cpp/tests/join/join_tests.cpp +++ b/cpp/tests/join/join_tests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1004,7 +1004,7 @@ TEST_F(JoinTest, EmptyRightTableInnerJoin) std::size_t const size_gold = 0; EXPECT_EQ(output_size, size_gold); - auto result = hash_join.inner_join(t0, cudf::null_equality::EQUAL, optional_size); + auto result = hash_join.inner_join(t0, optional_size); column_wrapper col_gold_0{}; column_wrapper col_gold_1{}; auto const [sorted_gold, sorted_result] = gather_maps_as_tables(col_gold_0, col_gold_1, result); @@ -1043,7 +1043,7 @@ TEST_F(JoinTest, EmptyRightTableLeftJoin) std::size_t const size_gold = 5; EXPECT_EQ(output_size, size_gold); - auto result = hash_join.left_join(t0, cudf::null_equality::EQUAL, optional_size); + auto result = hash_join.left_join(t0, optional_size); column_wrapper col_gold_0{{0, 1, 2, 3, 4}}; column_wrapper col_gold_1{{NoneValue, NoneValue, NoneValue, NoneValue, NoneValue}}; auto const [sorted_gold, sorted_result] = gather_maps_as_tables(col_gold_0, col_gold_1, result); @@ -1082,7 +1082,7 @@ TEST_F(JoinTest, EmptyRightTableFullJoin) std::size_t const size_gold = 5; EXPECT_EQ(output_size, size_gold); - auto result = hash_join.full_join(t0, cudf::null_equality::EQUAL, optional_size); + auto result = hash_join.full_join(t0, optional_size); column_wrapper col_gold_0{{0, 1, 2, 3, 4}}; column_wrapper col_gold_1{{NoneValue, NoneValue, NoneValue, NoneValue, NoneValue}}; auto const [sorted_gold, sorted_result] = gather_maps_as_tables(col_gold_0, col_gold_1, result); @@ -1310,7 +1310,7 @@ TEST_F(JoinTest, HashJoinSequentialProbes) std::size_t const size_gold = 9; EXPECT_EQ(output_size, size_gold); - auto result = hash_join.full_join(t0, cudf::null_equality::EQUAL, optional_size); + auto result = hash_join.full_join(t0, optional_size); column_wrapper col_gold_0{{NoneValue, NoneValue, NoneValue, NoneValue, 4, 0, 1, 2, 3}}; column_wrapper col_gold_1{{0, 1, 2, 3, 4, NoneValue, NoneValue, NoneValue, NoneValue}}; auto const [sorted_gold, sorted_result] = gather_maps_as_tables(col_gold_0, col_gold_1, result); @@ -1330,7 +1330,7 @@ TEST_F(JoinTest, HashJoinSequentialProbes) std::size_t const size_gold = 5; EXPECT_EQ(output_size, size_gold); - auto result = hash_join.left_join(t0, cudf::null_equality::EQUAL, optional_size); + auto result = hash_join.left_join(t0, optional_size); column_wrapper col_gold_0{{0, 1, 2, 3, 4}}; column_wrapper col_gold_1{{NoneValue, NoneValue, NoneValue, NoneValue, 4}}; auto const [sorted_gold, sorted_result] = gather_maps_as_tables(col_gold_0, col_gold_1, result); @@ -1350,7 +1350,7 @@ TEST_F(JoinTest, HashJoinSequentialProbes) std::size_t const size_gold = 3; EXPECT_EQ(output_size, size_gold); - auto result = hash_join.inner_join(t0, cudf::null_equality::EQUAL, optional_size); + auto result = hash_join.inner_join(t0, optional_size); column_wrapper col_gold_0{{2, 4, 0}}; column_wrapper col_gold_1{{1, 1, 4}}; auto const [sorted_gold, sorted_result] = gather_maps_as_tables(col_gold_0, col_gold_1, result); @@ -1390,7 +1390,7 @@ TEST_F(JoinTest, HashJoinWithStructsAndNulls) { auto output_size = hash_join.left_join_size(t0); EXPECT_EQ(5, output_size); - auto result = hash_join.left_join(t0, cudf::null_equality::EQUAL, output_size); + auto result = hash_join.left_join(t0, output_size); column_wrapper col_gold_0{{0, 1, 2, 3, 4}}; column_wrapper col_gold_1{{0, NoneValue, 2, NoneValue, NoneValue}}; auto const [sorted_gold, sorted_result] = gather_maps_as_tables(col_gold_0, col_gold_1, result); @@ -1400,7 +1400,7 @@ TEST_F(JoinTest, HashJoinWithStructsAndNulls) { auto output_size = hash_join.inner_join_size(t0); EXPECT_EQ(2, output_size); - auto result = hash_join.inner_join(t0, cudf::null_equality::EQUAL, output_size); + auto result = hash_join.inner_join(t0, output_size); column_wrapper col_gold_0{{0, 2}}; column_wrapper col_gold_1{{0, 2}}; auto const [sorted_gold, sorted_result] = gather_maps_as_tables(col_gold_0, col_gold_1, result); @@ -1410,7 +1410,7 @@ TEST_F(JoinTest, HashJoinWithStructsAndNulls) { auto output_size = hash_join.full_join_size(t0); EXPECT_EQ(8, output_size); - auto result = hash_join.full_join(t0, cudf::null_equality::EQUAL, output_size); + auto result = hash_join.full_join(t0, output_size); column_wrapper col_gold_0{{NoneValue, NoneValue, NoneValue, 0, 1, 2, 3, 4}}; column_wrapper col_gold_1{{1, 3, 4, 0, NoneValue, 2, NoneValue, NoneValue}}; auto const [sorted_gold, sorted_result] = gather_maps_as_tables(col_gold_0, col_gold_1, result); diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index bb0321d0a16..17e10933b65 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -539,14 +539,11 @@ private static native long[] leftJoin(long leftTable, int[] leftJoinCols, long r private static native long[] leftJoinGatherMaps(long leftKeys, long rightKeys, boolean compareNullsEqual) throws CudfException; - private static native long leftJoinRowCount(long leftTable, long rightHashJoin, - boolean nullsEqual) throws CudfException; + private static native long leftJoinRowCount(long leftTable, long rightHashJoin) throws CudfException; - private static native long[] leftHashJoinGatherMaps(long leftTable, long rightHashJoin, - boolean nullsEqual) throws CudfException; + private static native long[] leftHashJoinGatherMaps(long leftTable, long rightHashJoin) throws CudfException; private static native long[] leftHashJoinGatherMapsWithCount(long leftTable, long rightHashJoin, - boolean nullsEqual, long outputRowCount) throws CudfException; private static native long[] innerJoin(long leftTable, int[] leftJoinCols, long rightTable, @@ -555,14 +552,11 @@ private static native long[] innerJoin(long leftTable, int[] leftJoinCols, long private static native long[] innerJoinGatherMaps(long leftKeys, long rightKeys, boolean compareNullsEqual) throws CudfException; - private static native long innerJoinRowCount(long table, long hashJoin, - boolean nullsEqual) throws CudfException; + private static native long innerJoinRowCount(long table, long hashJoin) throws CudfException; - private static native long[] innerHashJoinGatherMaps(long table, long hashJoin, - boolean nullsEqual) throws CudfException; + private static native long[] innerHashJoinGatherMaps(long table, long hashJoin) throws CudfException; private static native long[] innerHashJoinGatherMapsWithCount(long table, long hashJoin, - boolean nullsEqual, long outputRowCount) throws CudfException; private static native long[] fullJoin(long leftTable, int[] leftJoinCols, long rightTable, @@ -571,14 +565,11 @@ private static native long[] fullJoin(long leftTable, int[] leftJoinCols, long r private static native long[] fullJoinGatherMaps(long leftKeys, long rightKeys, boolean compareNullsEqual) throws CudfException; - private static native long fullJoinRowCount(long leftTable, long rightHashJoin, - boolean nullsEqual) throws CudfException; + private static native long fullJoinRowCount(long leftTable, long rightHashJoin) throws CudfException; - private static native long[] fullHashJoinGatherMaps(long leftTable, long rightHashJoin, - boolean nullsEqual) throws CudfException; + private static native long[] fullHashJoinGatherMaps(long leftTable, long rightHashJoin) throws CudfException; private static native long[] fullHashJoinGatherMapsWithCount(long leftTable, long rightHashJoin, - boolean nullsEqual, long outputRowCount) throws CudfException; private static native long[] leftSemiJoin(long leftTable, int[] leftJoinCols, long rightTable, @@ -2318,8 +2309,7 @@ public long leftJoinRowCount(HashJoin rightHash) { throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightHash.getNumberOfColumns()); } - return leftJoinRowCount(getNativeView(), rightHash.getNativeView(), - rightHash.getCompareNulls()); + return leftJoinRowCount(getNativeView(), rightHash.getNativeView()); } /** @@ -2337,9 +2327,7 @@ public GatherMap[] leftJoinGatherMaps(HashJoin rightHash) { throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightHash.getNumberOfColumns()); } - long[] gatherMapData = - leftHashJoinGatherMaps(getNativeView(), rightHash.getNativeView(), - rightHash.getCompareNulls()); + long[] gatherMapData = leftHashJoinGatherMaps(getNativeView(), rightHash.getNativeView()); return buildJoinGatherMaps(gatherMapData); } @@ -2363,9 +2351,8 @@ public GatherMap[] leftJoinGatherMaps(HashJoin rightHash, long outputRowCount) { throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightHash.getNumberOfColumns()); } - long[] gatherMapData = - leftHashJoinGatherMapsWithCount(getNativeView(), rightHash.getNativeView(), - rightHash.getCompareNulls(), outputRowCount); + long[] gatherMapData = leftHashJoinGatherMapsWithCount(getNativeView(), + rightHash.getNativeView(), outputRowCount); return buildJoinGatherMaps(gatherMapData); } @@ -2545,8 +2532,7 @@ public long innerJoinRowCount(HashJoin otherHash) { throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + "otherKeys: " + otherHash.getNumberOfColumns()); } - return innerJoinRowCount(getNativeView(), otherHash.getNativeView(), - otherHash.getCompareNulls()); + return innerJoinRowCount(getNativeView(), otherHash.getNativeView()); } /** @@ -2564,9 +2550,7 @@ public GatherMap[] innerJoinGatherMaps(HashJoin rightHash) { throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightHash.getNumberOfColumns()); } - long[] gatherMapData = - innerHashJoinGatherMaps(getNativeView(), rightHash.getNativeView(), - rightHash.getCompareNulls()); + long[] gatherMapData = innerHashJoinGatherMaps(getNativeView(), rightHash.getNativeView()); return buildJoinGatherMaps(gatherMapData); } @@ -2590,9 +2574,8 @@ public GatherMap[] innerJoinGatherMaps(HashJoin rightHash, long outputRowCount) throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightHash.getNumberOfColumns()); } - long[] gatherMapData = - innerHashJoinGatherMapsWithCount(getNativeView(), rightHash.getNativeView(), - rightHash.getCompareNulls(), outputRowCount); + long[] gatherMapData = innerHashJoinGatherMapsWithCount(getNativeView(), + rightHash.getNativeView(), outputRowCount); return buildJoinGatherMaps(gatherMapData); } @@ -2778,8 +2761,7 @@ public long fullJoinRowCount(HashJoin rightHash) { throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightHash.getNumberOfColumns()); } - return fullJoinRowCount(getNativeView(), rightHash.getNativeView(), - rightHash.getCompareNulls()); + return fullJoinRowCount(getNativeView(), rightHash.getNativeView()); } /** @@ -2797,9 +2779,7 @@ public GatherMap[] fullJoinGatherMaps(HashJoin rightHash) { throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightHash.getNumberOfColumns()); } - long[] gatherMapData = - fullHashJoinGatherMaps(getNativeView(), rightHash.getNativeView(), - rightHash.getCompareNulls()); + long[] gatherMapData = fullHashJoinGatherMaps(getNativeView(), rightHash.getNativeView()); return buildJoinGatherMaps(gatherMapData); } @@ -2823,9 +2803,8 @@ public GatherMap[] fullJoinGatherMaps(HashJoin rightHash, long outputRowCount) { throw new IllegalArgumentException("column count mismatch, this: " + getNumberOfColumns() + "rightKeys: " + rightHash.getNumberOfColumns()); } - long[] gatherMapData = - fullHashJoinGatherMapsWithCount(getNativeView(), rightHash.getNativeView(), - rightHash.getCompareNulls(), outputRowCount); + long[] gatherMapData = fullHashJoinGatherMapsWithCount(getNativeView(), + rightHash.getNativeView(), outputRowCount); return buildJoinGatherMaps(gatherMapData); } diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index aeac1856db0..eac76222475 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -812,15 +812,14 @@ jlongArray join_gather_maps(JNIEnv *env, jlong j_left_keys, jlong j_right_keys, // a hash table built from the join's right table. template jlongArray hash_join_gather_maps(JNIEnv *env, jlong j_left_keys, jlong j_right_hash_join, - jboolean compare_nulls_equal, T join_func) { + T join_func) { JNI_NULL_CHECK(env, j_left_keys, "left table is null", NULL); JNI_NULL_CHECK(env, j_right_hash_join, "hash join is null", NULL); try { cudf::jni::auto_set_device(env); auto left_keys = reinterpret_cast(j_left_keys); auto hash_join = reinterpret_cast(j_right_hash_join); - auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; - return gather_maps_to_java(env, join_func(*left_keys, *hash_join, nulleq)); + return gather_maps_to_java(env, join_func(*left_keys, *hash_join)); } CATCH_STD(env, NULL); } @@ -2172,41 +2171,36 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftJoinGatherMaps( JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_leftJoinRowCount(JNIEnv *env, jclass, jlong j_left_table, - jlong j_right_hash_join, - jboolean compare_nulls_equal) { + jlong j_right_hash_join) { JNI_NULL_CHECK(env, j_left_table, "left table is null", 0); JNI_NULL_CHECK(env, j_right_hash_join, "right hash join is null", 0); try { cudf::jni::auto_set_device(env); auto left_table = reinterpret_cast(j_left_table); auto hash_join = reinterpret_cast(j_right_hash_join); - auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; - auto row_count = hash_join->left_join_size(*left_table, nulleq); + auto row_count = hash_join->left_join_size(*left_table); return static_cast(row_count); } CATCH_STD(env, 0); } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftHashJoinGatherMaps( - JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, - jboolean compare_nulls_equal) { + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join) { return cudf::jni::hash_join_gather_maps( - env, j_left_table, j_right_hash_join, compare_nulls_equal, - [](cudf::table_view const &left, cudf::hash_join const &hash, cudf::null_equality nulleq) { - return hash.left_join(left, nulleq); + env, j_left_table, j_right_hash_join, + [](cudf::table_view const &left, cudf::hash_join const &hash) { + return hash.left_join(left); }); } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftHashJoinGatherMapsWithCount( - JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, jboolean compare_nulls_equal, - jlong j_output_row_count) { + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, jlong j_output_row_count) { auto output_row_count = static_cast(j_output_row_count); - return cudf::jni::hash_join_gather_maps(env, j_left_table, j_right_hash_join, compare_nulls_equal, - [output_row_count](cudf::table_view const &left, - cudf::hash_join const &hash, - cudf::null_equality nulleq) { - return hash.left_join(left, nulleq, output_row_count); - }); + return cudf::jni::hash_join_gather_maps( + env, j_left_table, j_right_hash_join, + [output_row_count](cudf::table_view const &left, cudf::hash_join const &hash) { + return hash.left_join(left, output_row_count); + }); } JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalLeftJoinRowCount(JNIEnv *env, jclass, @@ -2305,41 +2299,36 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerJoinGatherMaps( JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_innerJoinRowCount(JNIEnv *env, jclass, jlong j_left_table, - jlong j_right_hash_join, - jboolean compare_nulls_equal) { + jlong j_right_hash_join) { JNI_NULL_CHECK(env, j_left_table, "left table is null", 0); JNI_NULL_CHECK(env, j_right_hash_join, "right hash join is null", 0); try { cudf::jni::auto_set_device(env); auto left_table = reinterpret_cast(j_left_table); auto hash_join = reinterpret_cast(j_right_hash_join); - auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; - auto row_count = hash_join->inner_join_size(*left_table, nulleq); + auto row_count = hash_join->inner_join_size(*left_table); return static_cast(row_count); } CATCH_STD(env, 0); } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerHashJoinGatherMaps( - JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, - jboolean compare_nulls_equal) { + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join) { return cudf::jni::hash_join_gather_maps( - env, j_left_table, j_right_hash_join, compare_nulls_equal, - [](cudf::table_view const &left, cudf::hash_join const &hash, cudf::null_equality nulleq) { - return hash.inner_join(left, nulleq); + env, j_left_table, j_right_hash_join, + [](cudf::table_view const &left, cudf::hash_join const &hash) { + return hash.inner_join(left); }); } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerHashJoinGatherMapsWithCount( - JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, jboolean compare_nulls_equal, - jlong j_output_row_count) { + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, jlong j_output_row_count) { auto output_row_count = static_cast(j_output_row_count); - return cudf::jni::hash_join_gather_maps(env, j_left_table, j_right_hash_join, compare_nulls_equal, - [output_row_count](cudf::table_view const &left, - cudf::hash_join const &hash, - cudf::null_equality nulleq) { - return hash.inner_join(left, nulleq, output_row_count); - }); + return cudf::jni::hash_join_gather_maps( + env, j_left_table, j_right_hash_join, + [output_row_count](cudf::table_view const &left, cudf::hash_join const &hash) { + return hash.inner_join(left, output_row_count); + }); } JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_conditionalInnerJoinRowCount(JNIEnv *env, jclass, @@ -2438,41 +2427,36 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullJoinGatherMaps( JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_fullJoinRowCount(JNIEnv *env, jclass, jlong j_left_table, - jlong j_right_hash_join, - jboolean compare_nulls_equal) { + jlong j_right_hash_join) { JNI_NULL_CHECK(env, j_left_table, "left table is null", 0); JNI_NULL_CHECK(env, j_right_hash_join, "right hash join is null", 0); try { cudf::jni::auto_set_device(env); auto left_table = reinterpret_cast(j_left_table); auto hash_join = reinterpret_cast(j_right_hash_join); - auto nulleq = compare_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL; - auto row_count = hash_join->full_join_size(*left_table, nulleq); + auto row_count = hash_join->full_join_size(*left_table); return static_cast(row_count); } CATCH_STD(env, 0); } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullHashJoinGatherMaps( - JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, - jboolean compare_nulls_equal) { + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join) { return cudf::jni::hash_join_gather_maps( - env, j_left_table, j_right_hash_join, compare_nulls_equal, - [](cudf::table_view const &left, cudf::hash_join const &hash, cudf::null_equality nulleq) { - return hash.full_join(left, nulleq); + env, j_left_table, j_right_hash_join, + [](cudf::table_view const &left, cudf::hash_join const &hash) { + return hash.full_join(left); }); } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_fullHashJoinGatherMapsWithCount( - JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, jboolean compare_nulls_equal, - jlong j_output_row_count) { + JNIEnv *env, jclass, jlong j_left_table, jlong j_right_hash_join, jlong j_output_row_count) { auto output_row_count = static_cast(j_output_row_count); - return cudf::jni::hash_join_gather_maps(env, j_left_table, j_right_hash_join, compare_nulls_equal, - [output_row_count](cudf::table_view const &left, - cudf::hash_join const &hash, - cudf::null_equality nulleq) { - return hash.full_join(left, nulleq, output_row_count); - }); + return cudf::jni::hash_join_gather_maps( + env, j_left_table, j_right_hash_join, + [output_row_count](cudf::table_view const &left, cudf::hash_join const &hash) { + return hash.full_join(left, output_row_count); + }); } JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_conditionalFullJoinGatherMaps(