diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 95c509efc5b..9ab8d1b16d4 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -462,8 +462,10 @@ add_library( src/join/mixed_join.cu src/join/mixed_join_kernel.cu src/join/mixed_join_kernel_nulls.cu - src/join/mixed_join_kernels_semi.cu src/join/mixed_join_semi.cu + src/join/mixed_join_kernels_semi.cu + src/join/mixed_join_kernels_semi_nested.cu + src/join/mixed_join_kernels_semi_compound.cu src/join/mixed_join_size_kernel.cu src/join/mixed_join_size_kernel_nulls.cu src/join/semi_join.cu diff --git a/cpp/include/cudf/detail/distinct_hash_join.cuh b/cpp/include/cudf/detail/distinct_hash_join.cuh index c3bc3ad89fa..2246556e035 100644 --- a/cpp/include/cudf/detail/distinct_hash_join.cuh +++ b/cpp/include/cudf/detail/distinct_hash_join.cuh @@ -29,6 +29,7 @@ #include #include #include +#include namespace cudf::detail { @@ -85,22 +86,42 @@ struct hasher_adapter { template struct distinct_hash_join { private: - /// Device row equal type - using d_equal_type = cudf::experimental::row::equality::strong_index_comparator_adapter< - cudf::experimental::row::equality::device_row_comparator>; + using row_comparator = cudf::experimental::row::equality::device_row_comparator< + true, + cudf::nullate::DYNAMIC, + cudf::experimental::row::equality::nan_equal_physical_equality_comparator, + cudf::experimental::type_identity_t>; + + using row_comparator_no_nested = cudf::experimental::row::equality::device_row_comparator< + false, + cudf::nullate::DYNAMIC, + cudf::experimental::row::equality::nan_equal_physical_equality_comparator, + cudf::experimental::dispatch_void_if_nested_t>; + + using row_comparator_no_compound = cudf::experimental::row::equality::device_row_comparator< + false, + cudf::nullate::DYNAMIC, + cudf::experimental::row::equality::nan_equal_physical_equality_comparator, + cudf::experimental::dispatch_void_if_compound_t>; + using hasher = hasher_adapter>; using probing_scheme_type = cuco::linear_probing<1, hasher>; using cuco_storage_type = cuco::storage<1>; /// Hash table type - using hash_table_type = cuco::static_set, - cuco::extent, - cuda::thread_scope_device, - comparator_adapter, - probing_scheme_type, - cudf::detail::cuco_allocator, - cuco_storage_type>; + template + using static_set_with_comparator = cuco::static_set< + cuco::pair, + cuco::extent, + cuda::thread_scope_device, + comparator_adapter< + cudf::experimental::row::equality::strong_index_comparator_adapter>, + probing_scheme_type, + cudf::detail::cuco_allocator, + cuco_storage_type>; + using hash_table_type = std::variant, + static_set_with_comparator, + static_set_with_comparator>; bool _has_nulls; ///< true if nulls are present in either build table or probe table cudf::null_equality _nulls_equal; ///< whether to consider nulls as equal @@ -109,8 +130,8 @@ struct distinct_hash_join { std::shared_ptr _preprocessed_build; ///< input table preprocssed for row operators std::shared_ptr - _preprocessed_probe; ///< input table preprocssed for row operators - hash_table_type _hash_table; ///< hash table built on `_build` + _preprocessed_probe; ///< input table preprocssed for row operators + std::unique_ptr _hash_table; ///< hash table built on `_build` public: distinct_hash_join() = delete; diff --git a/cpp/include/cudf/table/experimental/row_operators.cuh b/cpp/include/cudf/table/experimental/row_operators.cuh index f05e5f4ca5c..7dcac39f22e 100644 --- a/cpp/include/cudf/table/experimental/row_operators.cuh +++ b/cpp/include/cudf/table/experimental/row_operators.cuh @@ -29,6 +29,8 @@ #include #include #include +#include +#include #include #include #include @@ -52,12 +54,84 @@ #include #include #include +#include #include +#include namespace CUDF_EXPORT cudf { namespace experimental { +/// Type identity type from C++20 +template +using type_identity_t = T; + +/** + * @brief Recursive template to apply type transformations sequentially. Transformations are + * applied first-to-last in the order specified. + * + * @tparam T Base type or initial type on which transformations are applied + * @tparam Rest Type transformations to apply + */ +template typename... Rest> +struct transform_sequence; + +/// @copydoc transform_sequence +template typename First, template typename... Rest> +struct transform_sequence { + using type = + typename transform_sequence, Rest...>::type; ///< Resolved type after transformations +}; + +/// @copydoc transform_sequence +template +struct transform_sequence { + using type = T; ///< The underlying type +}; + +/** + * @brief Helper alias for transform_sequence + */ +template typename... Rest> +using transform_sequence_t = typename transform_sequence::type; + +/** + * @brief Void dispatcher helper + */ +template +using dispatch_void_conditional_t = std::conditional_t; + +/** + * @brief Void dispatcher generator + */ +template +struct dispatch_void_conditional_generator { + /// The underlying type + template + using type = dispatch_void_conditional_t...>::value, T>; +}; + +/** + * @brief Returns `void` if it's a nested type + */ +template +using dispatch_void_if_nested_t = + dispatch_void_conditional_generator, + id_to_type>::type; + +/** + * @brief Returns `void` if it's a compound type + * + */ +template +using dispatch_void_if_compound_t = + dispatch_void_conditional_generator, + id_to_type, + id_to_type, + id_to_type, + id_to_type, + id_to_type, + id_to_type>::type; /** * @brief A map from cudf::type_id to cudf type that excludes LIST and STRUCT types. * @@ -75,12 +149,11 @@ namespace experimental { */ template struct dispatch_void_if_nested { - /// The type to dispatch to if the type is nested - using type = std::conditional_t>; + /// The underlying type + using type = dispatch_void_if_nested_t>; }; namespace row { - enum class lhs_index_type : size_type {}; enum class rhs_index_type : size_type {}; @@ -1335,11 +1408,19 @@ struct nan_equal_physical_equality_comparator { */ template + typename PhysicalEqualityComparator = nan_equal_physical_equality_comparator, + template typename dispatch_conditional_t = cudf::experimental::type_identity_t> class device_row_comparator { friend class self_comparator; ///< Allow self_comparator to access private members friend class two_table_comparator; ///< Allow two_table_comparator to access private members + template + using dispatch_void_if_nested = + transform_sequence, dispatch_conditional_t, dispatch_void_if_nested_t>; + + template + using dispatch_conditional = transform_sequence, dispatch_conditional_t>; + public: /** * @brief Checks whether the row at `lhs_index` in the `lhs` table is equal to the row at @@ -1353,7 +1434,7 @@ class device_row_comparator { size_type const rhs_index) const noexcept { auto equal_elements = [=](column_device_view l, column_device_view r) { - return cudf::type_dispatcher( + return cudf::type_dispatcher( l.type(), element_comparator{check_nulls, l, r, nulls_are_equal, comparator}, lhs_index, @@ -1417,15 +1498,35 @@ class device_row_comparator { { } + /** + * @brief Dummy operator for dispatch to void type. Ideally, we want this to be unreachable, but + * using CUDF_UNREACHABLE leads to an increase in register usage and is avoided. + * + * @note A correct implementation should never call this function. + * + * @return False + */ + template )> + __device__ bool operator()(size_type const lhs_element_index, + size_type const rhs_element_index) const noexcept + { + return false; + } + /** * @brief Compares the specified elements for equality. * + * is_equality_comparable differs from implementation for std::equality_comparable and considers + * void as an equality comparable type. Thus we need to disable this for when type is void. + * * @param lhs_element_index The index of the first element * @param rhs_element_index The index of the second element * @return True if lhs and rhs are equal or if both lhs and rhs are null and nulls are * considered equal (`nulls_are_equal` == `null_equality::EQUAL`) */ - template ())> + template () and + (not std::is_void_v))> __device__ bool operator()(size_type const lhs_element_index, size_type const rhs_element_index) const noexcept { @@ -1445,14 +1546,17 @@ class device_row_comparator { template () and - (not has_nested_columns or not cudf::is_nested())), + (not has_nested_columns or not cudf::is_nested()) and + (not std::is_void_v)), typename... Args> __device__ bool operator()(Args...) { CUDF_UNREACHABLE("Attempted to compare elements of uncomparable types."); } - template ())> + template () and + (not std::is_void_v))> __device__ bool operator()(size_type const lhs_element_index, size_type const rhs_element_index) const noexcept { @@ -1655,16 +1759,83 @@ class self_comparator { * @return A binary callable object */ template typename dispatch_conditional_t = type_identity_t, typename Nullate, typename PhysicalEqualityComparator = nan_equal_physical_equality_comparator> auto equal_to(Nullate nullate = {}, null_equality nulls_are_equal = null_equality::EQUAL, PhysicalEqualityComparator comparator = {}) const noexcept { - return device_row_comparator{ + return device_row_comparator{ nullate, *d_t, *d_t, nulls_are_equal, comparator}; } + /** + * @brief Get the variant of comparison operator to use on the device + * + * Returns a binary callable variant, `F`, with signature `bool F(size_type, size_type)`. + * + * `F(i,j)` returns true if and only if row `i` compares equal to row `j`. + * + * @tparam Nullate A cudf::nullate type describing whether to check for nulls. + * @tparam PhysicalEqualityComparator A equality comparator functor that compares individual + * values rather than logical elements, defaults to a comparator for which `NaN == NaN`. + * @param column_types Column types in the row to be compared + * @param nullate Indicates if any input column contains nulls. + * @param nulls_are_equal Indicates if nulls are equal. + * @param comparator Physical element equality comparison functor. + * @return A binary callable object variant + */ + template + auto equal_to(std::unordered_set column_types, + Nullate nullate = {}, + null_equality nulls_are_equal = null_equality::EQUAL, + PhysicalEqualityComparator comparator = {}) const noexcept + { + using row_comparator_t = std::variant< + device_row_comparator, + device_row_comparator, + device_row_comparator>; + + auto find_any = [](std::initializer_list ids, + std::unordered_set const& search_set) { + for (auto id : ids) { + if (search_set.find(id) != search_set.end()) return true; + } + return false; + }; + + if (find_any({type_id::STRUCT, type_id::LIST}, column_types)) { + return row_comparator_t{ + device_row_comparator{ + nullate, *d_t, *d_t, nulls_are_equal, comparator}}; + } else if (find_any({type_id::DECIMAL32, + type_id::DECIMAL64, + type_id::DECIMAL128, + type_id::STRING, + type_id::DICTIONARY32}, + column_types)) { + return row_comparator_t{device_row_comparator{ + nullate, *d_t, *d_t, nulls_are_equal, comparator}}; + } else { + return row_comparator_t{device_row_comparator{ + nullate, *d_t, *d_t, nulls_are_equal, comparator}}; + } + } + private: std::shared_ptr d_t; }; @@ -1769,15 +1940,92 @@ class two_table_comparator { * @return A binary callable object */ template typename dispatch_conditional_t = type_identity_t, typename Nullate, typename PhysicalEqualityComparator = nan_equal_physical_equality_comparator> auto equal_to(Nullate nullate = {}, null_equality nulls_are_equal = null_equality::EQUAL, PhysicalEqualityComparator comparator = {}) const noexcept { - return strong_index_comparator_adapter{ - device_row_comparator( - nullate, *d_left_table, *d_right_table, nulls_are_equal, comparator)}; + return strong_index_comparator_adapter{device_row_comparator( + nullate, *d_left_table, *d_right_table, nulls_are_equal, comparator)}; + } + + /** + * @brief Return the binary operator for comparing rows in the table. + * + * Returns a binary callable variant, `F`, with signatures `bool F(lhs_index_type, + * rhs_index_type)` and `bool F(rhs_index_type, lhs_index_type)`. + * + * `F(lhs_index_type i, rhs_index_type j)` returns true if and only if row `i` of the left table + * compares equal to row `j` of the right table. + * + * Similarly, `F(rhs_index_type i, lhs_index_type j)` returns true if and only if row `i` of the + * right table compares equal to row `j` of the left table. + * + * @tparam Nullate A cudf::nullate type describing whether to check for nulls. + * @tparam PhysicalEqualityComparator A equality comparator functor that compares individual + * values rather than logical elements, defaults to a comparator for which `NaN == NaN`. + * @param column_types Column types in the row to be compared + * @param nullate Indicates if any input column contains nulls. + * @param nulls_are_equal Indicates if nulls are equal. + * @param comparator Physical element equality comparison functor. + * @return A binary callable object variant + */ + template + auto equal_to(std::unordered_set column_types, + Nullate nullate = {}, + null_equality nulls_are_equal = null_equality::EQUAL, + PhysicalEqualityComparator comparator = {}) const noexcept + { + using row_comparator_t = std::variant< + strong_index_comparator_adapter< + device_row_comparator>, + strong_index_comparator_adapter>, + strong_index_comparator_adapter>>; + + auto find_any = [](std::initializer_list ids, + std::unordered_set const& search_set) { + for (auto id : ids) { + if (search_set.find(id) != search_set.end()) return true; + } + return false; + }; + + if (find_any({type_id::STRUCT, type_id::LIST}, column_types)) { + return row_comparator_t{strong_index_comparator_adapter{ + device_row_comparator( + nullate, *d_left_table, *d_right_table, nulls_are_equal, comparator)}}; + } else if (find_any({type_id::DECIMAL32, + type_id::DECIMAL64, + type_id::DECIMAL128, + type_id::STRING, + type_id::DICTIONARY32}, + column_types)) { + return row_comparator_t{ + strong_index_comparator_adapter{device_row_comparator( + nullate, *d_left_table, *d_right_table, nulls_are_equal, comparator)}}; + } else { + return row_comparator_t{ + strong_index_comparator_adapter{device_row_comparator( + nullate, *d_left_table, *d_right_table, nulls_are_equal, comparator)}}; + } } private: @@ -1855,9 +2103,18 @@ class element_hasher { * @tparam hash_function Hash functor to use for hashing elements. * @tparam Nullate A cudf::nullate type describing whether to check for nulls. */ -template