Struct hashing support for SerialMurmur3 and SparkMurmur3 (#7714)
Adds struct column support for serial Murmur3 and Spark-compatible Murmur3 hashing. A struct column is exploded into its leaf columns before being passed to the existing hash support. The validity of the parent struct columns can be ignored because hashing a null is a no-op that returns the hash seed, so only the leaf columns within the struct column need to be considered for the hash computation.
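
To make the null argument concrete, here is a minimal host-side C++ sketch of the idea (illustrative only, not the CUDA kernel; `LeafColumn` and `hash_row` are hypothetical names, while `murmur3_combine` is the standard MurmurHash3 32-bit block step): a null struct row contributes nothing because its leaf elements are also null, and each null leaf simply passes the running seed through.

#include <cstdint>
#include <optional>
#include <vector>

using hash_value_type = uint32_t;

// Standard MurmurHash3 32-bit block step: mix one 4-byte value into the hash.
hash_value_type murmur3_combine(hash_value_type h, uint32_t k)
{
  k *= 0xcc9e2d51u;
  k = (k << 15) | (k >> 17);
  k *= 0x1b873593u;
  h ^= k;
  h = (h << 13) | (h >> 19);
  return h * 5u + 0xe6546b64u;
}

// A leaf column stand-in: one optional value per row (nullopt == null).
using LeafColumn = std::vector<std::optional<uint32_t>>;

// Serial hash of one row across the flattened leaf columns. Hashing a null
// is a no-op that returns the current seed, which is why parent struct
// validity can be ignored: a null struct's leaves are also null, so the
// seed carries through unchanged.
hash_value_type hash_row(std::vector<LeafColumn> const& leaves,
                         size_t row,
                         hash_value_type seed)
{
  hash_value_type result = seed;
  for (auto const& col : leaves) {
    if (col[row]) { result = murmur3_combine(result, *col[row]); }
  }
  return result;
}

(The real implementation also applies Murmur3 finalization per value; this sketch only shows the seed-chaining behavior that motivates the design.)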

Authors:
  - Jason Lowe (https://github.com/jlowe)

Approvers:
  - Robert (Bobby) Evans (https://github.com/revans2)
  - Conor Hoekstra (https://github.com/codereport)
  - Mark Harris (https://github.com/harrism)

URL: #7714
jlowe authored Mar 31, 2021
1 parent 4d6ea76 commit 9970f1d
Showing 5 changed files with 186 additions and 49 deletions.
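
For orientation before the diffs: from the caller's side, the change means a STRUCT column can now appear in the table passed to cudf::hash, while LIST columns still throw. A minimal sketch mirroring the test code in hash_test.cpp below (column contents are illustrative; include paths follow the cudf test sources and should be treated as an assumption):

#include <cudf/hashing.hpp>
#include <cudf_test/column_wrapper.hpp>

using namespace cudf::test;

void struct_hash_example()
{
  fixed_width_column_wrapper<int32_t> ints{1, 2, 3};
  strings_column_wrapper strs{"a", "b", "c"};
  structs_column_wrapper structs{{ints, strs}};

  // The struct column is flattened to its leaf columns internally; the
  // result is one 32-bit hash per row.
  auto hashed = cudf::hash(cudf::table_view({structs}),
                           cudf::hash_id::HASH_SERIAL_MURMUR3,
                           {},
                           42);

  // A LIST column in the input would throw cudf::logic_error instead
  // (see the ListThrows tests below).
}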
32 changes: 32 additions & 0 deletions cpp/include/cudf/detail/utilities/hash_functions.cuh
@@ -542,6 +542,22 @@ hash_value_type CUDA_DEVICE_CALLABLE MurmurHash3_32<double>::operator()(double c
return this->compute_floating_point(key);
}

template <>
hash_value_type CUDA_DEVICE_CALLABLE
MurmurHash3_32<cudf::list_view>::operator()(cudf::list_view const& key) const
{
cudf_assert(false && "List column hashing is not supported");
return 0;
}

template <>
hash_value_type CUDA_DEVICE_CALLABLE
MurmurHash3_32<cudf::struct_view>::operator()(cudf::struct_view const& key) const
{
cudf_assert(false && "Direct hashing of struct_view is not supported");
return 0;
}

template <typename Key>
struct SparkMurmurHash3_32 {
using argument_type = Key;
@@ -671,6 +687,22 @@ SparkMurmurHash3_32<numeric::decimal64>::operator()(numeric::decimal64 const& ke
return this->compute<uint64_t>(key.value());
}

template <>
hash_value_type CUDA_DEVICE_CALLABLE
SparkMurmurHash3_32<cudf::list_view>::operator()(cudf::list_view const& key) const
{
cudf_assert(false && "List column hashing is not supported");
return 0;
}

template <>
hash_value_type CUDA_DEVICE_CALLABLE
SparkMurmurHash3_32<cudf::struct_view>::operator()(cudf::struct_view const& key) const
{
cudf_assert(false && "Direct hashing of struct_view is not supported");
return 0;
}

/**
* @brief Specialization of MurmurHash3_32 operator for strings.
*/
25 changes: 22 additions & 3 deletions cpp/src/hash/hashing.cu
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -29,6 +29,8 @@

#include <rmm/cuda_stream_view.hpp>

#include <algorithm>

namespace cudf {
namespace {

@@ -38,6 +40,22 @@ bool md5_type_check(data_type dt)
return !is_chrono(dt) && (is_fixed_width(dt) || (dt.id() == type_id::STRING));
}

template <typename IterType>
std::vector<column_view> to_leaf_columns(IterType iter_begin, IterType iter_end)
{
std::vector<column_view> leaf_columns;
std::for_each(iter_begin, iter_end, [&leaf_columns](column_view const& col) {
if (is_nested(col.type())) {
CUDF_EXPECTS(col.type().id() == type_id::STRUCT, "unsupported nested type");
auto child_columns = to_leaf_columns(col.child_begin(), col.child_end());
leaf_columns.insert(leaf_columns.end(), child_columns.begin(), child_columns.end());
} else {
leaf_columns.emplace_back(col);
}
});
return leaf_columns;
}

} // namespace

namespace detail {
@@ -133,10 +151,11 @@ std::unique_ptr<column> serial_murmur_hash3_32(table_view const& input,

if (input.num_columns() == 0 || input.num_rows() == 0) { return output; }

auto const device_input = table_device_view::create(input, stream);
table_view const leaf_table(to_leaf_columns(input.begin(), input.end()));
auto const device_input = table_device_view::create(leaf_table, stream);
auto output_view = output->mutable_view();

if (has_nulls(input)) {
if (has_nulls(leaf_table)) {
thrust::tabulate(rmm::exec_policy(stream),
output_view.begin<int32_t>(),
output_view.end<int32_t>(),
125 changes: 83 additions & 42 deletions cpp/tests/hashing/hash_test.cpp
@@ -257,20 +257,35 @@ TEST_F(SerialMurmurHash3Test, MultiValueWithSeeds)
fixed_width_column_wrapper<bool> const bools_col1({0, 1, 1, 1, 0});
fixed_width_column_wrapper<bool> const bools_col2({0, 1, 2, 255, 0});

auto const input1 = cudf::table_view({strings_col});
auto const input2 = cudf::table_view({ints_col});
auto const input3 = cudf::table_view({strings_col, ints_col, bools_col1});
auto const input4 = cudf::table_view({strings_col, ints_col, bools_col2});

auto const hashed_output1 = cudf::hash(input1, cudf::hash_id::HASH_SERIAL_MURMUR3, {}, 314);
auto const hashed_output2 = cudf::hash(input2, cudf::hash_id::HASH_SERIAL_MURMUR3, {}, 42);
auto const hashed_output3 = cudf::hash(input3, cudf::hash_id::HASH_SERIAL_MURMUR3, {});
auto const hashed_output4 = cudf::hash(input4, cudf::hash_id::HASH_SERIAL_MURMUR3, {});

CUDF_TEST_EXPECT_COLUMNS_EQUAL(hashed_output1->view(), strings_col_result, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(hashed_output2->view(), ints_col_result, true);
EXPECT_EQ(input3.num_rows(), hashed_output3->size());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(hashed_output3->view(), hashed_output4->view(), true);
std::vector<std::unique_ptr<cudf::column>> struct_field_cols;
struct_field_cols.emplace_back(std::make_unique<cudf::column>(strings_col));
struct_field_cols.emplace_back(std::make_unique<cudf::column>(ints_col));
struct_field_cols.emplace_back(std::make_unique<cudf::column>(bools_col1));
structs_column_wrapper structs_col(std::move(struct_field_cols));

auto const combo1 = cudf::table_view({strings_col, ints_col, bools_col1});
auto const combo2 = cudf::table_view({strings_col, ints_col, bools_col2});

constexpr auto hasher = cudf::hash_id::HASH_SERIAL_MURMUR3;
auto const strings_hash = cudf::hash(cudf::table_view({strings_col}), hasher, {}, 314);
auto const ints_hash = cudf::hash(cudf::table_view({ints_col}), hasher, {}, 42);
auto const combo1_hash = cudf::hash(combo1, hasher, {});
auto const combo2_hash = cudf::hash(combo2, hasher, {});
auto const structs_hash = cudf::hash(cudf::table_view({structs_col}), hasher, {});

CUDF_TEST_EXPECT_COLUMNS_EQUAL(*strings_hash, strings_col_result, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*ints_hash, ints_col_result, true);
EXPECT_EQ(combo1.num_rows(), combo1_hash->size());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*combo1_hash, *combo2_hash, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*structs_hash, *combo1_hash, true);
}

TEST_F(SerialMurmurHash3Test, ListThrows)
{
lists_column_wrapper<cudf::string_view> strings_list_col({{""}, {"abc"}, {"123"}});
EXPECT_THROW(
cudf::hash(cudf::table_view({strings_list_col}), cudf::hash_id::HASH_SERIAL_MURMUR3, {}),
cudf::logic_error);
}

class SparkMurmurHash3Test : public cudf::test::BaseFixture {
@@ -280,31 +295,38 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
{
// The hash values were determined by running the following Scala code in Apache Spark:
// import org.apache.spark.sql.catalyst.util.DateTimeUtils
// val schema = new StructType().add("strings",StringType).add("doubles",DoubleType)
// .add("timestamps",TimestampType).add("decimal64", DecimalType(18,7)).add("longs",LongType)
// .add("floats",FloatType).add("dates",DateType).add("decimal32", DecimalType(9,3))
// .add("ints",IntegerType).add("shorts",ShortType).add("bytes",ByteType)
// .add("bools",BooleanType)
// val schema = new StructType().add("structs", new StructType().add("a",IntegerType)
// .add("b",StringType).add("c",new StructType().add("x",FloatType).add("y",LongType)))
// .add("strings",StringType).add("doubles",DoubleType).add("timestamps",TimestampType)
// .add("decimal64", DecimalType(18,7)).add("longs",LongType).add("floats",FloatType)
// .add("dates",DateType).add("decimal32", DecimalType(9,3)).add("ints",IntegerType)
// .add("shorts",ShortType).add("bytes",ByteType).add("bools",BooleanType)
// val data = Seq(
// Row("", 0.toDouble, DateTimeUtils.toJavaTimestamp(0), BigDecimal(0), 0.toLong, 0.toFloat,
// DateTimeUtils.toJavaDate(0), BigDecimal(0), 0, 0.toShort, 0.toByte, false),
// Row("The quick brown fox", -(0.toDouble), DateTimeUtils.toJavaTimestamp(100),
// BigDecimal("0.00001"), 100.toLong, -(0.toFloat), DateTimeUtils.toJavaDate(100),
// BigDecimal("0.1"), 100, 100.toShort, 100.toByte, true),
// Row("jumps over the lazy dog.", -Double.NaN, DateTimeUtils.toJavaTimestamp(-100),
// BigDecimal("-0.00001"), -100.toLong, -Float.NaN, DateTimeUtils.toJavaDate(-100),
// BigDecimal("-0.1"), -100, -100.toShort, -100.toByte, true),
// Row("All work and no play makes Jack a dull boy", Double.MinValue,
// DateTimeUtils.toJavaTimestamp(Long.MinValue/1000000), BigDecimal("-99999999999.9999999"),
// Long.MinValue, Float.MinValue, DateTimeUtils.toJavaDate(Int.MinValue/100),
// BigDecimal("-999999.999"), Int.MinValue, Short.MinValue, Byte.MinValue, true),
// Row("!\"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\ud720\ud721", Double.MaxValue,
// DateTimeUtils.toJavaTimestamp(Long.MaxValue/1000000), BigDecimal("99999999999.9999999"),
// Long.MaxValue, Float.MaxValue, DateTimeUtils.toJavaDate(Int.MaxValue/100),
// BigDecimal("999999.999"), Int.MaxValue, Short.MaxValue, Byte.MaxValue, false))
// Row(Row(0, "a", Row(0f, 0L)), "", 0.toDouble, DateTimeUtils.toJavaTimestamp(0), BigDecimal(0),
// 0.toLong, 0.toFloat, DateTimeUtils.toJavaDate(0), BigDecimal(0), 0, 0.toShort, 0.toByte,
// false),
// Row(Row(100, "bc", Row(100f, 100L)), "The quick brown fox", -(0.toDouble),
// DateTimeUtils.toJavaTimestamp(100), BigDecimal("0.00001"), 100.toLong, -(0.toFloat),
// DateTimeUtils.toJavaDate(100), BigDecimal("0.1"), 100, 100.toShort, 100.toByte, true),
// Row(Row(-100, "def", Row(-100f, -100L)), "jumps over the lazy dog.", -Double.NaN,
// DateTimeUtils.toJavaTimestamp(-100), BigDecimal("-0.00001"), -100.toLong, -Float.NaN,
// DateTimeUtils.toJavaDate(-100), BigDecimal("-0.1"), -100, -100.toShort, -100.toByte,
// true),
// Row(Row(0x12345678, "ghij", Row(Float.PositiveInfinity, 0x123456789abcdefL)),
// "All work and no play makes Jack a dull boy", Double.MinValue,
// DateTimeUtils.toJavaTimestamp(Long.MinValue/1000000), BigDecimal("-99999999999.9999999"),
// Long.MinValue, Float.MinValue, DateTimeUtils.toJavaDate(Int.MinValue/100),
// BigDecimal("-999999.999"), Int.MinValue, Short.MinValue, Byte.MinValue, true),
// Row(Row(-0x76543210, "klmno", Row(Float.NegativeInfinity, -0x123456789abcdefL)),
// "!\"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\ud720\ud721", Double.MaxValue,
// DateTimeUtils.toJavaTimestamp(Long.MaxValue/1000000), BigDecimal("99999999999.9999999"),
// Long.MaxValue, Float.MaxValue, DateTimeUtils.toJavaDate(Int.MaxValue/100),
// BigDecimal("999999.999"), Int.MaxValue, Short.MaxValue, Byte.MaxValue, false))
// val df = spark.createDataFrame(sc.parallelize(data), schema)
// df.columns.foreach(c => println(s"$c => ${df.select(hash(col(c))).collect.mkString(",")}"))
// df.select(hash(col("*"))).collect
fixed_width_column_wrapper<int32_t> const hash_structs_expected(
{-105406170, 90479889, -678041645, 1667387937, 301478567});
fixed_width_column_wrapper<int32_t> const hash_strings_expected(
{1467149710, 723257560, -1620282500, -2001858707, 1588473657});
fixed_width_column_wrapper<int32_t> const hash_doubles_expected(
@@ -330,18 +352,26 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
fixed_width_column_wrapper<int32_t> const hash_bools_expected(
{933211791, -559580957, -559580957, -559580957, 933211791});
fixed_width_column_wrapper<int32_t> const hash_combined_expected(
{-1947042614, -1731440908, 807283935, 725489209, 822276819});
{-1172364561, -442972638, 1213234395, 796626751, 214075225});

using double_limits = std::numeric_limits<double>;
using long_limits = std::numeric_limits<int64_t>;
using float_limits = std::numeric_limits<float>;
using int_limits = std::numeric_limits<int32_t>;
fixed_width_column_wrapper<int32_t> a_col{0, 100, -100, 0x12345678, -0x76543210};
strings_column_wrapper b_col{"a", "bc", "def", "ghij", "klmno"};
fixed_width_column_wrapper<float> x_col{
0.f, 100.f, -100.f, float_limits::infinity(), -float_limits::infinity()};
fixed_width_column_wrapper<int64_t> y_col{
0L, 100L, -100L, 0x123456789abcdefL, -0x123456789abcdefL};
structs_column_wrapper c_col{{x_col, y_col}};
structs_column_wrapper const structs_col{{a_col, b_col, c_col}};

strings_column_wrapper const strings_col({"",
"The quick brown fox",
"jumps over the lazy dog.",
"All work and no play makes Jack a dull boy",
"!\"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\ud720\ud721"});

using double_limits = std::numeric_limits<double>;
using long_limits = std::numeric_limits<int64_t>;
using float_limits = std::numeric_limits<float>;
using int_limits = std::numeric_limits<int32_t>;
fixed_width_column_wrapper<double> const doubles_col(
{0., -0., -double_limits::quiet_NaN(), double_limits::lowest(), double_limits::max()});
fixed_width_column_wrapper<cudf::timestamp_ms, cudf::timestamp_ms::rep> const timestamps_col(
@@ -364,6 +394,7 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
fixed_width_column_wrapper<bool> const bools_col2({0, 1, 2, 255, 0});

constexpr auto hasher = cudf::hash_id::HASH_SPARK_MURMUR3;
auto const hash_structs = cudf::hash(cudf::table_view({structs_col}), hasher, {}, 42);
auto const hash_strings = cudf::hash(cudf::table_view({strings_col}), hasher, {}, 314);
auto const hash_doubles = cudf::hash(cudf::table_view({doubles_col}), hasher, {}, 42);
auto const hash_timestamps = cudf::hash(cudf::table_view({timestamps_col}), hasher, {}, 42);
@@ -378,6 +409,7 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
auto const hash_bools1 = cudf::hash(cudf::table_view({bools_col1}), hasher, {}, 42);
auto const hash_bools2 = cudf::hash(cudf::table_view({bools_col2}), hasher, {}, 42);

CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_structs, hash_structs_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_strings, hash_strings_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_doubles, hash_doubles_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_timestamps, hash_timestamps_expected, true);
@@ -392,7 +424,8 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bools1, hash_bools_expected, true);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_bools2, hash_bools_expected, true);

auto const combined_table = cudf::table_view({strings_col,
auto const combined_table = cudf::table_view({structs_col,
strings_col,
doubles_col,
timestamps_col,
decimal64_col,
@@ -408,6 +441,14 @@ TEST_F(SparkMurmurHash3Test, MultiValueWithSeeds)
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*hash_combined, hash_combined_expected, true);
}

TEST_F(SparkMurmurHash3Test, ListThrows)
{
lists_column_wrapper<cudf::string_view> strings_list_col({{""}, {"abc"}, {"123"}});
EXPECT_THROW(
cudf::hash(cudf::table_view({strings_list_col}), cudf::hash_id::HASH_SPARK_MURMUR3, {}),
cudf::logic_error);
}

class MD5HashTest : public cudf::test::BaseFixture {
};

5 changes: 2 additions & 3 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
@@ -570,8 +570,7 @@ public static ColumnVector serial32BitMurmurHash3(int seed, ColumnView columns[]
assert columns[i] != null : "Column vectors passed may not be null";
assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size";
assert !columns[i].getType().isDurationType() : "Unsupported column type Duration";
assert !columns[i].getType().isTimestampType() : "Unsupported column type Timestamp";
assert !columns[i].getType().isNestedType() : "Unsupported column of nested type";
assert !columns[i].getType().equals(DType.LIST) : "List columns are not supported";
columnViews[i] = columns[i].getNativeView();
}
return new ColumnVector(hash(columnViews, HashType.HASH_SERIAL_MURMUR3.getNativeId(), new int[0], seed));
@@ -606,7 +605,7 @@ public static ColumnVector spark32BitMurmurHash3(int seed, ColumnView columns[])
assert columns[i] != null : "Column vectors passed may not be null";
assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size";
assert !columns[i].getType().isDurationType() : "Unsupported column type Duration";
assert !columns[i].getType().isNestedType() : "Unsupported column of nested type";
assert !columns[i].getType().equals(DType.LIST) : "List columns are not supported";
columnViews[i] = columns[i].getNativeView();
}
return new ColumnVector(hash(columnViews, HashType.HASH_SPARK_MURMUR3.getNativeId(), new int[0], seed));