Fix SparkMurmurHash3_32 hash inconsistencies with Apache Spark (#7672)
#7024 added a Spark variant of Murmur3 hashing, but it was inconsistent with Apache Spark's hash calculations in a few areas:
- Apache Spark does not treat `-0.0` and `0.0` as equal when hashing floats and doubles, but the previous implementation normalized `-0.0` to `0.0` before hashing
- byte and short integral values must be sign-extended to a 32-bit integer before calculating the hash, matching Spark's widening of narrow integral types

In addition, libcudf supports hashing timestamp columns, but the JNI bindings asserted when timestamp columns were passed in, preventing hashing on timestamps directly.
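
For reference, here is a minimal host-side sketch of both behaviors. This is an illustrative approximation, not the libcudf device code, and the helper names (`murmur3_int`, `spark_hash_float`, `spark_hash_byte`) are invented for this example:

#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

static uint32_t rotl32(uint32_t x, int r) { return (x << r) | (x >> (32 - r)); }

// Murmur3_x86_32 over a single 4-byte block, the way Spark's hashInt works.
static uint32_t murmur3_int(uint32_t block, uint32_t seed)
{
  uint32_t k = block * 0xcc9e2d51u;
  k          = rotl32(k, 15) * 0x1b873593u;
  uint32_t h = seed ^ k;
  h          = rotl32(h, 13) * 5u + 0xe6546b64u;
  h ^= 4u;  // total byte length
  h ^= h >> 16;
  h *= 0x85ebca6bu;
  h ^= h >> 13;
  h *= 0xc2b2ae35u;
  h ^= h >> 16;
  return h;
}

static uint32_t spark_hash_float(float key, uint32_t seed)
{
  // Spark normalizes only NaNs; -0.0f keeps its own bit pattern, so it hashes
  // differently from 0.0f. The removed special case folded the two together.
  if (std::isnan(key)) { key = std::numeric_limits<float>::quiet_NaN(); }
  uint32_t bits;
  std::memcpy(&bits, &key, sizeof(bits));
  return murmur3_int(bits, seed);
}

static uint32_t spark_hash_byte(int8_t key, uint32_t seed)
{
  // Spark hashes bytes and shorts as 32-bit ints, so sign-extend to 32 bits
  // first instead of feeding a single byte through the hash.
  return murmur3_int(static_cast<uint32_t>(static_cast<int32_t>(key)), seed);
}

With seed 42 this sketch should reproduce the expected values in the tests below, e.g. both the int 0 and the byte 0 hashing to 933211791.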

Authors:
  - Jason Lowe (@jlowe)

Approvers:
  - Nghia Truong (@ttnghia)
  - Jake Hemstad (@jrhemstad)
  - Alessandro Bellina (@abellina)
  - MithunR (@mythrocks)
  - Robert (Bobby) Evans (@revans2)

URL: #7672
jlowe authored Mar 24, 2021
1 parent e73fff0 commit aa7ca46
Showing 4 changed files with 226 additions and 38 deletions.
49 changes: 45 additions & 4 deletions cpp/include/cudf/detail/utilities/hash_functions.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2017-2020, NVIDIA CORPORATION.
* Copyright (c) 2017-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@

#include <cudf/column/column_device_view.cuh>
#include <cudf/detail/utilities/assert.cuh>
#include <cudf/fixed_point/fixed_point.hpp>
#include <cudf/strings/string_view.cuh>
#include <hash/hash_constants.hpp>

@@ -570,9 +571,7 @@ struct SparkMurmurHash3_32 {
template <typename T, std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
hash_value_type CUDA_DEVICE_CALLABLE compute_floating_point(T const& key) const
{
if (key == T{0.0}) {
return compute(T{0.0});
} else if (isnan(key)) {
if (isnan(key)) {
T nan = std::numeric_limits<T>::quiet_NaN();
return compute(nan);
} else {
@@ -630,6 +629,48 @@ hash_value_type CUDA_DEVICE_CALLABLE SparkMurmurHash3_32<bool>::operator()(bool
return this->compute<uint32_t>(key);
}

template <>
hash_value_type CUDA_DEVICE_CALLABLE
SparkMurmurHash3_32<int8_t>::operator()(int8_t const& key) const
{
return this->compute<uint32_t>(key);
}

template <>
hash_value_type CUDA_DEVICE_CALLABLE
SparkMurmurHash3_32<uint8_t>::operator()(uint8_t const& key) const
{
return this->compute<uint32_t>(key);
}

template <>
hash_value_type CUDA_DEVICE_CALLABLE
SparkMurmurHash3_32<int16_t>::operator()(int16_t const& key) const
{
return this->compute<uint32_t>(key);
}

template <>
hash_value_type CUDA_DEVICE_CALLABLE
SparkMurmurHash3_32<uint16_t>::operator()(uint16_t const& key) const
{
return this->compute<uint32_t>(key);
}

template <>
hash_value_type CUDA_DEVICE_CALLABLE
SparkMurmurHash3_32<numeric::decimal32>::operator()(numeric::decimal32 const& key) const
{
return this->compute<uint64_t>(key.value());
}

template <>
hash_value_type CUDA_DEVICE_CALLABLE
SparkMurmurHash3_32<numeric::decimal64>::operator()(numeric::decimal64 const& key) const
{
return this->compute<uint64_t>(key.value());
}

/**
* @brief Specialization of MurmurHash3_32 operator for strings.
*/
174 changes: 141 additions & 33 deletions cpp/tests/hashing/hash_test.cpp
@@ -15,6 +15,7 @@
*/

#include <cudf/detail/iterator.cuh>
#include <cudf/fixed_point/fixed_point.hpp>
#include <cudf/hashing.hpp>

#include <cudf_test/base_fixture.hpp>
@@ -201,27 +202,37 @@ TYPED_TEST(HashTestFloatTyped, TestExtremes)
T nan = std::numeric_limits<T>::quiet_NaN();
T inf = std::numeric_limits<T>::infinity();

fixed_width_column_wrapper<T> const col1({T(0.0), T(100.0), T(-100.0), min, max, nan, inf, -inf});
fixed_width_column_wrapper<T> const col2(
{T(-0.0), T(100.0), T(-100.0), min, max, -nan, inf, -inf});
fixed_width_column_wrapper<T> const col({T(0.0), T(100.0), T(-100.0), min, max, nan, inf, -inf});
fixed_width_column_wrapper<T> const col_neg_zero(
{T(-0.0), T(100.0), T(-100.0), min, max, nan, inf, -inf});
fixed_width_column_wrapper<T> const col_neg_nan(
{T(0.0), T(100.0), T(-100.0), min, max, -nan, inf, -inf});

auto const input1 = cudf::table_view({col1});
auto const input2 = cudf::table_view({col2});
auto const table_col = cudf::table_view({col});
auto const table_col_neg_zero = cudf::table_view({col_neg_zero});
auto const table_col_neg_nan = cudf::table_view({col_neg_nan});

auto const output1 = cudf::hash(input1);
auto const output2 = cudf::hash(input2);
auto const hash_col = cudf::hash(table_col);
auto const hash_col_neg_zero = cudf::hash(table_col_neg_zero);
auto const hash_col_neg_nan = cudf::hash(table_col_neg_nan);

CUDF_TEST_EXPECT_COLUMNS_EQUAL(output1->view(), output2->view(), true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_col, *hash_col_neg_zero, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_col, *hash_col_neg_nan, true);

auto const serial_output1 = cudf::hash(input1, cudf::hash_id::HASH_SERIAL_MURMUR3, {}, 0);
auto const serial_output2 = cudf::hash(input2, cudf::hash_id::HASH_SERIAL_MURMUR3);
constexpr auto serial_hasher = cudf::hash_id::HASH_SERIAL_MURMUR3;
auto const serial_col = cudf::hash(table_col, serial_hasher, {}, 0);
auto const serial_col_neg_zero = cudf::hash(table_col_neg_zero, serial_hasher);
auto const serial_col_neg_nan = cudf::hash(table_col_neg_nan, serial_hasher);

CUDF_TEST_EXPECT_COLUMNS_EQUAL(serial_output1->view(), serial_output2->view());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*serial_col, *serial_col_neg_zero, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*serial_col, *serial_col_neg_nan, true);

auto const spark_output1 = cudf::hash(input1, cudf::hash_id::HASH_SPARK_MURMUR3, {}, 0);
auto const spark_output2 = cudf::hash(input2, cudf::hash_id::HASH_SPARK_MURMUR3);
// Spark hash is sensitive to 0 and -0
constexpr auto spark_hasher = cudf::hash_id::HASH_SPARK_MURMUR3;
auto const spark_col = cudf::hash(table_col, spark_hasher, {}, 0);
auto const spark_col_neg_nan = cudf::hash(table_col_neg_nan, spark_hasher);

CUDF_TEST_EXPECT_COLUMNS_EQUAL(spark_output1->view(), spark_output2->view());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*spark_col, *spark_col_neg_nan);
}

class SerialMurmurHash3Test : public cudf::test::BaseFixture {
@@ -267,37 +278,134 @@ class SparkMurmurHash3Test : public cudf::test::BaseFixture {

TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
{
fixed_width_column_wrapper<int32_t> const strings_col_result(
// The hash values were determined by running the following Scala code in Apache Spark:
// import org.apache.spark.sql.catalyst.util.DateTimeUtils
// val schema = new StructType().add("strings",StringType).add("doubles",DoubleType)
// .add("timestamps",TimestampType).add("decimal64", DecimalType(18,7)).add("longs",LongType)
// .add("floats",FloatType).add("dates",DateType).add("decimal32", DecimalType(9,3))
// .add("ints",IntegerType).add("shorts",ShortType).add("bytes",ByteType)
// .add("bools",BooleanType)
// val data = Seq(
// Row("", 0.toDouble, DateTimeUtils.toJavaTimestamp(0), BigDecimal(0), 0.toLong, 0.toFloat,
// DateTimeUtils.toJavaDate(0), BigDecimal(0), 0, 0.toShort, 0.toByte, false),
// Row("The quick brown fox", -(0.toDouble), DateTimeUtils.toJavaTimestamp(100),
// BigDecimal("0.00001"), 100.toLong, -(0.toFloat), DateTimeUtils.toJavaDate(100),
// BigDecimal("0.1"), 100, 100.toShort, 100.toByte, true),
// Row("jumps over the lazy dog.", -Double.NaN, DateTimeUtils.toJavaTimestamp(-100),
// BigDecimal("-0.00001"), -100.toLong, -Float.NaN, DateTimeUtils.toJavaDate(-100),
// BigDecimal("-0.1"), -100, -100.toShort, -100.toByte, true),
// Row("All work and no play makes Jack a dull boy", Double.MinValue,
// DateTimeUtils.toJavaTimestamp(Long.MinValue/1000000), BigDecimal("-99999999999.9999999"),
// Long.MinValue, Float.MinValue, DateTimeUtils.toJavaDate(Int.MinValue/100),
// BigDecimal("-999999.999"), Int.MinValue, Short.MinValue, Byte.MinValue, true),
// Row("!\"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\ud720\ud721", Double.MaxValue,
// DateTimeUtils.toJavaTimestamp(Long.MaxValue/1000000), BigDecimal("99999999999.9999999"),
// Long.MaxValue, Float.MaxValue, DateTimeUtils.toJavaDate(Int.MaxValue/100),
// BigDecimal("999999.999"), Int.MaxValue, Short.MaxValue, Byte.MaxValue, false))
// val df = spark.createDataFrame(sc.parallelize(data), schema)
// df.columns.foreach(c => println(s"$c => ${df.select(hash(col(c))).collect.mkString(",")}"))
// df.select(hash(col("*"))).collect
fixed_width_column_wrapper<int32_t> const hash_strings_expected(
{1467149710, 723257560, -1620282500, -2001858707, 1588473657});
fixed_width_column_wrapper<int32_t> const ints_col_result(
fixed_width_column_wrapper<int32_t> const hash_doubles_expected(
{-1670924195, -853646085, -1281358385, 1897734433, -508695674});
fixed_width_column_wrapper<int32_t> const hash_timestamps_expected(
{-1670924195, 1114849490, 904948192, -1832979433, 1752430209});
fixed_width_column_wrapper<int32_t> const hash_decimal64_expected(
{-1670924195, 1114849490, 904948192, 1962370902, -1795328666});
fixed_width_column_wrapper<int32_t> const hash_longs_expected(
{-1670924195, 1114849490, 904948192, -853646085, -1604625029});
fixed_width_column_wrapper<int32_t> const hash_floats_expected(
{933211791, 723455942, -349261430, -1225560532, -338752985});
fixed_width_column_wrapper<int32_t> const hash_dates_expected(
{933211791, 751823303, -1080202046, -1906567553, -1503850410});
fixed_width_column_wrapper<int32_t> const hash_decimal32_expected(
{-1670924195, 1114849490, 904948192, -1454351396, -193774131});
fixed_width_column_wrapper<int32_t> const hash_ints_expected(
{933211791, 751823303, -1080202046, 723455942, 133916647});
fixed_width_column_wrapper<int32_t> const hash_shorts_expected(
{933211791, 751823303, -1080202046, -1871935946, 1249274084});
fixed_width_column_wrapper<int32_t> const hash_bytes_expected(
{933211791, 751823303, -1080202046, 1110053733, 1135925485});
fixed_width_column_wrapper<int32_t> const hash_bools_expected(
{933211791, -559580957, -559580957, -559580957, 933211791});
fixed_width_column_wrapper<int32_t> const hash_combined_expected(
{-1947042614, -1731440908, 807283935, 725489209, 822276819});

strings_column_wrapper const strings_col({"",
"The quick brown fox",
"jumps over the lazy dog.",
"All work and no play makes Jack a dull boy",
"!\"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\ud720\ud721"});

using limits = std::numeric_limits<int32_t>;
fixed_width_column_wrapper<int32_t> const ints_col({0, 100, -100, limits::min(), limits::max()});

using double_limits = std::numeric_limits<double>;
using long_limits = std::numeric_limits<int64_t>;
using float_limits = std::numeric_limits<float>;
using int_limits = std::numeric_limits<int32_t>;
fixed_width_column_wrapper<double> const doubles_col(
{0., -0., -double_limits::quiet_NaN(), double_limits::lowest(), double_limits::max()});
fixed_width_column_wrapper<cudf::timestamp_ms, cudf::timestamp_ms::rep> const timestamps_col(
{0L, 100L, -100L, long_limits::min() / 1000000, long_limits::max() / 1000000});
fixed_point_column_wrapper<int64_t> const decimal64_col(
{0L, 100L, -100L, -999999999999999999L, 999999999999999999L}, numeric::scale_type{-7});
fixed_width_column_wrapper<int64_t> const longs_col(
{0L, 100L, -100L, long_limits::min(), long_limits::max()});
fixed_width_column_wrapper<float> const floats_col(
{0.f, -0.f, -float_limits::quiet_NaN(), float_limits::lowest(), float_limits::max()});
fixed_width_column_wrapper<cudf::timestamp_D, cudf::timestamp_D::rep> dates_col(
{0, 100, -100, int_limits::min() / 100, int_limits::max() / 100});
fixed_point_column_wrapper<int32_t> const decimal32_col({0, 100, -100, -999999999, 999999999},
numeric::scale_type{-3});
fixed_width_column_wrapper<int32_t> const ints_col(
{0, 100, -100, int_limits::min(), int_limits::max()});
fixed_width_column_wrapper<int16_t> const shorts_col({0, 100, -100, -32768, 32767});
fixed_width_column_wrapper<int8_t> const bytes_col({0, 100, -100, -128, 127});
fixed_width_column_wrapper<bool> const bools_col1({0, 1, 1, 1, 0});
fixed_width_column_wrapper<bool> const bools_col2({0, 1, 2, 255, 0});

auto const input1 = cudf::table_view({strings_col});
auto const input2 = cudf::table_view({ints_col});
auto const input3 = cudf::table_view({strings_col, ints_col, bools_col1});
auto const input4 = cudf::table_view({strings_col, ints_col, bools_col2});

auto const hashed_output1 = cudf::hash(input1, cudf::hash_id::HASH_SPARK_MURMUR3, {}, 314);
auto const hashed_output2 = cudf::hash(input2, cudf::hash_id::HASH_SPARK_MURMUR3, {}, 42);
auto const hashed_output3 = cudf::hash(input3, cudf::hash_id::HASH_SPARK_MURMUR3, {});
auto const hashed_output4 = cudf::hash(input4, cudf::hash_id::HASH_SPARK_MURMUR3, {});

CUDF_TEST_EXPECT_COLUMNS_EQUAL(hashed_output1->view(), strings_col_result, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(hashed_output2->view(), ints_col_result, true);
EXPECT_EQ(input3.num_rows(), hashed_output3->size());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(hashed_output3->view(), hashed_output4->view(), true);
constexpr auto hasher = cudf::hash_id::HASH_SPARK_MURMUR3;
auto const hash_strings = cudf::hash(cudf::table_view({strings_col}), hasher, {}, 314);
auto const hash_doubles = cudf::hash(cudf::table_view({doubles_col}), hasher, {}, 42);
auto const hash_timestamps = cudf::hash(cudf::table_view({timestamps_col}), hasher, {}, 42);
auto const hash_decimal64 = cudf::hash(cudf::table_view({decimal64_col}), hasher, {}, 42);
auto const hash_longs = cudf::hash(cudf::table_view({longs_col}), hasher, {}, 42);
auto const hash_floats = cudf::hash(cudf::table_view({floats_col}), hasher, {}, 42);
auto const hash_dates = cudf::hash(cudf::table_view({dates_col}), hasher, {}, 42);
auto const hash_decimal32 = cudf::hash(cudf::table_view({decimal32_col}), hasher, {}, 42);
auto const hash_ints = cudf::hash(cudf::table_view({ints_col}), hasher, {}, 42);
auto const hash_shorts = cudf::hash(cudf::table_view({shorts_col}), hasher, {}, 42);
auto const hash_bytes = cudf::hash(cudf::table_view({bytes_col}), hasher, {}, 42);
auto const hash_bools1 = cudf::hash(cudf::table_view({bools_col1}), hasher, {}, 42);
auto const hash_bools2 = cudf::hash(cudf::table_view({bools_col2}), hasher, {}, 42);

CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_strings, hash_strings_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_doubles, hash_doubles_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_timestamps, hash_timestamps_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_decimal64, hash_decimal64_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_longs, hash_longs_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_floats, hash_floats_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_dates, hash_dates_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_decimal32, hash_decimal32_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_ints, hash_ints_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_shorts, hash_shorts_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bytes, hash_bytes_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bools1, hash_bools_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bools2, hash_bools_expected, true);

auto const combined_table = cudf::table_view({strings_col,
doubles_col,
timestamps_col,
decimal64_col,
longs_col,
floats_col,
dates_col,
decimal32_col,
ints_col,
shorts_col,
bytes_col,
bools_col2});
auto const hash_combined = cudf::hash(combined_table, hasher, {}, 42);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_combined, hash_combined_expected, true);
}

class MD5HashTest : public cudf::test::BaseFixture {
1 change: 0 additions & 1 deletion java/src/main/java/ai/rapids/cudf/ColumnVector.java
@@ -606,7 +606,6 @@ public static ColumnVector spark32BitMurmurHash3(int seed, ColumnView columns[])
assert columns[i] != null : "Column vectors passed may not be null";
assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size";
assert !columns[i].getType().isDurationType() : "Unsupported column type Duration";
assert !columns[i].getType().isTimestampType() : "Unsupported column type Timestamp";
assert !columns[i].getType().isNestedType() : "Unsupported column of nested type";
columnViews[i] = columns[i].getNativeView();
}
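
With the assert removed, timestamps take the same 64-bit path as longs: Spark hashes a timestamp as its microsecond count, mixing the low 32-bit word and then the high word before finalizing with length 8. A standalone sketch of that scheme, again with invented names rather than the JNI or device code:

#include <cstdint>

static uint32_t rotl32(uint32_t x, int r) { return (x << r) | (x >> (32 - r)); }

// One Murmur3_x86_32 body round: mix a 4-byte block into the running hash.
static uint32_t mix_block(uint32_t h, uint32_t block)
{
  uint32_t k = block * 0xcc9e2d51u;
  k          = rotl32(k, 15) * 0x1b873593u;
  h ^= k;
  return rotl32(h, 13) * 5u + 0xe6546b64u;
}

static uint32_t spark_hash_long(int64_t key, uint32_t seed)
{
  uint32_t const low  = static_cast<uint32_t>(key);
  uint32_t const high = static_cast<uint32_t>(static_cast<uint64_t>(key) >> 32);
  uint32_t h          = mix_block(seed, low);
  h                   = mix_block(h, high);
  h ^= 8u;  // total byte length
  h ^= h >> 16;
  h *= 0x85ebca6bu;
  h ^= h >> 13;
  h *= 0xc2b2ae35u;
  h ^= h >> 16;
  return h;
}

This is consistent with the timestamp test below, where microsecond values 0, 100, and -100 hash to the same results as the equivalent longs (-1670924195, 1114849490, 904948192 with seed 42).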
40 changes: 40 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
@@ -527,6 +527,46 @@ void testSpark32BitMurmur3HashDoubles() {
}
}

@Test
void testSpark32BitMurmur3HashTimestamps() {
try (ColumnVector v = ColumnVector.timestampMicroSecondsFromBoxedLongs(
0L, null, 100L, -100L, 0x123456789abcdefL, null, -0x123456789abcdefL);
ColumnVector result = ColumnVector.spark32BitMurmurHash3(42, new ColumnVector[]{v});
ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 42, 1114849490, 904948192, 657182333, 42, -57193045)) {
assertColumnsAreEqual(expected, result);
}
}

@Test
void testSpark32BitMurmur3HashDecimal64() {
try (ColumnVector v = ColumnVector.decimalFromLongs(-7,
0L, 100L, -100L, 0x123456789abcdefL, -0x123456789abcdefL);
ColumnVector result = ColumnVector.spark32BitMurmurHash3(42, new ColumnVector[]{v});
ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 1114849490, 904948192, 657182333, -57193045)) {
assertColumnsAreEqual(expected, result);
}
}

@Test
void testSpark32BitMurmur3HashDecimal32() {
try (ColumnVector v = ColumnVector.decimalFromInts(-3,
0, 100, -100, 0x12345678, -0x12345678);
ColumnVector result = ColumnVector.spark32BitMurmurHash3(42, new ColumnVector[]{v});
ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 1114849490, 904948192, -958054811, -1447702630)) {
assertColumnsAreEqual(expected, result);
}
}

@Test
void testSpark32BitMurmur3HashDates() {
try (ColumnVector v = ColumnVector.timestampDaysFromBoxedInts(
0, null, 100, -100, 0x12345678, null, -0x12345678);
ColumnVector result = ColumnVector.spark32BitMurmurHash3(42, new ColumnVector[]{v});
ColumnVector expected = ColumnVector.fromBoxedInts(933211791, 42, 751823303, -1080202046, -1721170160, 42, 1852996993)) {
assertColumnsAreEqual(expected, result);
}
}

@Test
void testSpark32BitMurmur3HashFloats() {
try (ColumnVector v = ColumnVector.fromBoxedFloats(
