Spark Decimal128 hashing #9919

Merged 22 commits on Jan 20, 2022
73 changes: 60 additions & 13 deletions cpp/include/cudf/detail/utilities/hash_functions.cuh
@@ -22,6 +22,8 @@
#include <cudf/strings/string_view.cuh>
#include <cudf/types.hpp>

+ #include <thrust/iterator/reverse_iterator.h>
+
using hash_value_type = uint32_t;

namespace cudf {
@@ -337,17 +339,20 @@ struct SparkMurmurHash3_32 {
template <typename TKey>
result_type __device__ inline compute(TKey const& key) const
{
- constexpr int len = sizeof(TKey);
- int8_t const* const data = reinterpret_cast<int8_t const*>(&key);
- constexpr int nblocks = len / 4;
+ return compute_bytes(reinterpret_cast<std::byte const*>(&key), sizeof(TKey));
+ }
+
+ result_type __device__ compute_bytes(std::byte const* const data, cudf::size_type const len) const
+ {
+ cudf::size_type const nblocks = len / 4;
uint32_t h1 = m_seed;
constexpr uint32_t c1 = 0xcc9e2d51;
constexpr uint32_t c2 = 0x1b873593;

//----------
- // body
- uint32_t const* const blocks = reinterpret_cast<uint32_t const*>(data + nblocks * 4);
- for (int i = -nblocks; i; i++) {
+ // Process all four-byte chunks
+ uint32_t const* const blocks = reinterpret_cast<uint32_t const*>(data);
+ for (int i = 0; i < nblocks; i++) {
uint32_t k1 = blocks[i];
k1 *= c1;
k1 = rotl32(k1, 15);
@@ -357,9 +362,14 @@ struct SparkMurmurHash3_32 {
h1 = h1 * 5 + 0xe6546b64;
}
//----------
- // byte by byte tail processing
+ // Process remaining bytes that do not fill a four-byte chunk using Spark's approach
+ // (does not conform to normal MurmurHash3)
for (int i = nblocks * 4; i < len; i++) {
- int32_t k1 = data[i];
+ // We require a two-step cast to get the k1 value from the byte. First,
+ // we must cast to a signed int8_t. Then, the sign bit is preserved when
+ // casting to uint32_t under 2's complement. Java preserves the
+ // signedness when casting byte-to-int, but C++ does not.
+ uint32_t k1 = static_cast<uint32_t>(std::to_integer<int8_t>(data[i]));
k1 *= c1;
k1 = rotl32(k1, 15);
k1 *= c2;
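For readers skimming the diff: the tail loop above is where Spark's variant departs from canonical MurmurHash3_x86_32. Each leftover byte is sign-extended (mimicking Java's byte-to-int widening) and run through the full per-block mix, whereas canonical MurmurHash3 accumulates the tail bytes unsigned into a single k1 and skips the extra h1 mixing. The following is a minimal host-side C++ sketch of the whole routine, not code from this PR, and the function name is illustrative:

```cpp
#include <cstdint>
#include <cstring>

static uint32_t rotl32(uint32_t x, int r) { return (x << r) | (x >> (32 - r)); }

// Sketch of Spark's Murmur3 variant: four-byte body blocks, then a signed,
// fully mixed per-byte tail, then the standard fmix32 finalizer.
uint32_t spark_murmur3_32(uint8_t const* data, int len, uint32_t seed)
{
  uint32_t h1           = seed;
  constexpr uint32_t c1 = 0xcc9e2d51;
  constexpr uint32_t c2 = 0x1b873593;

  // Body: process all four-byte chunks.
  int const nblocks = len / 4;
  for (int i = 0; i < nblocks; i++) {
    uint32_t k1;
    std::memcpy(&k1, data + i * 4, sizeof(k1));  // little-endian block load
    k1 *= c1;
    k1 = rotl32(k1, 15);
    k1 *= c2;
    h1 ^= k1;
    h1 = rotl32(h1, 13);
    h1 = h1 * 5 + 0xe6546b64;
  }

  // Tail: one full mix per remaining byte, sign-extended as in Java.
  // Canonical MurmurHash3 would treat these bytes as unsigned and defer
  // the h1 mixing to finalization.
  for (int i = nblocks * 4; i < len; i++) {
    uint32_t k1 = static_cast<uint32_t>(static_cast<int8_t>(data[i]));
    k1 *= c1;
    k1 = rotl32(k1, 15);
    k1 *= c2;
    h1 ^= k1;
    h1 = rotl32(h1, 13);
    h1 = h1 * 5 + 0xe6546b64;
  }

  // Finalization (fmix32), identical to canonical MurmurHash3.
  h1 ^= static_cast<uint32_t>(len);
  h1 ^= h1 >> 16;
  h1 *= 0x85ebca6b;
  h1 ^= h1 >> 13;
  h1 *= 0xc2b2ae35;
  h1 ^= h1 >> 16;
  return h1;
}
```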
@@ -427,7 +437,44 @@ template <>
hash_value_type __device__ inline SparkMurmurHash3_32<numeric::decimal128>::operator()(
numeric::decimal128 const& key) const
{
- return this->compute<__int128_t>(key.value());
+ // Generates the Spark MurmurHash3 hash value, mimicking the conversion:
+ // java.math.BigDecimal.valueOf(unscaled_value, _scale).unscaledValue().toByteArray()
+ // https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala#L381
+ __int128_t const val = key.value();
+ constexpr cudf::size_type key_size = sizeof(__int128_t);
+ std::byte const* data = reinterpret_cast<std::byte const*>(&val);
+
+ // Small negative values start with 0xff..., small positive values start with 0x00...
+ bool const is_negative = val < 0;
+ std::byte const zero_value = is_negative ? std::byte{0xff} : std::byte{0x00};
+
+ // Special cases for 0 and -1 which do not shorten correctly.
+ if (val == 0) { return this->compute(static_cast<uint8_t>(0)); }
+ if (val == static_cast<__int128>(-1)) { return this->compute(static_cast<uint8_t>(0xff)); }
+
+ // If the value can be represented with a shorter than 16-byte integer, the
+ // leading bytes of the little-endian value are truncated and are not hashed.
+ auto const reverse_begin = thrust::reverse_iterator(data + key_size);
+ auto const reverse_end = thrust::reverse_iterator(data);
+ auto const first_nonzero_byte =
+ thrust::find_if_not(thrust::seq, reverse_begin, reverse_end, [zero_value](std::byte const& v) {
+ return v == zero_value;
+ }).base();
+ cudf::size_type length = thrust::distance(data, first_nonzero_byte);
+
+ // Preserve the 2's complement sign bit by adding a byte back on if necessary.
+ // e.g. 0x0000ff would shorten to 0x00ff. The 0x00 byte is retained to
+ // preserve the sign bit, rather than leaving an "f" at the front which would
+ // change the sign bit. However, 0x00007f would shorten to 0x7f. No extra byte
+ // is needed because the leftmost bit matches the sign bit. Similarly for
+ // negative values: 0xffff00 --> 0xff00 and 0xffff80 --> 0x80.
+ if ((length < key_size) && (is_negative ^ bool(data[length - 1] & std::byte{0x80}))) { ++length; }
+
+ // Convert to big endian by reversing the range of nonzero bytes. Only those bytes are hashed.
+ __int128_t big_endian_value = 0;
+ auto big_endian_data = reinterpret_cast<std::byte*>(&big_endian_value);
+ thrust::reverse_copy(thrust::seq, data, data + length, big_endian_data);
+ return this->compute_bytes(big_endian_data, length);
}
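The shortening logic above mirrors java.math.BigInteger.toByteArray(): keep the minimal big-endian two's complement representation. A host-side sketch of the same truncation rules follows; the function name and the std::vector return are assumptions of this sketch, not the PR's code, and its `len > 1` guard subsumes the kernel's special cases for 0 and -1:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <vector>

// Sketch: minimal big-endian two's complement encoding of an __int128_t,
// matching java.math.BigInteger.toByteArray().
std::vector<uint8_t> to_java_byte_array(__int128_t val)
{
  uint8_t bytes[sizeof(__int128_t)];
  std::memcpy(bytes, &val, sizeof(bytes));  // little-endian byte view

  // Drop leading pad bytes (0x00 for positive, 0xff for negative) from the
  // most significant end, always keeping at least one byte. This guard also
  // covers 0 -> {0x00} and -1 -> {0xff}, the kernel's special cases.
  bool const is_negative = val < 0;
  uint8_t const pad      = is_negative ? 0xff : 0x00;
  int len                = sizeof(bytes);
  while (len > 1 && bytes[len - 1] == pad) { --len; }

  // Re-grow by one byte when the retained leading bit disagrees with the
  // sign: 0x0000ff -> 0x00ff but 0x00007f -> 0x7f, and for negatives
  // 0xffff00 -> 0xff00 but 0xffff80 -> 0x80.
  if (len < static_cast<int>(sizeof(bytes)) && is_negative != bool(bytes[len - 1] & 0x80)) {
    ++len;
  }

  std::vector<uint8_t> out(bytes, bytes + len);
  std::reverse(out.begin(), out.end());  // little-endian to big-endian
  return out;
}
```

Hashing `to_java_byte_array(v)` with the `spark_murmur3_32` sketch above (seed 42) should then agree with Spark's Decimal hash for precision > 18, which is what the new test vectors below exercise.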

template <>
@@ -480,7 +527,7 @@ hash_value_type __device__ inline SparkMurmurHash3_32<cudf::string_view>::operat
//----------
// Spark's byte by byte tail processing
for (int i = nblocks * 4; i < len; i++) {
- int32_t k1 = data[i];
+ uint32_t k1 = data[i];
k1 *= c1;
k1 = rotl32(k1, 15);
k1 *= c2;
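The int32_t to uint32_t change in the string tail above is type cleanup rather than a behavior change: converting a negative int8_t directly to uint32_t wraps modulo 2^32, which yields exactly the sign-extended bit pattern. A quick standalone check (illustrative, not part of the PR):

```cpp
#include <cassert>
#include <cstdint>

int main()
{
  int8_t const b = -100;
  // Both expressions yield 0xFFFFFF9C: direct conversion to uint32_t is
  // modular, which matches sign extension through int32_t first.
  assert(static_cast<uint32_t>(b) == static_cast<uint32_t>(static_cast<int32_t>(b)));
  return 0;
}
```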
32 changes: 24 additions & 8 deletions cpp/tests/hashing/hash_test.cpp
@@ -298,32 +298,36 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
// The hash values were determined by running the following Scala code in Apache Spark:
// import org.apache.spark.sql.catalyst.util.DateTimeUtils
// val schema = new StructType().add("structs", new StructType().add("a",IntegerType)
// .add("b",StringType).add("c",new StructType().add("x",FloatType).add("y",LongType)))
// .add("b",StringType).add("c",new StructType().add("x",FloatType).add("y",LongType)))
// .add("strings",StringType).add("doubles",DoubleType).add("timestamps",TimestampType)
// .add("decimal64", DecimalType(18,7)).add("longs",LongType).add("floats",FloatType)
// .add("dates",DateType).add("decimal32", DecimalType(9,3)).add("ints",IntegerType)
// .add("shorts",ShortType).add("bytes",ByteType).add("bools",BooleanType)
// .add("decimal128", DecimalType(38,11))
// val data = Seq(
// Row(Row(0, "a", Row(0f, 0L)), "", 0.toDouble, DateTimeUtils.toJavaTimestamp(0), BigDecimal(0),
// 0.toLong, 0.toFloat, DateTimeUtils.toJavaDate(0), BigDecimal(0), 0, 0.toShort, 0.toByte,
- // false),
+ // false, BigDecimal(0)),
// Row(Row(100, "bc", Row(100f, 100L)), "The quick brown fox", -(0.toDouble),
// DateTimeUtils.toJavaTimestamp(100), BigDecimal("0.00001"), 100.toLong, -(0.toFloat),
- // DateTimeUtils.toJavaDate(100), BigDecimal("0.1"), 100, 100.toShort, 100.toByte, true),
+ // DateTimeUtils.toJavaDate(100), BigDecimal("0.1"), 100, 100.toShort, 100.toByte, true,
+ // BigDecimal("0.000000001")),
// Row(Row(-100, "def", Row(-100f, -100L)), "jumps over the lazy dog.", -Double.NaN,
// DateTimeUtils.toJavaTimestamp(-100), BigDecimal("-0.00001"), -100.toLong, -Float.NaN,
// DateTimeUtils.toJavaDate(-100), BigDecimal("-0.1"), -100, -100.toShort, -100.toByte,
- // true),
+ // true, BigDecimal("-0.00000000001")),
// Row(Row(0x12345678, "ghij", Row(Float.PositiveInfinity, 0x123456789abcdefL)),
// "All work and no play makes Jack a dull boy", Double.MinValue,
// DateTimeUtils.toJavaTimestamp(Long.MinValue/1000000), BigDecimal("-99999999999.9999999"),
// Long.MinValue, Float.MinValue, DateTimeUtils.toJavaDate(Int.MinValue/100),
// BigDecimal("-999999.999"), Int.MinValue, Short.MinValue, Byte.MinValue, true),
// BigDecimal("-999999.999"), Int.MinValue, Short.MinValue, Byte.MinValue, true,
// BigDecimal("-9999999999999999.99999999999")),
// Row(Row(-0x76543210, "klmno", Row(Float.NegativeInfinity, -0x123456789abcdefL)),
// "!\"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\ud720\ud721", Double.MaxValue,
// DateTimeUtils.toJavaTimestamp(Long.MaxValue/1000000), BigDecimal("99999999999.9999999"),
// Long.MaxValue, Float.MaxValue, DateTimeUtils.toJavaDate(Int.MaxValue/100),
// BigDecimal("999999.999"), Int.MaxValue, Short.MaxValue, Byte.MaxValue, false))
// BigDecimal("999999.999"), Int.MaxValue, Short.MaxValue, Byte.MaxValue, false,
// BigDecimal("99999999999999999999999999.99999999999")))
// val df = spark.createDataFrame(sc.parallelize(data), schema)
// df.columns.foreach(c => println(s"$c => ${df.select(hash(col(c))).collect.mkString(",")}"))
// df.select(hash(col("*"))).collect
@@ -353,8 +357,10 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
{933211791, 751823303, -1080202046, 1110053733, 1135925485});
fixed_width_column_wrapper<int32_t> const hash_bools_expected(
{933211791, -559580957, -559580957, -559580957, 933211791});
+ fixed_width_column_wrapper<int32_t> const hash_decimal128_expected(
+ {-783713497, -295670906, 1398487324, -52622807, -1359749815});
fixed_width_column_wrapper<int32_t> const hash_combined_expected(
- {-1172364561, -442972638, 1213234395, 796626751, 214075225});
+ {401603227, 588162166, 552160517, 1132537411, -326043017});
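The combined-hash expectations change because every column's hash feeds the next: Spark's multi-column hash starts from seed 42 and uses the running hash as the seed for each subsequent column, so appending decimal128 reshuffles the whole row hash. A sketch of that chaining, reusing the hypothetical spark_murmur3_32 from earlier and assuming each value has already been serialized to the byte layout Spark hashes (4-byte ints, 8-byte longs, shortened decimal byte arrays, and so on):

```cpp
#include <cstdint>
#include <vector>

uint32_t spark_murmur3_32(uint8_t const* data, int len, uint32_t seed);  // earlier sketch

// Sketch: Spark's hash(col1, ..., colN) folds the running hash in as the
// seed for each column's hash, starting from the default seed 42.
uint32_t spark_hash_row(std::vector<std::vector<uint8_t>> const& serialized_cols)
{
  uint32_t h = 42;
  for (auto const& bytes : serialized_cols) {
    h = spark_murmur3_32(bytes.data(), static_cast<int>(bytes.size()), h);
  }
  return h;
}
```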

using double_limits = std::numeric_limits<double>;
using long_limits = std::numeric_limits<int64_t>;
@@ -394,6 +400,13 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
fixed_width_column_wrapper<int8_t> const bytes_col({0, 100, -100, -128, 127});
fixed_width_column_wrapper<bool> const bools_col1({0, 1, 1, 1, 0});
fixed_width_column_wrapper<bool> const bools_col2({0, 1, 2, 255, 0});
+ fixed_point_column_wrapper<__int128_t> const decimal128_col(
+ {static_cast<__int128>(0),
+ static_cast<__int128>(100),
+ static_cast<__int128>(-1),
+ (static_cast<__int128>(0xFFFFFFFFFCC4D1C3u) << 64 | 0x602F7FC318000001u),
+ (static_cast<__int128>(0x0785EE10D5DA46D9u) << 64 | 0x00F4369FFFFFFFFFu)},
+ numeric::scale_type{-11});
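C++ has no 128-bit integer literals, which is why the two large unscaled values above are assembled inline from 64-bit halves. A sketch of that pattern (the helper name is mine, not the test's; the unsigned intermediate avoids shifting into the sign bit):

```cpp
#include <cstdint>

// Sketch: build an __int128_t from two 64-bit halves, as the test values do
// inline. The high half carries the two's complement sign.
constexpr __int128_t make_int128(uint64_t high, uint64_t low)
{
  return static_cast<__int128_t>((static_cast<unsigned __int128>(high) << 64) | low);
}

// E.g. make_int128(0xFFFFFFFFFCC4D1C3u, 0x602F7FC318000001u) reproduces the
// negative unscaled value in the fourth row above.
static_assert(make_int128(0, 100) == 100, "low half only");
static_assert(make_int128(~0ull, ~0ull) == -1, "all ones is -1");
```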

constexpr auto hasher = cudf::hash_id::HASH_SPARK_MURMUR3;
auto const hash_structs = cudf::hash(cudf::table_view({structs_col}), hasher, 42);
@@ -410,6 +423,7 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
auto const hash_bytes = cudf::hash(cudf::table_view({bytes_col}), hasher, 42);
auto const hash_bools1 = cudf::hash(cudf::table_view({bools_col1}), hasher, 42);
auto const hash_bools2 = cudf::hash(cudf::table_view({bools_col2}), hasher, 42);
+ auto const hash_decimal128 = cudf::hash(cudf::table_view({decimal128_col}), hasher, 42);

CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_structs, hash_structs_expected, verbosity);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_strings, hash_strings_expected, verbosity);
@@ -425,6 +439,7 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bytes, hash_bytes_expected, verbosity);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bools1, hash_bools_expected, verbosity);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bools2, hash_bools_expected, verbosity);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_decimal128, hash_decimal128_expected, verbosity);

auto const combined_table = cudf::table_view({structs_col,
strings_col,
@@ -438,7 +453,8 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
ints_col,
shorts_col,
bytes_col,
- bools_col2});
+ bools_col2,
+ decimal128_col});
auto const hash_combined = cudf::hash(combined_table, hasher, 42);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_combined, hash_combined_expected, verbosity);
}