Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HiveHash support for nested types #9

Merged
merged 16 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 215 additions & 6 deletions src/main/cpp/src/hive_hash.cu
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ hive_hash_value_t __device__ inline hive_hash_function<cudf::timestamp_us>::oper
* @tparam hash_function Hash functor to use for hashing elements. Must be hive_hash_function.
* @tparam Nullate A cudf::nullate type describing whether to check for nulls.
*/
template <template <typename> class hash_function, typename Nullate>
template <template <typename> class hash_function, typename Nullate, int MAX_NESTED_LEN>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit:
Rename MAX_NESTED_LEN to MAX_NESTED_DEPTH

class hive_device_row_hasher {
public:
CUDF_HOST_DEVICE hive_device_row_hasher(Nullate check_nulls, cudf::table_device_view t) noexcept
Expand All @@ -182,7 +182,7 @@ class hive_device_row_hasher {
HIVE_INIT_HASH,
cuda::proclaim_return_type<hive_hash_value_t>(
[row_index, nulls = this->_check_nulls] __device__(auto hash, auto const& column) {
auto cur_hash = cudf::type_dispatcher<cudf::experimental::dispatch_void_if_nested>(
auto cur_hash = cudf::type_dispatcher(
column.type(), element_hasher_adapter{nulls}, column, row_index);
return HIVE_HASH_FACTOR * hash + cur_hash;
}));
Expand All @@ -191,8 +191,6 @@ class hive_device_row_hasher {
private:
/**
* @brief Computes the hash value of an element in the given column.
*
* Only supported non nested types now
*/
class element_hasher_adapter {
public:
Expand All @@ -210,11 +208,195 @@ class hive_device_row_hasher {
return this->hash_functor.template operator()<T>(col, row_index);
}

/**
* @brief A structure representing an element in the stack used for processing columns.
*
* This structure is used to keep track of the current column being processed, the index of the
* next child column to process, and the factor for the current column.
*
* @param col The current column being processed.
* @param child_idx The index of the next child column to process.
* @param factor The factor for the current column.
*
* @note The default constructor is deleted to prevent uninitialized usage.
*
* @constructor
* @param col The column device view to be processed.
* @param factor The factor for the column.
*/
/**
 * @brief One frame of the explicit DFS stack used to traverse a nested column.
 *
 * Tracks the column currently being processed, the index of the next child
 * column to visit, and the multiplicative factor applied to hash values
 * produced beneath this column.
 */
struct StackElement {
  cudf::column_device_view col;  // current column being processed
  int child_idx;                 // index of the child column to process next, starts at 0
  hive_hash_value_t factor;      // factor applied to hashes computed under this column

  // A frame is only meaningful with a column and a factor, so forbid
  // default construction to prevent uninitialized usage.
  __device__ StackElement() = delete;
  __device__ StackElement(cudf::column_device_view col, hive_hash_value_t factor)
    : col(col), child_idx(0), factor(factor)
  {
  }
};

// Prefer `using` over `typedef` for type aliases (modern C++ idiom).
using StackElementPtr = StackElement*;

/**
* @brief Functor to compute the hive hash value for a nested column.
*
* This functor computes the hive hash value for a given row in a nested column. It uses a depth-first search
* approach to traverse the nested structure of the column. The hash value is computed by accumulating the
* hash values of the primitive elements in the column.
*
* For example, consider the following nested column: `Struct<Struct<int, float>, decimal>`
*
* S1
* / \
* S2 d
* / \
* i f
*
* The hash value for the column is computed as:
* hash(S1) = hash(S2) * HIVE_HASH_FACTOR + hash(d)
* = (hash(i) * HIVE_HASH_FACTOR + hash(f)) * HIVE_HASH_FACTOR + hash(d)
* = hash(i) * HIVE_HASH_FACTOR^2 + hash(f) * HIVE_HASH_FACTOR + hash(d)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This transformation is not equivalent when considering overflow.

*
* From the above example, we can see that the factor of a column is calculated as the product of
* its parent column's factor and its factor relative to its parent column.
*
* The relative factor is calculated as:
* relative_factor = HIVE_HASH_FACTOR ^ (parent.num_child_columns() - 1 - child_idx)
* (parent.num_child_columns() - 1 - child_idx) can be interpreted as a reverse index of the child column.
*
* Thus, we can compute the factor of the current column during the dfs traversal process,
* performing the actual computation only for the primitive types and accumulating the hash values into the result.
*
* As lists columns have a different interface from structs columns, we need to handle them separately.
* For example, consider the following nested column: `List<List<int>>`
* lists_column = {{1, 0}, null, {2, null}}
*
* L1
* |
* L2
* |
* i
*
* List level L1:
* |Index| List<list<int>> |
* |-----|-------------------------|
* |0 |{{1, 0}, null, {2, null}}|
* length: 1
* Offsets: 0, 3
*
* List level L2:
* |Index|List<int>|
* |-----|---------|
* |0 |{1, 0} |
* |1 |null |
* |2 |{2, null}|
* length: 3
* Offsets: 0, 2, 2, 4
* null_mask: 101
*
* Int level i:
* |Index|Int |
* |-----|----|
* |0 |1 |
* |1 |0 |
* |2 |2 |
* |3 |null|
* length: 4
*
* Since the underlying data loses the null information of the top-level list column,
* I cannot directly use the underlying data to calculate the hash value.
*
* The computation process of my algorithm is as follows:
* L1 List<list<int>>
* |
* L2 List<int>
* / | \
* L2[0] L2[1] L2[2] List<int>
* | |
* i1 i2 Int
* / \ / \
* i1[0] i1[1] i2[0] i2[1] Int
*
* Int level i1
* |Index|Int |
* |-----|----|
* |0 |1 |
* |1 |0 |
* length: 2
*
* Int level i2
* |Index|Int |
* |-----|----|
* |0 |2 |
* |1 |null|
* length: 2
*
 * L2, i1, and i2 are all temporary columns, which would not be pushed into the stack.
*
* @tparam T The type of the column.
* @param col The column device view.
* @param row_index The index of the row to compute the hash for.
* @return The computed hive hash value.
*
* @note This function is only enabled for nested column types.
*/
template <typename T, CUDF_ENABLE_IF(cudf::is_nested<T>())>
__device__ hive_hash_value_t operator()(cudf::column_device_view const& col,
                                        cudf::size_type row_index) const noexcept
{
  // NOTE: the previous version began with CUDF_UNREACHABLE("Nested type is not
  // supported"), which aborts execution and made the entire implementation
  // below dead code. It is removed now that nested types are supported.
  hive_hash_value_t ret = HIVE_INIT_HASH;

  // Restrict the view to the single row being hashed.
  cudf::column_device_view curr_col = col.slice(row_index, 1);

  // column_device_view's default constructor is deleted, so a StackElement
  // array cannot be declared directly. Back the stack with raw bytes instead.
  // The buffer must be aligned for StackElement, otherwise reinterpreting it
  // as StackElement* is undefined behavior.
  alignas(StackElement) uint8_t stack_wrapper[MAX_NESTED_LEN * sizeof(StackElement)];
  StackElementPtr stack = reinterpret_cast<StackElementPtr>(stack_wrapper);
  int stack_size        = 0;

  stack[stack_size++] = StackElement(curr_col, 1);

  while (stack_size > 0) {
    StackElementPtr element = &stack[stack_size - 1];
    curr_col                = element->col;
    // Do not pop here. A node is popped only once it is fully processed:
    //  - nested types: when all of its children have been visited
    //  - primitive types: when its hash has been accumulated into the result

    if (curr_col.type().id() == cudf::type_id::STRUCT) {
      if (element->child_idx == curr_col.num_child_columns()) {
        // All child columns processed, pop the stack
        stack_size--;
      } else {
        // factor(child) = factor(parent) * HIVE_HASH_FACTOR ^ (reverse index of child),
        // where the reverse index is (num_child_columns() - 1 - child_idx).
        hive_hash_value_t cur_factor = element->factor;
        for (int i = 0; i < curr_col.num_child_columns() - 1 - element->child_idx; i++) {
          cur_factor *= HIVE_HASH_FACTOR;
        }
        stack[stack_size++] = StackElement(
          cudf::detail::structs_column_device_view(curr_col).get_sliced_child(element->child_idx),
          cur_factor);
        element->child_idx++;
      }
    } else if (curr_col.type().id() == cudf::type_id::LIST) {
      // lists_column_device_view exposes one sliced child rather than indexed
      // children, so lists are handled separately from structs. Each list
      // element is visited via a length-1 slice of the child column.
      curr_col = cudf::detail::lists_column_device_view(curr_col).get_sliced_child();
      if (element->child_idx == curr_col.size()) {
        stack_size--;
      } else {
        hive_hash_value_t cur_factor = element->factor;
        for (int i = 0; i < curr_col.size() - element->child_idx - 1; i++) {
          cur_factor *= HIVE_HASH_FACTOR;
        }
        stack[stack_size++] = StackElement(curr_col.slice(element->child_idx, 1), cur_factor);
        element->child_idx++;
      }
    } else {
      // Primitive leaf: the sliced view holds exactly one element (index 0).
      hive_hash_value_t cur_hash =
        cudf::type_dispatcher(curr_col.type(), this->hash_functor, curr_col, 0);
      // Accumulate the weighted hash value into the result.
      ret += cur_hash * element->factor;
      stack_size--;
    }
  }
  return ret;
}

private:
Expand All @@ -224,6 +406,28 @@ class hive_device_row_hasher {
Nullate const _check_nulls;
cudf::table_device_view const _table;
};

/**
 * @brief Verifies that no column in `input` is nested deeper than `max_nested_len`.
 *
 * Recursively walks each column tree; LIST and STRUCT levels each consume one
 * level of the budget, and every primitive leaf must be reached with at least
 * one level remaining.
 *
 * @param input The table whose columns are checked.
 * @param max_nested_len The maximum supported nesting depth.
 * @throws cudf::logic_error if any column exceeds the maximum depth.
 */
void check_nested_depth(cudf::table_view const& input, int max_nested_len)
{
  using column_checker_fn_t = std::function<void(cudf::column_view const&, int)>;

  // `remaining` is the number of nesting levels still allowed below this column.
  column_checker_fn_t check_nested_depth_impl = [&](cudf::column_view const& col,
                                                    int remaining) {
    if (col.type().id() == cudf::type_id::LIST) {
      check_nested_depth_impl(cudf::lists_column_view(col).child(), remaining - 1);
    } else if (col.type().id() == cudf::type_id::STRUCT) {
      for (auto child = col.child_begin(); child != col.child_end(); ++child) {
        check_nested_depth_impl(*child, remaining - 1);
      }
    } else {  // Primitive leaf: must be reached with budget to spare
      // Include the limit in the message so failures are actionable.
      CUDF_EXPECTS(remaining > 0,
                   "The nested depth of the input table exceeds the maximum supported depth: " +
                     std::to_string(max_nested_len));
    }
  };

  for (cudf::column_view const& col : input) {
    check_nested_depth_impl(col, max_nested_len);
  }
}

} // namespace

std::unique_ptr<cudf::column> hive_hash(cudf::table_view const& input,
Expand All @@ -239,6 +443,10 @@ std::unique_ptr<cudf::column> hive_hash(cudf::table_view const& input,
// Return early if there's nothing to hash
if (input.num_columns() == 0 || input.num_rows() == 0) { return output; }

//Nested depth cannot exceed 8
int const max_nested_len = 8;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

constexpr is better than int const

check_nested_depth(input, max_nested_len);

bool const nullable = has_nested_nulls(input);
auto const input_view = cudf::table_device_view::create(input, stream);
auto output_view = output->mutable_view();
Expand All @@ -247,7 +455,8 @@ std::unique_ptr<cudf::column> hive_hash(cudf::table_view const& input,
thrust::tabulate(rmm::exec_policy(stream),
output_view.begin<hive_hash_value_t>(),
output_view.end<hive_hash_value_t>(),
hive_device_row_hasher<hive_hash_function, bool>(nullable, *input_view));
hive_device_row_hasher<hive_hash_function, bool, max_nested_len>(nullable,
*input_view));

return output;
}
Expand Down
1 change: 0 additions & 1 deletion src/main/java/com/nvidia/spark/rapids/jni/Hash.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ public static ColumnVector hiveHash(ColumnView columns[]) {
assert columns[i] != null : "Column vectors passed may not be null";
assert columns[i].getRowCount() == size : "Row count mismatch, all columns must be the same size";
assert !columns[i].getType().isDurationType() : "Unsupported column type Duration";
assert !columns[i].getType().isNestedType() : "Unsupported column type Nested";
columnViews[i] = columns[i].getNativeView();
}
return new ColumnVector(hiveHash(columnViews));
Expand Down
Loading