Fix SparkMurmurHash3_32 hash inconsistencies with Apache Spark
jlowe committed Mar 22, 2021
1 parent 8632ca0 commit d01cc1c
Showing 4 changed files with 174 additions and 38 deletions.
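In brief: Spark's 32-bit MurmurHash3 operates on the raw byte representation of each value. +0.0 and -0.0 therefore hash to different values (they differ only in the sign bit), every NaN is first normalized to one canonical bit pattern (as Java's Float.floatToIntBits and Double.doubleToLongBits do), and byte/short columns are promoted to int before hashing. The changes below bring cudf's SparkMurmurHash3_32 in line with that behavior. A minimal host-side C++ sketch of the floating-point point; bits_of is an illustrative helper, not part of cudf:

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <limits>

// Reinterpret a float's IEEE-754 bits, like Java's Float.floatToRawIntBits.
static uint32_t bits_of(float f)
{
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return u;
}

int main()
{
  std::printf("%08x\n", bits_of(0.0f));   // 00000000
  std::printf("%08x\n", bits_of(-0.0f));  // 80000000: different key bytes, so a different hash
  // NaN payloads vary (e.g. ffc00000 for a negated quiet NaN on x86), which is
  // why cudf replaces every NaN with quiet_NaN() before hashing: all NaNs then
  // produce the same hash, matching Java's canonical NaN bits.
  std::printf("%08x\n", bits_of(-std::numeric_limits<float>::quiet_NaN()));
  return 0;
}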
34 changes: 30 additions & 4 deletions cpp/include/cudf/detail/utilities/hash_functions.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2017-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -570,9 +570,7 @@ struct SparkMurmurHash3_32 {
   template <typename T, std::enable_if_t<std::is_floating_point<T>::value>* = nullptr>
   hash_value_type CUDA_DEVICE_CALLABLE compute_floating_point(T const& key) const
   {
-    if (key == T{0.0}) {
-      return compute(T{0.0});
-    } else if (isnan(key)) {
+    if (isnan(key)) {
       T nan = std::numeric_limits<T>::quiet_NaN();
       return compute(nan);
     } else {
@@ -630,6 +628,34 @@ hash_value_type CUDA_DEVICE_CALLABLE SparkMurmurHash3_32<bool>::operator()(bool
   return this->compute<uint32_t>(key);
 }
 
+template <>
+hash_value_type CUDA_DEVICE_CALLABLE
+SparkMurmurHash3_32<int8_t>::operator()(int8_t const& key) const
+{
+  return this->compute<uint32_t>(key);
+}
+
+template <>
+hash_value_type CUDA_DEVICE_CALLABLE
+SparkMurmurHash3_32<uint8_t>::operator()(uint8_t const& key) const
+{
+  return this->compute<uint32_t>(key);
+}
+
+template <>
+hash_value_type CUDA_DEVICE_CALLABLE
+SparkMurmurHash3_32<int16_t>::operator()(int16_t const& key) const
+{
+  return this->compute<uint32_t>(key);
+}
+
+template <>
+hash_value_type CUDA_DEVICE_CALLABLE
+SparkMurmurHash3_32<uint16_t>::operator()(uint16_t const& key) const
+{
+  return this->compute<uint32_t>(key);
+}
+
 /**
  * @brief Specialization of MurmurHash3_32 operator for strings.
  */
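The int8_t/uint8_t/int16_t/uint16_t specializations added above hash the key as a 32-bit value because Spark promotes ByteType and ShortType to int before hashing; the implicit conversion to uint32_t sign-extends negative values, matching Java's widening conversion. Below is a self-contained sketch of the 32-bit path modeled on Spark's Murmur3_x86_32.hashInt (helper names are illustrative; the asserted value is taken from the test vectors added in this commit):

#include <cassert>
#include <cstdint>

static uint32_t rotl32(uint32_t x, int r) { return (x << r) | (x >> (32 - r)); }

// Final avalanche.
static uint32_t fmix32(uint32_t h)
{
  h ^= h >> 16;
  h *= 0x85ebca6b;
  h ^= h >> 13;
  h *= 0xc2b2ae35;
  return h ^ (h >> 16);
}

// One 4-byte block mixed into the seed, then finalized with length 4.
static int32_t spark_hash_int(int32_t key, uint32_t seed)
{
  uint32_t k = static_cast<uint32_t>(key);
  k *= 0xcc9e2d51;
  k = rotl32(k, 15);
  k *= 0x1b873593;
  uint32_t h = seed ^ k;
  h = rotl32(h, 13);
  h = h * 5 + 0xe6546b64;
  return static_cast<int32_t>(fmix32(h ^ 4));
}

int main()
{
  // compute<uint32_t>(key) in the diff above relies on this sign extension:
  static_assert(static_cast<uint32_t>(int8_t{-100}) == static_cast<uint32_t>(int32_t{-100}), "");
  // Expected value from hash_ints_expected / hash_bytes_expected below (seed 42).
  assert(spark_hash_int(0, 42) == 933211791);
  return 0;
}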
157 changes: 124 additions & 33 deletions cpp/tests/hashing/hash_test.cpp
@@ -201,27 +201,37 @@ TYPED_TEST(HashTestFloatTyped, TestExtremes)
   T nan = std::numeric_limits<T>::quiet_NaN();
   T inf = std::numeric_limits<T>::infinity();
 
-  fixed_width_column_wrapper<T> const col1({T(0.0), T(100.0), T(-100.0), min, max, nan, inf, -inf});
-  fixed_width_column_wrapper<T> const col2(
-    {T(-0.0), T(100.0), T(-100.0), min, max, -nan, inf, -inf});
+  fixed_width_column_wrapper<T> const col({T(0.0), T(100.0), T(-100.0), min, max, nan, inf, -inf});
+  fixed_width_column_wrapper<T> const col_neg_zero(
+    {T(-0.0), T(100.0), T(-100.0), min, max, nan, inf, -inf});
+  fixed_width_column_wrapper<T> const col_neg_nan(
+    {T(0.0), T(100.0), T(-100.0), min, max, -nan, inf, -inf});
 
-  auto const input1 = cudf::table_view({col1});
-  auto const input2 = cudf::table_view({col2});
+  auto const table_col = cudf::table_view({col});
+  auto const table_col_neg_zero = cudf::table_view({col_neg_zero});
+  auto const table_col_neg_nan = cudf::table_view({col_neg_nan});
 
-  auto const output1 = cudf::hash(input1);
-  auto const output2 = cudf::hash(input2);
+  auto const hash_col = cudf::hash(table_col);
+  auto const hash_col_neg_zero = cudf::hash(table_col_neg_zero);
+  auto const hash_col_neg_nan = cudf::hash(table_col_neg_nan);
 
-  CUDF_TEST_EXPECT_COLUMNS_EQUAL(output1->view(), output2->view(), true);
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_col, *hash_col_neg_zero, true);
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_col, *hash_col_neg_nan, true);
 
-  auto const serial_output1 = cudf::hash(input1, cudf::hash_id::HASH_SERIAL_MURMUR3, {}, 0);
-  auto const serial_output2 = cudf::hash(input2, cudf::hash_id::HASH_SERIAL_MURMUR3);
+  constexpr auto serial_hasher = cudf::hash_id::HASH_SERIAL_MURMUR3;
+  auto const serial_col = cudf::hash(table_col, serial_hasher, {}, 0);
+  auto const serial_col_neg_zero = cudf::hash(table_col_neg_zero, serial_hasher);
+  auto const serial_col_neg_nan = cudf::hash(table_col_neg_nan, serial_hasher);
 
-  CUDF_TEST_EXPECT_COLUMNS_EQUAL(serial_output1->view(), serial_output2->view());
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*serial_col, *serial_col_neg_zero, true);
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*serial_col, *serial_col_neg_nan, true);
 
-  auto const spark_output1 = cudf::hash(input1, cudf::hash_id::HASH_SPARK_MURMUR3, {}, 0);
-  auto const spark_output2 = cudf::hash(input2, cudf::hash_id::HASH_SPARK_MURMUR3);
+  // Spark hash is sensitive to 0 and -0
+  constexpr auto spark_hasher = cudf::hash_id::HASH_SPARK_MURMUR3;
+  auto const spark_col = cudf::hash(table_col, spark_hasher, {}, 0);
+  auto const spark_col_neg_nan = cudf::hash(table_col_neg_nan, spark_hasher);
 
-  CUDF_TEST_EXPECT_COLUMNS_EQUAL(spark_output1->view(), spark_output2->view());
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*spark_col, *spark_col_neg_nan);
 }
 
 class SerialMurmurHash3Test : public cudf::test::BaseFixture {
@@ -267,37 +277,118 @@ class SparkMurmurHash3Test : public cudf::test::BaseFixture {
 
 TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
 {
-  fixed_width_column_wrapper<int32_t> const strings_col_result(
+  // The hash values were determined by running the following Scala code in Apache Spark:
+  // import org.apache.spark.sql.catalyst.util.DateTimeUtils
+  // val schema = new StructType().add("strings",StringType).add("doubles",DoubleType)
+  //   .add("timestamps",TimestampType).add("longs",LongType).add("floats",FloatType)
+  //   .add("dates",DateType).add("ints",IntegerType).add("shorts",ShortType)
+  //   .add("bytes",ByteType).add("bools",BooleanType)
+  // val data = Seq(
+  //   Row("", 0.toDouble, DateTimeUtils.toJavaTimestamp(0), 0.toLong, 0.toFloat,
+  //       DateTimeUtils.toJavaDate(0), 0, 0.toShort, 0.toByte, false),
+  //   Row("The quick brown fox", -(0.toDouble), DateTimeUtils.toJavaTimestamp(100), 100.toLong,
+  //       -(0.toFloat), DateTimeUtils.toJavaDate(100), 100, 100.toShort, 100.toByte, true),
+  //   Row("jumps over the lazy dog.", -Double.NaN, DateTimeUtils.toJavaTimestamp(-100),
+  //       -100.toLong, -Float.NaN, DateTimeUtils.toJavaDate(-100), -100, -100.toShort,
+  //       -100.toByte, true),
+  //   Row("All work and no play makes Jack a dull boy", Double.MinValue,
+  //       DateTimeUtils.toJavaTimestamp(Long.MinValue/1000000), Long.MinValue, Float.MinValue,
+  //       DateTimeUtils.toJavaDate(Int.MinValue/100), Int.MinValue, Short.MinValue, Byte.MinValue,
+  //       true),
+  //   Row("!\"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\ud720\ud721", Double.MaxValue,
+  //       DateTimeUtils.toJavaTimestamp(Long.MaxValue/1000000), Long.MaxValue, Float.MaxValue,
+  //       DateTimeUtils.toJavaDate(Int.MaxValue/100), Int.MaxValue, Short.MaxValue, Byte.MaxValue,
+  //       false))
+  // val df = spark.createDataFrame(sc.parallelize(data), schema)
+  // df.columns.foreach(c => println(s"$c => ${df.select(hash(col(c))).collect.mkString(",")}"))
+  // df.select(hash(col("*"))).collect
+  fixed_width_column_wrapper<int32_t> const hash_strings_expected(
     {1467149710, 723257560, -1620282500, -2001858707, 1588473657});
-  fixed_width_column_wrapper<int32_t> const ints_col_result(
+  fixed_width_column_wrapper<int32_t> const hash_doubles_expected(
+    {-1670924195, -853646085, -1281358385, 1897734433, -508695674});
+  fixed_width_column_wrapper<int32_t> const hash_timestamps_expected(
+    {-1670924195, 1114849490, 904948192, -1832979433, 1752430209});
+  fixed_width_column_wrapper<int32_t> const hash_longs_expected(
+    {-1670924195, 1114849490, 904948192, -853646085, -1604625029});
+  fixed_width_column_wrapper<int32_t> const hash_floats_expected(
+    {933211791, 723455942, -349261430, -1225560532, -338752985});
+  fixed_width_column_wrapper<int32_t> const hash_dates_expected(
+    {933211791, 751823303, -1080202046, -1906567553, -1503850410});
+  fixed_width_column_wrapper<int32_t> const hash_ints_expected(
     {933211791, 751823303, -1080202046, 723455942, 133916647});
+  fixed_width_column_wrapper<int32_t> const hash_shorts_expected(
+    {933211791, 751823303, -1080202046, -1871935946, 1249274084});
+  fixed_width_column_wrapper<int32_t> const hash_bytes_expected(
+    {933211791, 751823303, -1080202046, 1110053733, 1135925485});
+  fixed_width_column_wrapper<int32_t> const hash_bools_expected(
+    {933211791, -559580957, -559580957, -559580957, 933211791});
+  fixed_width_column_wrapper<int32_t> const hash_combined_expected(
+    {969935434, 595083937, 720326214, 971129823, -1060339603});
 
   strings_column_wrapper const strings_col({"",
     "The quick brown fox",
     "jumps over the lazy dog.",
     "All work and no play makes Jack a dull boy",
     "!\"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\ud720\ud721"});
 
-  using limits = std::numeric_limits<int32_t>;
-  fixed_width_column_wrapper<int32_t> const ints_col({0, 100, -100, limits::min(), limits::max()});
-
+  using double_limits = std::numeric_limits<double>;
+  using long_limits = std::numeric_limits<int64_t>;
+  using float_limits = std::numeric_limits<float>;
+  using int_limits = std::numeric_limits<int32_t>;
+  fixed_width_column_wrapper<double> const doubles_col(
+    {0., -0., -double_limits::quiet_NaN(), double_limits::lowest(), double_limits::max()});
+  fixed_width_column_wrapper<cudf::timestamp_ms, cudf::timestamp_ms::rep> const timestamps_col(
+    {0L, 100L, -100L, long_limits::min() / 1000000, long_limits::max() / 1000000});
+  fixed_width_column_wrapper<int64_t> const longs_col(
+    {0L, 100L, -100L, long_limits::min(), long_limits::max()});
+  fixed_width_column_wrapper<float> const floats_col(
+    {0.f, -0.f, -float_limits::quiet_NaN(), float_limits::lowest(), float_limits::max()});
+  fixed_width_column_wrapper<cudf::timestamp_D, cudf::timestamp_D::rep> dates_col(
+    {0, 100, -100, int_limits::min() / 100, int_limits::max() / 100});
+  fixed_width_column_wrapper<int32_t> const ints_col(
+    {0, 100, -100, int_limits::min(), int_limits::max()});
+  fixed_width_column_wrapper<int16_t> const shorts_col({0, 100, -100, -32768, 32767});
+  fixed_width_column_wrapper<int8_t> const bytes_col({0, 100, -100, -128, 127});
   fixed_width_column_wrapper<bool> const bools_col1({0, 1, 1, 1, 0});
   fixed_width_column_wrapper<bool> const bools_col2({0, 1, 2, 255, 0});
 
-  auto const input1 = cudf::table_view({strings_col});
-  auto const input2 = cudf::table_view({ints_col});
-  auto const input3 = cudf::table_view({strings_col, ints_col, bools_col1});
-  auto const input4 = cudf::table_view({strings_col, ints_col, bools_col2});
-
-  auto const hashed_output1 = cudf::hash(input1, cudf::hash_id::HASH_SPARK_MURMUR3, {}, 314);
-  auto const hashed_output2 = cudf::hash(input2, cudf::hash_id::HASH_SPARK_MURMUR3, {}, 42);
-  auto const hashed_output3 = cudf::hash(input3, cudf::hash_id::HASH_SPARK_MURMUR3, {});
-  auto const hashed_output4 = cudf::hash(input4, cudf::hash_id::HASH_SPARK_MURMUR3, {});
-
-  CUDF_TEST_EXPECT_COLUMNS_EQUAL(hashed_output1->view(), strings_col_result, true);
-  CUDF_TEST_EXPECT_COLUMNS_EQUAL(hashed_output2->view(), ints_col_result, true);
-  EXPECT_EQ(input3.num_rows(), hashed_output3->size());
-  CUDF_TEST_EXPECT_COLUMNS_EQUAL(hashed_output3->view(), hashed_output4->view(), true);
+  constexpr auto hasher = cudf::hash_id::HASH_SPARK_MURMUR3;
+  auto const hash_strings = cudf::hash(cudf::table_view({strings_col}), hasher, {}, 314);
+  auto const hash_doubles = cudf::hash(cudf::table_view({doubles_col}), hasher, {}, 42);
+  auto const hash_timestamps = cudf::hash(cudf::table_view({timestamps_col}), hasher, {}, 42);
+  auto const hash_longs = cudf::hash(cudf::table_view({longs_col}), hasher, {}, 42);
+  auto const hash_floats = cudf::hash(cudf::table_view({floats_col}), hasher, {}, 42);
+  auto const hash_dates = cudf::hash(cudf::table_view({dates_col}), hasher, {}, 42);
+  auto const hash_ints = cudf::hash(cudf::table_view({ints_col}), hasher, {}, 42);
+  auto const hash_shorts = cudf::hash(cudf::table_view({shorts_col}), hasher, {}, 42);
+  auto const hash_bytes = cudf::hash(cudf::table_view({bytes_col}), hasher, {}, 42);
+  auto const hash_bools1 = cudf::hash(cudf::table_view({bools_col1}), hasher, {}, 42);
+  auto const hash_bools2 = cudf::hash(cudf::table_view({bools_col2}), hasher, {}, 42);
+
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_strings, hash_strings_expected, true);
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_doubles, hash_doubles_expected, true);
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_timestamps, hash_timestamps_expected, true);
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_longs, hash_longs_expected, true);
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_floats, hash_floats_expected, true);
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_dates, hash_dates_expected, true);
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_ints, hash_ints_expected, true);
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_shorts, hash_shorts_expected, true);
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bytes, hash_bytes_expected, true);
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bools1, hash_bools_expected, true);
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bools2, hash_bools_expected, true);
+
+  auto const combined_table = cudf::table_view({strings_col,
+                                                doubles_col,
+                                                timestamps_col,
+                                                longs_col,
+                                                floats_col,
+                                                dates_col,
+                                                ints_col,
+                                                shorts_col,
+                                                bytes_col,
+                                                bools_col2});
+  auto const hash_combined = cudf::hash(combined_table, hasher, {}, 42);
+  CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_combined, hash_combined_expected, true);
 }
 
 class MD5HashTest : public cudf::test::BaseFixture {
1 change: 0 additions & 1 deletion java/src/main/java/ai/rapids/cudf/ColumnVector.java
@@ -606,7 +606,6 @@ public static ColumnVector spark32BitMurmurHash3(int seed, ColumnView columns[])
       assert columns[i] != null : "Column vectors passed may not be null";
       assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size";
       assert !columns[i].getType().isDurationType() : "Unsupported column type Duration";
-      assert !columns[i].getType().isTimestampType() : "Unsupported column type Timestamp";
       assert !columns[i].getType().isNestedType() : "Unsupported column of nested type";
       columnViews[i] = columns[i].getNativeView();
     }
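The assertion removed above had blocked timestamp columns. Spark represents timestamps as 64-bit microsecond counts and hashes them through its long path, which is why the timestamp test vectors match the long vectors for the same underlying integer (e.g. 1114849490 for the value 100 with seed 42), and why dates, stored as 32-bit day counts, reuse the int path. A sketch of the 64-bit path modeled on Spark's Murmur3_x86_32.hashLong, which mixes the low 32-bit word and then the high word (helper names are illustrative; the asserted value comes from this commit's tests):

#include <cassert>
#include <cstdint>

static uint32_t rotl32(uint32_t x, int r) { return (x << r) | (x >> (32 - r)); }

// Final avalanche.
static uint32_t fmix32(uint32_t h)
{
  h ^= h >> 16;
  h *= 0x85ebca6b;
  h ^= h >> 13;
  h *= 0xc2b2ae35;
  return h ^ (h >> 16);
}

// Mix one 4-byte block into the running hash state.
static uint32_t mix_block(uint32_t h, uint32_t k)
{
  k *= 0xcc9e2d51;
  k = rotl32(k, 15);
  k *= 0x1b873593;
  h ^= k;
  return rotl32(h, 13) * 5 + 0xe6546b64;
}

static int32_t spark_hash_long(int64_t key, uint32_t seed)
{
  auto const bits = static_cast<uint64_t>(key);
  uint32_t h = mix_block(seed, static_cast<uint32_t>(bits));  // low word first
  h = mix_block(h, static_cast<uint32_t>(bits >> 32));        // then high word
  return static_cast<int32_t>(fmix32(h ^ 8));                 // total length = 8 bytes
}

int main()
{
  // Matches hash_longs_expected / hash_timestamps_expected above and the
  // Java timestamp test below (seed 42, value 100).
  assert(spark_hash_long(100, 42) == 1114849490);
  return 0;
}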
20 changes: 20 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
@@ -527,6 +527,26 @@ void testSpark32BitMurmur3HashDoubles() {
     }
   }
 
+  @Test
+  void testSpark32BitMurmur3HashTimestamps() {
+    try (ColumnVector v = ColumnVector.timestampMicroSecondsFromBoxedLongs(
+          0L, null, 100L, -100L, 0x123456789abcdefL, null, -0x123456789abcdefL);
+         ColumnVector result = ColumnVector.spark32BitMurmurHash3(42, new ColumnVector[]{v});
+         ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 42, 1114849490, 904948192, 657182333, 42, -57193045)) {
+      assertColumnsAreEqual(expected, result);
+    }
+  }
+
+  @Test
+  void testSpark32BitMurmur3HashDates() {
+    try (ColumnVector v = ColumnVector.timestampDaysFromBoxedInts(
+          0, null, 100, -100, 0x12345678, null, -0x12345678);
+         ColumnVector result = ColumnVector.spark32BitMurmurHash3(42, new ColumnVector[]{v});
+         ColumnVector expected = ColumnVector.fromBoxedInts(933211791, 42, 751823303, -1080202046, -1721170160, 42, 1852996993)) {
+      assertColumnsAreEqual(expected, result);
+    }
+  }
+
   @Test
   void testSpark32BitMurmur3HashFloats() {
     try (ColumnVector v = ColumnVector.fromBoxedFloats(
