Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor cudf::detail::contains(table_view, table_view) #11325

Closed
wants to merge 7 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 110 additions & 49 deletions cpp/src/search/contains_table.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

#include <thrust/iterator/counting_iterator.h>

#include <cuco/static_multimap.cuh>
#include <cuco/static_map.cuh>

namespace cudf::detail {

Expand All @@ -36,6 +36,60 @@ namespace {
using cudf::experimental::row::lhs_index_type;
using cudf::experimental::row::rhs_index_type;

/**
* @brief An adapter functor to support strong index types for row hasher.
*
* A copy of the given hasher is stored; every call converts the incoming index to `size_type`
* with `static_cast` and forwards the converted value to that copy.
*
* @tparam Hasher Type of the wrapped hasher; it must be callable with a `size_type` argument
*/
template <typename Hasher>
struct strong_index_hasher_adapter {
// Stores a copy of `hasher` for later use in `operator()`.
strong_index_hasher_adapter(Hasher const& hasher) : _hasher{hasher} {}

// Returns whatever the wrapped hasher returns for `idx` converted to `size_type`.
// `T` is unconstrained here; per the review note below, callers pass only strong index types.
template <typename T>
__device__ inline auto operator()(T const idx) const noexcept
{
return _hasher(static_cast<size_type>(idx));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will there be narrow conversion issues?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This functor is only used locally in this file, and the input idx is only strong index types (either lhs_index_type or rhs_index_type).

}

private:
// The wrapped hasher, held by value.
Hasher _hasher;
};

/**
 * @brief An adapter functor to support strong index types for table self comparator.
 *
 * A copy of the given comparator is stored; every call converts both incoming indices to
 * `size_type` with `static_cast` and forwards the converted values to that copy.
 *
 * @tparam Comparator Type of the wrapped comparator; it must be callable with two `size_type`
 *         arguments
 */
template <typename Comparator>
struct strong_index_self_comparator_adapter {
  /**
   * @brief Construct the adapter, storing a copy of `comparator`.
   *
   * @param comparator The comparator that calls are forwarded to
   */
  strong_index_self_comparator_adapter(Comparator const& comparator) : _comparator{comparator} {}

  /**
   * @brief Forward a pair of indices to the wrapped comparator.
   *
   * Both arguments must have the same type `T`.
   *
   * @param lhs_index The first index
   * @param rhs_index The second index
   * @return The result of the wrapped comparator called with both indices cast to `size_type`
   */
  template <typename T>
  __device__ inline auto operator()(T const lhs_index, T const rhs_index) const noexcept
  {
    return _comparator(static_cast<size_type>(lhs_index), static_cast<size_type>(rhs_index));
  }

 private:
  // Held by value and intentionally not `const`: a const data member would delete copy
  // assignment for this functor. `operator()` is const-qualified, so the wrapped comparator is
  // still only ever invoked through a const access path.
  Comparator _comparator;
};

/**
 * @brief Call `func` once with the row equality comparator selected by `compare_nans`.
 *
 * The two branches pass comparator objects of differently named types, so `func` needs an
 * `operator()` template (for example, a generic lambda) to accept either one.
 *
 * @tparam Func Type of the callable to invoke
 * @param compare_nans The flag to specify whether NaNs should be compared equal or not
 * @param func The input functor to invoke
 */
template <typename Func>
void dispatch_nan_comparator(nan_equality compare_nans, Func&& func)
{
  namespace row_equality = cudf::experimental::row::equality;

  // Any value other than ALL_EQUAL selects the plain physical comparator.
  if (compare_nans != nan_equality::ALL_EQUAL) {
    func(row_equality::physical_equality_comparator{});
    return;
  }

  func(row_equality::nan_equal_physical_equality_comparator{});
}

} // namespace

rmm::device_uvector<bool> contains(table_view const& haystack,
Expand All @@ -45,34 +99,33 @@ rmm::device_uvector<bool> contains(table_view const& haystack,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
// Use a hash map whose key type is the row hash value and whose mapped type is `lhs_index_type`
// to store all row indices of the haystack table.
using static_multimap =
cuco::static_multimap<hash_value_type,
lhs_index_type,
cuda::thread_scope_device,
rmm::mr::stream_allocator_adaptor<default_allocator<char>>,
cuco::double_hashing<detail::DEFAULT_JOIN_CG_SIZE, hash_type, hash_type>>;

auto map = static_multimap(compute_hash_table_size(haystack.num_rows()),
cuco::sentinel::empty_key{std::numeric_limits<hash_value_type>::max()},
cuco::sentinel::empty_value{lhs_index_type{detail::JoinNoneValue}},
stream.value(),
detail::hash_table_allocator_type{default_allocator<char>{}, stream});
using static_map = cuco::static_map<lhs_index_type,
size_type,
cuda::thread_scope_device,
rmm::mr::stream_allocator_adaptor<default_allocator<char>>>;

auto map =
static_map(compute_hash_table_size(haystack.num_rows()),
cuco::sentinel::empty_key{lhs_index_type{std::numeric_limits<size_type>::max()}},
cuco::sentinel::empty_value{detail::JoinNoneValue},
detail::hash_table_allocator_type{default_allocator<char>{}, stream},
stream.value());

auto const haystack_has_nulls = has_nested_nulls(haystack);
auto const needles_has_nulls = has_nested_nulls(needles);
auto const has_any_nulls = haystack_has_nulls || needles_has_nulls;

// Insert all row hash values and indices of the haystack table.
// Insert row indices of the haystack table as map keys.
{
auto const hasher = cudf::experimental::row::hash::row_hasher(haystack, stream);
auto const d_hasher = hasher.device_hasher(nullate::DYNAMIC{has_any_nulls});
auto const haystack_it = cudf::detail::make_counting_transform_iterator(
size_type{0},
[] __device__(auto const idx) { return cuco::make_pair(lhs_index_type{idx}, 0); });

using make_pair_fn = make_pair_function<decltype(d_hasher), lhs_index_type>;
auto const hasher = cudf::experimental::row::hash::row_hasher(haystack, stream);
auto const d_hasher =
strong_index_hasher_adapter{hasher.device_hasher(nullate::DYNAMIC{has_any_nulls})};

auto const haystack_it = cudf::detail::make_counting_transform_iterator(
size_type{0}, make_pair_fn{d_hasher, map.get_empty_key_sentinel()});
auto const comparator = cudf::experimental::row::equality::self_comparator(haystack, stream);

// If the haystack table has nulls but they are compared unequal, don't insert them.
// Otherwise, inserting them is known to cause a performance issue:
Expand All @@ -95,13 +148,29 @@ rmm::device_uvector<bool> contains(table_view const& haystack,
: haystack_nullable_columns.front().null_mask();

// Insert only rows that do not have any null at any level.
map.insert_if(haystack_it,
haystack_it + haystack.num_rows(),
thrust::counting_iterator<size_type>(0), // stencil
row_is_valid{row_bitmask_ptr},
stream.value());
} else {
map.insert(haystack_it, haystack_it + haystack.num_rows(), stream.value());
auto const insert_map = [&](auto const value_comp) {
auto const d_eqcomp = strong_index_self_comparator_adapter{
comparator.equal_to(nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)};
map.insert_if(haystack_it,
haystack_it + haystack.num_rows(),
thrust::counting_iterator<size_type>(0), // stencil
row_is_valid{row_bitmask_ptr},
d_hasher,
d_eqcomp,
stream.value());
};

dispatch_nan_comparator(compare_nans, insert_map);

} else { // haystack_doesn't_have_nulls || compare_nulls == null_equality::EQUAL
auto const insert_map = [&](auto const value_comp) {
auto const d_eqcomp = strong_index_self_comparator_adapter{
comparator.equal_to(nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)};
map.insert(
haystack_it, haystack_it + haystack.num_rows(), d_hasher, d_eqcomp, stream.value());
};

dispatch_nan_comparator(compare_nans, insert_map);
}
}

Expand All @@ -110,36 +179,28 @@ rmm::device_uvector<bool> contains(table_view const& haystack,

// Check existence for each row of the needles table in the haystack table.
{
auto const hasher = cudf::experimental::row::hash::row_hasher(needles, stream);
auto const d_hasher = hasher.device_hasher(nullate::DYNAMIC{has_any_nulls});
auto const needles_it = cudf::detail::make_counting_transform_iterator(
size_type{0}, [] __device__(auto const idx) { return rhs_index_type{idx}; });

auto const hasher = cudf::experimental::row::hash::row_hasher(needles, stream);
auto const d_hasher =
strong_index_hasher_adapter{hasher.device_hasher(nullate::DYNAMIC{has_any_nulls})};

auto const comparator =
cudf::experimental::row::equality::two_table_comparator(haystack, needles, stream);

using make_pair_fn = make_pair_function<decltype(d_hasher), rhs_index_type>;

auto const needles_it = cudf::detail::make_counting_transform_iterator(
size_type{0}, make_pair_fn{d_hasher, map.get_empty_key_sentinel()});

auto const check_contains = [&](auto const value_comp) {
auto const d_eqcomp =
comparator.equal_to(nullate::DYNAMIC{has_any_nulls}, compare_nulls, value_comp);
map.pair_contains(needles_it,
needles_it + needles.num_rows(),
contained.begin(),
pair_equality{d_eqcomp},
stream.value());
map.contains(needles_it,
needles_it + needles.num_rows(),
contained.begin(),
d_hasher,
d_eqcomp,
stream.value());
};

if (compare_nans == nan_equality::ALL_EQUAL) {
using nan_equal_comparator =
cudf::experimental::row::equality::nan_equal_physical_equality_comparator;
check_contains(nan_equal_comparator{});
} else {
using nan_unequal_comparator =
cudf::experimental::row::equality::physical_equality_comparator;
check_contains(nan_unequal_comparator{});
}
dispatch_nan_comparator(compare_nans, check_contains);
}

return contained;
Expand Down