Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HiveHash support for nested types #9

Merged
merged 16 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 245 additions & 6 deletions src/main/cpp/src/hive_hash.cu
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ hive_hash_value_t __device__ inline hive_hash_function<cudf::timestamp_us>::oper
* @tparam hash_function Hash functor to use for hashing elements. Must be hive_hash_function.
* @tparam Nullate A cudf::nullate type describing whether to check for nulls.
*/
template <template <typename> class hash_function, typename Nullate>
template <template <typename> class hash_function, typename Nullate, int MAX_NESTED_DEPTH>
class hive_device_row_hasher {
public:
CUDF_HOST_DEVICE hive_device_row_hasher(Nullate check_nulls, cudf::table_device_view t) noexcept
Expand All @@ -182,7 +182,7 @@ class hive_device_row_hasher {
HIVE_INIT_HASH,
cuda::proclaim_return_type<hive_hash_value_t>(
[row_index, nulls = this->_check_nulls] __device__(auto hash, auto const& column) {
auto cur_hash = cudf::type_dispatcher<cudf::experimental::dispatch_void_if_nested>(
auto cur_hash = cudf::type_dispatcher(
column.type(), element_hasher_adapter{nulls}, column, row_index);
return HIVE_HASH_FACTOR * hash + cur_hash;
}));
Expand All @@ -191,8 +191,6 @@ class hive_device_row_hasher {
private:
/**
* @brief Computes the hash value of an element in the given column.
*
* Only supported non nested types now
*/
class element_hasher_adapter {
public:
Expand All @@ -210,11 +208,219 @@ class hive_device_row_hasher {
return this->hash_functor.template operator()<T>(col, row_index);
}

/**
* @brief A structure representing an element in the stack used for processing columns.
*
* This structure is used to keep track of the current column being processed, the index of the
* next child column to process, and current hash value for the column.
*
* @param column The current column being processed.
* @param child_idx The index of the next child column to process.
* @param cur_hash The current hash value for the column.
*
* @note The default constructor is deleted to prevent uninitialized usage.
*
* @constructor
* @param col The column device view to be processed.
*/
/**
 * @brief Stack frame used while iteratively hashing a nested column.
 *
 * Each frame tracks the column currently being processed, the index of its
 * next child column to visit, and the hash value accumulated for the column
 * so far.
 *
 * @note The default constructor is deleted because
 * `cudf::column_device_view` itself has no default constructor.
 */
class col_stack_element {
 public:
  __device__ col_stack_element() = delete;

  /// Begin processing @p col: no children visited yet, hash seeded with HIVE_INIT_HASH.
  __device__ col_stack_element(cudf::column_device_view col)
    : column(col), cur_hash(HIVE_INIT_HASH), child_idx(0)
  {
  }

  /// Fold a finished child's hash into this column's running hash.
  __device__ void update_cur_hash(hive_hash_value_t child_hash)
  {
    cur_hash = cur_hash * HIVE_HASH_FACTOR + child_hash;
  }

  /// Hash value accumulated for this column so far.
  __device__ hive_hash_value_t get_hash() { return cur_hash; }

  /// Return the current child index, then advance to the next child.
  __device__ int child_idx_inc_one() { return child_idx++; }

  /// Index of the child column that will be processed next.
  __device__ int cur_child_idx() { return child_idx; }

  /// The column this frame is processing.
  __device__ cudf::column_device_view get_column() { return column; }

 private:
  cudf::column_device_view column;  // column being processed
  hive_hash_value_t cur_hash;       // running hash value for this column
  int child_idx;                    // next child to process, starts at 0
};

using col_stack_element_ptr = col_stack_element*;

/**
* @brief Functor to compute the hive hash value for a nested column.
*
* This functor produces the same result as "HiveHash" in Spark for nested types.
* The pseudocode of Spark's HiveHash function is as follows:
*
* hive_hash_value_t hive_hash(NestedType element) {
* hive_hash_value_t hash = HIVE_INIT_HASH;
* for (int i = 0; i < element.num_child(); i++) {
* hash = hash * HIVE_HASH_FACTOR + hive_hash(element.get_child(i));
* }
* return hash;
* }
*
* This functor uses a stack to simulate the recursive process of the above pseudocode.
* When an element is popped from the stack, it means that the hash value of it has been computed.
* Therefore, we should update the parent's `cur_hash` upon popping the element.
*
* The algorithm is as follows:
*
* 1. Initialize the stack and push the root column into the stack.
* 2. While the stack is not empty:
* a. Get the top element of the stack. Don't pop it until it is processed.
* b. If the column is a nested column:
* i. If all child columns are processed, pop the element and update `cur_hash` of its parent column.
* ii. Otherwise, push the next child column into the stack.
* c. If the column is a primitive column, compute the hash value, pop the element, and update `cur_hash` of its parent column.
* 3. Return the hash value of the root column.
*
* For example, consider the following nested column: `Struct<Struct<int, float>, decimal>`
*
* S1
* / \
* S2 d
* / \
* i f
*
* The `pop` order of the stack is: i, f, S2, d, S1.
* - When i is popped, S2's cur_hash is updated to `hash(i)`.
* - When f is popped, S2's cur_hash is updated to `hash(i) * HIVE_HASH_FACTOR + hash(f)`, which is the hash value of S2.
* - When S2 is popped, S1's cur_hash is updated to `hash(S2)`.
* - When d is popped, S1's cur_hash is updated to `hash(S2) * HIVE_HASH_FACTOR + hash(d)`, which is the hash value of S1.
* - When S1 is popped, the hash value of the root column is returned.
*
* As lists columns have a different interface from structs columns, we need to handle them separately.
*
* For example, consider the following nested column: `List<List<int>>`
* lists_column = {{1, 0}, null, {2, null}}
*
* L1
* |
* L2
* |
* i
*
* List level L1:
* |Index| List<list<int>> |
* |-----|-------------------------|
* |0 |{{1, 0}, null, {2, null}}|
* length: 1
* Offsets: 0, 3
*
* List level L2:
* |Index|List<int>|
* |-----|---------|
* |0 |{1, 0} |
* |1 |null |
* |2 |{2, null}|
* length: 3
* Offsets: 0, 2, 2, 4
* null_mask: 101
*
* Int level i:
* |Index|Int |
* |-----|----|
* |0 |1 |
* |1 |0 |
* |2 |2 |
* |3 |null|
* length: 4
*
 * Since the underlying data loses the null information of the top-level list column,
 * computing the hash value from the underlying data alone would produce results
 * that differ from Spark's.
*
* For example, `List<List<int>>` column {{1, 0}, {2, null}} has the same underlying data as the above
* `List<List<int>>` column {{1, 0}, null, {2, null}}. However, they have different hive hash values.
*
* The computation process for lists columns in this solution is as follows:
* L1 List<list<int>>
* |
* L2 List<int>
* / | \
* L2[0] L2[1] L2[2] List<int>
* | |
* i1 i2 Int
* / \ / \
* i1[0] i1[1] i2[0] i2[1] Int
*
 * Note: L2, i1, and i2 are all temporary columns, which are not pushed onto the stack.
*
* Int level i1
* |Index|Int |
* |-----|----|
* |0 |1 |
* |1 |0 |
* length: 2
*
* Int level i2
* |Index|Int |
* |-----|----|
* |0 |2 |
* |1 |null|
* length: 2
*
* @tparam T The type of the column.
* @param col The column device view.
* @param row_index The index of the row to compute the hash for.
* @return The computed hive hash value.
*
* @note This function is only enabled for nested column types.
*/
template <typename T, CUDF_ENABLE_IF(cudf::is_nested<T>())>
__device__ hive_hash_value_t operator()(cudf::column_device_view const& col,
                                        cudf::size_type row_index) const noexcept
{
  // Narrow the view to the single row being hashed.
  cudf::column_device_view curr_col = col.slice(row_index, 1);
  // `col_stack_element`'s default constructor is deleted, so a
  // col_stack_element array cannot be declared directly. Instead, use a raw
  // byte buffer as the backing storage for the stack.
  uint8_t stack_wrapper[MAX_NESTED_DEPTH * sizeof(col_stack_element)];
  col_stack_element_ptr col_stack = reinterpret_cast<col_stack_element_ptr>(stack_wrapper);
  int stack_size = 0;

  col_stack[stack_size++] = col_stack_element(curr_col);

  while (stack_size > 0) {
    col_stack_element& element = col_stack[stack_size - 1];
    curr_col = element.get_column();
    // Do not pop an element until it is processed. The definition of `processed` is:
    // - For nested types, it is when all children have been processed.
    // - For primitive types, it is when the hash value has been computed.
    if (curr_col.type().id() == cudf::type_id::STRUCT) {
      if (element.cur_child_idx() == curr_col.num_child_columns()) {
        // All child columns processed: pop and fold this hash into the parent (if any).
        if (--stack_size > 0) { col_stack[stack_size - 1].update_cur_hash(element.get_hash()); }
      } else {
        // Push the next child column onto the stack.
        col_stack[stack_size++] = col_stack_element(
          cudf::detail::structs_column_device_view(curr_col).get_sliced_child(
            element.child_idx_inc_one()));
      }
    } else if (curr_col.type().id() == cudf::type_id::LIST) {
      // lists_column_device_view has a different interface from
      // structs_column_device_view: it exposes one sliced child, so each list
      // element is pushed as a one-row slice of that child.
      curr_col = cudf::detail::lists_column_device_view(curr_col).get_sliced_child();
      if (element.cur_child_idx() == curr_col.size()) {
        if (--stack_size > 0) { col_stack[stack_size - 1].update_cur_hash(element.get_hash()); }
      } else {
        col_stack[stack_size++] =
          col_stack_element(curr_col.slice(element.child_idx_inc_one(), 1));
      }
    } else {
      // Primitive type: the one-row slice holds exactly one element, at index 0.
      auto hash = cudf::type_dispatcher<cudf::experimental::dispatch_void_if_nested>(
        curr_col.type(), this->hash_functor, curr_col, 0);
      element.update_cur_hash(hash);
      if (--stack_size > 0) { col_stack[stack_size - 1].update_cur_hash(element.get_hash()); }
    }
  }
  return col_stack[0].get_hash();
}

private:
Expand All @@ -224,6 +430,34 @@ class hive_device_row_hasher {
Nullate const _check_nulls;
cudf::table_device_view const _table;
};

/**
 * @brief Verifies that no column in @p input is nested more deeply than allowed.
 *
 * The nested depth of a column is 1 for a primitive column, and 1 plus the
 * maximum depth of its children for LIST and STRUCT columns.
 *
 * @param input The table whose columns are checked.
 * @param max_nested_depth The maximum allowed nested depth.
 *
 * @throws cudf::logic_error if any column exceeds @p max_nested_depth.
 */
void check_nested_depth(cudf::table_view const& input, int max_nested_depth)
{
  using column_checker_fn_t = std::function<int(cudf::column_view const&)>;

  // Recursive depth computation; `std::function` is required so the lambda can
  // refer to itself through the capture.
  column_checker_fn_t get_nested_depth = [&](cudf::column_view const& col) {
    if (col.type().id() == cudf::type_id::LIST) {
      return 1 + get_nested_depth(cudf::lists_column_view(col).child());
    } else if (col.type().id() == cudf::type_id::STRUCT) {
      int max_child_depth = 0;
      for (auto child = col.child_begin(); child != col.child_end(); ++child) {
        max_child_depth = std::max(max_child_depth, get_nested_depth(*child));
      }
      return 1 + max_child_depth;
    } else {  // Primitive type
      return 1;
    }
  };

  for (auto i = 0; i < input.num_columns(); i++) {
    cudf::column_view const& col = input.column(i);
    // Compute the depth once instead of re-walking the column tree for the
    // error message.
    int const depth = get_nested_depth(col);
    CUDF_EXPECTS(depth <= max_nested_depth,
                 "The " + std::to_string(i) + "-th column exceeds the maximum allowed nested depth. " +
                 "Current depth: " + std::to_string(depth) + ", " +
                 "Maximum allowed depth: " + std::to_string(max_nested_depth));
  }
}

} // namespace

std::unique_ptr<cudf::column> hive_hash(cudf::table_view const& input,
Expand All @@ -239,6 +473,10 @@ std::unique_ptr<cudf::column> hive_hash(cudf::table_view const& input,
// Return early if there's nothing to hash
if (input.num_columns() == 0 || input.num_rows() == 0) { return output; }

// The nested depth of input columns must not exceed 8 levels
constexpr int max_nested_depth = 8;
check_nested_depth(input, max_nested_depth);

bool const nullable = has_nested_nulls(input);
auto const input_view = cudf::table_device_view::create(input, stream);
auto output_view = output->mutable_view();
Expand All @@ -247,7 +485,8 @@ std::unique_ptr<cudf::column> hive_hash(cudf::table_view const& input,
thrust::tabulate(rmm::exec_policy(stream),
output_view.begin<hive_hash_value_t>(),
output_view.end<hive_hash_value_t>(),
hive_device_row_hasher<hive_hash_function, bool>(nullable, *input_view));
hive_device_row_hasher<hive_hash_function, bool, max_nested_depth>(nullable,
*input_view));

return output;
}
Expand Down
1 change: 0 additions & 1 deletion src/main/java/com/nvidia/spark/rapids/jni/Hash.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ public static ColumnVector hiveHash(ColumnView columns[]) {
assert columns[i] != null : "Column vectors passed may not be null";
assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size";
assert !columns[i].getType().isDurationType() : "Unsupported column type Duration";
assert !columns[i].getType().isNestedType() : "Unsupported column type Nested";
columnViews[i] = columns[i].getNativeView();
}
return new ColumnVector(hiveHash(columnViews));
Expand Down
Loading