From 3c3edfef406288e164cc80ab82f9c64c0b88d0bd Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Fri, 28 Jun 2024 13:58:22 -0700 Subject: [PATCH] Update implementations to build with the latest cuco (#15938) This PR updates existing libcudf to accommodate a cuco breaking change introduced in https://github.com/NVIDIA/cuCollections/pull/479. It helps avoid breaking cudf when bumping the cuco version in `rapids-cmake`. Redundant equal/hash overloads will be removed once the version bump is done on the `rapids-cmake` end. Authors: - Yunsong Wang (https://github.com/PointKernel) Approvers: - David Wendt (https://github.com/davidwendt) - Nghia Truong (https://github.com/ttnghia) URL: https://github.com/rapidsai/cudf/pull/15938 --- .../cudf/detail/distinct_hash_join.cuh | 22 +++++++++++- cpp/src/join/distinct_hash_join.cu | 10 +++--- cpp/src/search/contains_table.cu | 35 ++++++++++++++----- cpp/src/text/bpe/byte_pair_encoding.cuh | 13 +++++++ cpp/src/text/vocabulary_tokenize.cu | 8 +++++ 5 files changed, 74 insertions(+), 14 deletions(-) diff --git a/cpp/include/cudf/detail/distinct_hash_join.cuh b/cpp/include/cudf/detail/distinct_hash_join.cuh index de3d23e9470..1ef8b3b120a 100644 --- a/cpp/include/cudf/detail/distinct_hash_join.cuh +++ b/cpp/include/cudf/detail/distinct_hash_join.cuh @@ -42,6 +42,9 @@ template struct comparator_adapter { comparator_adapter(Equal const& d_equal) : _d_equal{d_equal} {} + // suppress "function was declared but never referenced warning" +#pragma nv_diagnostic push +#pragma nv_diag_suppress 177 __device__ constexpr auto operator()( cuco::pair const&, cuco::pair const&) const noexcept @@ -50,6 +53,14 @@ struct comparator_adapter { return false; } + __device__ constexpr auto operator()( + cuco::pair const&, + cuco::pair const&) const noexcept + { + // All build table keys are distinct thus `false` no matter what + return false; + } + __device__ constexpr auto operator()( cuco::pair const& lhs, cuco::pair const& rhs) const noexcept @@ -58,6 +69,15 @@ struct comparator_adapter { return _d_equal(lhs.second, rhs.second); } + __device__ constexpr auto operator()( + cuco::pair const& lhs, + cuco::pair const& rhs) const noexcept + { + if (lhs.first != rhs.first) { return false; } + return _d_equal(lhs.second, rhs.second); + } +#pragma nv_diagnostic pop + private: Equal _d_equal; }; @@ -94,7 +114,7 @@ struct distinct_hash_join { using cuco_storage_type = cuco::storage<1>; /// Hash table type - using hash_table_type = cuco::static_set, + using hash_table_type = cuco::static_set, cuco::extent, cuda::thread_scope_device, comparator_adapter, diff --git a/cpp/src/join/distinct_hash_join.cu b/cpp/src/join/distinct_hash_join.cu index 5048da25e86..daa1bf17c0d 100644 --- a/cpp/src/join/distinct_hash_join.cu +++ b/cpp/src/join/distinct_hash_join.cu @@ -54,7 +54,7 @@ auto prepare_device_equal( cudf::null_equality compare_nulls) { auto const two_table_equal = - cudf::experimental::row::equality::two_table_comparator(build, probe); + cudf::experimental::row::equality::two_table_comparator(probe, build); return comparator_adapter{two_table_equal.equal_to( nullate::DYNAMIC{has_nulls}, compare_nulls)}; } @@ -113,7 +113,7 @@ distinct_hash_join::distinct_hash_join(cudf::table_view const& build, _hash_table{build.num_rows(), CUCO_DESIRED_LOAD_FACTOR, cuco::empty_key{cuco::pair{std::numeric_limits::max(), - lhs_index_type{JoinNoneValue}}}, + rhs_index_type{JoinNoneValue}}}, prepare_device_equal( _preprocessed_build, _preprocessed_probe, has_nulls, compare_nulls), {}, @@ -131,7 +131,7 @@ distinct_hash_join::distinct_hash_join(cudf::table_view const& build, auto const d_hasher = row_hasher.device_hasher(nullate::DYNAMIC{this->_has_nulls}); auto const iter = cudf::detail::make_counting_transform_iterator( - 0, build_keys_fn{d_hasher}); + 0, build_keys_fn{d_hasher}); size_type const build_table_num_rows{build.num_rows()}; if (this->_nulls_equal == cudf::null_equality::EQUAL or (not cudf::nullable(this->_build))) { @@ -174,7 +174,7 @@ distinct_hash_join::inner_join(rmm::cuda_stream_view stream, cudf::experimental::row::hash::row_hasher{this->_preprocessed_probe}; auto const d_probe_hasher = probe_row_hasher.device_hasher(nullate::DYNAMIC{this->_has_nulls}); auto const iter = cudf::detail::make_counting_transform_iterator( - 0, build_keys_fn{d_probe_hasher}); + 0, build_keys_fn{d_probe_hasher}); auto const build_indices_begin = thrust::make_transform_output_iterator(build_indices->begin(), output_fn{}); @@ -216,7 +216,7 @@ std::unique_ptr> distinct_hash_join::l cudf::experimental::row::hash::row_hasher{this->_preprocessed_probe}; auto const d_probe_hasher = probe_row_hasher.device_hasher(nullate::DYNAMIC{this->_has_nulls}); auto const iter = cudf::detail::make_counting_transform_iterator( - 0, build_keys_fn{d_probe_hasher}); + 0, build_keys_fn{d_probe_hasher}); auto const output_begin = thrust::make_transform_output_iterator(build_indices->begin(), output_fn{}); diff --git a/cpp/src/search/contains_table.cu b/cpp/src/search/contains_table.cu index 466f9093194..fbb0f6cb0f5 100644 --- a/cpp/src/search/contains_table.cu +++ b/cpp/src/search/contains_table.cu @@ -53,12 +53,12 @@ struct hasher_adapter { __device__ constexpr auto operator()(lhs_index_type idx) const noexcept { - return _haystack_hasher(static_cast(idx)); + return _needle_hasher(static_cast(idx)); } __device__ constexpr auto operator()(rhs_index_type idx) const noexcept { - return _needle_hasher(static_cast(idx)); + return _haystack_hasher(static_cast(idx)); } private: @@ -76,6 +76,9 @@ struct comparator_adapter { { } + // suppress "function was declared but never referenced warning" +#pragma nv_diagnostic push +#pragma nv_diag_suppress 177 __device__ constexpr auto operator()(lhs_index_type lhs_index, lhs_index_type rhs_index) const noexcept { @@ -85,12 +88,28 @@ struct comparator_adapter { return _self_equal(lhs, rhs); } + __device__ constexpr auto operator()(rhs_index_type lhs_index, + rhs_index_type rhs_index) const noexcept + { + auto const lhs = static_cast(lhs_index); + auto const rhs = static_cast(rhs_index); + + return _self_equal(lhs, rhs); + } + __device__ constexpr auto operator()(lhs_index_type lhs_index, rhs_index_type rhs_index) const noexcept { return _two_table_equal(lhs_index, rhs_index); } + __device__ constexpr auto operator()(rhs_index_type lhs_index, + lhs_index_type rhs_index) const noexcept + { + return _two_table_equal(lhs_index, rhs_index); + } +#pragma nv_diagnostic pop + private: SelfEqual const _self_equal; TwoTableEqual const _two_table_equal; @@ -210,26 +229,26 @@ rmm::device_uvector contains(table_view const& haystack, auto const self_equal = cudf::experimental::row::equality::self_comparator(preprocessed_haystack); auto const two_table_equal = cudf::experimental::row::equality::two_table_comparator( - preprocessed_haystack, preprocessed_needles); + preprocessed_needles, preprocessed_haystack); // The output vector. auto contained = rmm::device_uvector(needles.num_rows(), stream, mr); auto const haystack_iter = cudf::detail::make_counting_transform_iterator( - size_type{0}, cuda::proclaim_return_type([] __device__(auto idx) { - return lhs_index_type{idx}; - })); - auto const needles_iter = cudf::detail::make_counting_transform_iterator( size_type{0}, cuda::proclaim_return_type([] __device__(auto idx) { return rhs_index_type{idx}; })); + auto const needles_iter = cudf::detail::make_counting_transform_iterator( + size_type{0}, cuda::proclaim_return_type([] __device__(auto idx) { + return lhs_index_type{idx}; + })); auto const helper_func = [&](auto const& d_self_equal, auto const& d_two_table_equal, auto const& probing_scheme) { auto const d_equal = comparator_adapter{d_self_equal, d_two_table_equal}; auto set = cuco::static_set{cuco::extent{compute_hash_table_size(haystack.num_rows())}, - cuco::empty_key{lhs_index_type{-1}}, + cuco::empty_key{rhs_index_type{-1}}, d_equal, probing_scheme, {}, diff --git a/cpp/src/text/bpe/byte_pair_encoding.cuh b/cpp/src/text/bpe/byte_pair_encoding.cuh index 2ad22fd4e46..3bb574748b6 100644 --- a/cpp/src/text/bpe/byte_pair_encoding.cuh +++ b/cpp/src/text/bpe/byte_pair_encoding.cuh @@ -96,6 +96,14 @@ struct bpe_equal { auto const right = d_strings.element(lhs + 1); return (left == rhs.first) && (right == rhs.second); } + // used by find + __device__ bool operator()(merge_pair_type const& lhs, cudf::size_type rhs) const noexcept + { + rhs *= 2; + auto const left = d_strings.element(rhs); + auto const right = d_strings.element(rhs + 1); + return (left == lhs.first) && (right == lhs.second); + } }; using bpe_probe_scheme = cuco::linear_probing<1, bpe_hasher>; @@ -154,6 +162,11 @@ struct mp_equal { auto const left = d_strings.element(lhs); return left == rhs; } + __device__ bool operator()(cudf::string_view const& lhs, cudf::size_type rhs) const noexcept + { + auto const right = d_strings.element(rhs); + return lhs == right; + } }; using mp_probe_scheme = cuco::linear_probing<1, mp_hasher>; diff --git a/cpp/src/text/vocabulary_tokenize.cu b/cpp/src/text/vocabulary_tokenize.cu index f012f7ce09a..ea09f5d17af 100644 --- a/cpp/src/text/vocabulary_tokenize.cu +++ b/cpp/src/text/vocabulary_tokenize.cu @@ -86,10 +86,18 @@ struct vocab_equal { return lhs == rhs; // all rows are expected to be unique } // used by find + // suppress "function was declared but never referenced warning" +#pragma nv_diagnostic push +#pragma nv_diag_suppress 177 __device__ bool operator()(cudf::size_type lhs, cudf::string_view const& rhs) const noexcept { return d_strings.element(lhs) == rhs; } + __device__ bool operator()(cudf::string_view const& lhs, cudf::size_type rhs) const noexcept + { + return d_strings.element(rhs) == lhs; + } +#pragma nv_diagnostic pop }; using probe_scheme = cuco::linear_probing<1, vocab_hasher>;