Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reapply mixed_semi_join refactoring and bug fixes #16859

Merged
6 changes: 0 additions & 6 deletions cpp/src/join/join_common_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include <cudf/table/row_operators.cuh>
#include <cudf/table/table_view.hpp>

#include <cuco/static_map.cuh>
#include <cuco/static_multimap.cuh>
#include <cuda/atomic>

Expand Down Expand Up @@ -51,11 +50,6 @@ using mixed_multimap_type =
cudf::detail::cuco_allocator<char>,
cuco::legacy::double_hashing<1, hash_type, hash_type>>;

using semi_map_type = cuco::legacy::static_map<hash_value_type,
size_type,
cuda::thread_scope_device,
cudf::detail::cuco_allocator<char>>;

using row_hash_legacy =
cudf::row_hasher<cudf::hashing::detail::default_hash, cudf::nullate::DYNAMIC>;

Expand Down
34 changes: 34 additions & 0 deletions cpp/src/join/mixed_join_common_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <rmm/device_uvector.hpp>

#include <cub/cub.cuh>
#include <cuco/static_set.cuh>

namespace cudf {
namespace detail {
Expand Down Expand Up @@ -160,6 +161,39 @@ struct pair_expression_equality : public expression_equality<has_nulls> {
}
};

/**
* @brief Equality comparator that composes two row_equality comparators.
*/
struct double_row_equality_comparator {
row_equality const equality_comparator;
row_equality const conditional_comparator;

__device__ bool operator()(size_type lhs_row_index, size_type rhs_row_index) const noexcept
{
using experimental::row::lhs_index_type;
using experimental::row::rhs_index_type;

return equality_comparator(lhs_index_type{lhs_row_index}, rhs_index_type{rhs_row_index}) &&
conditional_comparator(lhs_index_type{lhs_row_index}, rhs_index_type{rhs_row_index});
}
};

// A CUDA Cooperative Group of 1 thread for the hash set for mixed semi.
auto constexpr DEFAULT_MIXED_SEMI_JOIN_CG_SIZE = 1;

// The hash set type used by mixed_semi_join with the build_table.
using hash_set_type =
cuco::static_set<size_type,
cuco::extent<size_t>,
cuda::thread_scope_device,
double_row_equality_comparator,
cuco::linear_probing<DEFAULT_MIXED_SEMI_JOIN_CG_SIZE, row_hash>,
cudf::detail::cuco_allocator<char>,
cuco::storage<1>>;

// The hash_set_ref_type used by mixed_semi_join kerenels for probing.
using hash_set_ref_type = hash_set_type::ref_type<cuco::contains_tag>;

} // namespace detail

} // namespace cudf
51 changes: 29 additions & 22 deletions cpp/src/join/mixed_join_kernels_semi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,38 +38,48 @@ CUDF_KERNEL void __launch_bounds__(block_size)
table_device_view right_table,
table_device_view probe,
table_device_view build,
row_hash const hash_probe,
row_equality const equality_probe,
cudf::detail::semi_map_type::device_view hash_table_view,
hash_set_ref_type set_ref,
cudf::device_span<bool> left_table_keep_mask,
cudf::ast::detail::expression_device_view device_expression_data)
{
auto constexpr cg_size = hash_set_ref_type::cg_size;

auto const tile = cg::tiled_partition<cg_size>(cg::this_thread_block());

// Normally the casting of a shared memory array is used to create multiple
// arrays of different types from the shared memory buffer, but here it is
// used to circumvent conflicts between arrays of different types between
// different template instantiations due to the extern specifier.
extern __shared__ char raw_intermediate_storage[];
cudf::ast::detail::IntermediateDataType<has_nulls>* intermediate_storage =
auto intermediate_storage =
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
reinterpret_cast<cudf::ast::detail::IntermediateDataType<has_nulls>*>(raw_intermediate_storage);
auto thread_intermediate_storage =
&intermediate_storage[threadIdx.x * device_expression_data.num_intermediates];
intermediate_storage + (tile.meta_group_rank() * device_expression_data.num_intermediates);

cudf::size_type const left_num_rows = left_table.num_rows();
cudf::size_type const right_num_rows = right_table.num_rows();
auto const outer_num_rows = left_num_rows;
// Equality evaluator to use
auto const evaluator = cudf::ast::detail::expression_evaluator<has_nulls>(
left_table, right_table, device_expression_data);

cudf::size_type outer_row_index = threadIdx.x + blockIdx.x * block_size;
// Make sure to swap_tables here as hash_set will use probe table as the left one
auto constexpr swap_tables = true;
auto const equality = single_expression_equality<has_nulls>{
evaluator, thread_intermediate_storage, swap_tables, equality_probe};

auto evaluator = cudf::ast::detail::expression_evaluator<has_nulls>(
left_table, right_table, device_expression_data);
// Create set ref with the new equality comparator
auto const set_ref_equality = set_ref.with_key_eq(equality);

if (outer_row_index < outer_num_rows) {
// Figure out the number of elements for this key.
auto equality = single_expression_equality<has_nulls>{
evaluator, thread_intermediate_storage, false, equality_probe};
// Total number of rows to query the set
auto const outer_num_rows = left_table.num_rows();
// Grid stride for the tile
auto const cg_grid_stride = cudf::detail::grid_1d::grid_stride<block_size>() / cg_size;

left_table_keep_mask[outer_row_index] =
hash_table_view.contains(outer_row_index, hash_probe, equality);
// Find all the rows in the left table that are in the hash table
for (auto outer_row_index = cudf::detail::grid_1d::global_thread_id<block_size>() / cg_size;
Copy link
Member Author

@mhaseeb123 mhaseeb123 Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bug fix (part 1): This was a simple if before so only querying nelems/cg_size items instead of all nelems

outer_row_index < outer_num_rows;
outer_row_index += cg_grid_stride) {
auto const result = set_ref_equality.contains(tile, outer_row_index);
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
if (tile.thread_rank() == 0) { left_table_keep_mask[outer_row_index] = result; }
}
}

Expand All @@ -78,9 +88,8 @@ void launch_mixed_join_semi(bool has_nulls,
table_device_view right_table,
table_device_view probe,
table_device_view build,
row_hash const hash_probe,
row_equality const equality_probe,
cudf::detail::semi_map_type::device_view hash_table_view,
hash_set_ref_type set_ref,
cudf::device_span<bool> left_table_keep_mask,
cudf::ast::detail::expression_device_view device_expression_data,
detail::grid_1d const config,
Expand All @@ -94,9 +103,8 @@ void launch_mixed_join_semi(bool has_nulls,
right_table,
probe,
build,
hash_probe,
equality_probe,
hash_table_view,
set_ref,
left_table_keep_mask,
device_expression_data);
} else {
Expand All @@ -106,9 +114,8 @@ void launch_mixed_join_semi(bool has_nulls,
right_table,
probe,
build,
hash_probe,
equality_probe,
hash_table_view,
set_ref,
left_table_keep_mask,
device_expression_data);
}
Expand Down
6 changes: 2 additions & 4 deletions cpp/src/join/mixed_join_kernels_semi.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ namespace detail {
* @param[in] right_table The right table
* @param[in] probe The table with which to probe the hash table for matches.
* @param[in] build The table with which the hash table was built.
* @param[in] hash_probe The hasher used for the probe table.
* @param[in] equality_probe The equality comparator used when probing the hash table.
* @param[in] hash_table_view The hash table built from `build`.
* @param[in] set_ref The hash table device view built from `build`.
* @param[out] left_table_keep_mask The result of the join operation with "true" element indicating
* the corresponding index from left table is present in output
* @param[in] device_expression_data Container of device data required to evaluate the desired
Expand All @@ -58,9 +57,8 @@ void launch_mixed_join_semi(bool has_nulls,
table_device_view right_table,
table_device_view probe,
table_device_view build,
row_hash const hash_probe,
row_equality const equality_probe,
cudf::detail::semi_map_type::device_view hash_table_view,
hash_set_ref_type set_ref,
cudf::device_span<bool> left_table_keep_mask,
cudf::ast::detail::expression_device_view device_expression_data,
detail::grid_1d const config,
Expand Down
92 changes: 27 additions & 65 deletions cpp/src/join/mixed_join_semi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,45 +45,6 @@
namespace cudf {
namespace detail {

namespace {
/**
* @brief Device functor to create a pair of hash value and index for a given row.
*/
struct make_pair_function_semi {
__device__ __forceinline__ cudf::detail::pair_type operator()(size_type i) const noexcept
{
// The value is irrelevant since we only ever use the hash map to check for
// membership of a particular row index.
return cuco::make_pair(static_cast<hash_value_type>(i), 0);
}
};

/**
* @brief Equality comparator that composes two row_equality comparators.
*/
class double_row_equality {
public:
double_row_equality(row_equality equality_comparator, row_equality conditional_comparator)
: _equality_comparator{equality_comparator}, _conditional_comparator{conditional_comparator}
{
}

__device__ bool operator()(size_type lhs_row_index, size_type rhs_row_index) const noexcept
{
using experimental::row::lhs_index_type;
using experimental::row::rhs_index_type;

return _equality_comparator(lhs_index_type{lhs_row_index}, rhs_index_type{rhs_row_index}) &&
_conditional_comparator(lhs_index_type{lhs_row_index}, rhs_index_type{rhs_row_index});
}

private:
row_equality _equality_comparator;
row_equality _conditional_comparator;
};

} // namespace

std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(
table_view const& left_equality,
table_view const& right_equality,
Expand All @@ -95,7 +56,7 @@ std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_EXPECTS((join_type != join_kind::INNER_JOIN) && (join_type != join_kind::LEFT_JOIN) &&
CUDF_EXPECTS((join_type != join_kind::INNER_JOIN) and (join_type != join_kind::LEFT_JOIN) and
(join_type != join_kind::FULL_JOIN),
"Inner, left, and full joins should use mixed_join.");

Expand Down Expand Up @@ -136,7 +97,7 @@ std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(
// output column and follow the null-supporting expression evaluation code
// path.
auto const has_nulls = cudf::nullate::DYNAMIC{
cudf::has_nulls(left_equality) || cudf::has_nulls(right_equality) ||
cudf::has_nulls(left_equality) or cudf::has_nulls(right_equality) or
binary_predicate.may_evaluate_null(left_conditional, right_conditional, stream)};

auto const parser = ast::detail::expression_parser{
Expand All @@ -155,27 +116,20 @@ std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(
auto right_conditional_view = table_device_view::create(right_conditional, stream);

auto const preprocessed_build =
experimental::row::equality::preprocessed_table::create(build, stream);
cudf::experimental::row::equality::preprocessed_table::create(build, stream);
auto const preprocessed_probe =
experimental::row::equality::preprocessed_table::create(probe, stream);
cudf::experimental::row::equality::preprocessed_table::create(probe, stream);
auto const row_comparator =
cudf::experimental::row::equality::two_table_comparator{preprocessed_probe, preprocessed_build};
cudf::experimental::row::equality::two_table_comparator{preprocessed_build, preprocessed_probe};
auto const equality_probe = row_comparator.equal_to<false>(has_nulls, compare_nulls);

semi_map_type hash_table{
compute_hash_table_size(build.num_rows()),
cuco::empty_key{std::numeric_limits<hash_value_type>::max()},
cuco::empty_value{cudf::detail::JoinNoneValue},
cudf::detail::cuco_allocator<char>{rmm::mr::polymorphic_allocator<char>{}, stream},
stream.value()};

// Create hash table containing all keys found in right table
// TODO: To add support for nested columns we will need to flatten in many
// places. However, this probably isn't worth adding any time soon since we
// won't be able to support AST conditions for those types anyway.
auto const build_nulls = cudf::nullate::DYNAMIC{cudf::has_nulls(build)};
auto const row_hash_build = cudf::experimental::row::hash::row_hasher{preprocessed_build};
auto const hash_build = row_hash_build.device_hasher(build_nulls);

// Since we may see multiple rows that are identical in the equality tables
// but differ in the conditional tables, the equality comparator used for
// insertion must account for both sets of tables. An alternative solution
Expand All @@ -190,39 +144,48 @@ std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(
auto const equality_build_equality =
row_comparator_build.equal_to<false>(build_nulls, compare_nulls);
auto const preprocessed_build_condtional =
experimental::row::equality::preprocessed_table::create(right_conditional, stream);
cudf::experimental::row::equality::preprocessed_table::create(right_conditional, stream);
auto const row_comparator_conditional_build =
cudf::experimental::row::equality::two_table_comparator{preprocessed_build_condtional,
preprocessed_build_condtional};
auto const equality_build_conditional =
row_comparator_conditional_build.equal_to<false>(build_nulls, compare_nulls);
double_row_equality equality_build{equality_build_equality, equality_build_conditional};
make_pair_function_semi pair_func_build{};

auto iter = cudf::detail::make_counting_transform_iterator(0, pair_func_build);
hash_set_type row_set{
{compute_hash_table_size(build.num_rows())},
cuco::empty_key{JoinNoneValue},
{equality_build_equality, equality_build_conditional},
{row_hash_build.device_hasher(build_nulls)},
{},
{},
cudf::detail::cuco_allocator<char>{rmm::mr::polymorphic_allocator<char>{}, stream},
{stream.value()}};

auto iter = thrust::make_counting_iterator(0);

// skip rows that are null here.
if ((compare_nulls == null_equality::EQUAL) or (not nullable(build))) {
hash_table.insert(iter, iter + right_num_rows, hash_build, equality_build, stream.value());
row_set.insert_async(iter, iter + right_num_rows, stream.value());
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some perf improvement through use of _async APIs.

} else {
thrust::counting_iterator<cudf::size_type> stencil(0);
auto const [row_bitmask, _] =
cudf::detail::bitmask_and(build, stream, cudf::get_current_device_resource_ref());
row_is_valid pred{static_cast<bitmask_type const*>(row_bitmask.data())};

// insert valid rows
hash_table.insert_if(
iter, iter + right_num_rows, stencil, pred, hash_build, equality_build, stream.value());
row_set.insert_if_async(iter, iter + right_num_rows, stencil, pred, stream.value());
}

auto hash_table_view = hash_table.get_device_view();

detail::grid_1d const config(outer_num_rows, DEFAULT_JOIN_BLOCK_SIZE);
auto const shmem_size_per_block = parser.shmem_per_thread * config.num_threads_per_block;
detail::grid_1d const config(outer_num_rows * hash_set_type::cg_size, DEFAULT_JOIN_BLOCK_SIZE);
Copy link
Member Author

@mhaseeb123 mhaseeb123 Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug fix (part 2): We need to launch outer_num_rows * hash_set_type::cg_size threads now. Even without this the loop in part 1 will take care of everything.

auto const shmem_size_per_block =
parser.shmem_per_thread *
cuco::detail::int_div_ceil(config.num_threads_per_block, hash_set_type::cg_size);

auto const row_hash = cudf::experimental::row::hash::row_hasher{preprocessed_probe};
auto const hash_probe = row_hash.device_hasher(has_nulls);

hash_set_ref_type const row_set_ref = row_set.ref(cuco::contains).with_hash_function(hash_probe);

// Vector used to indicate indices from left/probe table which are present in output
auto left_table_keep_mask = rmm::device_uvector<bool>(probe.num_rows(), stream);

Expand All @@ -231,9 +194,8 @@ std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(
*right_conditional_view,
*probe_view,
*build_view,
hash_probe,
equality_probe,
hash_table_view,
row_set_ref,
cudf::device_span<bool>(left_table_keep_mask),
parser.device_expression_data,
config,
Expand Down
Loading
Loading