diff --git a/cpp/src/join/hash_join.cu b/cpp/src/join/hash_join.cu index ee62008b90f..c6f842c6c55 100644 --- a/cpp/src/join/hash_join.cu +++ b/cpp/src/join/hash_join.cu @@ -81,7 +81,7 @@ void build_join_hash_table(cudf::table_view const& build, CUDF_EXPECTS(0 != build_table_ptr->num_columns(), "Selected build dataset is empty"); CUDF_EXPECTS(0 != build_table_ptr->num_rows(), "Build side table has no rows"); - row_hash hash_build{nullate::YES{}, *build_table_ptr}; + row_hash hash_build{nullate::DYNAMIC{cudf::has_nulls(build)}, *build_table_ptr}; auto const empty_key_sentinel = hash_table.get_empty_key_sentinel(); make_pair_function pair_func{hash_build, empty_key_sentinel}; @@ -123,6 +123,7 @@ std::pair>, probe_join_hash_table(cudf::table_device_view build_table, cudf::table_device_view probe_table, multimap_type const& hash_table, + bool has_nulls, null_equality compare_nulls, std::optional output_size, rmm::cuda_stream_view stream, @@ -133,10 +134,10 @@ probe_join_hash_table(cudf::table_device_view build_table, ? cudf::detail::join_kind::LEFT_JOIN : JoinKind; - std::size_t const join_size = output_size - ? *output_size - : compute_join_output_size( - build_table, probe_table, hash_table, compare_nulls, stream); + std::size_t const join_size = + output_size ? *output_size + : compute_join_output_size( + build_table, probe_table, hash_table, has_nulls, compare_nulls, stream); // If output size is zero, return immediately if (join_size == 0) { @@ -147,9 +148,10 @@ probe_join_hash_table(cudf::table_device_view build_table, auto left_indices = std::make_unique>(join_size, stream, mr); auto right_indices = std::make_unique>(join_size, stream, mr); - pair_equality equality{probe_table, build_table, compare_nulls}; + auto const probe_nulls = cudf::nullate::DYNAMIC{has_nulls}; + pair_equality equality{probe_table, build_table, probe_nulls, compare_nulls}; - row_hash hash_probe{nullate::YES{}, probe_table}; + row_hash hash_probe{probe_nulls, probe_table}; auto const empty_key_sentinel = hash_table.get_empty_key_sentinel(); make_pair_function pair_func{hash_probe, empty_key_sentinel}; @@ -197,12 +199,13 @@ probe_join_hash_table(cudf::table_device_view build_table, std::size_t get_full_join_size(cudf::table_device_view build_table, cudf::table_device_view probe_table, multimap_type const& hash_table, + bool has_nulls, null_equality compare_nulls, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { std::size_t join_size = compute_join_output_size( - build_table, probe_table, hash_table, compare_nulls, stream); + build_table, probe_table, hash_table, has_nulls, compare_nulls, stream); // If output size is zero, return immediately if (join_size == 0) { return join_size; } @@ -212,9 +215,10 @@ std::size_t get_full_join_size(cudf::table_device_view build_table, auto left_indices = std::make_unique>(join_size, stream, mr); auto right_indices = std::make_unique>(join_size, stream, mr); - pair_equality equality{probe_table, build_table, compare_nulls}; + auto const probe_nulls = cudf::nullate::DYNAMIC{has_nulls}; + pair_equality equality{probe_table, build_table, probe_nulls, compare_nulls}; - row_hash hash_probe{nullate::YES{}, probe_table}; + row_hash hash_probe{probe_nulls, probe_table}; auto const empty_key_sentinel = hash_table.get_empty_key_sentinel(); make_pair_function pair_func{hash_probe, empty_key_sentinel}; @@ -367,7 +371,12 @@ std::size_t hash_join::hash_join_impl::inner_join_size(cudf::table_view const& p auto flattened_probe_table_ptr = cudf::table_device_view::create(flattened_probe_table, stream); return cudf::detail::compute_join_output_size( - *build_table_ptr, *flattened_probe_table_ptr, _hash_table, compare_nulls, stream); + *build_table_ptr, + *flattened_probe_table_ptr, + _hash_table, + cudf::has_nulls(flattened_probe_table) | cudf::has_nulls(_build), + compare_nulls, + stream); } std::size_t hash_join::hash_join_impl::left_join_size(cudf::table_view const& probe, @@ -387,7 +396,12 @@ std::size_t hash_join::hash_join_impl::left_join_size(cudf::table_view const& pr auto flattened_probe_table_ptr = cudf::table_device_view::create(flattened_probe_table, stream); return cudf::detail::compute_join_output_size( - *build_table_ptr, *flattened_probe_table_ptr, _hash_table, compare_nulls, stream); + *build_table_ptr, + *flattened_probe_table_ptr, + _hash_table, + cudf::has_nulls(flattened_probe_table) | cudf::has_nulls(_build), + compare_nulls, + stream); } std::size_t hash_join::hash_join_impl::full_join_size(cudf::table_view const& probe, @@ -407,8 +421,13 @@ std::size_t hash_join::hash_join_impl::full_join_size(cudf::table_view const& pr auto build_table_ptr = cudf::table_device_view::create(_build, stream); auto flattened_probe_table_ptr = cudf::table_device_view::create(flattened_probe_table, stream); - return get_full_join_size( - *build_table_ptr, *flattened_probe_table_ptr, _hash_table, compare_nulls, stream, mr); + return get_full_join_size(*build_table_ptr, + *flattened_probe_table_ptr, + _hash_table, + cudf::has_nulls(flattened_probe_table) | cudf::has_nulls(_build), + compare_nulls, + stream, + mr); } template @@ -466,8 +485,15 @@ hash_join::hash_join_impl::probe_join_indices(cudf::table_view const& probe, auto build_table_ptr = cudf::table_device_view::create(_build, stream); auto probe_table_ptr = cudf::table_device_view::create(probe, stream); - auto join_indices = cudf::detail::probe_join_hash_table( - *build_table_ptr, *probe_table_ptr, _hash_table, compare_nulls, output_size, stream, mr); + auto join_indices = + cudf::detail::probe_join_hash_table(*build_table_ptr, + *probe_table_ptr, + _hash_table, + cudf::has_nulls(probe) | cudf::has_nulls(_build), + compare_nulls, + output_size, + stream, + mr); if constexpr (JoinKind == cudf::detail::join_kind::FULL_JOIN) { auto complement_indices = detail::get_left_join_indices_complement( diff --git a/cpp/src/join/hash_join.cuh b/cpp/src/join/hash_join.cuh index 976b0c81ead..5a042f65aad 100644 --- a/cpp/src/join/hash_join.cuh +++ b/cpp/src/join/hash_join.cuh @@ -96,6 +96,7 @@ template std::size_t compute_join_output_size(table_device_view build_table, table_device_view probe_table, multimap_type const& hash_table, + bool has_nulls, null_equality compare_nulls, rmm::cuda_stream_view stream) { @@ -117,9 +118,10 @@ std::size_t compute_join_output_size(table_device_view build_table, } } - pair_equality equality{probe_table, build_table, compare_nulls}; + auto const probe_nulls = cudf::nullate::DYNAMIC{has_nulls}; + pair_equality equality{probe_table, build_table, probe_nulls, compare_nulls}; - row_hash hash_probe{nullate::YES{}, probe_table}; + row_hash hash_probe{probe_nulls, probe_table}; auto const empty_key_sentinel = hash_table.get_empty_key_sentinel(); make_pair_function pair_func{hash_probe, empty_key_sentinel}; diff --git a/cpp/src/join/join_common_utils.cuh b/cpp/src/join/join_common_utils.cuh index 4b33772dd69..39a9f19c0ee 100644 --- a/cpp/src/join/join_common_utils.cuh +++ b/cpp/src/join/join_common_utils.cuh @@ -34,8 +34,9 @@ class pair_equality { public: pair_equality(table_device_view lhs, table_device_view rhs, + nullate::DYNAMIC has_nulls, null_equality nulls_are_equal = null_equality::EQUAL) - : _check_row_equality{cudf::nullate::YES{}, lhs, rhs, nulls_are_equal} + : _check_row_equality{has_nulls, lhs, rhs, nulls_are_equal} { } diff --git a/cpp/src/join/join_common_utils.hpp b/cpp/src/join/join_common_utils.hpp index c4692a50fec..9a7540bcd33 100644 --- a/cpp/src/join/join_common_utils.hpp +++ b/cpp/src/join/join_common_utils.hpp @@ -51,9 +51,9 @@ using multimap_type = hash_table_allocator_type, cuco::double_hashing>; -using row_hash = cudf::row_hasher; +using row_hash = cudf::row_hasher; -using row_equality = cudf::row_equality_comparator; +using row_equality = cudf::row_equality_comparator; enum class join_kind { INNER_JOIN, LEFT_JOIN, FULL_JOIN, LEFT_SEMI_JOIN, LEFT_ANTI_JOIN }; diff --git a/cpp/src/join/semi_join.cu b/cpp/src/join/semi_join.cu index 3d27c5740f4..e781472e025 100644 --- a/cpp/src/join/semi_join.cu +++ b/cpp/src/join/semi_join.cu @@ -77,13 +77,15 @@ std::unique_ptr> left_semi_anti_join( // Create hash table containing all keys found in right table auto right_rows_d = table_device_view::create(right_flattened_keys, stream); size_t const hash_table_size = compute_hash_table_size(right_num_rows); - row_hash hash_build{cudf::nullate::YES{}, *right_rows_d}; - row_equality equality_build{cudf::nullate::YES{}, *right_rows_d, *right_rows_d, compare_nulls}; + auto const right_nulls = cudf::nullate::DYNAMIC{cudf::has_nulls(right_flattened_keys)}; + row_hash hash_build{right_nulls, *right_rows_d}; + row_equality equality_build{right_nulls, *right_rows_d, *right_rows_d, compare_nulls}; // Going to join it with left table - auto left_rows_d = table_device_view::create(left_flattened_keys, stream); - row_hash hash_probe{cudf::nullate::YES{}, *left_rows_d}; - row_equality equality_probe{cudf::nullate::YES{}, *left_rows_d, *right_rows_d, compare_nulls}; + auto left_rows_d = table_device_view::create(left_flattened_keys, stream); + auto const left_nulls = cudf::nullate::DYNAMIC{cudf::has_nulls(left_flattened_keys)}; + row_hash hash_probe{left_nulls, *left_rows_d}; + row_equality equality_probe{left_nulls, *left_rows_d, *right_rows_d, compare_nulls}; auto hash_table_ptr = hash_table_type::create(hash_table_size, stream,