From a61c1b35c75702a7eaa27289076186a1f0629260 Mon Sep 17 00:00:00 2001 From: Ryan Lee Date: Wed, 19 Jan 2022 17:10:14 -0800 Subject: [PATCH] Remove need for special case handling --- .../cudf/detail/utilities/hash_functions.cuh | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/cpp/include/cudf/detail/utilities/hash_functions.cuh b/cpp/include/cudf/detail/utilities/hash_functions.cuh index b4c2cc4ef2d..8a7f4276d05 100644 --- a/cpp/include/cudf/detail/utilities/hash_functions.cuh +++ b/cpp/include/cudf/detail/utilities/hash_functions.cuh @@ -16,6 +16,8 @@ #pragma once +#include <algorithm> + #include #include #include @@ -344,15 +346,16 @@ struct SparkMurmurHash3_32 { result_type __device__ compute_bytes(std::byte const* const data, cudf::size_type const len) const { - cudf::size_type const nblocks = len / 4; - uint32_t h1 = m_seed; - constexpr uint32_t c1 = 0xcc9e2d51; - constexpr uint32_t c2 = 0x1b873593; + constexpr cudf::size_type block_size = sizeof(uint32_t) / sizeof(std::byte); + cudf::size_type const nblocks = len / block_size; + uint32_t h1 = m_seed; + constexpr uint32_t c1 = 0xcc9e2d51; + constexpr uint32_t c2 = 0x1b873593; //---------- // Process all four-byte chunks uint32_t const* const blocks = reinterpret_cast<uint32_t const*>(data); - for (int i = 0; i < nblocks; i++) { + for (cudf::size_type i = 0; i < nblocks; i++) { uint32_t k1 = blocks[i]; k1 *= c1; k1 = rotl32(k1, 15); @@ -364,7 +367,7 @@ struct SparkMurmurHash3_32 { //---------- // Process remaining bytes that do not fill a four-byte chunk using Spark's approach // (does not conform to normal MurmurHash3) - for (int i = nblocks * 4; i < len; i++) { + for (cudf::size_type i = nblocks * 4; i < len; i++) { // We require a two-step cast to get the k1 value from the byte. First, // we must cast to a signed int8_t. Then, the sign bit is preserved when // casting to uint32_t under 2's complement. 
Java preserves the @@ -448,10 +451,6 @@ hash_value_type __device__ inline SparkMurmurHash3_32::oper bool const is_negative = val < 0; std::byte const zero_value = is_negative ? std::byte{0xff} : std::byte{0x00}; - // Special cases for 0 and -1 which do not shorten correctly. - if (val == 0) { return this->compute(static_cast(0)); } - if (val == static_cast<__int128>(-1)) { return this->compute(static_cast(0xff)); } - // If the value can be represented with a shorter than 16-byte integer, the // leading bytes of the little-endian value are truncated and are not hashed. auto const reverse_begin = thrust::reverse_iterator(data + key_size); @@ -460,7 +459,9 @@ hash_value_type __device__ inline SparkMurmurHash3_32::oper thrust::find_if_not(thrust::seq, reverse_begin, reverse_end, [zero_value](std::byte const& v) { return v == zero_value; }).base(); - cudf::size_type length = thrust::distance(data, first_nonzero_byte); + // Max handles special case of 0 and -1 which would shorten to 0 length otherwise + cudf::size_type length = + std::max(1, static_cast<cudf::size_type>(thrust::distance(data, first_nonzero_byte))); // Preserve the 2's complement sign bit by adding a byte back on if necessary. // e.g. 0x0000ff would shorten to 0x00ff. The 0x00 byte is retained to