From 72694d22cff81e629f24dd64d7f9993dc5cc27fc Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Thu, 21 Oct 2021 16:25:19 -0700 Subject: [PATCH] Refactor MD5 implementation. (#9212) This PR refactors the MD5 hash implementation in libcudf. I used the MD5 code as a reference while working on SHA (extending #6020, PR #9215 to follow). List of high-level changes: - I moved the implementation of `MD5Hash` and related logic from `include/cudf/detail/utilities/hash_functions.cuh` to `src/hash/md5_hash.cu` because it is only used in that file and nowhere else. We don't need to include and build MD5 in `hash_functions.cuh` for all the collections/sorting/groupby tools that only use Murmur3 variants and `IdentityHash`. (This will be a bigger deal once we add the SHA hash functions, soon to follow this PR, because the size of `hash_functions.cuh` would be substantially larger without this separation.) - I removed an `MD5Hash` constructor that accepted and stored a seed whose value was unused. - Improved use of namespaces. - Use named constants instead of magic numbers. - Introduced a `hash_circular_buffer` and refactored dispatch logic. No changes were made to the feature scope or public APIs of the MD5 feature, so existing unit tests and bindings should remain the same. 
Authors: - Bradley Dice (https://github.com/bdice) Approvers: - David Wendt (https://github.com/davidwendt) - Mark Harris (https://github.com/harrism) - Jake Hemstad (https://github.com/jrhemstad) - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/9212 --- .../cudf/detail/utilities/hash_functions.cuh | 312 +---------------- cpp/src/hash/hash_constants.hpp | 64 ---- cpp/src/hash/md5_hash.cu | 322 ++++++++++++++++-- 3 files changed, 294 insertions(+), 404 deletions(-) delete mode 100644 cpp/src/hash/hash_constants.hpp diff --git a/cpp/include/cudf/detail/utilities/hash_functions.cuh b/cpp/include/cudf/detail/utilities/hash_functions.cuh index 65deadd6cd0..ebb21492be9 100644 --- a/cpp/include/cudf/detail/utilities/hash_functions.cuh +++ b/cpp/include/cudf/detail/utilities/hash_functions.cuh @@ -21,110 +21,27 @@ #include #include #include -#include using hash_value_type = uint32_t; namespace cudf { namespace detail { -namespace { -/** - * @brief Core MD5 algorithm implementation. Processes a single 512-bit chunk, - * updating the hash value so far. Does not zero out the buffer contents. 
- */ -void CUDA_DEVICE_CALLABLE md5_hash_step(md5_intermediate_data* hash_state) -{ - uint32_t A = hash_state->hash_value[0]; - uint32_t B = hash_state->hash_value[1]; - uint32_t C = hash_state->hash_value[2]; - uint32_t D = hash_state->hash_value[3]; - - for (unsigned int j = 0; j < 64; j++) { - uint32_t F; - uint32_t g; - switch (j / 16) { - case 0: - F = (B & C) | ((~B) & D); - g = j; - break; - case 1: - F = (D & B) | ((~D) & C); - g = (5 * j + 1) % 16; - break; - case 2: - F = B ^ C ^ D; - g = (3 * j + 5) % 16; - break; - case 3: - F = C ^ (B | (~D)); - g = (7 * j) % 16; - break; - } - - uint32_t buffer_element_as_int; - std::memcpy(&buffer_element_as_int, hash_state->buffer + g * 4, 4); - F = F + A + md5_hash_constants[j] + buffer_element_as_int; - A = D; - D = C; - C = B; - B = B + __funnelshift_l(F, F, md5_shift_constants[((j / 16) * 4) + (j % 4)]); - } - - hash_state->hash_value[0] += A; - hash_state->hash_value[1] += B; - hash_state->hash_value[2] += C; - hash_state->hash_value[3] += D; - - hash_state->buffer_length = 0; -} /** - * @brief Core MD5 element processing function + * Normalization of floating point NaNs and zeros, passthrough for all other values. 
*/ -template -void CUDA_DEVICE_CALLABLE md5_process(TKey const& key, md5_intermediate_data* hash_state) +template +T CUDA_DEVICE_CALLABLE normalize_nans_and_zeros(T const& key) { - uint32_t const len = sizeof(TKey); - uint8_t const* data = reinterpret_cast(&key); - hash_state->message_length += len; - - // 64 bytes for the number of bytes processed in a given step - constexpr int md5_chunk_size = 64; - if (hash_state->buffer_length + len < md5_chunk_size) { - std::memcpy(hash_state->buffer + hash_state->buffer_length, data, len); - hash_state->buffer_length += len; - } else { - uint32_t copylen = md5_chunk_size - hash_state->buffer_length; - - std::memcpy(hash_state->buffer + hash_state->buffer_length, data, copylen); - md5_hash_step(hash_state); - - while (len > md5_chunk_size + copylen) { - std::memcpy(hash_state->buffer, data + copylen, md5_chunk_size); - md5_hash_step(hash_state); - copylen += md5_chunk_size; + if constexpr (is_floating_point()) { + if (isnan(key)) { + return std::numeric_limits::quiet_NaN(); + } else if (key == T{0.0}) { + return T{0.0}; } - - std::memcpy(hash_state->buffer, data + copylen, len - copylen); - hash_state->buffer_length = len - copylen; - } -} - -/** - * Normalization of floating point NANs and zeros helper - */ -template ::value>* = nullptr> -T CUDA_DEVICE_CALLABLE normalize_nans_and_zeros_helper(T key) -{ - if (isnan(key)) { - return std::numeric_limits::quiet_NaN(); - } else if (key == T{0.0}) { - return T{0.0}; - } else { - return key; } + return key; } -} // namespace /** * Modified GPU implementation of @@ -149,217 +66,6 @@ void CUDA_DEVICE_CALLABLE uint32ToLowercaseHexString(uint32_t num, char* destina std::memcpy(destination, reinterpret_cast(&x), 8); } -struct MD5ListHasher { - template ()>* = nullptr> - void __device__ operator()(column_device_view data_col, - size_type offset_begin, - size_type offset_end, - md5_intermediate_data* hash_state) const - { - cudf_assert(false && "MD5 Unsupported chrono type column"); - }
- - template ()>* = nullptr> - void __device__ operator()(column_device_view data_col, - size_type offset_begin, - size_type offset_end, - md5_intermediate_data* hash_state) const - { - cudf_assert(false && "MD5 Unsupported non-fixed-width type column"); - } - - template ()>* = nullptr> - void __device__ operator()(column_device_view data_col, - size_type offset_begin, - size_type offset_end, - md5_intermediate_data* hash_state) const - { - for (int i = offset_begin; i < offset_end; i++) { - if (!data_col.is_null(i)) { - md5_process(normalize_nans_and_zeros_helper(data_col.element(i)), hash_state); - } - } - } - - template < - typename T, - std::enable_if_t() && !is_floating_point() && !is_chrono()>* = nullptr> - void CUDA_DEVICE_CALLABLE operator()(column_device_view data_col, - size_type offset_begin, - size_type offset_end, - md5_intermediate_data* hash_state) const - { - for (int i = offset_begin; i < offset_end; i++) { - if (!data_col.is_null(i)) md5_process(data_col.element(i), hash_state); - } - } -}; - -template <> -void CUDA_DEVICE_CALLABLE -MD5ListHasher::operator()(column_device_view data_col, - size_type offset_begin, - size_type offset_end, - md5_intermediate_data* hash_state) const -{ - for (int i = offset_begin; i < offset_end; i++) { - if (!data_col.is_null(i)) { - string_view key = data_col.element(i); - uint32_t const len = static_cast(key.size_bytes()); - uint8_t const* data = reinterpret_cast(key.data()); - - hash_state->message_length += len; - - if (hash_state->buffer_length + len < 64) { - std::memcpy(hash_state->buffer + hash_state->buffer_length, data, len); - hash_state->buffer_length += len; - } else { - uint32_t copylen = 64 - hash_state->buffer_length; - std::memcpy(hash_state->buffer + hash_state->buffer_length, data, copylen); - md5_hash_step(hash_state); - - while (len > 64 + copylen) { - std::memcpy(hash_state->buffer, data + copylen, 64); - md5_hash_step(hash_state); - copylen += 64; - } - - std::memcpy(hash_state->buffer, data + 
copylen, len - copylen); - hash_state->buffer_length = len - copylen; - } - } - } -} - -struct MD5Hash { - MD5Hash() = default; - constexpr MD5Hash(uint32_t seed) : m_seed(seed) {} - - void __device__ finalize(md5_intermediate_data* hash_state, char* result_location) const - { - auto const full_length = (static_cast(hash_state->message_length)) << 3; - thrust::fill_n(thrust::seq, hash_state->buffer + hash_state->buffer_length, 1, 0x80); - - // 64 bytes for the number of bytes processed in a given step - constexpr int md5_chunk_size = 64; - // 8 bytes for the total message length, appended to the end of the last chunk processed - constexpr int message_length_size = 8; - // 1 byte for the end of the message flag - constexpr int end_of_message_size = 1; - if (hash_state->buffer_length + message_length_size + end_of_message_size <= md5_chunk_size) { - thrust::fill_n( - thrust::seq, - hash_state->buffer + hash_state->buffer_length + 1, - (md5_chunk_size - message_length_size - end_of_message_size - hash_state->buffer_length), - 0x00); - } else { - thrust::fill_n(thrust::seq, - hash_state->buffer + hash_state->buffer_length + 1, - (md5_chunk_size - hash_state->buffer_length), - 0x00); - md5_hash_step(hash_state); - - thrust::fill_n(thrust::seq, hash_state->buffer, md5_chunk_size - message_length_size, 0x00); - } - - std::memcpy(hash_state->buffer + md5_chunk_size - message_length_size, - reinterpret_cast(&full_length), - message_length_size); - md5_hash_step(hash_state); - -#pragma unroll - for (int i = 0; i < 4; ++i) - uint32ToLowercaseHexString(hash_state->hash_value[i], result_location + (8 * i)); - } - - template ()>* = nullptr> - void __device__ operator()(column_device_view col, - size_type row_index, - md5_intermediate_data* hash_state) const - { - cudf_assert(false && "MD5 Unsupported chrono type column"); - } - - template ()>* = nullptr> - void __device__ operator()(column_device_view col, - size_type row_index, - md5_intermediate_data* hash_state) const - { - 
cudf_assert(false && "MD5 Unsupported non-fixed-width type column"); - } - - template ()>* = nullptr> - void __device__ operator()(column_device_view col, - size_type row_index, - md5_intermediate_data* hash_state) const - { - md5_process(normalize_nans_and_zeros_helper(col.element(row_index)), hash_state); - } - - template < - typename T, - std::enable_if_t() && !is_floating_point() && !is_chrono()>* = nullptr> - void CUDA_DEVICE_CALLABLE operator()(column_device_view col, - size_type row_index, - md5_intermediate_data* hash_state) const - { - md5_process(col.element(row_index), hash_state); - } - - private: - uint32_t m_seed{cudf::DEFAULT_HASH_SEED}; -}; - -template <> -void CUDA_DEVICE_CALLABLE MD5Hash::operator()(column_device_view col, - size_type row_index, - md5_intermediate_data* hash_state) const -{ - string_view key = col.element(row_index); - uint32_t const len = static_cast(key.size_bytes()); - uint8_t const* data = reinterpret_cast(key.data()); - - hash_state->message_length += len; - - if (hash_state->buffer_length + len < 64) { - std::memcpy(hash_state->buffer + hash_state->buffer_length, data, len); - hash_state->buffer_length += len; - } else { - uint32_t copylen = 64 - hash_state->buffer_length; - std::memcpy(hash_state->buffer + hash_state->buffer_length, data, copylen); - md5_hash_step(hash_state); - - while (len > 64 + copylen) { - std::memcpy(hash_state->buffer, data + copylen, 64); - md5_hash_step(hash_state); - copylen += 64; - } - - std::memcpy(hash_state->buffer, data + copylen, len - copylen); - hash_state->buffer_length = len - copylen; - } -} - -template <> -void CUDA_DEVICE_CALLABLE MD5Hash::operator()(column_device_view col, - size_type row_index, - md5_intermediate_data* hash_state) const -{ - static constexpr size_type offsets_column_index{0}; - static constexpr size_type data_column_index{1}; - - column_device_view offsets = col.child(offsets_column_index); - column_device_view data = col.child(data_column_index); - - if 
(data.type().id() == type_id::LIST) cudf_assert(false && "Nested list unsupported"); - - cudf::type_dispatcher(data.type(), - MD5ListHasher{}, - data, - offsets.element(row_index), - offsets.element(row_index + 1), - hash_state); -} } // namespace detail } // namespace cudf diff --git a/cpp/src/hash/hash_constants.hpp b/cpp/src/hash/hash_constants.hpp deleted file mode 100644 index 0a5a9e0be93..00000000000 --- a/cpp/src/hash/hash_constants.hpp +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (c) 2020, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -namespace cudf { -namespace detail { - -struct md5_intermediate_data { - uint64_t message_length = 0; - uint32_t buffer_length = 0; - uint32_t hash_value[4] = {0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476}; - uint8_t buffer[64]; -}; - -// Type for the shift constants table. -using md5_shift_constants_type = uint32_t; - -__device__ __constant__ md5_shift_constants_type md5_shift_constants[16] = { - 7, - 12, - 17, - 22, - 5, - 9, - 14, - 20, - 4, - 11, - 16, - 23, - 6, - 10, - 15, - 21, -}; - -// Type for the hash constants table. 
-using md5_hash_constants_type = uint32_t; - -__device__ __constant__ md5_hash_constants_type md5_hash_constants[64] = { - 0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee, 0xf57c0faf, 0x4787c62a, 0xa8304613, 0xfd469501, - 0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be, 0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821, - 0xf61e2562, 0xc040b340, 0x265e5a51, 0xe9b6c7aa, 0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8, - 0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed, 0xa9e3e905, 0xfcefa3f8, 0x676f02d9, 0x8d2a4c8a, - 0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c, 0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70, - 0x289b7ec6, 0xeaa127fa, 0xd4ef3085, 0x04881d05, 0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665, - 0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039, 0x655b59c3, 0x8f0ccc92, 0xffeff47d, 0x85845dd1, - 0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1, 0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391, -}; -} // namespace detail -} // namespace cudf diff --git a/cpp/src/hash/md5_hash.cu b/cpp/src/hash/md5_hash.cu index 973f3204c37..d0e47d93bc6 100644 --- a/cpp/src/hash/md5_hash.cu +++ b/cpp/src/hash/md5_hash.cu @@ -15,10 +15,14 @@ */ #include #include +#include #include +#include #include #include +#include #include +#include #include #include @@ -26,71 +30,315 @@ #include #include +#include + namespace cudf { + +namespace detail { + namespace { +// The MD5 algorithm and its hash/shift constants are officially specified in +// RFC 1321. 
For convenience, these values can also be found on Wikipedia: +// https://en.wikipedia.org/wiki/MD5 +const __constant__ uint32_t md5_shift_constants[16] = { + 7, 12, 17, 22, 5, 9, 14, 20, 4, 11, 16, 23, 6, 10, 15, 21}; + +const __constant__ uint32_t md5_hash_constants[64] = { + 0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee, 0xf57c0faf, 0x4787c62a, 0xa8304613, 0xfd469501, + 0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be, 0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821, + 0xf61e2562, 0xc040b340, 0x265e5a51, 0xe9b6c7aa, 0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8, + 0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed, 0xa9e3e905, 0xfcefa3f8, 0x676f02d9, 0x8d2a4c8a, + 0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c, 0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70, + 0x289b7ec6, 0xeaa127fa, 0xd4ef3085, 0x04881d05, 0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665, + 0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039, 0x655b59c3, 0x8f0ccc92, 0xffeff47d, 0x85845dd1, + 0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1, 0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391, +}; + +template +struct hash_circular_buffer { + uint8_t storage[capacity]; + uint8_t* cur; + int available_space{capacity}; + hash_step_callable hash_step; + + CUDA_DEVICE_CALLABLE hash_circular_buffer(hash_step_callable hash_step) + : cur{storage}, hash_step{hash_step} + { + } + + CUDA_DEVICE_CALLABLE void put(uint8_t const* in, int size) + { + int copy_start = 0; + while (size >= available_space) { + // The buffer will be filled by this chunk of data. Copy a chunk of the + // data to fill the buffer and trigger a hash step. + memcpy(cur, in + copy_start, available_space); + hash_step(storage); + size -= available_space; + copy_start += available_space; + cur = storage; + available_space = capacity; + } + // The buffer will not be filled by the remaining data. That is, `size >= 0 + // && size < capacity`. We copy the remaining data into the buffer but do + // not trigger a hash step. 
+ memcpy(cur, in + copy_start, size); + cur += size; + available_space -= size; + } + + CUDA_DEVICE_CALLABLE void pad(int const space_to_leave) + { + if (space_to_leave > available_space) { + memset(cur, 0x00, available_space); + hash_step(storage); + cur = storage; + available_space = capacity; + } + memset(cur, 0x00, available_space - space_to_leave); + cur += available_space - space_to_leave; + available_space = space_to_leave; + } + + CUDA_DEVICE_CALLABLE const uint8_t& operator[](int idx) const { return storage[idx]; } +}; + +// Get a uint8_t pointer to a column element and its size as a pair. +template +auto CUDA_DEVICE_CALLABLE get_element_pointer_and_size(Element const& element) +{ + if constexpr (is_fixed_width() && !is_chrono()) { + return thrust::make_pair(reinterpret_cast(&element), sizeof(Element)); + } else { + cudf_assert(false && "Unsupported type."); + } +} + +template <> +auto CUDA_DEVICE_CALLABLE get_element_pointer_and_size(string_view const& element) +{ + return thrust::make_pair(reinterpret_cast(element.data()), element.size_bytes()); +} + +struct MD5Hasher { + static constexpr int message_chunk_size = 64; + + CUDA_DEVICE_CALLABLE MD5Hasher(char* result_location) + : result_location(result_location), buffer(md5_hash_step{hash_values}) + { + } + + CUDA_DEVICE_CALLABLE ~MD5Hasher() + { + // On destruction, finalize the message buffer and write out the current + // hexadecimal hash value to the result location. + // Add a one byte flag 0b10000000 to signal the end of the message. + uint8_t constexpr end_of_message = 0x80; + // The message length is appended to the end of the last chunk processed. 
+ uint64_t const message_length_in_bits = message_length * 8; + + buffer.put(&end_of_message, sizeof(end_of_message)); + buffer.pad(sizeof(message_length_in_bits)); + buffer.put(reinterpret_cast(&message_length_in_bits), + sizeof(message_length_in_bits)); + + for (int i = 0; i < 4; ++i) { + uint32ToLowercaseHexString(hash_values[i], result_location + (8 * i)); + } + } + + MD5Hasher(const MD5Hasher&) = delete; + MD5Hasher& operator=(const MD5Hasher&) = delete; + MD5Hasher(MD5Hasher&&) = delete; + MD5Hasher& operator=(MD5Hasher&&) = delete; + + template + void CUDA_DEVICE_CALLABLE process(Element const& element) + { + auto const normalized_element = normalize_nans_and_zeros(element); + auto const [element_ptr, size] = get_element_pointer_and_size(normalized_element); + buffer.put(element_ptr, size); + message_length += size; + } + + /** + * @brief Core MD5 algorithm implementation. Processes a single 64-byte chunk, + * updating the hash value so far. Does not zero out the buffer contents. + */ + struct md5_hash_step { + uint32_t (&hash_values)[4]; + + void CUDA_DEVICE_CALLABLE operator()(const uint8_t (&buffer)[message_chunk_size]) + { + uint32_t A = hash_values[0]; + uint32_t B = hash_values[1]; + uint32_t C = hash_values[2]; + uint32_t D = hash_values[3]; + + for (int j = 0; j < message_chunk_size; j++) { + uint32_t F; + uint32_t g; + // No default case is needed because j < 64. j / 16 is always 0, 1, 2, or 3. 
+ switch (j / 16) { + case 0: + F = (B & C) | ((~B) & D); + g = j; + break; + case 1: + F = (D & B) | ((~D) & C); + g = (5 * j + 1) % 16; + break; + case 2: + F = B ^ C ^ D; + g = (3 * j + 5) % 16; + break; + case 3: + F = C ^ (B | (~D)); + g = (7 * j) % 16; + break; + } + + uint32_t buffer_element_as_int; + memcpy(&buffer_element_as_int, &buffer[g * 4], 4); + F = F + A + md5_hash_constants[j] + buffer_element_as_int; + A = D; + D = C; + C = B; + B = B + __funnelshift_l(F, F, md5_shift_constants[((j / 16) * 4) + (j % 4)]); + } + + hash_values[0] += A; + hash_values[1] += B; + hash_values[2] += C; + hash_values[3] += D; + } + }; + + char* result_location; + hash_circular_buffer buffer; + uint64_t message_length = 0; + uint32_t hash_values[4] = {0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476}; +}; + +template +struct HasherDispatcher { + Hasher* hasher; + column_device_view const& input_col; + + CUDA_DEVICE_CALLABLE HasherDispatcher(Hasher* hasher, column_device_view const& input_col) + : hasher{hasher}, input_col{input_col} + { + } + + template + void CUDA_DEVICE_CALLABLE operator()(size_type const row_index) const + { + if constexpr ((is_fixed_width() && !is_chrono()) || + std::is_same_v) { + hasher->process(input_col.element(row_index)); + } else { + cudf_assert(false && "Unsupported type for hash function."); + } + } +}; + +template +struct ListHasherDispatcher { + Hasher* hasher; + column_device_view const& input_col; + + CUDA_DEVICE_CALLABLE ListHasherDispatcher(Hasher* hasher, column_device_view const& input_col) + : hasher{hasher}, input_col{input_col} + { + } + + template + void CUDA_DEVICE_CALLABLE operator()(size_type const offset_begin, + size_type const offset_end) const + { + if constexpr ((is_fixed_width() && !is_chrono()) || + std::is_same_v) { + for (size_type i = offset_begin; i < offset_end; i++) { + if (input_col.is_valid(i)) { hasher->process(input_col.element(i)); } + } + } else { + cudf_assert(false && "Unsupported type for hash function."); + 
} + } +}; + // MD5 supported leaf data type check -bool md5_type_check(data_type dt) +constexpr inline bool md5_leaf_type_check(data_type dt) { - return !is_chrono(dt) && (is_fixed_width(dt) || (dt.id() == type_id::STRING)); + return (is_fixed_width(dt) && !is_chrono(dt)) || (dt.id() == type_id::STRING); } } // namespace -namespace detail { - std::unique_ptr md5_hash(table_view const& input, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { if (input.num_columns() == 0 || input.num_rows() == 0) { - const string_scalar string_128bit("d41d8cd98f00b204e9800998ecf8427e"); - auto output = make_column_from_scalar(string_128bit, input.num_rows(), stream, mr); - return output; + // Return the MD5 hash of a zero-length input. + string_scalar const string_128bit("d41d8cd98f00b204e9800998ecf8427e"); + return make_column_from_scalar(string_128bit, input.num_rows(), stream, mr); } // Accepts string and fixed width columns, or single layer list columns holding those types - CUDF_EXPECTS( - std::all_of(input.begin(), - input.end(), - [](auto col) { - return md5_type_check(col.type()) || - (col.type().id() == type_id::LIST && md5_type_check(col.child(1).type())); - }), - "MD5 unsupported column type"); + CUDF_EXPECTS(std::all_of(input.begin(), + input.end(), + [](auto const& col) { + if (col.type().id() == type_id::LIST) { + return md5_leaf_type_check(lists_column_view(col).child().type()); + } + return md5_leaf_type_check(col.type()); + }), + "Unsupported column type for hash function."); + // Digest size in bytes + auto constexpr digest_size = 32; // Result column allocation and creation - auto begin = thrust::make_constant_iterator(32); + auto begin = thrust::make_constant_iterator(digest_size); auto offsets_column = cudf::strings::detail::make_offsets_child_column(begin, begin + input.num_rows(), stream, mr); - auto chars_column = strings::detail::create_chars_child_column(input.num_rows() * 32, stream, mr); - auto chars_view = chars_column->mutable_view();
- auto d_chars = chars_view.data(); + auto chars_column = + strings::detail::create_chars_child_column(input.num_rows() * digest_size, stream, mr); + auto chars_view = chars_column->mutable_view(); + auto d_chars = chars_view.data(); rmm::device_buffer null_mask{0, stream, mr}; auto const device_input = table_device_view::create(input, stream); // Hash each row, hashing each element sequentially left to right - thrust::for_each(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(input.num_rows()), - [d_chars, device_input = *device_input] __device__(auto row_index) { - md5_intermediate_data hash_state; - MD5Hash hasher = MD5Hash{}; - for (int col_index = 0; col_index < device_input.num_columns(); col_index++) { - if (device_input.column(col_index).is_valid(row_index)) { - cudf::type_dispatcher( - device_input.column(col_index).type(), - hasher, - device_input.column(col_index), - row_index, - &hash_state); - } - } - hasher.finalize(&hash_state, d_chars + (row_index * 32)); - }); + thrust::for_each( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(input.num_rows()), + [d_chars, device_input = *device_input] __device__(auto row_index) { + MD5Hasher hasher(d_chars + (row_index * digest_size)); + for (auto const& col : device_input) { + if (col.is_valid(row_index)) { + if (col.type().id() == type_id::LIST) { + auto const data_col = col.child(lists_column_view::child_column_index); + auto const offsets = col.child(lists_column_view::offsets_column_index); + if (data_col.type().id() == type_id::LIST) { + cudf_assert(false && "Nested list unsupported"); + } + auto const offset_begin = offsets.element(row_index); + auto const offset_end = offsets.element(row_index + 1); + cudf::type_dispatcher( + data_col.type(), ListHasherDispatcher(&hasher, data_col), offset_begin, offset_end); + } else { + cudf::type_dispatcher( + col.type(), HasherDispatcher(&hasher, col), row_index); + } + } + 
} + }); return make_strings_column( input.num_rows(), std::move(offsets_column), std::move(chars_column), 0, std::move(null_mask));