From b42b6c5ae3584ed82e7e7f5f0ca80d6d83b1bfb9 Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Wed, 24 Mar 2021 10:44:34 -0500 Subject: [PATCH] Fix Spark hash of decimal32/decimal64 --- .../cudf/detail/utilities/hash_functions.cuh | 15 +++++ cpp/tests/hashing/hash_test.cpp | 55 ++++++++++++------- .../java/ai/rapids/cudf/ColumnVectorTest.java | 20 +++++++ 3 files changed, 71 insertions(+), 19 deletions(-) diff --git a/cpp/include/cudf/detail/utilities/hash_functions.cuh b/cpp/include/cudf/detail/utilities/hash_functions.cuh index a6f24240c0c..8b04651e1e6 100644 --- a/cpp/include/cudf/detail/utilities/hash_functions.cuh +++ b/cpp/include/cudf/detail/utilities/hash_functions.cuh @@ -18,6 +18,7 @@ #include #include +#include <cudf/fixed_point/fixed_point.hpp> #include #include @@ -656,6 +657,20 @@ SparkMurmurHash3_32::operator()(uint16_t const& key) const return this->compute(key); } +template <> +hash_value_type CUDA_DEVICE_CALLABLE +SparkMurmurHash3_32<numeric::decimal32>::operator()(numeric::decimal32 const& key) const +{ + return this->compute<uint64_t>(key.value()); +} + +template <> +hash_value_type CUDA_DEVICE_CALLABLE +SparkMurmurHash3_32<numeric::decimal64>::operator()(numeric::decimal64 const& key) const +{ + return this->compute<uint64_t>(key.value()); +} + /** * @brief Specialization of MurmurHash3_32 operator for strings.
*/ diff --git a/cpp/tests/hashing/hash_test.cpp b/cpp/tests/hashing/hash_test.cpp index 36049436732..5641d445ff3 100644 --- a/cpp/tests/hashing/hash_test.cpp +++ b/cpp/tests/hashing/hash_test.cpp @@ -15,6 +15,7 @@ */ #include +#include #include #include @@ -280,25 +281,27 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds) // The hash values were determined by running the following Scala code in Apache Spark: // import org.apache.spark.sql.catalyst.util.DateTimeUtils // val schema = new StructType().add("strings",StringType).add("doubles",DoubleType) - // .add("timestamps",TimestampType).add("longs",LongType).add("floats",FloatType) - // .add("dates",DateType).add("ints",IntegerType).add("shorts",ShortType) - // .add("bytes",ByteType).add("bools",BooleanType) + // .add("timestamps",TimestampType).add("decimal64", DecimalType(18,7)).add("longs",LongType) + // .add("floats",FloatType).add("dates",DateType).add("decimal32", DecimalType(9,3)) + // .add("ints",IntegerType).add("shorts",ShortType).add("bytes",ByteType) + // .add("bools",BooleanType) // val data = Seq( - // Row("", 0.toDouble, DateTimeUtils.toJavaTimestamp(0), 0.toLong, 0.toFloat, - // DateTimeUtils.toJavaDate(0), 0, 0.toShort, 0.toByte, false), - // Row("The quick brown fox", -(0.toDouble), DateTimeUtils.toJavaTimestamp(100), 100.toLong, - // -(0.toFloat), DateTimeUtils.toJavaDate(100), 100, 100.toShort, 100.toByte, true), - // Row("jumps over the lazy dog.", -Double.NaN, DateTimeUtils.toJavaTimestamp(-100), - // -100.toLong, -Float.NaN, DateTimeUtils.toJavaDate(-100), -100, -100.toShort, - // -100.toByte, true), - // Row("All work and no play makes Jack a dull boy", Double.MinValue, - // DateTimeUtils.toJavaTimestamp(Long.MinValue/1000000), Long.MinValue, Float.MinValue, - // DateTimeUtils.toJavaDate(Int.MinValue/100), Int.MinValue, Short.MinValue, Byte.MinValue, - // true), - // Row("!\"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\ud720\ud721", Double.MaxValue, - // 
DateTimeUtils.toJavaTimestamp(Long.MaxValue/1000000), Long.MaxValue, Float.MaxValue, - // DateTimeUtils.toJavaDate(Int.MaxValue/100), Int.MaxValue, Short.MaxValue, Byte.MaxValue, - // false)) + // Row("", 0.toDouble, DateTimeUtils.toJavaTimestamp(0), BigDecimal(0), 0.toLong, 0.toFloat, + // DateTimeUtils.toJavaDate(0), BigDecimal(0), 0, 0.toShort, 0.toByte, false), + // Row("The quick brown fox", -(0.toDouble), DateTimeUtils.toJavaTimestamp(100), + // BigDecimal("0.00001"), 100.toLong, -(0.toFloat), DateTimeUtils.toJavaDate(100), + // BigDecimal("0.1"), 100, 100.toShort, 100.toByte, true), + // Row("jumps over the lazy dog.", -Double.NaN, DateTimeUtils.toJavaTimestamp(-100), + // BigDecimal("-0.00001"), -100.toLong, -Float.NaN, DateTimeUtils.toJavaDate(-100), + // BigDecimal("-0.1"), -100, -100.toShort, -100.toByte, true), + // Row("All work and no play makes Jack a dull boy", Double.MinValue, + // DateTimeUtils.toJavaTimestamp(Long.MinValue/1000000), BigDecimal("-99999999999.9999999"), + // Long.MinValue, Float.MinValue, DateTimeUtils.toJavaDate(Int.MinValue/100), + // BigDecimal("-999999.999"), Int.MinValue, Short.MinValue, Byte.MinValue, true), + // Row("!\"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\ud720\ud721", Double.MaxValue, + // DateTimeUtils.toJavaTimestamp(Long.MaxValue/1000000), BigDecimal("99999999999.9999999"), + // Long.MaxValue, Float.MaxValue, DateTimeUtils.toJavaDate(Int.MaxValue/100), + // BigDecimal("999999.999"), Int.MaxValue, Short.MaxValue, Byte.MaxValue, false)) // val df = spark.createDataFrame(sc.parallelize(data), schema) // df.columns.foreach(c => println(s"$c => ${df.select(hash(col(c))).collect.mkString(",")}")) // df.select(hash(col("*"))).collect @@ -308,12 +311,16 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds) {-1670924195, -853646085, -1281358385, 1897734433, -508695674}); fixed_width_column_wrapper const hash_timestamps_expected( {-1670924195, 1114849490, 904948192, -1832979433, 1752430209}); + fixed_width_column_wrapper const 
hash_decimal64_expected( + {-1670924195, 1114849490, 904948192, 1962370902, -1795328666}); fixed_width_column_wrapper const hash_longs_expected( {-1670924195, 1114849490, 904948192, -853646085, -1604625029}); fixed_width_column_wrapper const hash_floats_expected( {933211791, 723455942, -349261430, -1225560532, -338752985}); fixed_width_column_wrapper const hash_dates_expected( {933211791, 751823303, -1080202046, -1906567553, -1503850410}); + fixed_width_column_wrapper const hash_decimal32_expected( + {-1670924195, 1114849490, 904948192, -1454351396, -193774131}); fixed_width_column_wrapper const hash_ints_expected( {933211791, 751823303, -1080202046, 723455942, 133916647}); fixed_width_column_wrapper const hash_shorts_expected( @@ -323,7 +330,7 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds) fixed_width_column_wrapper const hash_bools_expected( {933211791, -559580957, -559580957, -559580957, 933211791}); fixed_width_column_wrapper const hash_combined_expected( - {969935434, 595083937, 720326214, 971129823, -1060339603}); + {-1947042614, -1731440908, 807283935, 725489209, 822276819}); strings_column_wrapper const strings_col({"", "The quick brown fox", @@ -339,12 +346,16 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds) {0., -0., -double_limits::quiet_NaN(), double_limits::lowest(), double_limits::max()}); fixed_width_column_wrapper const timestamps_col( {0L, 100L, -100L, long_limits::min() / 1000000, long_limits::max() / 1000000}); + fixed_point_column_wrapper const decimal64_col( + {0L, 100L, -100L, -999999999999999999L, 999999999999999999L}, numeric::scale_type{-7}); fixed_width_column_wrapper const longs_col( {0L, 100L, -100L, long_limits::min(), long_limits::max()}); fixed_width_column_wrapper const floats_col( {0.f, -0.f, -float_limits::quiet_NaN(), float_limits::lowest(), float_limits::max()}); fixed_width_column_wrapper dates_col( {0, 100, -100, int_limits::min() / 100, int_limits::max() / 100}); + fixed_point_column_wrapper const decimal32_col({0, 
100, -100, -999999999, 999999999}, + numeric::scale_type{-3}); fixed_width_column_wrapper const ints_col( {0, 100, -100, int_limits::min(), int_limits::max()}); fixed_width_column_wrapper const shorts_col({0, 100, -100, -32768, 32767}); @@ -356,9 +367,11 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds) auto const hash_strings = cudf::hash(cudf::table_view({strings_col}), hasher, {}, 314); auto const hash_doubles = cudf::hash(cudf::table_view({doubles_col}), hasher, {}, 42); auto const hash_timestamps = cudf::hash(cudf::table_view({timestamps_col}), hasher, {}, 42); + auto const hash_decimal64 = cudf::hash(cudf::table_view({decimal64_col}), hasher, {}, 42); auto const hash_longs = cudf::hash(cudf::table_view({longs_col}), hasher, {}, 42); auto const hash_floats = cudf::hash(cudf::table_view({floats_col}), hasher, {}, 42); auto const hash_dates = cudf::hash(cudf::table_view({dates_col}), hasher, {}, 42); + auto const hash_decimal32 = cudf::hash(cudf::table_view({decimal32_col}), hasher, {}, 42); auto const hash_ints = cudf::hash(cudf::table_view({ints_col}), hasher, {}, 42); auto const hash_shorts = cudf::hash(cudf::table_view({shorts_col}), hasher, {}, 42); auto const hash_bytes = cudf::hash(cudf::table_view({bytes_col}), hasher, {}, 42); @@ -368,9 +381,11 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds) CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_strings, hash_strings_expected, true); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_doubles, hash_doubles_expected, true); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_timestamps, hash_timestamps_expected, true); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_decimal64, hash_decimal64_expected, true); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_longs, hash_longs_expected, true); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_floats, hash_floats_expected, true); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_dates, hash_dates_expected, true); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_decimal32, hash_decimal32_expected, true); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_ints, 
hash_ints_expected, true); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_shorts, hash_shorts_expected, true); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bytes, hash_bytes_expected, true); @@ -380,9 +395,11 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds) auto const combined_table = cudf::table_view({strings_col, doubles_col, timestamps_col, + decimal64_col, longs_col, floats_col, dates_col, + decimal32_col, ints_col, shorts_col, bytes_col, diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 673b8a67467..8b40f6e93d4 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -537,6 +537,26 @@ void testSpark32BitMurmur3HashTimestamps() { } } + @Test + void testSpark32BitMurmur3HashDecimal64() { + try (ColumnVector v = ColumnVector.decimalFromLongs(-7, + 0L, 100L, -100L, 0x123456789abcdefL, -0x123456789abcdefL); + ColumnVector result = ColumnVector.spark32BitMurmurHash3(42, new ColumnVector[]{v}); + ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 1114849490, 904948192, 657182333, -57193045)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void testSpark32BitMurmur3HashDecimal32() { + try (ColumnVector v = ColumnVector.decimalFromInts(-3, + 0, 100, -100, 0x12345678, -0x12345678); + ColumnVector result = ColumnVector.spark32BitMurmurHash3(42, new ColumnVector[]{v}); + ColumnVector expected = ColumnVector.fromBoxedInts(-1670924195, 1114849490, 904948192, -958054811, -1447702630)) { + assertColumnsAreEqual(expected, result); + } + } + @Test void testSpark32BitMurmur3HashDates() { try (ColumnVector v = ColumnVector.timestampDaysFromBoxedInts(