Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spark Decimal128 hashing #9919

Merged
merged 22 commits into from
Jan 20, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 60 additions & 10 deletions cpp/include/cudf/detail/utilities/hash_functions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include <cudf/strings/string_view.cuh>
#include <cudf/types.hpp>

#include <thrust/iterator/reverse_iterator.h>

rwlee marked this conversation as resolved.
Show resolved Hide resolved
using hash_value_type = uint32_t;

namespace cudf {
Expand Down Expand Up @@ -337,18 +339,21 @@ struct SparkMurmurHash3_32 {
template <typename TKey>
result_type __device__ inline compute(TKey const& key) const
{
  // Hash any fixed-width key by handing its raw bytes to the generic
  // byte-wise implementation; sizeof(TKey) is the exact byte count.
  return compute_bytes(reinterpret_cast<std::byte const*>(&key), sizeof(TKey));
}

result_type __device__ compute_bytes(std::byte const* const data, cudf::size_type const len) const
{
int32_t const nblocks = len / 4;
rwlee marked this conversation as resolved.
Show resolved Hide resolved
uint32_t h1 = m_seed;
constexpr uint32_t c1 = 0xcc9e2d51;
constexpr uint32_t c2 = 0x1b873593;
rwlee marked this conversation as resolved.
Show resolved Hide resolved

//----------
// body
uint32_t const* const blocks = reinterpret_cast<uint32_t const*>(data + nblocks * 4);
for (int i = -nblocks; i; i++) {
uint32_t k1 = blocks[i];
// Process all four-byte chunks
uint32_t const* const blocks = reinterpret_cast<uint32_t const*>(data);
for (int i = 0; i < nblocks; i++) {
rwlee marked this conversation as resolved.
Show resolved Hide resolved
int32_t k1 = blocks[i];
rwlee marked this conversation as resolved.
Show resolved Hide resolved
k1 *= c1;
k1 = rotl32(k1, 15);
k1 *= c2;
Expand All @@ -357,9 +362,10 @@ struct SparkMurmurHash3_32 {
h1 = h1 * 5 + 0xe6546b64;
}
//----------
// byte by byte tail processing
// Process remaining bytes that do not fill a four-byte chunk using Spark's approach
// (does not conform to normal MurmurHash3)
for (int i = nblocks * 4; i < len; i++) {
rwlee marked this conversation as resolved.
Show resolved Hide resolved
int32_t k1 = data[i];
uint32_t k1 = std::to_integer<int8_t>(data[i]);
rwlee marked this conversation as resolved.
Show resolved Hide resolved
k1 *= c1;
k1 = rotl32(k1, 15);
k1 *= c2;
Expand Down Expand Up @@ -427,7 +433,51 @@ template <>
hash_value_type __device__ inline SparkMurmurHash3_32<numeric::decimal128>::operator()(
numeric::decimal128 const& key) const
{
return this->compute<__int128_t>(key.value());
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
// Generates the Spark MurmurHash3 hash value, mimicking the conversion:
// java.math.BigDecimal.valueOf(unscaled_value, _scale).unscaledValue().toByteArray()
// https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala#L381
__int128_t const val = key.value();
cudf::size_type key_size = sizeof(__int128_t);
rwlee marked this conversation as resolved.
Show resolved Hide resolved
std::byte const* data = reinterpret_cast<std::byte const*>(&val);

// Extract the first bit of the key, which holds the sign.
std::byte const sign_bit = data[key_size - 1] & static_cast<std::byte>(0x80);
rwlee marked this conversation as resolved.
Show resolved Hide resolved

// Small negative values start with 0xff..., small positive values start with 0x00...
std::byte const zero_value = std::to_integer<uint8_t>(sign_bit) ? static_cast<std::byte>(0xff)
: static_cast<std::byte>(0x00);
rwlee marked this conversation as resolved.
Show resolved Hide resolved
// Special cases for 0 and -1 which do not shorten correctly.
if (val == 0) { return this->compute(static_cast<uint8_t>(0)); }
if (val == static_cast<__int128>(-1)) { return this->compute(static_cast<uint8_t>(0xFF)); }
rwlee marked this conversation as resolved.
Show resolved Hide resolved

// Searching from the big-byte end, find the first non-zero byte in the unscaled little endian
// value. The preceding zero bytes that are bigger-bytes in the endian value are not hashed.
rwlee marked this conversation as resolved.
Show resolved Hide resolved
auto const reverse_begin = thrust::reverse_iterator(data + key_size);
auto const reverse_end = thrust::reverse_iterator(data);
auto const first_nonzero_byte =
thrust::find_if_not(thrust::device,
rwlee marked this conversation as resolved.
Show resolved Hide resolved
reverse_begin,
reverse_end,
[zero_value](std::byte const& v) { return v == zero_value; })
.base();
cudf::size_type length = thrust::distance(data, first_nonzero_byte);
rwlee marked this conversation as resolved.
Show resolved Hide resolved

// Preserve the 2's complement sign bit by adding a byte back on if necessary.
// e.g. 0x0000FF would shorten to 0x00FF. The 0x00 byte is retained to
// preserve the sign bit, rather than leaving an "F" at the front which would
// change the sign bit. However, 0x00007F would shorten to 0x7F. No extra byte
// is needed because the leftmost bit matches the sign bit. Similarly for
// negative values: 0xFFFF00 --> 0xFF00 and 0xFFFF80 --> 0x80.
if ((length < key_size) &&
std::to_integer<uint8_t>(sign_bit ^ (data[length - 1] & static_cast<std::byte>(0x80)))) {
rwlee marked this conversation as resolved.
Show resolved Hide resolved
++length;
}

// Convert to big endian by reversing the range of nonzero bytes. Only those bytes are hashed.
__int128_t big_endian_value = 0;
auto big_endian_data = reinterpret_cast<std::byte*>(&big_endian_value);
thrust::reverse_copy(thrust::device, data, data + length, big_endian_data);
rwlee marked this conversation as resolved.
Show resolved Hide resolved
return this->compute_bytes(big_endian_data, length);
}

template <>
Expand Down
32 changes: 24 additions & 8 deletions cpp/tests/hashing/hash_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,32 +298,36 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
// The hash values were determined by running the following Scala code in Apache Spark:
// import org.apache.spark.sql.catalyst.util.DateTimeUtils
// val schema = new StructType().add("structs", new StructType().add("a",IntegerType)
// .add("b",StringType).add("c",new StructType().add("x",FloatType).add("y",LongType)))
// .add("b",StringType).add("c",new StructType().add("x",FloatType).add("y",LongType)))
// .add("strings",StringType).add("doubles",DoubleType).add("timestamps",TimestampType)
// .add("decimal64", DecimalType(18,7)).add("longs",LongType).add("floats",FloatType)
// .add("dates",DateType).add("decimal32", DecimalType(9,3)).add("ints",IntegerType)
// .add("shorts",ShortType).add("bytes",ByteType).add("bools",BooleanType)
// .add("decimal128", DecimalType(38,11))
// val data = Seq(
// Row(Row(0, "a", Row(0f, 0L)), "", 0.toDouble, DateTimeUtils.toJavaTimestamp(0), BigDecimal(0),
// 0.toLong, 0.toFloat, DateTimeUtils.toJavaDate(0), BigDecimal(0), 0, 0.toShort, 0.toByte,
// false),
// false, BigDecimal(0)),
// Row(Row(100, "bc", Row(100f, 100L)), "The quick brown fox", -(0.toDouble),
// DateTimeUtils.toJavaTimestamp(100), BigDecimal("0.00001"), 100.toLong, -(0.toFloat),
// DateTimeUtils.toJavaDate(100), BigDecimal("0.1"), 100, 100.toShort, 100.toByte, true),
// DateTimeUtils.toJavaDate(100), BigDecimal("0.1"), 100, 100.toShort, 100.toByte, true,
// BigDecimal("0.000000001")),
// Row(Row(-100, "def", Row(-100f, -100L)), "jumps over the lazy dog.", -Double.NaN,
// DateTimeUtils.toJavaTimestamp(-100), BigDecimal("-0.00001"), -100.toLong, -Float.NaN,
// DateTimeUtils.toJavaDate(-100), BigDecimal("-0.1"), -100, -100.toShort, -100.toByte,
// true),
// true, BigDecimal("-0.00000000001")),
// Row(Row(0x12345678, "ghij", Row(Float.PositiveInfinity, 0x123456789abcdefL)),
// "All work and no play makes Jack a dull boy", Double.MinValue,
// DateTimeUtils.toJavaTimestamp(Long.MinValue/1000000), BigDecimal("-99999999999.9999999"),
// Long.MinValue, Float.MinValue, DateTimeUtils.toJavaDate(Int.MinValue/100),
// BigDecimal("-999999.999"), Int.MinValue, Short.MinValue, Byte.MinValue, true),
// BigDecimal("-999999.999"), Int.MinValue, Short.MinValue, Byte.MinValue, true,
// BigDecimal("-9999999999999999.99999999999")),
// Row(Row(-0x76543210, "klmno", Row(Float.NegativeInfinity, -0x123456789abcdefL)),
// "!\"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\ud720\ud721", Double.MaxValue,
// DateTimeUtils.toJavaTimestamp(Long.MaxValue/1000000), BigDecimal("99999999999.9999999"),
// Long.MaxValue, Float.MaxValue, DateTimeUtils.toJavaDate(Int.MaxValue/100),
// BigDecimal("999999.999"), Int.MaxValue, Short.MaxValue, Byte.MaxValue, false))
// BigDecimal("999999.999"), Int.MaxValue, Short.MaxValue, Byte.MaxValue, false,
// BigDecimal("99999999999999999999999999.99999999999")))
// val df = spark.createDataFrame(sc.parallelize(data), schema)
// df.columns.foreach(c => println(s"$c => ${df.select(hash(col(c))).collect.mkString(",")}"))
// df.select(hash(col("*"))).collect
Expand Down Expand Up @@ -353,8 +357,10 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
{933211791, 751823303, -1080202046, 1110053733, 1135925485});
fixed_width_column_wrapper<int32_t> const hash_bools_expected(
{933211791, -559580957, -559580957, -559580957, 933211791});
fixed_width_column_wrapper<int32_t> const hash_decimal128_expected(
{-783713497, -295670906, 1398487324, -52622807, -1359749815});
fixed_width_column_wrapper<int32_t> const hash_combined_expected(
{-1172364561, -442972638, 1213234395, 796626751, 214075225});
{401603227, 588162166, 552160517, 1132537411, -326043017});

using double_limits = std::numeric_limits<double>;
using long_limits = std::numeric_limits<int64_t>;
Expand Down Expand Up @@ -394,6 +400,13 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
fixed_width_column_wrapper<int8_t> const bytes_col({0, 100, -100, -128, 127});
fixed_width_column_wrapper<bool> const bools_col1({0, 1, 1, 1, 0});
fixed_width_column_wrapper<bool> const bools_col2({0, 1, 2, 255, 0});
fixed_point_column_wrapper<__int128_t> const decimal128_col(
rwlee marked this conversation as resolved.
Show resolved Hide resolved
{static_cast<__int128>(0),
static_cast<__int128>(100),
static_cast<__int128>(-1),
(static_cast<__int128>(0xFFFFFFFFFCC4D1C3u) << 64 | 0x602F7FC318000001u),
(static_cast<__int128>(0x0785EE10D5DA46D9u) << 64 | 0x00F4369FFFFFFFFFu)},
numeric::scale_type{-11});

constexpr auto hasher = cudf::hash_id::HASH_SPARK_MURMUR3;
auto const hash_structs = cudf::hash(cudf::table_view({structs_col}), hasher, 42);
Expand All @@ -410,6 +423,7 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
auto const hash_bytes = cudf::hash(cudf::table_view({bytes_col}), hasher, 42);
auto const hash_bools1 = cudf::hash(cudf::table_view({bools_col1}), hasher, 42);
auto const hash_bools2 = cudf::hash(cudf::table_view({bools_col2}), hasher, 42);
auto const hash_decimal128 = cudf::hash(cudf::table_view({decimal128_col}), hasher, 42);

CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_structs, hash_structs_expected, verbosity);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_strings, hash_strings_expected, verbosity);
Expand All @@ -425,6 +439,7 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bytes, hash_bytes_expected, verbosity);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bools1, hash_bools_expected, verbosity);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bools2, hash_bools_expected, verbosity);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_decimal128, hash_decimal128_expected, verbosity);

auto const combined_table = cudf::table_view({structs_col,
strings_col,
Expand All @@ -438,7 +453,8 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
ints_col,
shorts_col,
bytes_col,
bools_col2});
bools_col2,
decimal128_col});
auto const hash_combined = cudf::hash(combined_table, hasher, 42);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_combined, hash_combined_expected, verbosity);
}
Expand Down