Skip to content

Commit

Permalink
Refactor MD5 implementation. (#9212)
Browse files Browse the repository at this point in the history
This PR refactors the MD5 hash implementation in libcudf. I used the MD5 code as a reference while working on SHA (extending #6020, PR #9215 to follow).

List of high-level changes:
- I moved the implementation of `MD5Hash` and related logic from `include/cudf/detail/utilities/hash_functions.cuh` to `src/hash/md5_hash.cu` because it is only used in that file and nowhere else. We don't need to include and build MD5 in `hash_functions.cuh` for all the collections/sorting/groupby tools that only use Murmur3 variants and `IdentityHash`. (This will be a bigger deal once we add the SHA hash functions, soon to follow this PR, because the size of `hash_functions.cuh` would be substantially larger without this separation.)
- I removed an `MD5Hash` constructor that accepted and stored a seed whose value was unused.
- Improved use of namespaces.
- Use named constants instead of magic numbers.
- Introduced a `hash_circular_buffer` and refactored dispatch logic.

No changes were made to the feature scope or public APIs of the MD5 feature, so existing unit tests and bindings should remain the same.

Authors:
  - Bradley Dice (https://github.com/bdice)

Approvers:
  - David Wendt (https://github.com/davidwendt)
  - Mark Harris (https://github.com/harrism)
  - Jake Hemstad (https://github.com/jrhemstad)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #9212
  • Loading branch information
bdice authored Oct 21, 2021
1 parent 5c76bc2 commit 72694d2
Show file tree
Hide file tree
Showing 3 changed files with 294 additions and 404 deletions.
312 changes: 9 additions & 303 deletions cpp/include/cudf/detail/utilities/hash_functions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,110 +21,27 @@
#include <cudf/fixed_point/fixed_point.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/types.hpp>
#include <hash/hash_constants.hpp>

using hash_value_type = uint32_t;

namespace cudf {
namespace detail {
namespace {
/**
* @brief Core MD5 algorithm implementation. Processes a single 512-bit chunk,
* updating the hash value so far. Does not zero out the buffer contents.
*/
void CUDA_DEVICE_CALLABLE md5_hash_step(md5_intermediate_data* hash_state)
{
uint32_t A = hash_state->hash_value[0];
uint32_t B = hash_state->hash_value[1];
uint32_t C = hash_state->hash_value[2];
uint32_t D = hash_state->hash_value[3];

for (unsigned int j = 0; j < 64; j++) {
uint32_t F;
uint32_t g;
switch (j / 16) {
case 0:
F = (B & C) | ((~B) & D);
g = j;
break;
case 1:
F = (D & B) | ((~D) & C);
g = (5 * j + 1) % 16;
break;
case 2:
F = B ^ C ^ D;
g = (3 * j + 5) % 16;
break;
case 3:
F = C ^ (B | (~D));
g = (7 * j) % 16;
break;
}

uint32_t buffer_element_as_int;
std::memcpy(&buffer_element_as_int, hash_state->buffer + g * 4, 4);
F = F + A + md5_hash_constants[j] + buffer_element_as_int;
A = D;
D = C;
C = B;
B = B + __funnelshift_l(F, F, md5_shift_constants[((j / 16) * 4) + (j % 4)]);
}

hash_state->hash_value[0] += A;
hash_state->hash_value[1] += B;
hash_state->hash_value[2] += C;
hash_state->hash_value[3] += D;

hash_state->buffer_length = 0;
}

/**
* @brief Core MD5 element processing function
* Normalization of floating point NaNs and zeros, passthrough for all other values.
*/
template <typename TKey>
void CUDA_DEVICE_CALLABLE md5_process(TKey const& key, md5_intermediate_data* hash_state)
template <typename T>
T CUDA_DEVICE_CALLABLE normalize_nans_and_zeros(T const& key)
{
uint32_t const len = sizeof(TKey);
uint8_t const* data = reinterpret_cast<uint8_t const*>(&key);
hash_state->message_length += len;

// 64 bytes for the number of byt es processed in a given step
constexpr int md5_chunk_size = 64;
if (hash_state->buffer_length + len < md5_chunk_size) {
std::memcpy(hash_state->buffer + hash_state->buffer_length, data, len);
hash_state->buffer_length += len;
} else {
uint32_t copylen = md5_chunk_size - hash_state->buffer_length;

std::memcpy(hash_state->buffer + hash_state->buffer_length, data, copylen);
md5_hash_step(hash_state);

while (len > md5_chunk_size + copylen) {
std::memcpy(hash_state->buffer, data + copylen, md5_chunk_size);
md5_hash_step(hash_state);
copylen += md5_chunk_size;
if constexpr (is_floating_point<T>()) {
if (isnan(key)) {
return std::numeric_limits<T>::quiet_NaN();
} else if (key == T{0.0}) {
return T{0.0};
}

std::memcpy(hash_state->buffer, data + copylen, len - copylen);
hash_state->buffer_length = len - copylen;
}
}

/**
* Normalization of floating point NANs and zeros helper
*/
template <typename T, std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
T CUDA_DEVICE_CALLABLE normalize_nans_and_zeros_helper(T key)
{
if (isnan(key)) {
return std::numeric_limits<T>::quiet_NaN();
} else if (key == T{0.0}) {
return T{0.0};
} else {
return key;
}
return key;
}
} // namespace

/**
* Modified GPU implementation of
Expand All @@ -149,217 +66,6 @@ void CUDA_DEVICE_CALLABLE uint32ToLowercaseHexString(uint32_t num, char* destina
std::memcpy(destination, reinterpret_cast<uint8_t*>(&x), 8);
}

struct MD5ListHasher {
template <typename T, std::enable_if_t<is_chrono<T>()>* = nullptr>
void __device__ operator()(column_device_view data_col,
size_type offset_begin,
size_type offset_end,
md5_intermediate_data* hash_state) const
{
cudf_assert(false && "MD5 Unsupported chrono type column");
}

template <typename T, std::enable_if_t<!is_fixed_width<T>()>* = nullptr>
void __device__ operator()(column_device_view data_col,
size_type offset_begin,
size_type offset_end,
md5_intermediate_data* hash_state) const
{
cudf_assert(false && "MD5 Unsupported non-fixed-width type column");
}

template <typename T, std::enable_if_t<is_floating_point<T>()>* = nullptr>
void __device__ operator()(column_device_view data_col,
size_type offset_begin,
size_type offset_end,
md5_intermediate_data* hash_state) const
{
for (int i = offset_begin; i < offset_end; i++) {
if (!data_col.is_null(i)) {
md5_process(normalize_nans_and_zeros_helper<T>(data_col.element<T>(i)), hash_state);
}
}
}

template <
typename T,
std::enable_if_t<is_fixed_width<T>() && !is_floating_point<T>() && !is_chrono<T>()>* = nullptr>
void CUDA_DEVICE_CALLABLE operator()(column_device_view data_col,
size_type offset_begin,
size_type offset_end,
md5_intermediate_data* hash_state) const
{
for (int i = offset_begin; i < offset_end; i++) {
if (!data_col.is_null(i)) md5_process(data_col.element<T>(i), hash_state);
}
}
};

template <>
void CUDA_DEVICE_CALLABLE
MD5ListHasher::operator()<string_view>(column_device_view data_col,
size_type offset_begin,
size_type offset_end,
md5_intermediate_data* hash_state) const
{
for (int i = offset_begin; i < offset_end; i++) {
if (!data_col.is_null(i)) {
string_view key = data_col.element<string_view>(i);
uint32_t const len = static_cast<uint32_t>(key.size_bytes());
uint8_t const* data = reinterpret_cast<uint8_t const*>(key.data());

hash_state->message_length += len;

if (hash_state->buffer_length + len < 64) {
std::memcpy(hash_state->buffer + hash_state->buffer_length, data, len);
hash_state->buffer_length += len;
} else {
uint32_t copylen = 64 - hash_state->buffer_length;
std::memcpy(hash_state->buffer + hash_state->buffer_length, data, copylen);
md5_hash_step(hash_state);

while (len > 64 + copylen) {
std::memcpy(hash_state->buffer, data + copylen, 64);
md5_hash_step(hash_state);
copylen += 64;
}

std::memcpy(hash_state->buffer, data + copylen, len - copylen);
hash_state->buffer_length = len - copylen;
}
}
}
}

struct MD5Hash {
MD5Hash() = default;
constexpr MD5Hash(uint32_t seed) : m_seed(seed) {}

void __device__ finalize(md5_intermediate_data* hash_state, char* result_location) const
{
auto const full_length = (static_cast<uint64_t>(hash_state->message_length)) << 3;
thrust::fill_n(thrust::seq, hash_state->buffer + hash_state->buffer_length, 1, 0x80);

// 64 bytes for the number of bytes processed in a given step
constexpr int md5_chunk_size = 64;
// 8 bytes for the total message length, appended to the end of the last chunk processed
constexpr int message_length_size = 8;
// 1 byte for the end of the message flag
constexpr int end_of_message_size = 1;
if (hash_state->buffer_length + message_length_size + end_of_message_size <= md5_chunk_size) {
thrust::fill_n(
thrust::seq,
hash_state->buffer + hash_state->buffer_length + 1,
(md5_chunk_size - message_length_size - end_of_message_size - hash_state->buffer_length),
0x00);
} else {
thrust::fill_n(thrust::seq,
hash_state->buffer + hash_state->buffer_length + 1,
(md5_chunk_size - hash_state->buffer_length),
0x00);
md5_hash_step(hash_state);

thrust::fill_n(thrust::seq, hash_state->buffer, md5_chunk_size - message_length_size, 0x00);
}

std::memcpy(hash_state->buffer + md5_chunk_size - message_length_size,
reinterpret_cast<uint8_t const*>(&full_length),
message_length_size);
md5_hash_step(hash_state);

#pragma unroll
for (int i = 0; i < 4; ++i)
uint32ToLowercaseHexString(hash_state->hash_value[i], result_location + (8 * i));
}

template <typename T, std::enable_if_t<is_chrono<T>()>* = nullptr>
void __device__ operator()(column_device_view col,
size_type row_index,
md5_intermediate_data* hash_state) const
{
cudf_assert(false && "MD5 Unsupported chrono type column");
}

template <typename T, std::enable_if_t<!is_fixed_width<T>()>* = nullptr>
void __device__ operator()(column_device_view col,
size_type row_index,
md5_intermediate_data* hash_state) const
{
cudf_assert(false && "MD5 Unsupported non-fixed-width type column");
}

template <typename T, std::enable_if_t<is_floating_point<T>()>* = nullptr>
void __device__ operator()(column_device_view col,
size_type row_index,
md5_intermediate_data* hash_state) const
{
md5_process(normalize_nans_and_zeros_helper<T>(col.element<T>(row_index)), hash_state);
}

template <
typename T,
std::enable_if_t<is_fixed_width<T>() && !is_floating_point<T>() && !is_chrono<T>()>* = nullptr>
void CUDA_DEVICE_CALLABLE operator()(column_device_view col,
size_type row_index,
md5_intermediate_data* hash_state) const
{
md5_process(col.element<T>(row_index), hash_state);
}

private:
uint32_t m_seed{cudf::DEFAULT_HASH_SEED};
};

template <>
void CUDA_DEVICE_CALLABLE MD5Hash::operator()<string_view>(column_device_view col,
size_type row_index,
md5_intermediate_data* hash_state) const
{
string_view key = col.element<string_view>(row_index);
uint32_t const len = static_cast<uint32_t>(key.size_bytes());
uint8_t const* data = reinterpret_cast<uint8_t const*>(key.data());

hash_state->message_length += len;

if (hash_state->buffer_length + len < 64) {
std::memcpy(hash_state->buffer + hash_state->buffer_length, data, len);
hash_state->buffer_length += len;
} else {
uint32_t copylen = 64 - hash_state->buffer_length;
std::memcpy(hash_state->buffer + hash_state->buffer_length, data, copylen);
md5_hash_step(hash_state);

while (len > 64 + copylen) {
std::memcpy(hash_state->buffer, data + copylen, 64);
md5_hash_step(hash_state);
copylen += 64;
}

std::memcpy(hash_state->buffer, data + copylen, len - copylen);
hash_state->buffer_length = len - copylen;
}
}

template <>
void CUDA_DEVICE_CALLABLE MD5Hash::operator()<list_view>(column_device_view col,
size_type row_index,
md5_intermediate_data* hash_state) const
{
static constexpr size_type offsets_column_index{0};
static constexpr size_type data_column_index{1};

column_device_view offsets = col.child(offsets_column_index);
column_device_view data = col.child(data_column_index);

if (data.type().id() == type_id::LIST) cudf_assert(false && "Nested list unsupported");

cudf::type_dispatcher(data.type(),
MD5ListHasher{},
data,
offsets.element<size_type>(row_index),
offsets.element<size_type>(row_index + 1),
hash_state);
}
} // namespace detail
} // namespace cudf

Expand Down
Loading

0 comments on commit 72694d2

Please sign in to comment.