Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HiveHash support for nested types #9

Merged
merged 16 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 234 additions & 6 deletions src/main/cpp/src/hive_hash.cu
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ hive_hash_value_t __device__ inline hive_hash_function<cudf::timestamp_us>::oper
* @tparam hash_function Hash functor to use for hashing elements. Must be hive_hash_function.
* @tparam Nullate A cudf::nullate type describing whether to check for nulls.
*/
template <template <typename> class hash_function, typename Nullate>
template <template <typename> class hash_function, typename Nullate, int MAX_NESTED_DEPTH>
class hive_device_row_hasher {
public:
CUDF_HOST_DEVICE hive_device_row_hasher(Nullate check_nulls, cudf::table_device_view t) noexcept
Expand All @@ -182,7 +182,7 @@ class hive_device_row_hasher {
HIVE_INIT_HASH,
cuda::proclaim_return_type<hive_hash_value_t>(
[row_index, nulls = this->_check_nulls] __device__(auto hash, auto const& column) {
auto cur_hash = cudf::type_dispatcher<cudf::experimental::dispatch_void_if_nested>(
auto cur_hash = cudf::type_dispatcher(
column.type(), element_hasher_adapter{nulls}, column, row_index);
return HIVE_HASH_FACTOR * hash + cur_hash;
}));
Expand All @@ -191,8 +191,6 @@ class hive_device_row_hasher {
private:
/**
* @brief Computes the hash value of an element in the given column.
*
* Only supported non nested types now
*/
class element_hasher_adapter {
public:
Expand All @@ -210,11 +208,208 @@ class hive_device_row_hasher {
return this->hash_functor.template operator()<T>(col, row_index);
}

/**
 * @brief A stack frame used while iteratively hashing a nested column.
 *
 * Keeps track of the column currently being processed, the index of the next
 * child column to visit, and the hash value accumulated so far for this column.
 *
 * @note The default constructor is deleted to prevent uninitialized usage:
 * a frame is only meaningful once it is bound to a column.
 */
struct StackElement {
  cudf::column_device_view col;  // current column being processed
  int child_idx;                 // index of the child column to process next, starts at 0
  hive_hash_value_t cur_hash;    // hash value accumulated for this column so far

  __device__ StackElement() = delete;

  /// Binds the frame to @p col with no children visited yet and the initial hash seed.
  __device__ explicit StackElement(cudf::column_device_view col)
    : col(col), child_idx(0), cur_hash(HIVE_INIT_HASH)
  {
  }
};

using StackElementPtr = StackElement*;

/**
 * @brief Functor to compute the hive hash value for a nested column.
 *
 * This functor produces the same result as "HiveHash" in Spark for nested types.
 * The pseudocode of Spark's HiveHash function is as follows:
 *
 *     hive_hash_value_t hive_hash(NestedType element) {
 *       hive_hash_value_t hash = HIVE_INIT_HASH;
 *       for (int i = 0; i < element.num_child(); i++) {
 *         hash = hash * HIVE_HASH_FACTOR + hive_hash(element.get_child(i));
 *       }
 *       return hash;
 *     }
 *
 * This functor uses a stack to simulate the recursive process of the above pseudocode.
 * When an element is popped from the stack, its hash value has been fully computed,
 * so the parent's `cur_hash` is updated at pop time.
 *
 * The algorithm is as follows:
 *
 * 1. Initialize the stack and push the root column onto it.
 * 2. While the stack is not empty:
 *    a. Peek at the top element. Do not pop it until it is processed.
 *    b. If the column is a nested column:
 *       i.  If all child columns are processed, pop the element and update
 *           `cur_hash` of its parent column.
 *       ii. Otherwise, push the next child column onto the stack.
 *    c. If the column is a primitive column, compute the hash value, pop the
 *       element, and update `cur_hash` of its parent column.
 * 3. Return the hash value of the root column.
 *
 * For example, consider the following nested column: `Struct<Struct<int, float>, decimal>`
 *
 *          S1
 *         /  \
 *        S2   d
 *       /  \
 *      i    f
 *
 * The `pop` order of the stack is: i, f, S2, d, S1.
 * - When i is popped, S2's cur_hash is updated to `hash(i)`.
 * - When f is popped, S2's cur_hash is updated to `hash(i) * HIVE_HASH_FACTOR + hash(f)`,
 *   which is the hash value of S2.
 * - When S2 is popped, S1's cur_hash is updated to `hash(S2)`.
 * - When d is popped, S1's cur_hash is updated to `hash(S2) * HIVE_HASH_FACTOR + hash(d)`,
 *   which is the hash value of S1.
 * - When S1 is popped, the hash value of the root column is returned.
 *
 * As lists columns have a different interface from structs columns, they are handled
 * separately.
 *
 * For example, consider the following nested column: `List<List<int>>`
 * lists_column = {{1, 0}, null, {2, null}}
 *
 *     L1
 *     |
 *     L2
 *     |
 *     i
 *
 * List level L1:
 * |Index| List<list<int>>         |
 * |-----|-------------------------|
 * |0    |{{1, 0}, null, {2, null}}|
 * length: 1
 * Offsets: 0, 3
 *
 * List level L2:
 * |Index|List<int>|
 * |-----|---------|
 * |0    |{1, 0}   |
 * |1    |null     |
 * |2    |{2, null}|
 * length: 3
 * Offsets: 0, 2, 2, 4
 * null_mask: 101
 *
 * Int level i:
 * |Index|Int |
 * |-----|----|
 * |0    |1   |
 * |1    |0   |
 * |2    |2   |
 * |3    |null|
 * length: 4
 *
 * Since the underlying data loses the null information of the top-level list column,
 * the hash value cannot be computed from the underlying data alone.
 *
 * For example, the `List<List<int>>` column {{1, 0}, {2, null}} has the same underlying
 * data as the above `List<List<int>>` column {{1, 0}, null, {2, null}}. However, they
 * have different hive hash values.
 *
 * The computation process for lists columns in this implementation is as follows:
 *
 *           L1             List<list<int>>
 *           |
 *           L2             List<int>
 *        /  |   \
 *   L2[0] L2[1] L2[2]      List<int>
 *     |           |
 *     i1          i2       Int
 *    /  \        /  \
 * i1[0] i1[1] i2[0] i2[1]  Int
 *
 * Note: L2, i1, i2 are all temporary columns, which are not pushed onto the stack.
 *
 * Int level i1
 * |Index|Int |
 * |-----|----|
 * |0    |1   |
 * |1    |0   |
 * length: 2
 *
 * Int level i2
 * |Index|Int |
 * |-----|----|
 * |0    |2   |
 * |1    |null|
 * length: 2
 *
 * @tparam T The type of the column.
 * @param col The column device view.
 * @param row_index The index of the row to compute the hash for.
 * @return The computed hive hash value.
 *
 * @note This function is only enabled for nested column types.
 */
template <typename T, CUDF_ENABLE_IF(cudf::is_nested<T>())>
__device__ hive_hash_value_t operator()(cudf::column_device_view const& col,
                                        cudf::size_type row_index) const noexcept
{
  cudf::column_device_view curr_col = col.slice(row_index, 1);
  // column_device_view's default constructor is deleted, so a StackElement array cannot
  // be allocated directly; use a byte array to wrap the StackElement storage instead.
  constexpr int len_of_maxlen_stack_element = MAX_NESTED_DEPTH * sizeof(StackElement);
  uint8_t stack_wrapper[len_of_maxlen_stack_element];
  StackElementPtr stack = reinterpret_cast<StackElementPtr>(stack_wrapper);
  int stack_size = 0;

  stack[stack_size++] = StackElement(curr_col);

  while (stack_size > 0) {
    StackElementPtr element = &stack[stack_size - 1];
    curr_col                = element->col;
    // Do not pop the element until it is processed. The definition of `processed` is:
    // - For nested types: all children have been processed
    //   (i.e., element->child_idx == number of children).
    // - For primitive types: the hash value has been computed.
    if (curr_col.type().id() == cudf::type_id::STRUCT) {
      if (element->child_idx == curr_col.num_child_columns()) {
        // All child columns processed: pop and fold this hash into the parent's `cur_hash`.
        stack_size--;
        if (stack_size > 0) {
          stack[stack_size - 1].cur_hash =
            stack[stack_size - 1].cur_hash * HIVE_HASH_FACTOR + element->cur_hash;
        }
      } else {
        // Push the next child column onto the stack.
        stack[stack_size++] = StackElement(
          cudf::detail::structs_column_device_view(curr_col).get_sliced_child(element->child_idx));
        element->child_idx++;
      }
    } else if (curr_col.type().id() == cudf::type_id::LIST) {
      // lists_column_device_view has a different interface from structs_column_device_view:
      // iterate over the rows of the sliced child column, pushing one element per row.
      curr_col = cudf::detail::lists_column_device_view(curr_col).get_sliced_child();
      if (element->child_idx == curr_col.size()) {
        stack_size--;
        if (stack_size > 0) {
          stack[stack_size - 1].cur_hash =
            stack[stack_size - 1].cur_hash * HIVE_HASH_FACTOR + element->cur_hash;
        }
      } else {
        stack[stack_size++] = StackElement(curr_col.slice(element->child_idx, 1));
        element->child_idx++;
      }
    } else {
      // Primitive type: the sliced column contains exactly one element (row 0).
      element->cur_hash = cudf::type_dispatcher(curr_col.type(), this->hash_functor, curr_col, 0);
      stack_size--;
      if (stack_size > 0) {
        stack[stack_size - 1].cur_hash =
          stack[stack_size - 1].cur_hash * HIVE_HASH_FACTOR + element->cur_hash;
      }
    }
  }
  // stack[0] still holds the root column's final hash after it was popped.
  return stack[0].cur_hash;
}

private:
Expand All @@ -224,6 +419,34 @@ class hive_device_row_hasher {
Nullate const _check_nulls;
cudf::table_device_view const _table;
};

/**
 * @brief Verifies that no column in @p input is nested deeper than @p max_nested_depth.
 *
 * @param input Table whose columns are to be checked.
 * @param max_nested_depth Maximum allowed nesting depth (a primitive column has depth 1).
 * @throws cudf::logic_error if any column exceeds the maximum nested depth.
 */
void check_nested_depth(cudf::table_view const& input, int max_nested_depth)
{
  using column_checker_fn_t = std::function<int(cudf::column_view const&)>;

  // Recursively computes the nesting depth of a column: primitives are depth 1,
  // and each LIST/STRUCT level adds one.
  column_checker_fn_t get_nested_depth = [&](cudf::column_view const& col) {
    if (col.type().id() == cudf::type_id::LIST) {
      return 1 + get_nested_depth(cudf::lists_column_view(col).child());
    } else if (col.type().id() == cudf::type_id::STRUCT) {
      int max_child_depth = 0;
      for (auto child = col.child_begin(); child != col.child_end(); ++child) {
        max_child_depth = std::max(max_child_depth, get_nested_depth(*child));
      }
      return 1 + max_child_depth;
    } else {  // Primitive type
      return 1;
    }
  };

  for (auto i = 0; i < input.num_columns(); i++) {
    cudf::column_view const& col = input.column(i);
    // Compute the depth once and reuse it in the error message, instead of
    // traversing the column hierarchy a second time.
    auto const depth = get_nested_depth(col);
    CUDF_EXPECTS(depth <= max_nested_depth,
                 "The " + std::to_string(i) + "-th column exceeds the maximum nested depth. " +
                   "Current depth: " + std::to_string(depth) + ", " +
                   "Maximum allowed depth: " + std::to_string(max_nested_depth));
  }
}

} // namespace

std::unique_ptr<cudf::column> hive_hash(cudf::table_view const& input,
Expand All @@ -239,6 +462,10 @@ std::unique_ptr<cudf::column> hive_hash(cudf::table_view const& input,
// Return early if there's nothing to hash
if (input.num_columns() == 0 || input.num_rows() == 0) { return output; }

//Nested depth cannot exceed 8
constexpr int max_nested_depth = 8;
check_nested_depth(input, max_nested_depth);

bool const nullable = has_nested_nulls(input);
auto const input_view = cudf::table_device_view::create(input, stream);
auto output_view = output->mutable_view();
Expand All @@ -247,7 +474,8 @@ std::unique_ptr<cudf::column> hive_hash(cudf::table_view const& input,
thrust::tabulate(rmm::exec_policy(stream),
output_view.begin<hive_hash_value_t>(),
output_view.end<hive_hash_value_t>(),
hive_device_row_hasher<hive_hash_function, bool>(nullable, *input_view));
hive_device_row_hasher<hive_hash_function, bool, max_nested_depth>(nullable,
*input_view));

return output;
}
Expand Down
1 change: 0 additions & 1 deletion src/main/java/com/nvidia/spark/rapids/jni/Hash.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ public static ColumnVector hiveHash(ColumnView columns[]) {
assert columns[i] != null : "Column vectors passed may not be null";
assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size";
assert !columns[i].getType().isDurationType() : "Unsupported column type Duration";
assert !columns[i].getType().isNestedType() : "Unsupported column type Nested";
columnViews[i] = columns[i].getNativeView();
}
return new ColumnVector(hiveHash(columnViews));
Expand Down
Loading