Skip to content

Commit

Permalink
Update implementations to build with the latest cuco (rapidsai#15938)
Browse files Browse the repository at this point in the history
This PR updates existing libcudf to accommodate a cuco breaking change introduced in NVIDIA/cuCollections#479. It helps avoid breaking cudf when bumping the cuco version in `rapids-cmake`.

Redundant equal/hash overloads will be removed once the version bump is done on the `rapids-cmake` end.

Authors:
  - Yunsong Wang (https://github.com/PointKernel)

Approvers:
  - David Wendt (https://github.com/davidwendt)
  - Nghia Truong (https://github.com/ttnghia)

URL: rapidsai#15938
  • Loading branch information
PointKernel authored Jun 28, 2024
1 parent df88cf5 commit 3c3edfe
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 14 deletions.
22 changes: 21 additions & 1 deletion cpp/include/cudf/detail/distinct_hash_join.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ template <typename Equal>
struct comparator_adapter {
comparator_adapter(Equal const& d_equal) : _d_equal{d_equal} {}

// suppress "function was declared but never referenced warning"
#pragma nv_diagnostic push
#pragma nv_diag_suppress 177
__device__ constexpr auto operator()(
cuco::pair<hash_value_type, lhs_index_type> const&,
cuco::pair<hash_value_type, lhs_index_type> const&) const noexcept
Expand All @@ -50,6 +53,14 @@ struct comparator_adapter {
return false;
}

__device__ constexpr auto operator()(
cuco::pair<hash_value_type, rhs_index_type> const&,
cuco::pair<hash_value_type, rhs_index_type> const&) const noexcept
{
// All build table keys are distinct thus `false` no matter what
return false;
}

__device__ constexpr auto operator()(
cuco::pair<hash_value_type, lhs_index_type> const& lhs,
cuco::pair<hash_value_type, rhs_index_type> const& rhs) const noexcept
Expand All @@ -58,6 +69,15 @@ struct comparator_adapter {
return _d_equal(lhs.second, rhs.second);
}

__device__ constexpr auto operator()(
cuco::pair<hash_value_type, rhs_index_type> const& lhs,
cuco::pair<hash_value_type, lhs_index_type> const& rhs) const noexcept
{
if (lhs.first != rhs.first) { return false; }
return _d_equal(lhs.second, rhs.second);
}
#pragma nv_diagnostic pop

private:
Equal _d_equal;
};
Expand Down Expand Up @@ -94,7 +114,7 @@ struct distinct_hash_join {
using cuco_storage_type = cuco::storage<1>;

/// Hash table type
using hash_table_type = cuco::static_set<cuco::pair<hash_value_type, lhs_index_type>,
using hash_table_type = cuco::static_set<cuco::pair<hash_value_type, rhs_index_type>,
cuco::extent<size_type>,
cuda::thread_scope_device,
comparator_adapter<d_equal_type>,
Expand Down
10 changes: 5 additions & 5 deletions cpp/src/join/distinct_hash_join.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ auto prepare_device_equal(
cudf::null_equality compare_nulls)
{
auto const two_table_equal =
cudf::experimental::row::equality::two_table_comparator(build, probe);
cudf::experimental::row::equality::two_table_comparator(probe, build);
return comparator_adapter{two_table_equal.equal_to<HasNested == cudf::has_nested::YES>(
nullate::DYNAMIC{has_nulls}, compare_nulls)};
}
Expand Down Expand Up @@ -113,7 +113,7 @@ distinct_hash_join<HasNested>::distinct_hash_join(cudf::table_view const& build,
_hash_table{build.num_rows(),
CUCO_DESIRED_LOAD_FACTOR,
cuco::empty_key{cuco::pair{std::numeric_limits<hash_value_type>::max(),
lhs_index_type{JoinNoneValue}}},
rhs_index_type{JoinNoneValue}}},
prepare_device_equal<HasNested>(
_preprocessed_build, _preprocessed_probe, has_nulls, compare_nulls),
{},
Expand All @@ -131,7 +131,7 @@ distinct_hash_join<HasNested>::distinct_hash_join(cudf::table_view const& build,
auto const d_hasher = row_hasher.device_hasher(nullate::DYNAMIC{this->_has_nulls});

auto const iter = cudf::detail::make_counting_transform_iterator(
0, build_keys_fn<decltype(d_hasher), lhs_index_type>{d_hasher});
0, build_keys_fn<decltype(d_hasher), rhs_index_type>{d_hasher});

size_type const build_table_num_rows{build.num_rows()};
if (this->_nulls_equal == cudf::null_equality::EQUAL or (not cudf::nullable(this->_build))) {
Expand Down Expand Up @@ -174,7 +174,7 @@ distinct_hash_join<HasNested>::inner_join(rmm::cuda_stream_view stream,
cudf::experimental::row::hash::row_hasher{this->_preprocessed_probe};
auto const d_probe_hasher = probe_row_hasher.device_hasher(nullate::DYNAMIC{this->_has_nulls});
auto const iter = cudf::detail::make_counting_transform_iterator(
0, build_keys_fn<decltype(d_probe_hasher), rhs_index_type>{d_probe_hasher});
0, build_keys_fn<decltype(d_probe_hasher), lhs_index_type>{d_probe_hasher});

auto const build_indices_begin =
thrust::make_transform_output_iterator(build_indices->begin(), output_fn{});
Expand Down Expand Up @@ -216,7 +216,7 @@ std::unique_ptr<rmm::device_uvector<size_type>> distinct_hash_join<HasNested>::l
cudf::experimental::row::hash::row_hasher{this->_preprocessed_probe};
auto const d_probe_hasher = probe_row_hasher.device_hasher(nullate::DYNAMIC{this->_has_nulls});
auto const iter = cudf::detail::make_counting_transform_iterator(
0, build_keys_fn<decltype(d_probe_hasher), rhs_index_type>{d_probe_hasher});
0, build_keys_fn<decltype(d_probe_hasher), lhs_index_type>{d_probe_hasher});

auto const output_begin =
thrust::make_transform_output_iterator(build_indices->begin(), output_fn{});
Expand Down
35 changes: 27 additions & 8 deletions cpp/src/search/contains_table.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ struct hasher_adapter {

__device__ constexpr auto operator()(lhs_index_type idx) const noexcept
{
return _haystack_hasher(static_cast<size_type>(idx));
return _needle_hasher(static_cast<size_type>(idx));
}

__device__ constexpr auto operator()(rhs_index_type idx) const noexcept
{
return _needle_hasher(static_cast<size_type>(idx));
return _haystack_hasher(static_cast<size_type>(idx));
}

private:
Expand All @@ -76,6 +76,9 @@ struct comparator_adapter {
{
}

// suppress "function was declared but never referenced warning"
#pragma nv_diagnostic push
#pragma nv_diag_suppress 177
__device__ constexpr auto operator()(lhs_index_type lhs_index,
lhs_index_type rhs_index) const noexcept
{
Expand All @@ -85,12 +88,28 @@ struct comparator_adapter {
return _self_equal(lhs, rhs);
}

__device__ constexpr auto operator()(rhs_index_type lhs_index,
rhs_index_type rhs_index) const noexcept
{
auto const lhs = static_cast<size_type>(lhs_index);
auto const rhs = static_cast<size_type>(rhs_index);

return _self_equal(lhs, rhs);
}

__device__ constexpr auto operator()(lhs_index_type lhs_index,
rhs_index_type rhs_index) const noexcept
{
return _two_table_equal(lhs_index, rhs_index);
}

__device__ constexpr auto operator()(rhs_index_type lhs_index,
lhs_index_type rhs_index) const noexcept
{
return _two_table_equal(lhs_index, rhs_index);
}
#pragma nv_diagnostic pop

private:
SelfEqual const _self_equal;
TwoTableEqual const _two_table_equal;
Expand Down Expand Up @@ -210,26 +229,26 @@ rmm::device_uvector<bool> contains(table_view const& haystack,

auto const self_equal = cudf::experimental::row::equality::self_comparator(preprocessed_haystack);
auto const two_table_equal = cudf::experimental::row::equality::two_table_comparator(
preprocessed_haystack, preprocessed_needles);
preprocessed_needles, preprocessed_haystack);

// The output vector.
auto contained = rmm::device_uvector<bool>(needles.num_rows(), stream, mr);

auto const haystack_iter = cudf::detail::make_counting_transform_iterator(
size_type{0}, cuda::proclaim_return_type<lhs_index_type>([] __device__(auto idx) {
return lhs_index_type{idx};
}));
auto const needles_iter = cudf::detail::make_counting_transform_iterator(
size_type{0}, cuda::proclaim_return_type<rhs_index_type>([] __device__(auto idx) {
return rhs_index_type{idx};
}));
auto const needles_iter = cudf::detail::make_counting_transform_iterator(
size_type{0}, cuda::proclaim_return_type<lhs_index_type>([] __device__(auto idx) {
return lhs_index_type{idx};
}));

auto const helper_func =
[&](auto const& d_self_equal, auto const& d_two_table_equal, auto const& probing_scheme) {
auto const d_equal = comparator_adapter{d_self_equal, d_two_table_equal};

auto set = cuco::static_set{cuco::extent{compute_hash_table_size(haystack.num_rows())},
cuco::empty_key{lhs_index_type{-1}},
cuco::empty_key{rhs_index_type{-1}},
d_equal,
probing_scheme,
{},
Expand Down
13 changes: 13 additions & 0 deletions cpp/src/text/bpe/byte_pair_encoding.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ struct bpe_equal {
auto const right = d_strings.element<cudf::string_view>(lhs + 1);
return (left == rhs.first) && (right == rhs.second);
}
// used by find
__device__ bool operator()(merge_pair_type const& lhs, cudf::size_type rhs) const noexcept
{
rhs *= 2;
auto const left = d_strings.element<cudf::string_view>(rhs);
auto const right = d_strings.element<cudf::string_view>(rhs + 1);
return (left == lhs.first) && (right == lhs.second);
}
};

using bpe_probe_scheme = cuco::linear_probing<1, bpe_hasher>;
Expand Down Expand Up @@ -154,6 +162,11 @@ struct mp_equal {
auto const left = d_strings.element<cudf::string_view>(lhs);
return left == rhs;
}
__device__ bool operator()(cudf::string_view const& lhs, cudf::size_type rhs) const noexcept
{
auto const right = d_strings.element<cudf::string_view>(rhs);
return lhs == right;
}
};

using mp_probe_scheme = cuco::linear_probing<1, mp_hasher>;
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/text/vocabulary_tokenize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,18 @@ struct vocab_equal {
return lhs == rhs; // all rows are expected to be unique
}
// used by find
// suppress "function was declared but never referenced warning"
#pragma nv_diagnostic push
#pragma nv_diag_suppress 177
__device__ bool operator()(cudf::size_type lhs, cudf::string_view const& rhs) const noexcept
{
return d_strings.element<cudf::string_view>(lhs) == rhs;
}
__device__ bool operator()(cudf::string_view const& lhs, cudf::size_type rhs) const noexcept
{
return d_strings.element<cudf::string_view>(rhs) == lhs;
}
#pragma nv_diagnostic pop
};

using probe_scheme = cuco::linear_probing<1, vocab_hasher>;
Expand Down

0 comments on commit 3c3edfe

Please sign in to comment.