diff --git a/cpp/include/cudf/detail/utilities/hash_functions.cuh b/cpp/include/cudf/detail/utilities/hash_functions.cuh
index c35d24ddeac..8a7f4276d05 100644
--- a/cpp/include/cudf/detail/utilities/hash_functions.cuh
+++ b/cpp/include/cudf/detail/utilities/hash_functions.cuh
@@ -16,12 +16,16 @@
 #pragma once
 
+#include 
+
 #include 
 #include 
 #include 
 #include 
 #include 
+#include 
+
 using hash_value_type = uint32_t;
 
 namespace cudf {
@@ -337,17 +341,21 @@ struct SparkMurmurHash3_32 {
   template <typename TKey>
   result_type __device__ inline compute(TKey const& key) const
   {
-    constexpr int len        = sizeof(TKey);
-    int8_t const* const data = reinterpret_cast<int8_t const*>(&key);
-    constexpr int nblocks    = len / 4;
+    return compute_bytes(reinterpret_cast<std::byte const*>(&key), sizeof(TKey));
+  }
+
+  result_type __device__ compute_bytes(std::byte const* const data, cudf::size_type const len) const
+  {
+    constexpr cudf::size_type block_size = sizeof(uint32_t) / sizeof(std::byte);
+    cudf::size_type const nblocks        = len / block_size;
+    uint32_t h1                          = m_seed;
+    constexpr uint32_t c1                = 0xcc9e2d51;
+    constexpr uint32_t c2                = 0x1b873593;
-    uint32_t h1           = m_seed;
-    constexpr uint32_t c1 = 0xcc9e2d51;
-    constexpr uint32_t c2 = 0x1b873593;
     //----------
-    // body
-    uint32_t const* const blocks = reinterpret_cast<uint32_t const*>(data + nblocks * 4);
-    for (int i = -nblocks; i; i++) {
+    // Process all four-byte chunks
+    uint32_t const* const blocks = reinterpret_cast<uint32_t const*>(data);
+    for (cudf::size_type i = 0; i < nblocks; i++) {
       uint32_t k1 = blocks[i];
       k1 *= c1;
       k1 = rotl32(k1, 15);
@@ -357,9 +365,14 @@ struct SparkMurmurHash3_32 {
       h1 = h1 * 5 + 0xe6546b64;
     }
     //----------
-    // byte by byte tail processing
-    for (int i = nblocks * 4; i < len; i++) {
-      int32_t k1 = data[i];
+    // Process remaining bytes that do not fill a four-byte chunk using Spark's approach
+    // (does not conform to normal MurmurHash3)
+    for (cudf::size_type i = nblocks * 4; i < len; i++) {
+      // We require a two-step cast to get the k1 value from the byte. First,
+      // we must cast to a signed int8_t. Then, the sign bit is preserved when
+      // casting to uint32_t under 2's complement. Java preserves the
+      // signedness when casting byte-to-int, but C++ does not.
+      uint32_t k1 = static_cast<uint32_t>(std::to_integer<int8_t>(data[i]));
       k1 *= c1;
       k1 = rotl32(k1, 15);
       k1 *= c2;
@@ -427,7 +440,42 @@ template <>
 hash_value_type __device__ inline SparkMurmurHash3_32<numeric::decimal128>::operator()(
   numeric::decimal128 const& key) const
 {
-  return this->compute<__int128_t>(key.value());
+  // Generates the Spark MurmurHash3 hash value, mimicking the conversion:
+  // java.math.BigDecimal.valueOf(unscaled_value, _scale).unscaledValue().toByteArray()
+  // https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala#L381
+  __int128_t const val               = key.value();
+  constexpr cudf::size_type key_size = sizeof(__int128_t);
+  std::byte const* data              = reinterpret_cast<std::byte const*>(&val);
+
+  // Small negative values start with 0xff..., small positive values start with 0x00...
+  bool const is_negative     = val < 0;
+  std::byte const zero_value = is_negative ? std::byte{0xff} : std::byte{0x00};
+
+  // If the value can be represented with a shorter than 16-byte integer, the
+  // leading bytes of the little-endian value are truncated and are not hashed.
+  auto const reverse_begin = thrust::reverse_iterator(data + key_size);
+  auto const reverse_end   = thrust::reverse_iterator(data);
+  auto const first_nonzero_byte =
+    thrust::find_if_not(thrust::seq, reverse_begin, reverse_end, [zero_value](std::byte const& v) {
+      return v == zero_value;
+    }).base();
+  // Max handles special case of 0 and -1 which would shorten to 0 length otherwise
+  cudf::size_type length =
+    std::max(1, static_cast<cudf::size_type>(thrust::distance(data, first_nonzero_byte)));
+
+  // Preserve the 2's complement sign bit by adding a byte back on if necessary.
+  // e.g. 0x0000ff would shorten to 0x00ff. The 0x00 byte is retained to
+  // preserve the sign bit, rather than leaving an "f" at the front which would
+  // change the sign bit. However, 0x00007f would shorten to 0x7f. No extra byte
+  // is needed because the leftmost bit matches the sign bit. Similarly for
+  // negative values: 0xffff00 --> 0xff00 and 0xffff80 --> 0x80.
+  if ((length < key_size) && (is_negative ^ bool(data[length - 1] & std::byte{0x80}))) { ++length; }
+
+  // Convert to big endian by reversing the range of nonzero bytes. Only those bytes are hashed.
+  __int128_t big_endian_value = 0;
+  auto big_endian_data        = reinterpret_cast<std::byte*>(&big_endian_value);
+  thrust::reverse_copy(thrust::seq, data, data + length, big_endian_data);
+  return this->compute_bytes(big_endian_data, length);
 }
 
 template <>
@@ -480,7 +528,7 @@ hash_value_type __device__ inline SparkMurmurHash3_32<cudf::string_view>::operat
   //----------
   // Spark's byte by byte tail processing
   for (int i = nblocks * 4; i < len; i++) {
-    int32_t k1 = data[i];
+    uint32_t k1 = data[i];
     k1 *= c1;
     k1 = rotl32(k1, 15);
     k1 *= c2;
diff --git a/cpp/tests/hashing/hash_test.cpp b/cpp/tests/hashing/hash_test.cpp
index bd6deae9dc4..1a73fb3abc9 100644
--- a/cpp/tests/hashing/hash_test.cpp
+++ b/cpp/tests/hashing/hash_test.cpp
@@ -298,32 +298,36 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
   // The hash values were determined by running the following Scala code in Apache Spark:
   // import org.apache.spark.sql.catalyst.util.DateTimeUtils
   // val schema = new StructType().add("structs", new StructType().add("a",IntegerType)
-  //  .add("b",StringType).add("c",new StructType().add("x",FloatType).add("y",LongType)))
+  //   .add("b",StringType).add("c",new StructType().add("x",FloatType).add("y",LongType)))
   //   .add("strings",StringType).add("doubles",DoubleType).add("timestamps",TimestampType)
   //   .add("decimal64", DecimalType(18,7)).add("longs",LongType).add("floats",FloatType)
   //   .add("dates",DateType).add("decimal32", DecimalType(9,3)).add("ints",IntegerType)
   //   .add("shorts",ShortType).add("bytes",ByteType).add("bools",BooleanType)
+  //   .add("decimal128", DecimalType(38,11))
   // val data = Seq(
   //   Row(Row(0, "a", Row(0f, 0L)), "", 0.toDouble, DateTimeUtils.toJavaTimestamp(0), BigDecimal(0),
   //     0.toLong, 0.toFloat, DateTimeUtils.toJavaDate(0), BigDecimal(0), 0, 0.toShort, 0.toByte,
-  //     false),
+  //     false, BigDecimal(0)),
   //   Row(Row(100, "bc", Row(100f, 100L)), "The quick brown fox", -(0.toDouble),
   //     DateTimeUtils.toJavaTimestamp(100), BigDecimal("0.00001"), 100.toLong, -(0.toFloat),
-  //     DateTimeUtils.toJavaDate(100), BigDecimal("0.1"), 100, 100.toShort, 100.toByte, true),
+  //     DateTimeUtils.toJavaDate(100), BigDecimal("0.1"), 100, 100.toShort, 100.toByte, true,
+  //     BigDecimal("0.000000001")),
   //   Row(Row(-100, "def", Row(-100f, -100L)), "jumps over the lazy dog.", -Double.NaN,
   //     DateTimeUtils.toJavaTimestamp(-100), BigDecimal("-0.00001"), -100.toLong, -Float.NaN,
   //     DateTimeUtils.toJavaDate(-100), BigDecimal("-0.1"), -100, -100.toShort, -100.toByte,
BigDecimal("-0.1"), -100, -100.toShort, -100.toByte, - // true), + // true, BigDecimal("-0.00000000001")), // Row(Row(0x12345678, "ghij", Row(Float.PositiveInfinity, 0x123456789abcdefL)), // "All work and no play makes Jack a dull boy", Double.MinValue, // DateTimeUtils.toJavaTimestamp(Long.MinValue/1000000), BigDecimal("-99999999999.9999999"), // Long.MinValue, Float.MinValue, DateTimeUtils.toJavaDate(Int.MinValue/100), - // BigDecimal("-999999.999"), Int.MinValue, Short.MinValue, Byte.MinValue, true), + // BigDecimal("-999999.999"), Int.MinValue, Short.MinValue, Byte.MinValue, true, + // BigDecimal("-9999999999999999.99999999999")), // Row(Row(-0x76543210, "klmno", Row(Float.NegativeInfinity, -0x123456789abcdefL)), // "!\"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\ud720\ud721", Double.MaxValue, // DateTimeUtils.toJavaTimestamp(Long.MaxValue/1000000), BigDecimal("99999999999.9999999"), // Long.MaxValue, Float.MaxValue, DateTimeUtils.toJavaDate(Int.MaxValue/100), - // BigDecimal("999999.999"), Int.MaxValue, Short.MaxValue, Byte.MaxValue, false)) + // BigDecimal("999999.999"), Int.MaxValue, Short.MaxValue, Byte.MaxValue, false, + // BigDecimal("99999999999999999999999999.99999999999"))) // val df = spark.createDataFrame(sc.parallelize(data), schema) // df.columns.foreach(c => println(s"$c => ${df.select(hash(col(c))).collect.mkString(",")}")) // df.select(hash(col("*"))).collect @@ -353,8 +357,10 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds) {933211791, 751823303, -1080202046, 1110053733, 1135925485}); fixed_width_column_wrapper const hash_bools_expected( {933211791, -559580957, -559580957, -559580957, 933211791}); + fixed_width_column_wrapper const hash_decimal128_expected( + {-783713497, -295670906, 1398487324, -52622807, -1359749815}); fixed_width_column_wrapper const hash_combined_expected( - {-1172364561, -442972638, 1213234395, 796626751, 214075225}); + {401603227, 588162166, 552160517, 1132537411, -326043017}); using double_limits = std::numeric_limits; using long_limits = std::numeric_limits; @@ -394,6 +400,13 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds) fixed_width_column_wrapper const bytes_col({0, 100, -100, -128, 127}); fixed_width_column_wrapper const bools_col1({0, 1, 1, 1, 0}); fixed_width_column_wrapper const bools_col2({0, 1, 2, 255, 0}); + fixed_point_column_wrapper<__int128_t> const decimal128_col( + {static_cast<__int128>(0), + static_cast<__int128>(100), + static_cast<__int128>(-1), + (static_cast<__int128>(0xFFFFFFFFFCC4D1C3u) << 64 | 0x602F7FC318000001u), + (static_cast<__int128>(0x0785EE10D5DA46D9u) << 64 | 0x00F4369FFFFFFFFFu)}, + numeric::scale_type{-11}); constexpr auto hasher = cudf::hash_id::HASH_SPARK_MURMUR3; auto const hash_structs = cudf::hash(cudf::table_view({structs_col}), hasher, 42); @@ -410,6 +423,7 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds) auto const hash_bytes = cudf::hash(cudf::table_view({bytes_col}), hasher, 42); auto const hash_bools1 = cudf::hash(cudf::table_view({bools_col1}), hasher, 42); auto const hash_bools2 = cudf::hash(cudf::table_view({bools_col2}), hasher, 42); + auto const hash_decimal128 = cudf::hash(cudf::table_view({decimal128_col}), hasher, 42); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_structs, hash_structs_expected, verbosity); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_strings, hash_strings_expected, verbosity); @@ -425,6 +439,7 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds) CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bytes, hash_bytes_expected, verbosity); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bools1, hash_bools_expected, 
   CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bools2, hash_bools_expected, verbosity);
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_decimal128, hash_decimal128_expected, verbosity);
 
   auto const combined_table = cudf::table_view({structs_col,
                                                 strings_col,
@@ -438,7 +453,8 @@
                                                 ints_col,
                                                 shorts_col,
                                                 bytes_col,
-                                                bools_col2});
+                                                bools_col2,
+                                                decimal128_col});
   auto const hash_combined = cudf::hash(combined_table, hasher, 42);
   CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_combined, hash_combined_expected, verbosity);
 }
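
Reviewer note on the Spark-style tail loop: Java widens a byte to int with sign extension, so Spark's per-byte tail rounds see 0x80 as 0xffffff80 rather than 0x00000080. The two-step `int8_t` to `uint32_t` cast in the patch reproduces that, and the `int32_t` to `uint32_t` change in the string specialization keeps the same bit pattern while avoiding signed-overflow UB in the multiplies. The following is a minimal host-side sketch of the same loop structure, assuming the standard fmix32 finalization that the surrounding file applies; `spark_murmur3_32` is a hypothetical name for illustration, not part of the patch.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

uint32_t rotl32(uint32_t x, int r) { return (x << r) | (x >> (32 - r)); }

// Host-side mirror of compute_bytes() above (sketch only).
uint32_t spark_murmur3_32(std::byte const* data, int len, uint32_t seed)
{
  uint32_t h1           = seed;
  constexpr uint32_t c1 = 0xcc9e2d51;
  constexpr uint32_t c2 = 0x1b873593;
  int const nblocks     = len / 4;

  // Body: process all four-byte chunks.
  for (int i = 0; i < nblocks; i++) {
    uint32_t k1;
    std::memcpy(&k1, data + i * 4, sizeof(k1));  // alignment-safe load
    k1 *= c1;
    k1 = rotl32(k1, 15);
    k1 *= c2;
    h1 ^= k1;
    h1 = rotl32(h1, 13);
    h1 = h1 * 5 + 0xe6546b64;
  }

  // Tail: Spark runs each leftover byte through a full round, sign-extending
  // it the way Java's byte-to-int widening does (0x80 becomes 0xffffff80).
  for (int i = nblocks * 4; i < len; i++) {
    uint32_t k1 = static_cast<uint32_t>(std::to_integer<int8_t>(data[i]));
    k1 *= c1;
    k1 = rotl32(k1, 15);
    k1 *= c2;
    h1 ^= k1;
    h1 = rotl32(h1, 13);
    h1 = h1 * 5 + 0xe6546b64;
  }

  // Standard MurmurHash3 finalization (fmix32).
  h1 ^= static_cast<uint32_t>(len);
  h1 ^= h1 >> 16;
  h1 *= 0x85ebca6b;
  h1 ^= h1 >> 13;
  h1 *= 0xc2b2ae35;
  h1 ^= h1 >> 16;
  return h1;
}
```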
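
Reviewer note on the decimal128 byte shortening: Java's `BigInteger.toByteArray()` returns the minimal big-endian two's complement encoding, so small magnitudes hash fewer bytes than `sizeof(__int128_t)`. Below is a host-side sketch of just that shortening, under the assumption of a little-endian host and a compiler with `__int128_t` support; `to_minimal_big_endian` is an illustrative helper, not a cudf API, and the patch feeds the resulting bytes to compute_bytes() rather than returning them.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Returns the minimal big-endian two's complement representation of val.
std::vector<std::byte> to_minimal_big_endian(__int128_t val)
{
  std::byte bytes[sizeof(__int128_t)];
  std::memcpy(bytes, &val, sizeof(bytes));  // little-endian host assumed

  bool const is_negative = val < 0;
  std::byte const filler = is_negative ? std::byte{0xff} : std::byte{0x00};

  // Drop leading sign-filler bytes (the high-order end of the little-endian
  // buffer), keeping at least one byte so 0 and -1 do not become empty.
  int length = sizeof(bytes);
  while (length > 1 && bytes[length - 1] == filler) { --length; }

  // Add one byte back when the top bit of the remaining leading byte no
  // longer matches the sign: 0x0000ff shortens to 0x00ff, not 0xff, while
  // 0x00007f shortens to 0x7f.
  bool const top_bit_set = (bytes[length - 1] & std::byte{0x80}) != std::byte{0};
  if (length < static_cast<int>(sizeof(bytes)) && (is_negative != top_bit_set)) { ++length; }

  // Reverse into big-endian order; these are the bytes that get hashed.
  std::vector<std::byte> out(bytes, bytes + length);
  std::reverse(out.begin(), out.end());
  return out;
}

int main()
{
  assert(to_minimal_big_endian(0).size() == 1);     // 0x00
  assert(to_minimal_big_endian(-1).size() == 1);    // 0xff
  assert(to_minimal_big_endian(127).size() == 1);   // 0x7f
  assert(to_minimal_big_endian(255).size() == 2);   // 0x00 0xff (sign byte kept)
  assert(to_minimal_big_endian(-128).size() == 1);  // 0x80
  assert(to_minimal_big_endian(-256).size() == 2);  // 0xff 0x00
  return 0;
}
```

The `length < key_size` guard in the patch plays the same role as the `length < sizeof(bytes)` check here: a value that already uses all 16 bytes, such as the most negative `__int128_t`, has no room for and no need of an extra sign byte.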