From a20460c160560d38e51bcf35a5ffb5848992c9c4 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Tue, 27 Aug 2024 09:35:34 -0700 Subject: [PATCH] Expose hash_function member function (#587) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Close #582 This PR exposes `hash_function` member function for cuco hash tables. --------- Co-authored-by: Daniel Jünger --- .../open_addressing/open_addressing_impl.cuh | 11 +++++++++ .../open_addressing_ref_impl.cuh | 14 ++++++++++- .../probing_scheme/probing_scheme_impl.inl | 15 ++++++++++++ include/cuco/detail/static_map/static_map.inl | 15 ++++++++++++ .../cuco/detail/static_map/static_map_ref.inl | 20 ++++++++++++++++ .../static_multimap/static_multimap.inl | 16 +++++++++++++ .../static_multimap/static_multimap_ref.inl | 20 ++++++++++++++++ .../static_multiset/static_multiset.inl | 14 +++++++++++ .../static_multiset/static_multiset_ref.inl | 18 +++++++++++++++ include/cuco/detail/static_set/static_set.inl | 14 +++++++++++ .../cuco/detail/static_set/static_set_ref.inl | 18 +++++++++++++++ include/cuco/probing_scheme.cuh | 23 +++++++++++++++++-- include/cuco/static_map.cuh | 8 +++++++ include/cuco/static_map_ref.cuh | 8 +++++++ include/cuco/static_multimap.cuh | 8 +++++++ include/cuco/static_multimap_ref.cuh | 8 +++++++ include/cuco/static_multiset.cuh | 8 +++++++ include/cuco/static_multiset_ref.cuh | 8 +++++++ include/cuco/static_set.cuh | 8 +++++++ include/cuco/static_set_ref.cuh | 8 +++++++ 20 files changed, 259 insertions(+), 3 deletions(-) diff --git a/include/cuco/detail/open_addressing/open_addressing_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_impl.cuh index 176239be2..a8eff9036 100644 --- a/include/cuco/detail/open_addressing/open_addressing_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_impl.cuh @@ -100,6 +100,7 @@ class open_addressing_impl { using storage_ref_type = typename storage_type::ref_type; ///< Non-owning window storage ref type using probing_scheme_type = ProbingScheme; ///< Probe scheme type + using hasher = typename probing_scheme_type::hasher; ///< Hash function type /** * @brief Constructs a statically-sized open addressing data structure with the specified initial @@ -933,6 +934,16 @@ class open_addressing_impl { return probing_scheme_; } + /** + * @brief Gets the function(s) used to hash keys + * + * @return The function(s) used to hash keys + */ + [[nodiscard]] constexpr hasher hash_function() const noexcept + { + return this->probing_scheme().hash_function(); + } + /** * @brief Gets the container allocator. * diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 12a306a71..f4c20f829 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -109,6 +109,7 @@ class open_addressing_ref_impl { public: using key_type = Key; ///< Key type using probing_scheme_type = ProbingScheme; ///< Type of probing scheme + using hasher = typename probing_scheme_type::hasher; ///< Hash function type using storage_ref_type = StorageRef; ///< Type of storage ref using window_type = typename storage_ref_type::window_type; ///< Window type using value_type = typename storage_ref_type::value_type; ///< Storage element type @@ -233,11 +234,22 @@ class open_addressing_ref_impl { * * @return The probing scheme used for the container */ - [[nodiscard]] __device__ constexpr probing_scheme_type const& probing_scheme() const noexcept + [[nodiscard]] __host__ __device__ constexpr probing_scheme_type const& probing_scheme() + const noexcept { return probing_scheme_; } + /** + * @brief Gets the function(s) used to hash keys + * + * @return The function(s) used to hash keys + */ + [[nodiscard]] __host__ __device__ constexpr hasher hash_function() const noexcept + { + return this->probing_scheme().hash_function(); + } + /** * @brief Gets the non-owning storage ref. * diff --git a/include/cuco/detail/probing_scheme/probing_scheme_impl.inl b/include/cuco/detail/probing_scheme/probing_scheme_impl.inl index 61670d7be..047ec7987 100644 --- a/include/cuco/detail/probing_scheme/probing_scheme_impl.inl +++ b/include/cuco/detail/probing_scheme/probing_scheme_impl.inl @@ -127,6 +127,13 @@ __host__ __device__ constexpr auto linear_probing::operator()( upper_bound}; } +template +__host__ __device__ constexpr linear_probing::hasher +linear_probing::hash_function() const noexcept +{ + return hash_; +} + template __host__ __device__ constexpr double_hashing::double_hashing( Hash1 const& hash1, Hash2 const& hash2) @@ -192,4 +199,12 @@ __host__ __device__ constexpr auto double_hashing::operato cg_size), upper_bound}; // TODO use fast_int operator } + +template +__host__ __device__ constexpr double_hashing::hasher +double_hashing::hash_function() const noexcept +{ + return {hash1_, hash2_}; +} + } // namespace cuco diff --git a/include/cuco/detail/static_map/static_map.inl b/include/cuco/detail/static_map/static_map.inl index a90315187..e2915e1fd 100644 --- a/include/cuco/detail/static_map/static_map.inl +++ b/include/cuco/detail/static_map/static_map.inl @@ -724,6 +724,21 @@ static_map:: return impl_->key_eq(); } +template +constexpr static_map::hasher +static_map::hash_function() + const noexcept +{ + return impl_->hash_function(); +} + template return this->impl_.key_eq(); } +template +__host__ __device__ constexpr static_map_ref::hasher +static_map_ref::hash_function() + const noexcept +{ + return impl_.hash_function(); +} + template key_eq(); } +template +constexpr static_multimap:: + hasher + static_multimap:: + hash_function() const noexcept +{ + return impl_->hash_function(); +} + template +__host__ __device__ constexpr static_multimap_ref::hasher +static_multimap_ref:: + hash_function() const noexcept +{ + return impl_.hash_function(); +} + template key_eq(); } +template +constexpr static_multiset::hasher +static_multiset::hash_function() + const noexcept +{ + return impl_->hash_function(); +} + template impl_.key_eq(); } +template +__host__ __device__ constexpr static_multiset_ref::hasher +static_multiset_ref::hash_function() + const noexcept +{ + return impl_.hash_function(); +} + template ::key return impl_->key_eq(); } +template +constexpr static_set::hasher +static_set::hash_function() + const noexcept +{ + return impl_->hash_function(); +} + template ::k return this->impl_.key_eq(); } +template +__host__ __device__ constexpr static_set_ref::hasher +static_set_ref::hash_function() + const noexcept +{ + return impl_.hash_function(); +} + template #include +#include #include #include @@ -37,10 +38,12 @@ namespace cuco { */ template class linear_probing : private detail::probing_scheme_base { - public: using probing_scheme_base_type = detail::probing_scheme_base; ///< The base probe scheme type + + public: using probing_scheme_base_type::cg_size; + using hasher = Hash; ///< Hash function type /** *@brief Constructs linear probing scheme with the hasher callable. @@ -93,6 +96,13 @@ class linear_probing : private detail::probing_scheme_base { ProbeKey const& probe_key, Extent upper_bound) const noexcept; + /** + * @brief Gets the function used to hash keys + * + * @return The function used to hash keys + */ + __host__ __device__ constexpr hasher hash_function() const noexcept; + private: Hash hash_; }; @@ -113,10 +123,12 @@ class linear_probing : private detail::probing_scheme_base { */ template class double_hashing : private detail::probing_scheme_base { - public: using probing_scheme_base_type = detail::probing_scheme_base; ///< The base probe scheme type + + public: using probing_scheme_base_type::cg_size; + using hasher = cuda::std::tuple; ///< Hash function type /** *@brief Constructs double hashing probing scheme with the two hasher callables. @@ -195,6 +207,13 @@ class double_hashing : private detail::probing_scheme_base { ProbeKey const& probe_key, Extent upper_bound) const noexcept; + /** + * @brief Gets the functions used to hash keys + * + * @return The functions used to hash keys + */ + __host__ __device__ constexpr hasher hash_function() const noexcept; + private: Hash1 hash1_; Hash2 hash2_; diff --git a/include/cuco/static_map.cuh b/include/cuco/static_map.cuh index efc202b4f..fc7dc088d 100644 --- a/include/cuco/static_map.cuh +++ b/include/cuco/static_map.cuh @@ -127,6 +127,7 @@ class static_map { /// Non-owning window storage ref type using storage_ref_type = typename impl_type::storage_ref_type; using probing_scheme_type = typename impl_type::probing_scheme_type; ///< Probing scheme type + using hasher = typename probing_scheme_type::hasher; ///< Hash function type using mapped_type = T; ///< Payload type template @@ -959,6 +960,13 @@ class static_map { */ [[nodiscard]] constexpr key_equal key_eq() const noexcept; + /** + * @brief Gets the function(s) used to hash keys + * + * @return The function(s) used to hash keys + */ + [[nodiscard]] constexpr hasher hash_function() const noexcept; + /** * @brief Get device ref with operators. * diff --git a/include/cuco/static_map_ref.cuh b/include/cuco/static_map_ref.cuh index e12bdb6f2..1da1e501a 100644 --- a/include/cuco/static_map_ref.cuh +++ b/include/cuco/static_map_ref.cuh @@ -86,6 +86,7 @@ class static_map_ref using key_type = Key; ///< Key type using mapped_type = T; ///< Mapped type using probing_scheme_type = ProbingScheme; ///< Type of probing scheme + using hasher = typename probing_scheme_type::hasher; ///< Hash function type using storage_ref_type = StorageRef; ///< Type of storage ref using window_type = typename storage_ref_type::window_type; ///< Window type using value_type = typename storage_ref_type::value_type; ///< Storage element type @@ -190,6 +191,13 @@ class static_map_ref */ [[nodiscard]] __host__ __device__ constexpr key_equal key_eq() const noexcept; + /** + * @brief Gets the function(s) used to hash keys + * + * @return The function(s) used to hash keys + */ + [[nodiscard]] __host__ __device__ constexpr hasher hash_function() const noexcept; + /** * @brief Returns a const_iterator to one past the last slot. * diff --git a/include/cuco/static_multimap.cuh b/include/cuco/static_multimap.cuh index 6eb2a960c..ebf17edba 100644 --- a/include/cuco/static_multimap.cuh +++ b/include/cuco/static_multimap.cuh @@ -130,6 +130,7 @@ class static_multimap { /// Non-owning window storage ref type using storage_ref_type = typename impl_type::storage_ref_type; using probing_scheme_type = typename impl_type::probing_scheme_type; ///< Probing scheme type + using hasher = typename probing_scheme_type::hasher; ///< Hash function type using mapped_type = T; ///< Payload type template @@ -522,6 +523,13 @@ class static_multimap { */ [[nodiscard]] constexpr key_equal key_eq() const noexcept; + /** + * @brief Gets the function(s) used to hash keys + * + * @return The function(s) used to hash keys + */ + [[nodiscard]] constexpr hasher hash_function() const noexcept; + /** * @brief Get device ref with operators. * diff --git a/include/cuco/static_multimap_ref.cuh b/include/cuco/static_multimap_ref.cuh index ac2526285..b23925b86 100644 --- a/include/cuco/static_multimap_ref.cuh +++ b/include/cuco/static_multimap_ref.cuh @@ -83,6 +83,7 @@ class static_multimap_ref using key_type = Key; ///< Key type using mapped_type = T; ///< Mapped type using probing_scheme_type = ProbingScheme; ///< Type of probing scheme + using hasher = typename probing_scheme_type::hasher; ///< Hash function type using storage_ref_type = StorageRef; ///< Type of storage ref using window_type = typename storage_ref_type::window_type; ///< Window type using value_type = typename storage_ref_type::value_type; ///< Storage element type @@ -189,6 +190,13 @@ class static_multimap_ref */ [[nodiscard]] __host__ __device__ constexpr key_equal key_eq() const noexcept; + /** + * @brief Gets the function(s) used to hash keys + * + * @return The function(s) used to hash keys + */ + [[nodiscard]] __host__ __device__ constexpr hasher hash_function() const noexcept; + /** * @brief Returns a const_iterator to one past the last slot. * diff --git a/include/cuco/static_multiset.cuh b/include/cuco/static_multiset.cuh index 90f57f2f4..22cda307f 100644 --- a/include/cuco/static_multiset.cuh +++ b/include/cuco/static_multiset.cuh @@ -100,6 +100,7 @@ class static_multiset { /// Non-owning window storage ref type using storage_ref_type = typename impl_type::storage_ref_type; using probing_scheme_type = typename impl_type::probing_scheme_type; ///< Probing scheme type + using hasher = typename probing_scheme_type::hasher; ///< Hash function type template using ref_type = cuco::static_multiset_ref using ref_type = cuco::static_set_ref