From 19944cf330cf247ced15490204ae1ea3ca486465 Mon Sep 17 00:00:00 2001 From: Tanmay Gujar Date: Thu, 1 Aug 2024 12:09:50 -0700 Subject: [PATCH] address review comments --- .../cudf/detail/distinct_hash_join.cuh | 40 ++++++------------- .../cudf/table/experimental/row_operators.cuh | 2 +- cpp/src/join/distinct_hash_join.cu | 17 ++++---- cpp/src/join/mixed_join_kernel_semi_impl.cuh | 2 +- cpp/src/join/mixed_join_kernels_semi.cu | 2 +- .../join/mixed_join_kernels_semi_compound.cu | 2 +- .../join/mixed_join_kernels_semi_nested.cu | 2 +- cpp/src/join/mixed_join_semi.cu | 6 +-- 8 files changed, 28 insertions(+), 45 deletions(-) diff --git a/cpp/include/cudf/detail/distinct_hash_join.cuh b/cpp/include/cudf/detail/distinct_hash_join.cuh index 6a5b3355c3d..2246556e035 100644 --- a/cpp/include/cudf/detail/distinct_hash_join.cuh +++ b/cpp/include/cudf/detail/distinct_hash_join.cuh @@ -109,33 +109,19 @@ struct distinct_hash_join { using cuco_storage_type = cuco::storage<1>; /// Hash table type - using hash_table_type = std::variant< - cuco::static_set, - cuco::extent, - cuda::thread_scope_device, - comparator_adapter>, - probing_scheme_type, - cudf::detail::cuco_allocator, - cuco_storage_type>, - cuco::static_set< - cuco::pair, - cuco::extent, - cuda::thread_scope_device, - comparator_adapter>, - probing_scheme_type, - cudf::detail::cuco_allocator, - cuco_storage_type>, - cuco::static_set< - cuco::pair, - cuco::extent, - cuda::thread_scope_device, - comparator_adapter>, - probing_scheme_type, - cudf::detail::cuco_allocator, - cuco_storage_type>>; + template + using static_set_with_comparator = cuco::static_set< + cuco::pair, + cuco::extent, + cuda::thread_scope_device, + comparator_adapter< + cudf::experimental::row::equality::strong_index_comparator_adapter>, + probing_scheme_type, + cudf::detail::cuco_allocator, + cuco_storage_type>; + using hash_table_type = std::variant, + static_set_with_comparator, + static_set_with_comparator>; bool _has_nulls; ///< true if nulls are present in either build table or probe table cudf::null_equality _nulls_equal; ///< whether to consider nulls as equal diff --git a/cpp/include/cudf/table/experimental/row_operators.cuh b/cpp/include/cudf/table/experimental/row_operators.cuh index feb0bfca2db..7dcac39f22e 100644 --- a/cpp/include/cudf/table/experimental/row_operators.cuh +++ b/cpp/include/cudf/table/experimental/row_operators.cuh @@ -1517,7 +1517,7 @@ class device_row_comparator { * @brief Compares the specified elements for equality. * * is_equality_comparable differs from implementation for std::equality_comparable and considers - * void as and equality comparable type. Thus we need to disable this for when type is void. + * void as an equality comparable type. Thus we need to disable this for when type is void. * * @param lhs_element_index The index of the first element * @param rhs_element_index The index of the second element diff --git a/cpp/src/join/distinct_hash_join.cu b/cpp/src/join/distinct_hash_join.cu index 949b386ddab..4d49e23dc01 100644 --- a/cpp/src/join/distinct_hash_join.cu +++ b/cpp/src/join/distinct_hash_join.cu @@ -83,10 +83,9 @@ auto prepare_device_equal( return std::visit( [&](auto& comparator) { - return ret_type{ - std::in_place_type< - comparator_adapter::type>>, - comparator}; + return ret_type{std::in_place_type< + comparator_adapter>>, + comparator}; }, d_comparator); } @@ -157,7 +156,7 @@ distinct_hash_join::distinct_hash_join(cudf::table_view const& build, cuco::static_set, cuco::extent, cuda::thread_scope_device, - typename std::remove_reference::type, + typename std::remove_reference_t, distinct_hash_join::probing_scheme_type, cudf::detail::cuco_allocator, distinct_hash_join::cuco_storage_type>; @@ -190,8 +189,7 @@ distinct_hash_join::distinct_hash_join(cudf::table_view const& build, [&](auto&& hasher, auto&& hash_table) { auto const iter = cudf::detail::make_counting_transform_iterator( 0, - build_keys_fn::type, rhs_index_type>{ - hasher}); + build_keys_fn, rhs_index_type>{hasher}); size_type const build_table_num_rows{build.num_rows()}; if (this->_nulls_equal == cudf::null_equality::EQUAL or (not cudf::nullable(this->_build))) { @@ -253,8 +251,7 @@ distinct_hash_join::inner_join(rmm::cuda_stream_view stream, [&](auto&& hasher, auto&& hash_table) { auto const iter = cudf::detail::make_counting_transform_iterator( 0, - build_keys_fn::type, lhs_index_type>{ - hasher}); + build_keys_fn, lhs_index_type>{hasher}); auto const [probe_indices_end, _] = hash_table.retrieve(iter, iter + probe_table_num_rows, @@ -307,7 +304,7 @@ std::unique_ptr> distinct_hash_join::l [&](auto&& hasher, auto&& hash_table) { auto const iter = cudf::detail::make_counting_transform_iterator( 0, - build_keys_fn::type, lhs_index_type>{ + build_keys_fn, lhs_index_type>{ hasher}); auto const output_begin = diff --git a/cpp/src/join/mixed_join_kernel_semi_impl.cuh b/cpp/src/join/mixed_join_kernel_semi_impl.cuh index 6737904064e..1f6abdc33cb 100644 --- a/cpp/src/join/mixed_join_kernel_semi_impl.cuh +++ b/cpp/src/join/mixed_join_kernel_semi_impl.cuh @@ -17,7 +17,7 @@ #include "join/join_common_utils.cuh" #include "join/join_common_utils.hpp" #include "join/mixed_join_common_utils.cuh" -#include "mixed_join_kernels_semi.cuh" +#include "join/mixed_join_kernels_semi.cuh" #include #include diff --git a/cpp/src/join/mixed_join_kernels_semi.cu b/cpp/src/join/mixed_join_kernels_semi.cu index 69dfab0b3ed..5369a99808b 100644 --- a/cpp/src/join/mixed_join_kernels_semi.cu +++ b/cpp/src/join/mixed_join_kernels_semi.cu @@ -15,7 +15,7 @@ */ #include "join/mixed_join_common_utils.cuh" -#include "mixed_join_kernel_semi_impl.cuh" +#include "join/mixed_join_kernel_semi_impl.cuh" namespace cudf { namespace detail { diff --git a/cpp/src/join/mixed_join_kernels_semi_compound.cu b/cpp/src/join/mixed_join_kernels_semi_compound.cu index 06e4753fcd6..68d45c19560 100644 --- a/cpp/src/join/mixed_join_kernels_semi_compound.cu +++ b/cpp/src/join/mixed_join_kernels_semi_compound.cu @@ -15,7 +15,7 @@ */ #include "join/mixed_join_common_utils.cuh" -#include "mixed_join_kernel_semi_impl.cuh" +#include "join/mixed_join_kernel_semi_impl.cuh " namespace cudf { namespace detail { diff --git a/cpp/src/join/mixed_join_kernels_semi_nested.cu b/cpp/src/join/mixed_join_kernels_semi_nested.cu index ef422c4f31f..44a612bba3a 100644 --- a/cpp/src/join/mixed_join_kernels_semi_nested.cu +++ b/cpp/src/join/mixed_join_kernels_semi_nested.cu @@ -15,7 +15,7 @@ */ #include "join/mixed_join_common_utils.cuh" -#include "mixed_join_kernel_semi_impl.cuh" +#include "join/mixed_join_kernel_semi_impl.cuh" namespace cudf { namespace detail { diff --git a/cpp/src/join/mixed_join_semi.cu b/cpp/src/join/mixed_join_semi.cu index 69e3b440e8d..4c78b892b1f 100644 --- a/cpp/src/join/mixed_join_semi.cu +++ b/cpp/src/join/mixed_join_semi.cu @@ -14,9 +14,9 @@ * limitations under the License. */ -#include "join_common_utils.cuh" -#include "join_common_utils.hpp" -#include "mixed_join_kernels_semi.cuh" +#include "join/join_common_utils.cuh" +#include "join/join_common_utils.hpp" +#include "join/mixed_join_kernels_semi.cuh" #include #include