Skip to content

Commit

Permalink
Use dynamic nullate for join hasher and equality comparator (#9902)
Browse files Browse the repository at this point in the history
Follow on PR for this comment: #9623 (comment)

The join hasher and equality-comparator were previously hardcoded with `has_nulls=true` (and migrated to `nullate::YES`) to help minimize code size. The new `nullate::DYNAMIC` allows for runtime checking of nulls so this can now be used here instead.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Jake Hemstad (https://github.com/jrhemstad)
  - Vukasin Milovanovic (https://github.com/vuule)

URL: #9902
  • Loading branch information
davidwendt authored Dec 17, 2021
1 parent 428a1b3 commit e6c6991
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 26 deletions.
58 changes: 42 additions & 16 deletions cpp/src/join/hash_join.cu
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void build_join_hash_table(cudf::table_view const& build,
CUDF_EXPECTS(0 != build_table_ptr->num_columns(), "Selected build dataset is empty");
CUDF_EXPECTS(0 != build_table_ptr->num_rows(), "Build side table has no rows");

row_hash hash_build{nullate::YES{}, *build_table_ptr};
row_hash hash_build{nullate::DYNAMIC{cudf::has_nulls(build)}, *build_table_ptr};
auto const empty_key_sentinel = hash_table.get_empty_key_sentinel();
make_pair_function pair_func{hash_build, empty_key_sentinel};

Expand Down Expand Up @@ -123,6 +123,7 @@ std::pair<std::unique_ptr<rmm::device_uvector<size_type>>,
probe_join_hash_table(cudf::table_device_view build_table,
cudf::table_device_view probe_table,
multimap_type const& hash_table,
bool has_nulls,
null_equality compare_nulls,
std::optional<std::size_t> output_size,
rmm::cuda_stream_view stream,
Expand All @@ -133,10 +134,10 @@ probe_join_hash_table(cudf::table_device_view build_table,
? cudf::detail::join_kind::LEFT_JOIN
: JoinKind;

std::size_t const join_size = output_size
? *output_size
: compute_join_output_size<ProbeJoinKind>(
build_table, probe_table, hash_table, compare_nulls, stream);
std::size_t const join_size =
output_size ? *output_size
: compute_join_output_size<ProbeJoinKind>(
build_table, probe_table, hash_table, has_nulls, compare_nulls, stream);

// If output size is zero, return immediately
if (join_size == 0) {
Expand All @@ -147,9 +148,10 @@ probe_join_hash_table(cudf::table_device_view build_table,
auto left_indices = std::make_unique<rmm::device_uvector<size_type>>(join_size, stream, mr);
auto right_indices = std::make_unique<rmm::device_uvector<size_type>>(join_size, stream, mr);

pair_equality equality{probe_table, build_table, compare_nulls};
auto const probe_nulls = cudf::nullate::DYNAMIC{has_nulls};
pair_equality equality{probe_table, build_table, probe_nulls, compare_nulls};

row_hash hash_probe{nullate::YES{}, probe_table};
row_hash hash_probe{probe_nulls, probe_table};
auto const empty_key_sentinel = hash_table.get_empty_key_sentinel();
make_pair_function pair_func{hash_probe, empty_key_sentinel};

Expand Down Expand Up @@ -197,12 +199,13 @@ probe_join_hash_table(cudf::table_device_view build_table,
std::size_t get_full_join_size(cudf::table_device_view build_table,
cudf::table_device_view probe_table,
multimap_type const& hash_table,
bool has_nulls,
null_equality compare_nulls,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
std::size_t join_size = compute_join_output_size<cudf::detail::join_kind::LEFT_JOIN>(
build_table, probe_table, hash_table, compare_nulls, stream);
build_table, probe_table, hash_table, has_nulls, compare_nulls, stream);

// If output size is zero, return immediately
if (join_size == 0) { return join_size; }
Expand All @@ -212,9 +215,10 @@ std::size_t get_full_join_size(cudf::table_device_view build_table,
auto left_indices = std::make_unique<rmm::device_uvector<size_type>>(join_size, stream, mr);
auto right_indices = std::make_unique<rmm::device_uvector<size_type>>(join_size, stream, mr);

pair_equality equality{probe_table, build_table, compare_nulls};
auto const probe_nulls = cudf::nullate::DYNAMIC{has_nulls};
pair_equality equality{probe_table, build_table, probe_nulls, compare_nulls};

row_hash hash_probe{nullate::YES{}, probe_table};
row_hash hash_probe{probe_nulls, probe_table};
auto const empty_key_sentinel = hash_table.get_empty_key_sentinel();
make_pair_function pair_func{hash_probe, empty_key_sentinel};

Expand Down Expand Up @@ -367,7 +371,12 @@ std::size_t hash_join::hash_join_impl::inner_join_size(cudf::table_view const& p
auto flattened_probe_table_ptr = cudf::table_device_view::create(flattened_probe_table, stream);

return cudf::detail::compute_join_output_size<cudf::detail::join_kind::INNER_JOIN>(
*build_table_ptr, *flattened_probe_table_ptr, _hash_table, compare_nulls, stream);
*build_table_ptr,
*flattened_probe_table_ptr,
_hash_table,
cudf::has_nulls(flattened_probe_table) | cudf::has_nulls(_build),
compare_nulls,
stream);
}

std::size_t hash_join::hash_join_impl::left_join_size(cudf::table_view const& probe,
Expand All @@ -387,7 +396,12 @@ std::size_t hash_join::hash_join_impl::left_join_size(cudf::table_view const& pr
auto flattened_probe_table_ptr = cudf::table_device_view::create(flattened_probe_table, stream);

return cudf::detail::compute_join_output_size<cudf::detail::join_kind::LEFT_JOIN>(
*build_table_ptr, *flattened_probe_table_ptr, _hash_table, compare_nulls, stream);
*build_table_ptr,
*flattened_probe_table_ptr,
_hash_table,
cudf::has_nulls(flattened_probe_table) | cudf::has_nulls(_build),
compare_nulls,
stream);
}

std::size_t hash_join::hash_join_impl::full_join_size(cudf::table_view const& probe,
Expand All @@ -407,8 +421,13 @@ std::size_t hash_join::hash_join_impl::full_join_size(cudf::table_view const& pr
auto build_table_ptr = cudf::table_device_view::create(_build, stream);
auto flattened_probe_table_ptr = cudf::table_device_view::create(flattened_probe_table, stream);

return get_full_join_size(
*build_table_ptr, *flattened_probe_table_ptr, _hash_table, compare_nulls, stream, mr);
return get_full_join_size(*build_table_ptr,
*flattened_probe_table_ptr,
_hash_table,
cudf::has_nulls(flattened_probe_table) | cudf::has_nulls(_build),
compare_nulls,
stream,
mr);
}

template <cudf::detail::join_kind JoinKind>
Expand Down Expand Up @@ -466,8 +485,15 @@ hash_join::hash_join_impl::probe_join_indices(cudf::table_view const& probe,
auto build_table_ptr = cudf::table_device_view::create(_build, stream);
auto probe_table_ptr = cudf::table_device_view::create(probe, stream);

auto join_indices = cudf::detail::probe_join_hash_table<JoinKind>(
*build_table_ptr, *probe_table_ptr, _hash_table, compare_nulls, output_size, stream, mr);
auto join_indices =
cudf::detail::probe_join_hash_table<JoinKind>(*build_table_ptr,
*probe_table_ptr,
_hash_table,
cudf::has_nulls(probe) | cudf::has_nulls(_build),
compare_nulls,
output_size,
stream,
mr);

if constexpr (JoinKind == cudf::detail::join_kind::FULL_JOIN) {
auto complement_indices = detail::get_left_join_indices_complement(
Expand Down
6 changes: 4 additions & 2 deletions cpp/src/join/hash_join.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ template <join_kind JoinKind, typename multimap_type>
std::size_t compute_join_output_size(table_device_view build_table,
table_device_view probe_table,
multimap_type const& hash_table,
bool has_nulls,
null_equality compare_nulls,
rmm::cuda_stream_view stream)
{
Expand All @@ -117,9 +118,10 @@ std::size_t compute_join_output_size(table_device_view build_table,
}
}

pair_equality equality{probe_table, build_table, compare_nulls};
auto const probe_nulls = cudf::nullate::DYNAMIC{has_nulls};
pair_equality equality{probe_table, build_table, probe_nulls, compare_nulls};

row_hash hash_probe{nullate::YES{}, probe_table};
row_hash hash_probe{probe_nulls, probe_table};
auto const empty_key_sentinel = hash_table.get_empty_key_sentinel();
make_pair_function pair_func{hash_probe, empty_key_sentinel};

Expand Down
3 changes: 2 additions & 1 deletion cpp/src/join/join_common_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ class pair_equality {
public:
pair_equality(table_device_view lhs,
table_device_view rhs,
nullate::DYNAMIC has_nulls,
null_equality nulls_are_equal = null_equality::EQUAL)
: _check_row_equality{cudf::nullate::YES{}, lhs, rhs, nulls_are_equal}
: _check_row_equality{has_nulls, lhs, rhs, nulls_are_equal}
{
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/src/join/join_common_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ using multimap_type =
hash_table_allocator_type,
cuco::double_hashing<DEFAULT_JOIN_CG_SIZE, hash_type, hash_type>>;

using row_hash = cudf::row_hasher<default_hash, cudf::nullate::YES>;
using row_hash = cudf::row_hasher<default_hash, cudf::nullate::DYNAMIC>;

using row_equality = cudf::row_equality_comparator<cudf::nullate::YES>;
using row_equality = cudf::row_equality_comparator<cudf::nullate::DYNAMIC>;

enum class join_kind { INNER_JOIN, LEFT_JOIN, FULL_JOIN, LEFT_SEMI_JOIN, LEFT_ANTI_JOIN };

Expand Down
12 changes: 7 additions & 5 deletions cpp/src/join/semi_join.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,15 @@ std::unique_ptr<rmm::device_uvector<cudf::size_type>> left_semi_anti_join(
// Create hash table containing all keys found in right table
auto right_rows_d = table_device_view::create(right_flattened_keys, stream);
size_t const hash_table_size = compute_hash_table_size(right_num_rows);
row_hash hash_build{cudf::nullate::YES{}, *right_rows_d};
row_equality equality_build{cudf::nullate::YES{}, *right_rows_d, *right_rows_d, compare_nulls};
auto const right_nulls = cudf::nullate::DYNAMIC{cudf::has_nulls(right_flattened_keys)};
row_hash hash_build{right_nulls, *right_rows_d};
row_equality equality_build{right_nulls, *right_rows_d, *right_rows_d, compare_nulls};

// Going to join it with left table
auto left_rows_d = table_device_view::create(left_flattened_keys, stream);
row_hash hash_probe{cudf::nullate::YES{}, *left_rows_d};
row_equality equality_probe{cudf::nullate::YES{}, *left_rows_d, *right_rows_d, compare_nulls};
auto left_rows_d = table_device_view::create(left_flattened_keys, stream);
auto const left_nulls = cudf::nullate::DYNAMIC{cudf::has_nulls(left_flattened_keys)};
row_hash hash_probe{left_nulls, *left_rows_d};
row_equality equality_probe{left_nulls, *left_rows_d, *right_rows_d, compare_nulls};

auto hash_table_ptr = hash_table_type::create(hash_table_size,
stream,
Expand Down

0 comments on commit e6c6991

Please sign in to comment.