Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update implementations to build with the latest cuco #15938

Merged
merged 18 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion cpp/include/cudf/detail/distinct_hash_join.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ template <typename Equal>
struct comparator_adapter {
comparator_adapter(Equal const& d_equal) : _d_equal{d_equal} {}

#pragma nv_diagnostic push
#pragma nv_diag_suppress 177
__device__ constexpr auto operator()(
cuco::pair<hash_value_type, lhs_index_type> const&,
cuco::pair<hash_value_type, lhs_index_type> const&) const noexcept
Expand All @@ -50,6 +52,14 @@ struct comparator_adapter {
return false;
}

/// Equality between two entries that both carry an `rhs_index_type` index.
/// Build table keys are all distinct, so such a pair never compares equal.
__device__ constexpr auto operator()(cuco::pair<hash_value_type, rhs_index_type> const&,
                                     cuco::pair<hash_value_type, rhs_index_type> const&) const noexcept
{
  return false;  // distinct build keys => unconditionally unequal
}

__device__ constexpr auto operator()(
cuco::pair<hash_value_type, lhs_index_type> const& lhs,
cuco::pair<hash_value_type, rhs_index_type> const& rhs) const noexcept
Expand All @@ -58,6 +68,15 @@ struct comparator_adapter {
return _d_equal(lhs.second, rhs.second);
}

/// Equality between an `rhs_index_type` entry and an `lhs_index_type` entry.
/// The first members (typed `hash_value_type`) are compared first; the row
/// comparator `_d_equal` is only invoked when they match.
__device__ constexpr auto operator()(cuco::pair<hash_value_type, rhs_index_type> const& lhs,
                                     cuco::pair<hash_value_type, lhs_index_type> const& rhs) const noexcept
{
  // `&&` short-circuits, so `_d_equal` is skipped whenever the first members differ.
  return lhs.first == rhs.first && _d_equal(lhs.second, rhs.second);
}
#pragma nv_diagnostic pop

private:
Equal _d_equal;
};
Expand Down Expand Up @@ -94,7 +113,7 @@ struct distinct_hash_join {
using cuco_storage_type = cuco::storage<1>;

/// Hash table type
using hash_table_type = cuco::static_set<cuco::pair<hash_value_type, lhs_index_type>,
using hash_table_type = cuco::static_set<cuco::pair<hash_value_type, rhs_index_type>,
cuco::extent<size_type>,
cuda::thread_scope_device,
comparator_adapter<d_equal_type>,
Expand Down
10 changes: 5 additions & 5 deletions cpp/src/join/distinct_hash_join.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ auto prepare_device_equal(
cudf::null_equality compare_nulls)
{
auto const two_table_equal =
cudf::experimental::row::equality::two_table_comparator(build, probe);
cudf::experimental::row::equality::two_table_comparator(probe, build);
return comparator_adapter{two_table_equal.equal_to<HasNested == cudf::has_nested::YES>(
nullate::DYNAMIC{has_nulls}, compare_nulls)};
}
Expand Down Expand Up @@ -113,7 +113,7 @@ distinct_hash_join<HasNested>::distinct_hash_join(cudf::table_view const& build,
_hash_table{build.num_rows(),
CUCO_DESIRED_LOAD_FACTOR,
cuco::empty_key{cuco::pair{std::numeric_limits<hash_value_type>::max(),
lhs_index_type{JoinNoneValue}}},
rhs_index_type{JoinNoneValue}}},
prepare_device_equal<HasNested>(
_preprocessed_build, _preprocessed_probe, has_nulls, compare_nulls),
{},
Expand All @@ -131,7 +131,7 @@ distinct_hash_join<HasNested>::distinct_hash_join(cudf::table_view const& build,
auto const d_hasher = row_hasher.device_hasher(nullate::DYNAMIC{this->_has_nulls});

auto const iter = cudf::detail::make_counting_transform_iterator(
0, build_keys_fn<decltype(d_hasher), lhs_index_type>{d_hasher});
0, build_keys_fn<decltype(d_hasher), rhs_index_type>{d_hasher});

size_type const build_table_num_rows{build.num_rows()};
if (this->_nulls_equal == cudf::null_equality::EQUAL or (not cudf::nullable(this->_build))) {
Expand Down Expand Up @@ -174,7 +174,7 @@ distinct_hash_join<HasNested>::inner_join(rmm::cuda_stream_view stream,
cudf::experimental::row::hash::row_hasher{this->_preprocessed_probe};
auto const d_probe_hasher = probe_row_hasher.device_hasher(nullate::DYNAMIC{this->_has_nulls});
auto const iter = cudf::detail::make_counting_transform_iterator(
0, build_keys_fn<decltype(d_probe_hasher), rhs_index_type>{d_probe_hasher});
0, build_keys_fn<decltype(d_probe_hasher), lhs_index_type>{d_probe_hasher});

auto const build_indices_begin =
thrust::make_transform_output_iterator(build_indices->begin(), output_fn{});
Expand Down Expand Up @@ -216,7 +216,7 @@ std::unique_ptr<rmm::device_uvector<size_type>> distinct_hash_join<HasNested>::l
cudf::experimental::row::hash::row_hasher{this->_preprocessed_probe};
auto const d_probe_hasher = probe_row_hasher.device_hasher(nullate::DYNAMIC{this->_has_nulls});
auto const iter = cudf::detail::make_counting_transform_iterator(
0, build_keys_fn<decltype(d_probe_hasher), rhs_index_type>{d_probe_hasher});
0, build_keys_fn<decltype(d_probe_hasher), lhs_index_type>{d_probe_hasher});

auto const output_begin =
thrust::make_transform_output_iterator(build_indices->begin(), output_fn{});
Expand Down
34 changes: 26 additions & 8 deletions cpp/src/search/contains_table.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ struct hasher_adapter {

__device__ constexpr auto operator()(lhs_index_type idx) const noexcept
{
return _haystack_hasher(static_cast<size_type>(idx));
return _needle_hasher(static_cast<size_type>(idx));
}

__device__ constexpr auto operator()(rhs_index_type idx) const noexcept
{
return _needle_hasher(static_cast<size_type>(idx));
return _haystack_hasher(static_cast<size_type>(idx));
}

private:
Expand All @@ -76,6 +76,8 @@ struct comparator_adapter {
{
}

#pragma nv_diagnostic push
#pragma nv_diag_suppress 177
__device__ constexpr auto operator()(lhs_index_type lhs_index,
lhs_index_type rhs_index) const noexcept
{
Expand All @@ -85,12 +87,28 @@ struct comparator_adapter {
return _self_equal(lhs, rhs);
}

/// Equality between two indices that both carry the `rhs_index_type` tag.
/// Each index is unwrapped to `size_type` and handed to `_self_equal`.
__device__ constexpr auto operator()(rhs_index_type lhs_index, rhs_index_type rhs_index) const noexcept
{
  auto const first_row  = static_cast<size_type>(lhs_index);
  auto const second_row = static_cast<size_type>(rhs_index);
  return _self_equal(first_row, second_row);
}

/// Equality between an `lhs_index_type` index and an `rhs_index_type` index.
/// The strongly typed indices are forwarded unchanged to `_two_table_equal`.
__device__ constexpr auto operator()(lhs_index_type lhs_index, rhs_index_type rhs_index) const noexcept
{
  return _two_table_equal(lhs_index, rhs_index);
}

/// Equality between an `rhs_index_type` index and an `lhs_index_type` index
/// (the mirrored argument order). The strongly typed indices are forwarded
/// unchanged, in the order received, to `_two_table_equal`.
__device__ constexpr auto operator()(rhs_index_type lhs_index, lhs_index_type rhs_index) const noexcept
{
  return _two_table_equal(lhs_index, rhs_index);
}
#pragma nv_diagnostic pop

private:
SelfEqual const _self_equal;
TwoTableEqual const _two_table_equal;
Expand Down Expand Up @@ -210,26 +228,26 @@ rmm::device_uvector<bool> contains(table_view const& haystack,

auto const self_equal = cudf::experimental::row::equality::self_comparator(preprocessed_haystack);
auto const two_table_equal = cudf::experimental::row::equality::two_table_comparator(
preprocessed_haystack, preprocessed_needles);
preprocessed_needles, preprocessed_haystack);

// The output vector.
auto contained = rmm::device_uvector<bool>(needles.num_rows(), stream, mr);

auto const haystack_iter = cudf::detail::make_counting_transform_iterator(
size_type{0}, cuda::proclaim_return_type<lhs_index_type>([] __device__(auto idx) {
return lhs_index_type{idx};
}));
auto const needles_iter = cudf::detail::make_counting_transform_iterator(
size_type{0}, cuda::proclaim_return_type<rhs_index_type>([] __device__(auto idx) {
return rhs_index_type{idx};
}));
auto const needles_iter = cudf::detail::make_counting_transform_iterator(
size_type{0}, cuda::proclaim_return_type<lhs_index_type>([] __device__(auto idx) {
return lhs_index_type{idx};
}));

auto const helper_func =
[&](auto const& d_self_equal, auto const& d_two_table_equal, auto const& probing_scheme) {
auto const d_equal = comparator_adapter{d_self_equal, d_two_table_equal};

auto set = cuco::static_set{cuco::extent{compute_hash_table_size(haystack.num_rows())},
cuco::empty_key{lhs_index_type{-1}},
cuco::empty_key{rhs_index_type{-1}},
d_equal,
probing_scheme,
{},
Expand Down
13 changes: 13 additions & 0 deletions cpp/src/text/bpe/byte_pair_encoding.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ struct bpe_equal {
auto const right = d_strings.element<cudf::string_view>(lhs + 1);
return (left == rhs.first) && (right == rhs.second);
}
// used by find
// Compares a merge pair against the two consecutive `d_strings` elements that
// start at element index `2 * rhs`.
__device__ bool operator()(merge_pair_type const& lhs, cudf::size_type rhs) const noexcept
{
  cudf::size_type const pair_begin = rhs * 2;
  auto const left  = d_strings.element<cudf::string_view>(pair_begin);
  auto const right = d_strings.element<cudf::string_view>(pair_begin + 1);
  return (left == lhs.first) && (right == lhs.second);
}
};

using bpe_probe_scheme = cuco::linear_probing<1, bpe_hasher>;
Expand Down Expand Up @@ -154,6 +162,11 @@ struct mp_equal {
auto const left = d_strings.element<cudf::string_view>(lhs);
return left == rhs;
}
// Mirrored argument order: compares the string `lhs` against the `d_strings`
// element stored at index `rhs`.
__device__ bool operator()(cudf::string_view const& lhs, cudf::size_type rhs) const noexcept
{
  return lhs == d_strings.element<cudf::string_view>(rhs);
}
};

using mp_probe_scheme = cuco::linear_probing<1, mp_hasher>;
Expand Down
7 changes: 7 additions & 0 deletions cpp/src/text/vocabulary_tokenize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,17 @@ struct vocab_equal {
return lhs == rhs; // all rows are expected to be unique
}
// used by find
#pragma nv_diagnostic push
#pragma nv_diag_suppress 177
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you put a comment here on what this is suppressing?
Perhaps it can be removed with a future version of the compiler?

Copy link
Member Author

@PointKernel PointKernel Jun 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Comments added as requested.

Once the cuco bump PR is merged into rapids-cmake, I will open another cudf PR that removes all the redundant overloads together with the pragmas added in this PR. I also tested [[maybe_unused]] locally, but it couldn't get rid of the warning.

// Compares the `d_strings` element stored at index `lhs` against the string `rhs`.
__device__ bool operator()(cudf::size_type lhs, cudf::string_view const& rhs) const noexcept
{
  auto const stored = d_strings.element<cudf::string_view>(lhs);
  return stored == rhs;
}
// Mirrored argument order: compares the `d_strings` element stored at index
// `rhs` against the string `lhs`.
__device__ bool operator()(cudf::string_view const& lhs, cudf::size_type rhs) const noexcept
{
  auto const stored = d_strings.element<cudf::string_view>(rhs);
  return stored == lhs;
}
#pragma nv_diagnostic pop
};

using probe_scheme = cuco::linear_probing<1, vocab_hasher>;
Expand Down
Loading