Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor contains_table with cuco::static_set #14064

Merged
merged 28 commits into from
Sep 26, 2023
Merged
Changes from 3 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
8f2294f
Refactor contains_table with cuco::static_set
PointKernel Sep 7, 2023
f2fd994
Refactor contains_table with cuco::static_set
PointKernel Sep 7, 2023
6c85572
Merge remote-tracking branch 'upstream/branch-23.10' into cuco-contai…
PointKernel Sep 8, 2023
40dfead
Merge remote-tracking branch 'upstream/branch-23.10' into cuco-contai…
PointKernel Sep 12, 2023
bec3cd6
Fix logic issues with hashset
PointKernel Sep 13, 2023
5ada788
Get rid of build_row_bitmask function
PointKernel Sep 13, 2023
204bd45
Minor cleanups: renaming
PointKernel Sep 13, 2023
6ae6f2a
Merge branch 'branch-23.10' into cuco-contains-table
PointKernel Sep 13, 2023
9c7f4f6
Use build_row_bitmmask instead of bitmask_and
PointKernel Sep 14, 2023
237fd70
Code formatting
PointKernel Sep 14, 2023
7283dd3
Merge branch 'branch-23.10' into cuco-contains-table
PointKernel Sep 18, 2023
0a8b678
Add comments back
PointKernel Sep 18, 2023
93ba686
Merge remote-tracking branch 'origin/cuco-contains-table' into cuco-c…
PointKernel Sep 18, 2023
eca017f
Rename contains benchmark file as contains_scalar
PointKernel Sep 18, 2023
fc0980f
Add contains_table benchmark
PointKernel Sep 18, 2023
b2934e5
Add peak memory usage in contains_table benchmark
PointKernel Sep 19, 2023
218ab4f
Minor doc cleanups
PointKernel Sep 19, 2023
6fcaa46
Merge remote-tracking branch 'upstream/branch-23.10' into cuco-contai…
PointKernel Sep 20, 2023
9553104
Distinguish probing scheme CG sizes between nested and flat types for…
PointKernel Sep 20, 2023
0b67f9b
Merge branch 'branch-23.10' into cuco-contains-table
PointKernel Sep 20, 2023
71a8793
Merge branch 'branch-23.10' into cuco-contains-table
PointKernel Sep 21, 2023
cd75057
Merge remote-tracking branch 'upstream/branch-23.10' into cuco-contai…
PointKernel Sep 22, 2023
e1125c3
Remove redundant docs
PointKernel Sep 22, 2023
37f7048
Throw if needles and haystack column types mismatch
PointKernel Sep 22, 2023
4f6af5d
Simplify nested column handling
PointKernel Sep 22, 2023
de99c48
Merge branch 'cuco-contains-table' of github.com:PointKernel/cudf int…
PointKernel Sep 22, 2023
cb9614d
Merge branch 'branch-23.10' into cuco-contains-table
PointKernel Sep 22, 2023
b848ada
Merge branch 'branch-23.10' into cuco-contains-table
PointKernel Sep 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 108 additions & 139 deletions cpp/src/search/contains_table.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

#include <thrust/iterator/counting_iterator.h>

#include <cuco/static_map.cuh>
#include <cuco/static_set.cuh>

#include <type_traits>

Expand All @@ -37,11 +37,6 @@ namespace {
using cudf::experimental::row::lhs_index_type;
using cudf::experimental::row::rhs_index_type;

using static_map = cuco::static_map<lhs_index_type,
size_type,
cuda::thread_scope_device,
rmm::mr::stream_allocator_adaptor<default_allocator<char>>>;

/**
* @brief Check if the given type `T` is a strong index type (i.e., `lhs_index_type` or
* `rhs_index_type`).
Expand All @@ -58,48 +53,58 @@ constexpr auto is_strong_index_type()
* @brief An adapter functor to support strong index types for row hasher that must be operating on
* `cudf::size_type`.
*/
template <typename Hasher>
struct strong_index_hasher_adapter {
strong_index_hasher_adapter(Hasher const& hasher) : _hasher{hasher} {}
template <typename HaystackHasher, typename NeedleHasher>
struct hasher_adapter {
hasher_adapter(HaystackHasher const& haystack_hasher, NeedleHasher const& needle_hasher)
: _haystack_hasher{haystack_hasher}, _needle_hasher{needle_hasher}
{
}

__device__ constexpr auto operator()(lhs_index_type idx) const noexcept
{
return _haystack_hasher(static_cast<size_type>(idx));
}

template <typename T, CUDF_ENABLE_IF(is_strong_index_type<T>())>
__device__ constexpr auto operator()(T const idx) const noexcept
__device__ constexpr auto operator()(rhs_index_type idx) const noexcept
{
return _hasher(static_cast<size_type>(idx));
return _needle_hasher(static_cast<size_type>(idx));
}

private:
Hasher const _hasher;
HaystackHasher const _haystack_hasher;
NeedleHasher const _needle_hasher;
};

/**
* @brief An adapter functor to support strong index type for table row comparator that must be
* operating on `cudf::size_type`.
*/
template <typename Comparator>
struct strong_index_comparator_adapter {
strong_index_comparator_adapter(Comparator const& comparator) : _comparator{comparator} {}

template <typename T,
typename U,
CUDF_ENABLE_IF(is_strong_index_type<T>() && is_strong_index_type<U>())>
__device__ constexpr auto operator()(T const lhs_index, U const rhs_index) const noexcept
template <typename SelfComparator, typename TwoTableComparator>
struct comparator_adapter {
comparator_adapter(SelfComparator const& self_comparator,
TwoTableComparator const& two_table_comparator)
: _self_comparator{self_comparator}, _two_table_comparator{two_table_comparator}
{
}

__device__ constexpr auto operator()(lhs_index_type lhs_index,
lhs_index_type rhs_index) const noexcept
{
auto const lhs = static_cast<size_type>(lhs_index);
auto const rhs = static_cast<size_type>(rhs_index);

if constexpr (std::is_same_v<T, U> || std::is_same_v<T, lhs_index_type>) {
return _comparator(lhs, rhs);
} else {
// Here we have T == rhs_index_type.
// This is when the indices are provided in wrong order for two table comparator, so we need
// to switch them back to the right order before calling the underlying comparator.
return _comparator(rhs, lhs);
}
return _self_comparator(lhs, rhs);
}

__device__ constexpr auto operator()(lhs_index_type lhs_index,
rhs_index_type rhs_index) const noexcept
{
return _two_table_comparator(lhs_index, rhs_index);
}

private:
Comparator const _comparator;
SelfComparator const _self_comparator;
TwoTableComparator const _two_table_comparator;
};

/**
Expand Down Expand Up @@ -133,23 +138,40 @@ std::pair<rmm::device_buffer, bitmask_type const*> build_row_bitmask(table_view
return std::pair(rmm::device_buffer{0, stream}, nullable_columns.front().null_mask());
}

template <bool haystack_has_nulls, bool has_any_nulls, typename Func>
void dispatch(
null_equality compare_nulls, auto self_comp, auto two_table_comp, auto nan_comp, Func func)
{
auto const d_self_eq = self_comp.equal_to<haystack_has_nulls>(
PointKernel marked this conversation as resolved.
Show resolved Hide resolved
nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, nan_comp);
auto const d_two_table_eq = two_table_comp.equal_to<has_any_nulls>(
nullate::DYNAMIC{has_any_nulls}, compare_nulls, nan_comp);
func(d_self_eq, d_two_table_eq);
}

/**
* @brief Invoke an `operator()` template with a row equality comparator based on the specified
* `compare_nans` parameter.
*
* @param compare_nans The flag to specify whether NaNs should be compared equal or not
* @param func The input functor to invoke
*/
template <typename Func>
void dispatch_nan_comparator(nan_equality compare_nans, Func&& func)
template <bool haystack_has_nulls, bool has_any_nulls, typename Func>
void dispatch_nan_comparator(nan_equality compare_nans,
null_equality compare_nulls,
auto self_comp,
auto two_table_comp,
Func func)
{
if (compare_nans == nan_equality::ALL_EQUAL) {
using nan_equal_comparator =
cudf::experimental::row::equality::nan_equal_physical_equality_comparator;
func(nan_equal_comparator{});
dispatch<haystack_has_nulls, has_any_nulls>(
compare_nulls, self_comp, two_table_comp, nan_equal_comparator{}, func);
} else {
using nan_unequal_comparator = cudf::experimental::row::equality::physical_equality_comparator;
func(nan_unequal_comparator{});
dispatch<haystack_has_nulls, has_any_nulls>(
compare_nulls, self_comp, two_table_comp, nan_unequal_comparator{}, func);
}
}

Expand All @@ -173,124 +195,71 @@ rmm::device_uvector<bool> contains(table_view const& haystack,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto map = static_map(compute_hash_table_size(haystack.num_rows()),
cuco::empty_key{lhs_index_type{std::numeric_limits<size_type>::max()}},
cuco::empty_value{detail::JoinNoneValue},
detail::hash_table_allocator_type{default_allocator<char>{}, stream},
stream.value());

auto const haystack_has_nulls = has_nested_nulls(haystack);
auto const needles_has_nulls = has_nested_nulls(needles);
auto const has_any_nulls = haystack_has_nulls || needles_has_nulls;

auto const preprocessed_needles =
cudf::experimental::row::equality::preprocessed_table::create(needles, stream);
auto const preprocessed_haystack =
cudf::experimental::row::equality::preprocessed_table::create(haystack, stream);
// Insert row indices of the haystack table as map keys.
{
auto const haystack_it = cudf::detail::make_counting_transform_iterator(
size_type{0},
[] __device__(auto const idx) { return cuco::make_pair(lhs_index_type{idx}, 0); });

auto const hasher = cudf::experimental::row::hash::row_hasher(preprocessed_haystack);
auto const d_hasher =
strong_index_hasher_adapter{hasher.device_hasher(nullate::DYNAMIC{has_any_nulls})};
auto const haystack_hasher = cudf::experimental::row::hash::row_hasher(preprocessed_haystack);
auto const d_haystack_hasher = haystack_hasher.device_hasher(nullate::DYNAMIC{has_any_nulls});
auto const needle_hasher = cudf::experimental::row::hash::row_hasher(preprocessed_needles);
auto const d_needle_hasher = needle_hasher.device_hasher(nullate::DYNAMIC{has_any_nulls});
auto const d_hasher = hasher_adapter{d_haystack_hasher, d_needle_hasher};
using hasher_type = decltype(d_hasher);

auto const self_comparator =
cudf::experimental::row::equality::self_comparator(preprocessed_haystack);
auto const two_table_comparator = cudf::experimental::row::equality::two_table_comparator(
preprocessed_haystack, preprocessed_needles);

auto const comparator =
cudf::experimental::row::equality::self_comparator(preprocessed_haystack);
// The output vector.
auto contained = rmm::device_uvector<bool>(needles.num_rows(), stream, mr);

auto const haystack_iter = cudf::detail::make_counting_transform_iterator(
size_type{0}, [] __device__(auto idx) { return lhs_index_type{idx}; });
auto const needles_iter = cudf::detail::make_counting_transform_iterator(
size_type{0}, [] __device__(auto idx) { return rhs_index_type{idx}; });

auto helper_func = [&](auto const& d_self_equal, auto const& d_two_table_equal) {
auto const d_equal = comparator_adapter{d_self_equal, d_two_table_equal};

auto set = cuco::experimental::static_set{
cuco::experimental::extent{compute_hash_table_size(haystack.num_rows())},
PointKernel marked this conversation as resolved.
Show resolved Hide resolved
cuco::empty_key{lhs_index_type{-1}},
d_equal,
cuco::experimental::linear_probing<1, hasher_type>{d_hasher},
detail::hash_table_allocator_type{default_allocator<lhs_index_type>{}, stream},
stream.value()};

// If the haystack table has nulls but they are compared unequal, don't insert them.
// Otherwise, it was known to cause performance issue:
// - https://github.com/rapidsai/cudf/pull/6943
// - https://github.com/rapidsai/cudf/pull/8277
if (haystack_has_nulls && compare_nulls == null_equality::UNEQUAL) {
auto const bitmask_buffer_and_ptr = build_row_bitmask(haystack, stream);
auto const row_bitmask_ptr = bitmask_buffer_and_ptr.second;

auto const insert_map = [&](auto const value_comp) {
if (cudf::detail::has_nested_columns(haystack)) {
auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to<true>(
nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)};
map.insert_if(haystack_it,
haystack_it + haystack.num_rows(),
thrust::counting_iterator<size_type>(0), // stencil
row_is_valid{row_bitmask_ptr},
d_hasher,
d_eqcomp,
stream.value());
} else {
auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to<false>(
nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)};
map.insert_if(haystack_it,
haystack_it + haystack.num_rows(),
thrust::counting_iterator<size_type>(0), // stencil
row_is_valid{row_bitmask_ptr},
d_hasher,
d_eqcomp,
stream.value());
}
};

// Insert only rows that do not have any null at any level.
dispatch_nan_comparator(compare_nans, insert_map);
} else { // haystack_doesn't_have_nulls || compare_nulls == null_equality::EQUAL
auto const insert_map = [&](auto const value_comp) {
if (cudf::detail::has_nested_columns(haystack)) {
auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to<true>(
nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)};
map.insert(
haystack_it, haystack_it + haystack.num_rows(), d_hasher, d_eqcomp, stream.value());
} else {
auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to<false>(
nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)};
map.insert(
haystack_it, haystack_it + haystack.num_rows(), d_hasher, d_eqcomp, stream.value());
}
};

dispatch_nan_comparator(compare_nans, insert_map);
auto const row_bitmask = build_row_bitmask(haystack, stream).second;
set.insert_if_async(haystack_iter,
PointKernel marked this conversation as resolved.
Show resolved Hide resolved
haystack_iter + haystack.num_rows(),
thrust::counting_iterator<size_type>(0), // stencil
row_is_valid{row_bitmask},
stream.value());
} else {
set.insert_async(haystack_iter, haystack_iter + haystack.num_rows(), stream.value());
}
}

// The output vector.
auto contained = rmm::device_uvector<bool>(needles.num_rows(), stream, mr);
if (needles_has_nulls && compare_nulls == null_equality::UNEQUAL) {
set.contains_if_async(
needles_iter, needles_iter + needles.num_rows(), contained.begin(), stream.value());
} else {
set.contains_async(
needles_iter, needles_iter + needles.num_rows(), contained.begin(), stream.value());
}
};

auto const preprocessed_needles =
cudf::experimental::row::equality::preprocessed_table::create(needles, stream);
// Check existence for each row of the needles table in the haystack table.
{
auto const needles_it = cudf::detail::make_counting_transform_iterator(
size_type{0}, [] __device__(auto const idx) { return rhs_index_type{idx}; });

auto const hasher = cudf::experimental::row::hash::row_hasher(preprocessed_needles);
auto const d_hasher =
strong_index_hasher_adapter{hasher.device_hasher(nullate::DYNAMIC{has_any_nulls})};

auto const comparator = cudf::experimental::row::equality::two_table_comparator(
preprocessed_haystack, preprocessed_needles);

auto const check_contains = [&](auto const value_comp) {
if (cudf::detail::has_nested_columns(haystack) or cudf::detail::has_nested_columns(needles)) {
auto const d_eqcomp =
comparator.equal_to<true>(nullate::DYNAMIC{has_any_nulls}, compare_nulls, value_comp);
map.contains(needles_it,
needles_it + needles.num_rows(),
contained.begin(),
d_hasher,
d_eqcomp,
stream.value());
} else {
auto const d_eqcomp =
comparator.equal_to<false>(nullate::DYNAMIC{has_any_nulls}, compare_nulls, value_comp);
map.contains(needles_it,
needles_it + needles.num_rows(),
contained.begin(),
d_hasher,
d_eqcomp,
stream.value());
}
};

dispatch_nan_comparator(compare_nans, check_contains);
if (haystack_has_nulls) {
if (has_any_nulls) {
dispatch_nan_comparator<true, true>(
compare_nans, compare_nulls, self_comparator, two_table_comparator, helper_func);
}
}

return contained;
Expand Down
Loading