Skip to content

Commit

Permalink
Refactor contains_table with cuco::static_set
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Sep 7, 2023
1 parent c9d8821 commit 8f2294f
Showing 1 changed file with 109 additions and 141 deletions.
250 changes: 109 additions & 141 deletions cpp/src/search/contains_table.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

#include <thrust/iterator/counting_iterator.h>

#include <cuco/static_map.cuh>
#include <cuco/static_set.cuh>

#include <type_traits>

Expand All @@ -37,11 +37,6 @@ namespace {
using cudf::experimental::row::lhs_index_type;
using cudf::experimental::row::rhs_index_type;

// Alias for the `cuco::static_map` instantiation used by this file: `lhs_index_type` and
// `size_type` are its first two type arguments, followed by device thread scope and a
// stream-ordered adaptor over `default_allocator<char>`.
using static_map = cuco::static_map<lhs_index_type,
size_type,
cuda::thread_scope_device,
rmm::mr::stream_allocator_adaptor<default_allocator<char>>>;

/**
* @brief Check if the given type `T` is a strong index type (i.e., `lhs_index_type` or
* `rhs_index_type`).
Expand All @@ -58,48 +53,58 @@ constexpr auto is_strong_index_type()
* @brief An adapter functor to support strong index types for row hasher that must be operating on
* `cudf::size_type`.
*/
template <typename Hasher>
struct strong_index_hasher_adapter {
strong_index_hasher_adapter(Hasher const& hasher) : _hasher{hasher} {}
template <typename HaystackHasher, typename NeedleHasher>
struct hasher_adapter {
// Stores a copy of each hasher in the private members declared below.
hasher_adapter(HaystackHasher const& haystack_hasher, NeedleHasher const& needle_hasher)
: _haystack_hasher{haystack_hasher}, _needle_hasher{needle_hasher}
{
}

// `lhs_index_type` overload: converts the strong index to `size_type` and forwards it to the
// haystack hasher.
__device__ constexpr auto operator()(lhs_index_type idx) const noexcept
{
return _haystack_hasher(static_cast<size_type>(idx));
}

// `rhs_index_type` overload: converts the strong index to `size_type` and forwards it to the
// needle hasher.
template <typename T, CUDF_ENABLE_IF(is_strong_index_type<T>())>
__device__ constexpr auto operator()(T const idx) const noexcept
__device__ constexpr auto operator()(rhs_index_type idx) const noexcept
{
return _hasher(static_cast<size_type>(idx));
return _needle_hasher(static_cast<size_type>(idx));
}

private:
Hasher const _hasher;
HaystackHasher const _haystack_hasher;  // invoked for `lhs_index_type` indices
NeedleHasher const _needle_hasher;      // invoked for `rhs_index_type` indices
};

/**
* @brief An adapter functor to support strong index types for table row comparators.
*
* Two `lhs_index_type` indices are converted to `cudf::size_type` and passed to the self
* comparator. An (`lhs_index_type`, `rhs_index_type`) pair is forwarded unchanged to the
* two-table comparator.
*/
template <typename Comparator>
struct strong_index_comparator_adapter {
strong_index_comparator_adapter(Comparator const& comparator) : _comparator{comparator} {}

template <typename T,
typename U,
CUDF_ENABLE_IF(is_strong_index_type<T>() && is_strong_index_type<U>())>
__device__ constexpr auto operator()(T const lhs_index, U const rhs_index) const noexcept
template <typename SelfComparator, typename TwoTableComparator>
struct comparator_adapter {
// Stores a copy of each comparator in the private members declared below.
comparator_adapter(SelfComparator const& self_comparator,
TwoTableComparator const& two_table_comparator)
: _self_comparator{self_comparator}, _two_table_comparator{two_table_comparator}
{
}

// Both indices are `lhs_index_type`: unwrap them to `size_type` and use the self comparator.
__device__ constexpr auto operator()(lhs_index_type lhs_index,
lhs_index_type rhs_index) const noexcept
{
auto const lhs = static_cast<size_type>(lhs_index);
auto const rhs = static_cast<size_type>(rhs_index);

if constexpr (std::is_same_v<T, U> || std::is_same_v<T, lhs_index_type>) {
return _comparator(lhs, rhs);
} else {
// Here we have T == rhs_index_type.
// This happens when the indices are provided in the wrong order for the two-table comparator,
// so we need to switch them back to the right order before calling the underlying comparator.
return _comparator(rhs, lhs);
}
return _self_comparator(lhs, rhs);
}

// Mixed index types: the strong-typed indices are passed through as-is to the two-table
// comparator, with no conversion to `size_type`.
__device__ constexpr auto operator()(lhs_index_type lhs_index,
rhs_index_type rhs_index) const noexcept
{
return _two_table_comparator(lhs_index, rhs_index);
}

private:
Comparator const _comparator;
SelfComparator const _self_comparator;            // used for (lhs, lhs) comparisons
TwoTableComparator const _two_table_comparator;   // used for (lhs, rhs) comparisons
};

/**
Expand Down Expand Up @@ -133,28 +138,44 @@ std::pair<rmm::device_buffer, bitmask_type const*> build_row_bitmask(table_view
return std::pair(rmm::device_buffer{0, stream}, nullable_columns.front().null_mask());
}

} // namespace
/**
 * @brief Build the two device row equality comparators and invoke `func` with them.
 *
 * Calls `equal_to` on `self_comp` and on `two_table_comp`, then calls
 * `func(d_self_eq, d_two_table_eq)` with the two results.
 *
 * NOTE(review): elsewhere in this file `equal_to<...>` is instantiated on whether the tables
 * have nested columns, whereas here it is instantiated on the null flags - confirm that this is
 * what is intended.
 *
 * @tparam haystack_has_nulls Template argument and nullate flag used for the `self_comp` call
 * @tparam has_any_nulls Template argument and nullate flag used for the `two_table_comp` call
 * @tparam Func Type of the functor to invoke
 *
 * @param compare_nulls Null equality policy forwarded to both `equal_to` calls
 * @param self_comp Object providing an `equal_to` member template (self comparison)
 * @param two_table_comp Object providing an `equal_to` member template (two-table comparison)
 * @param nan_comp Value comparator forwarded to both `equal_to` calls
 * @param func Functor invoked as `func(d_self_eq, d_two_table_eq)`
 */
template <bool haystack_has_nulls, bool has_any_nulls, typename Func>
void dispatch(
  null_equality compare_nulls, auto self_comp, auto two_table_comp, auto nan_comp, Func func)
{
  // `self_comp` and `two_table_comp` have deduced, and therefore dependent, types. Without the
  // `template` disambiguator, `equal_to < ...` is parsed as a less-than comparison and the
  // function fails to compile once instantiated.
  auto const d_self_eq = self_comp.template equal_to<haystack_has_nulls>(
    nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, nan_comp);
  auto const d_two_table_eq = two_table_comp.template equal_to<has_any_nulls>(
    nullate::DYNAMIC{has_any_nulls}, compare_nulls, nan_comp);
  func(d_self_eq, d_two_table_eq);
}

/**
* @brief Select a row value comparator based on the specified `compare_nans` parameter and
* dispatch with it.
*
* When `compare_nans` is `nan_equality::ALL_EQUAL`, `nan_equal_physical_equality_comparator` is
* used; otherwise `physical_equality_comparator` is used. The selected comparator is passed on
* together with the remaining arguments to `dispatch<haystack_has_nulls, has_any_nulls>`.
*
* @tparam haystack_has_nulls Forwarded as the first template argument of `dispatch`
* @tparam has_any_nulls Forwarded as the second template argument of `dispatch`
* @tparam Func Type of the functor to invoke
*
* @param compare_nans The flag to specify whether NaNs should be compared equal or not
* @param compare_nulls The null equality policy, forwarded to `dispatch`
* @param self_comp The self comparator, forwarded to `dispatch`
* @param two_table_comp The two-table comparator, forwarded to `dispatch`
* @param func The input functor to invoke
*/
template <typename Func>
void dispatch_nan_comparator(nan_equality compare_nans, Func&& func)
template <bool haystack_has_nulls, bool has_any_nulls, typename Func>
void dispatch_nan_comparator(nan_equality compare_nans,
null_equality compare_nulls,
auto self_comp,
auto two_table_comp,
Func func)
{
if (compare_nans == nan_equality::ALL_EQUAL) {
// NaNs compare equal to each other.
using nan_equal_comparator =
cudf::experimental::row::equality::nan_equal_physical_equality_comparator;
func(nan_equal_comparator{});
dispatch<haystack_has_nulls, has_any_nulls>(
compare_nulls, self_comp, two_table_comp, nan_equal_comparator{}, func);
} else {
// Any other `compare_nans` value: use the plain physical equality comparator.
using nan_unequal_comparator = cudf::experimental::row::equality::physical_equality_comparator;
func(nan_unequal_comparator{});
dispatch<haystack_has_nulls, has_any_nulls>(
compare_nulls, self_comp, two_table_comp, nan_unequal_comparator{}, func);
}
}

} // namespace

/**
* @brief Check if rows in the given `needles` table exist in the `haystack` table.
*
Expand All @@ -173,124 +194,71 @@ rmm::device_uvector<bool> contains(table_view const& haystack,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto map = static_map(compute_hash_table_size(haystack.num_rows()),
cuco::empty_key{lhs_index_type{std::numeric_limits<size_type>::max()}},
cuco::empty_value{detail::JoinNoneValue},
detail::hash_table_allocator_type{default_allocator<char>{}, stream},
stream.value());

auto const haystack_has_nulls = has_nested_nulls(haystack);
auto const needles_has_nulls = has_nested_nulls(needles);
auto const has_any_nulls = haystack_has_nulls || needles_has_nulls;

auto const preprocessed_needles =
cudf::experimental::row::equality::preprocessed_table::create(needles, stream);
auto const preprocessed_haystack =
cudf::experimental::row::equality::preprocessed_table::create(haystack, stream);
// Insert row indices of the haystack table as map keys.
{
auto const haystack_it = cudf::detail::make_counting_transform_iterator(
size_type{0},
[] __device__(auto const idx) { return cuco::make_pair(lhs_index_type{idx}, 0); });

auto const hasher = cudf::experimental::row::hash::row_hasher(preprocessed_haystack);
auto const d_hasher =
strong_index_hasher_adapter{hasher.device_hasher(nullate::DYNAMIC{has_any_nulls})};
auto const haystack_hasher = cudf::experimental::row::hash::row_hasher(preprocessed_haystack);
auto const d_haystack_hasher = haystack_hasher.device_hasher(nullate::DYNAMIC{has_any_nulls});
auto const needle_hasher = cudf::experimental::row::hash::row_hasher(preprocessed_needles);
auto const d_needle_hasher = needle_hasher.device_hasher(nullate::DYNAMIC{has_any_nulls});
auto const d_hasher = hasher_adapter{d_haystack_hasher, d_needle_hasher};
using hasher_type = decltype(d_hasher);

auto const self_comparator =
cudf::experimental::row::equality::self_comparator(preprocessed_haystack);
auto const two_table_comparator = cudf::experimental::row::equality::two_table_comparator(
preprocessed_haystack, preprocessed_needles);

// The output vector.
auto contained = rmm::device_uvector<bool>(needles.num_rows(), stream, mr);

auto const haystack_iter = cudf::detail::make_counting_transform_iterator(
size_type{0}, [] __device__(auto idx) { return lhs_index_type{idx}; });
auto const needles_iter = cudf::detail::make_counting_transform_iterator(
size_type{0}, [] __device__(auto idx) { return rhs_index_type{idx}; });

auto const helper_func = [&](auto const& d_self_equal, auto const& d_two_table_equal) {
auto const d_equal = comparator_adapter{d_self_equal, d_two_table_equal};

auto const comparator =
cudf::experimental::row::equality::self_comparator(preprocessed_haystack);
auto set = cuco::experimental::static_set{
cuco::experimental::extent{compute_hash_table_size(haystack.num_rows())},
cuco::empty_key{lhs_index_type{-1}},
d_equal,
cuco::experimental::linear_probing<1, hasher_type>{d_hasher},
detail::hash_table_allocator_type{default_allocator<lhs_index_type>{}, stream},
stream.value()};

// If the haystack table has nulls but they are compared unequal, don't insert them.
// Inserting them is known to cause performance issues:
// - https://github.com/rapidsai/cudf/pull/6943
// - https://github.com/rapidsai/cudf/pull/8277
if (haystack_has_nulls && compare_nulls == null_equality::UNEQUAL) {
auto const bitmask_buffer_and_ptr = build_row_bitmask(haystack, stream);
auto const row_bitmask_ptr = bitmask_buffer_and_ptr.second;

auto const insert_map = [&](auto const value_comp) {
if (cudf::detail::has_nested_columns(haystack)) {
auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to<true>(
nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)};
map.insert_if(haystack_it,
haystack_it + haystack.num_rows(),
thrust::counting_iterator<size_type>(0), // stencil
row_is_valid{row_bitmask_ptr},
d_hasher,
d_eqcomp,
stream.value());
} else {
auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to<false>(
nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)};
map.insert_if(haystack_it,
haystack_it + haystack.num_rows(),
thrust::counting_iterator<size_type>(0), // stencil
row_is_valid{row_bitmask_ptr},
d_hasher,
d_eqcomp,
stream.value());
}
};

// Insert only rows that do not have any null at any level.
dispatch_nan_comparator(compare_nans, insert_map);
} else { // haystack_doesn't_have_nulls || compare_nulls == null_equality::EQUAL
auto const insert_map = [&](auto const value_comp) {
if (cudf::detail::has_nested_columns(haystack)) {
auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to<true>(
nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)};
map.insert(
haystack_it, haystack_it + haystack.num_rows(), d_hasher, d_eqcomp, stream.value());
} else {
auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to<false>(
nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)};
map.insert(
haystack_it, haystack_it + haystack.num_rows(), d_hasher, d_eqcomp, stream.value());
}
};

dispatch_nan_comparator(compare_nans, insert_map);
auto const row_bitmask = build_row_bitmask(haystack, stream).second;
set.insert_if_async(haystack_iter,
haystack_iter + haystack.num_rows(),
thrust::counting_iterator<size_type>(0), // stencil
row_is_valid{row_bitmask},
stream.value());
} else {
set.insert_async(haystack_iter, haystack_iter + haystack.num_rows(), stream.value());
}
}

// The output vector.
auto contained = rmm::device_uvector<bool>(needles.num_rows(), stream, mr);
if (needles_has_nulls && compare_nulls == null_equality::UNEQUAL) {
set.contains_if_async(
needles_iter, needles_iter + needles.num_rows(), contained.begin(), stream.value());
} else {
set.contains_async(
needles_iter, needles_iter + needles.num_rows(), contained.begin(), stream.value());
}
};

auto const preprocessed_needles =
cudf::experimental::row::equality::preprocessed_table::create(needles, stream);
// Check existence for each row of the needles table in the haystack table.
{
auto const needles_it = cudf::detail::make_counting_transform_iterator(
size_type{0}, [] __device__(auto const idx) { return rhs_index_type{idx}; });

auto const hasher = cudf::experimental::row::hash::row_hasher(preprocessed_needles);
auto const d_hasher =
strong_index_hasher_adapter{hasher.device_hasher(nullate::DYNAMIC{has_any_nulls})};

auto const comparator = cudf::experimental::row::equality::two_table_comparator(
preprocessed_haystack, preprocessed_needles);

auto const check_contains = [&](auto const value_comp) {
if (cudf::detail::has_nested_columns(haystack) or cudf::detail::has_nested_columns(needles)) {
auto const d_eqcomp =
comparator.equal_to<true>(nullate::DYNAMIC{has_any_nulls}, compare_nulls, value_comp);
map.contains(needles_it,
needles_it + needles.num_rows(),
contained.begin(),
d_hasher,
d_eqcomp,
stream.value());
} else {
auto const d_eqcomp =
comparator.equal_to<false>(nullate::DYNAMIC{has_any_nulls}, compare_nulls, value_comp);
map.contains(needles_it,
needles_it + needles.num_rows(),
contained.begin(),
d_hasher,
d_eqcomp,
stream.value());
}
};

dispatch_nan_comparator(compare_nans, check_contains);
if (haystack_has_nulls) {
if (has_any_nulls) {
dispatch_nan_comparator<true, true>(
compare_nans, compare_nulls, self_comparator, two_table_comparator, helper_func);
}
}

return contained;
Expand Down

0 comments on commit 8f2294f

Please sign in to comment.