From 8f2294f3e809304201564fa19ca50df84b4bdccf Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Thu, 7 Sep 2023 16:58:34 -0700 Subject: [PATCH] Refactor contains_table with cuco::static_set --- cpp/src/search/contains_table.cu | 250 ++++++++++++++----------------- 1 file changed, 109 insertions(+), 141 deletions(-) diff --git a/cpp/src/search/contains_table.cu b/cpp/src/search/contains_table.cu index e37f0686ac3..628ba2c4c3b 100644 --- a/cpp/src/search/contains_table.cu +++ b/cpp/src/search/contains_table.cu @@ -26,7 +26,7 @@ #include -#include +#include #include @@ -37,11 +37,6 @@ namespace { using cudf::experimental::row::lhs_index_type; using cudf::experimental::row::rhs_index_type; -using static_map = cuco::static_map>>; - /** * @brief Check if the given type `T` is a strong index type (i.e., `lhs_index_type` or * `rhs_index_type`). @@ -58,48 +53,58 @@ constexpr auto is_strong_index_type() * @brief An adapter functor to support strong index types for row hasher that must be operating on * `cudf::size_type`. */ -template -struct strong_index_hasher_adapter { - strong_index_hasher_adapter(Hasher const& hasher) : _hasher{hasher} {} +template +struct hasher_adapter { + hasher_adapter(HaystackHasher const& haystack_hasher, NeedleHasher const& needle_hasher) + : _haystack_hasher{haystack_hasher}, _needle_hasher{needle_hasher} + { + } + + __device__ constexpr auto operator()(lhs_index_type idx) const noexcept + { + return _haystack_hasher(static_cast(idx)); + } - template ())> - __device__ constexpr auto operator()(T const idx) const noexcept + __device__ constexpr auto operator()(rhs_index_type idx) const noexcept { - return _hasher(static_cast(idx)); + return _needle_hasher(static_cast(idx)); } private: - Hasher const _hasher; + HaystackHasher const _haystack_hasher; + NeedleHasher const _needle_hasher; }; /** * @brief An adapter functor to support strong index type for table row comparator that must be * operating on `cudf::size_type`. */ -template -struct strong_index_comparator_adapter { - strong_index_comparator_adapter(Comparator const& comparator) : _comparator{comparator} {} - - template () && is_strong_index_type())> - __device__ constexpr auto operator()(T const lhs_index, U const rhs_index) const noexcept +template +struct comparator_adapter { + comparator_adapter(SelfComparator const& self_comparator, + TwoTableComparator const& two_table_comparator) + : _self_comparator{self_comparator}, _two_table_comparator{two_table_comparator} + { + } + + __device__ constexpr auto operator()(lhs_index_type lhs_index, + lhs_index_type rhs_index) const noexcept { auto const lhs = static_cast(lhs_index); auto const rhs = static_cast(rhs_index); - if constexpr (std::is_same_v || std::is_same_v) { - return _comparator(lhs, rhs); - } else { - // Here we have T == rhs_index_type. - // This is when the indices are provided in wrong order for two table comparator, so we need - // to switch them back to the right order before calling the underlying comparator. - return _comparator(rhs, lhs); - } + return _self_comparator(lhs, rhs); + } + + __device__ constexpr auto operator()(lhs_index_type lhs_index, + rhs_index_type rhs_index) const noexcept + { + return _two_table_comparator(lhs_index, rhs_index); } private: - Comparator const _comparator; + SelfComparator const _self_comparator; + TwoTableComparator const _two_table_comparator; }; /** @@ -133,6 +138,18 @@ std::pair build_row_bitmask(table_view return std::pair(rmm::device_buffer{0, stream}, nullable_columns.front().null_mask()); } +} // namespace +template +void dispatch( + null_equality compare_nulls, auto self_comp, auto two_table_comp, auto nan_comp, Func func) +{ + auto const d_self_eq = self_comp.equal_to( + nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, nan_comp); + auto const d_two_table_eq = two_table_comp.equal_to( + nullate::DYNAMIC{has_any_nulls}, compare_nulls, nan_comp); + func(d_self_eq, d_two_table_eq); +} + /** * @brief Invoke an `operator()` template with a row equality comparator based on the specified * `compare_nans` parameter. @@ -140,21 +157,25 @@ std::pair build_row_bitmask(table_view * @param compare_nans The flag to specify whether NaNs should be compared equal or not * @param func The input functor to invoke */ -template -void dispatch_nan_comparator(nan_equality compare_nans, Func&& func) +template +void dispatch_nan_comparator(nan_equality compare_nans, + null_equality compare_nulls, + auto self_comp, + auto two_table_comp, + Func func) { if (compare_nans == nan_equality::ALL_EQUAL) { using nan_equal_comparator = cudf::experimental::row::equality::nan_equal_physical_equality_comparator; - func(nan_equal_comparator{}); + dispatch( + compare_nulls, self_comp, two_table_comp, nan_equal_comparator{}, func); } else { using nan_unequal_comparator = cudf::experimental::row::equality::physical_equality_comparator; - func(nan_unequal_comparator{}); + dispatch( + compare_nulls, self_comp, two_table_comp, nan_unequal_comparator{}, func); } } -} // namespace - /** * @brief Check if rows in the given `needles` table exist in the `haystack` table. * @@ -173,124 +194,71 @@ rmm::device_uvector contains(table_view const& haystack, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - auto map = static_map(compute_hash_table_size(haystack.num_rows()), - cuco::empty_key{lhs_index_type{std::numeric_limits::max()}}, - cuco::empty_value{detail::JoinNoneValue}, - detail::hash_table_allocator_type{default_allocator{}, stream}, - stream.value()); - auto const haystack_has_nulls = has_nested_nulls(haystack); auto const needles_has_nulls = has_nested_nulls(needles); auto const has_any_nulls = haystack_has_nulls || needles_has_nulls; + auto const preprocessed_needles = + cudf::experimental::row::equality::preprocessed_table::create(needles, stream); auto const preprocessed_haystack = cudf::experimental::row::equality::preprocessed_table::create(haystack, stream); - // Insert row indices of the haystack table as map keys. - { - auto const haystack_it = cudf::detail::make_counting_transform_iterator( - size_type{0}, - [] __device__(auto const idx) { return cuco::make_pair(lhs_index_type{idx}, 0); }); - auto const hasher = cudf::experimental::row::hash::row_hasher(preprocessed_haystack); - auto const d_hasher = - strong_index_hasher_adapter{hasher.device_hasher(nullate::DYNAMIC{has_any_nulls})}; + auto const haystack_hasher = cudf::experimental::row::hash::row_hasher(preprocessed_haystack); + auto const d_haystack_hasher = haystack_hasher.device_hasher(nullate::DYNAMIC{has_any_nulls}); + auto const needle_hasher = cudf::experimental::row::hash::row_hasher(preprocessed_needles); + auto const d_needle_hasher = needle_hasher.device_hasher(nullate::DYNAMIC{has_any_nulls}); + auto const d_hasher = hasher_adapter{d_haystack_hasher, d_needle_hasher}; + using hasher_type = decltype(d_hasher); + + auto const self_comparator = + cudf::experimental::row::equality::self_comparator(preprocessed_haystack); + auto const two_table_comparator = cudf::experimental::row::equality::two_table_comparator( + preprocessed_haystack, preprocessed_needles); + + // The output vector. + auto contained = rmm::device_uvector(needles.num_rows(), stream, mr); + + auto const haystack_iter = cudf::detail::make_counting_transform_iterator( + size_type{0}, [] __device__(auto idx) { return lhs_index_type{idx}; }); + auto const needles_iter = cudf::detail::make_counting_transform_iterator( + size_type{0}, [] __device__(auto idx) { return rhs_index_type{idx}; }); + + auto const helper_func = [&](auto const& d_self_equal, auto const& d_two_table_equal) { + auto const d_equal = comparator_adapter{d_self_equal, d_two_table_equal}; - auto const comparator = - cudf::experimental::row::equality::self_comparator(preprocessed_haystack); + auto set = cuco::experimental::static_set{ + cuco::experimental::extent{compute_hash_table_size(haystack.num_rows())}, + cuco::empty_key{lhs_index_type{-1}}, + d_equal, + cuco::experimental::linear_probing<1, hasher_type>{d_hasher}, + detail::hash_table_allocator_type{default_allocator{}, stream}, + stream.value()}; - // If the haystack table has nulls but they are compared unequal, don't insert them. - // Otherwise, it was known to cause performance issue: - // - https://github.com/rapidsai/cudf/pull/6943 - // - https://github.com/rapidsai/cudf/pull/8277 if (haystack_has_nulls && compare_nulls == null_equality::UNEQUAL) { - auto const bitmask_buffer_and_ptr = build_row_bitmask(haystack, stream); - auto const row_bitmask_ptr = bitmask_buffer_and_ptr.second; - - auto const insert_map = [&](auto const value_comp) { - if (cudf::detail::has_nested_columns(haystack)) { - auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to( - nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)}; - map.insert_if(haystack_it, - haystack_it + haystack.num_rows(), - thrust::counting_iterator(0), // stencil - row_is_valid{row_bitmask_ptr}, - d_hasher, - d_eqcomp, - stream.value()); - } else { - auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to( - nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)}; - map.insert_if(haystack_it, - haystack_it + haystack.num_rows(), - thrust::counting_iterator(0), // stencil - row_is_valid{row_bitmask_ptr}, - d_hasher, - d_eqcomp, - stream.value()); - } - }; - - // Insert only rows that do not have any null at any level. - dispatch_nan_comparator(compare_nans, insert_map); - } else { // haystack_doesn't_have_nulls || compare_nulls == null_equality::EQUAL - auto const insert_map = [&](auto const value_comp) { - if (cudf::detail::has_nested_columns(haystack)) { - auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to( - nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)}; - map.insert( - haystack_it, haystack_it + haystack.num_rows(), d_hasher, d_eqcomp, stream.value()); - } else { - auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to( - nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)}; - map.insert( - haystack_it, haystack_it + haystack.num_rows(), d_hasher, d_eqcomp, stream.value()); - } - }; - - dispatch_nan_comparator(compare_nans, insert_map); + auto const row_bitmask = build_row_bitmask(haystack, stream).second; + set.insert_if_async(haystack_iter, + haystack_iter + haystack.num_rows(), + thrust::counting_iterator(0), // stencil + row_is_valid{row_bitmask}, + stream.value()); + } else { + set.insert_async(haystack_iter, haystack_iter + haystack.num_rows(), stream.value()); } - } - // The output vector. - auto contained = rmm::device_uvector(needles.num_rows(), stream, mr); + if (needles_has_nulls && compare_nulls == null_equality::UNEQUAL) { + set.contains_if_async( + needles_iter, needles_iter + needles.num_rows(), contained.begin(), stream.value()); + } else { + set.contains_async( + needles_iter, needles_iter + needles.num_rows(), contained.begin(), stream.value()); + } + }; - auto const preprocessed_needles = - cudf::experimental::row::equality::preprocessed_table::create(needles, stream); - // Check existence for each row of the needles table in the haystack table. - { - auto const needles_it = cudf::detail::make_counting_transform_iterator( - size_type{0}, [] __device__(auto const idx) { return rhs_index_type{idx}; }); - - auto const hasher = cudf::experimental::row::hash::row_hasher(preprocessed_needles); - auto const d_hasher = - strong_index_hasher_adapter{hasher.device_hasher(nullate::DYNAMIC{has_any_nulls})}; - - auto const comparator = cudf::experimental::row::equality::two_table_comparator( - preprocessed_haystack, preprocessed_needles); - - auto const check_contains = [&](auto const value_comp) { - if (cudf::detail::has_nested_columns(haystack) or cudf::detail::has_nested_columns(needles)) { - auto const d_eqcomp = - comparator.equal_to(nullate::DYNAMIC{has_any_nulls}, compare_nulls, value_comp); - map.contains(needles_it, - needles_it + needles.num_rows(), - contained.begin(), - d_hasher, - d_eqcomp, - stream.value()); - } else { - auto const d_eqcomp = - comparator.equal_to(nullate::DYNAMIC{has_any_nulls}, compare_nulls, value_comp); - map.contains(needles_it, - needles_it + needles.num_rows(), - contained.begin(), - d_hasher, - d_eqcomp, - stream.value()); - } - }; - - dispatch_nan_comparator(compare_nans, check_contains); + if (haystack_has_nulls) { + if (has_any_nulls) { + dispatch_nan_comparator( + compare_nans, compare_nulls, self_comparator, two_table_comparator, helper_func); + } } return contained;