diff --git a/cpp/include/cudf/table/experimental/row_operators.cuh b/cpp/include/cudf/table/experimental/row_operators.cuh index 3ab7e7546d8..2397f9bc6fb 100644 --- a/cpp/include/cudf/table/experimental/row_operators.cuh +++ b/cpp/include/cudf/table/experimental/row_operators.cuh @@ -142,6 +142,67 @@ using rhs_iterator = strong_index_iterator; namespace lexicographic { +/** + * @brief Computes a weak ordering of two values with special sorting behavior. + * + * This relational comparator functor compares physical values rather than logical + * elements like lists, strings, or structs. It evaluates `NaN` as not less than, equal to, or + * greater than other values and is IEEE-754 compliant. + */ +struct physical_element_comparator { + /** + * @brief Operator for relational comparisons. + * + * @param lhs First element + * @param rhs Second element + * @return Relation between elements + */ + template + __device__ constexpr weak_ordering operator()(Element const lhs, Element const rhs) const noexcept + { + return detail::compare_elements(lhs, rhs); + } +}; + +/** + * @brief Relational comparator functor that compares physical values rather than logical + * elements like lists, strings, or structs. It evaluates `NaN` as equivalent to other `NaN`s and + * greater than all other values. + */ +struct sorting_physical_element_comparator { + /** + * @brief Operator for relational comparison of non-floating point values. + * + * @param lhs First element + * @param rhs Second element + * @return Relation between elements + */ + template )> + __device__ constexpr weak_ordering operator()(Element const lhs, Element const rhs) const noexcept + { + return detail::compare_elements(lhs, rhs); + } + + /** + * @brief Operator for relational comparison of floating point values. + * + * @param lhs First element + * @param rhs Second element + * @return Relation between elements + */ + template )> + __device__ constexpr weak_ordering operator()(Element const lhs, Element const rhs) const noexcept + { + if (isnan(lhs)) { + return isnan(rhs) ? weak_ordering::EQUIVALENT : weak_ordering::GREATER; + } else if (isnan(rhs)) { + return weak_ordering::LESS; + } + + return detail::compare_elements(lhs, rhs); + } +}; + /** * @brief Computes the lexicographic comparison between 2 rows. * @@ -158,8 +219,12 @@ namespace lexicographic { * `aac < abb`. * * @tparam Nullate A cudf::nullate type describing whether to check for nulls. + * @tparam PhysicalElementComparator A relational comparator functor that compares individual values + * rather than logical elements, defaults to `NaN` aware relational comparator that evaluates `NaN` + * as greater than all other values. */ -template +template class device_row_comparator { friend class self_comparator; ///< Allow self_comparator to access private members friend class two_table_comparator; ///< Allow two_table_comparator to access private members @@ -168,7 +233,7 @@ class device_row_comparator { * @brief Construct a function object for performing a lexicographic * comparison between the rows of two tables. * - * @param check_nulls Indicates if either input table contains columns with nulls. + * @param check_nulls Indicates if any input column contains nulls. * @param lhs The first table * @param rhs The second table (may be the same table as `lhs`) * @param depth Optional, device array the same length as a row that contains starting depths of @@ -179,20 +244,22 @@ class device_row_comparator { * @param null_precedence Optional, device array the same length as a row and indicates how null * values compare to all other for every column. If `nullopt`, then null precedence would be * `null_order::BEFORE` for all columns. + * @param comparator Physical element relational comparison functor. */ - device_row_comparator( - Nullate check_nulls, - table_device_view lhs, - table_device_view rhs, - std::optional> depth = std::nullopt, - std::optional> column_order = std::nullopt, - std::optional> null_precedence = std::nullopt) noexcept + device_row_comparator(Nullate check_nulls, + table_device_view lhs, + table_device_view rhs, + std::optional> depth = std::nullopt, + std::optional> column_order = std::nullopt, + std::optional> null_precedence = std::nullopt, + PhysicalElementComparator comparator = {}) noexcept : _lhs{lhs}, _rhs{rhs}, _check_nulls{check_nulls}, _depth{depth}, _column_order{column_order}, - _null_precedence{null_precedence} + _null_precedence{null_precedence}, + _comparator{comparator} { } @@ -213,17 +280,20 @@ class device_row_comparator { * @param null_precedence Indicates how null values are ordered with other values * @param depth The depth of the column if part of a nested column @see * preprocessed_table::depths + * @param comparator Physical element relational comparison functor. */ __device__ element_comparator(Nullate check_nulls, column_device_view lhs, column_device_view rhs, - null_order null_precedence = null_order::BEFORE, - int depth = 0) + null_order null_precedence = null_order::BEFORE, + int depth = 0, + PhysicalElementComparator comparator = {}) : _lhs{lhs}, _rhs{rhs}, _check_nulls{check_nulls}, _null_precedence{null_precedence}, - _depth{depth} + _depth{depth}, + _comparator{comparator} { } @@ -249,8 +319,8 @@ class device_row_comparator { } } - return cuda::std::pair(relational_compare(_lhs.element(lhs_element_index), - _rhs.element(rhs_element_index)), + return cuda::std::pair(_comparator(_lhs.element(lhs_element_index), + _rhs.element(rhs_element_index)), std::numeric_limits::max()); } @@ -289,9 +359,11 @@ class device_row_comparator { ++depth; } - auto const comparator = element_comparator{_check_nulls, lcol, rcol, _null_precedence, depth}; return cudf::type_dispatcher( - lcol.type(), comparator, lhs_element_index, rhs_element_index); + lcol.type(), + element_comparator{_check_nulls, lcol, rcol, _null_precedence, depth, _comparator}, + lhs_element_index, + rhs_element_index); } private: @@ -300,6 +372,7 @@ class device_row_comparator { Nullate const _check_nulls; null_order const _null_precedence; int const _depth; + PhysicalElementComparator const _comparator; }; public: @@ -312,8 +385,8 @@ class device_row_comparator { * @return weak ordering comparison of the row in the `lhs` table relative to the row in the `rhs` * table */ - __device__ weak_ordering operator()(size_type const lhs_index, - size_type const rhs_index) const noexcept + __device__ constexpr weak_ordering operator()(size_type const lhs_index, + size_type const rhs_index) const noexcept { int last_null_depth = std::numeric_limits::max(); for (size_type i = 0; i < _lhs.num_columns(); ++i) { @@ -326,11 +399,13 @@ class device_row_comparator { null_order const null_precedence = _null_precedence.has_value() ? (*_null_precedence)[i] : null_order::BEFORE; - auto const comparator = - element_comparator{_check_nulls, _lhs.column(i), _rhs.column(i), null_precedence, depth}; weak_ordering state; - cuda::std::tie(state, last_null_depth) = - cudf::type_dispatcher(_lhs.column(i).type(), comparator, lhs_index, rhs_index); + cuda::std::tie(state, last_null_depth) = cudf::type_dispatcher( + _lhs.column(i).type(), + element_comparator{ + _check_nulls, _lhs.column(i), _rhs.column(i), null_precedence, depth, _comparator}, + lhs_index, + rhs_index); if (state == weak_ordering::EQUIVALENT) { continue; } @@ -344,10 +419,11 @@ class device_row_comparator { private: table_device_view const _lhs; table_device_view const _rhs; - Nullate const _check_nulls{}; + Nullate const _check_nulls; std::optional> const _depth; std::optional> const _column_order; std::optional> const _null_precedence; + PhysicalElementComparator const _comparator; }; // class device_row_comparator /** @@ -362,6 +438,10 @@ class device_row_comparator { */ template struct weak_ordering_comparator_impl { + static_assert(not((weak_ordering::EQUIVALENT == values) && ...), + "weak_ordering_comparator should not be used for pure equality comparisons. The " + "`row_equality_comparator` should be used instead"); + template __device__ constexpr bool operator()(LhsType const lhs_index, RhsType const rhs_index) const noexcept @@ -379,11 +459,22 @@ struct weak_ordering_comparator_impl { * @tparam Nullate A cudf::nullate type describing whether to check for nulls. */ template -using less_comparator = weak_ordering_comparator_impl; +struct less_comparator : weak_ordering_comparator_impl { + less_comparator(Comparator const& comparator) + : weak_ordering_comparator_impl{comparator} + { + } +}; template -using less_equivalent_comparator = - weak_ordering_comparator_impl; +struct less_equivalent_comparator + : weak_ordering_comparator_impl { + less_equivalent_comparator(Comparator const& comparator) + : weak_ordering_comparator_impl{ + comparator} + { + } +}; /** * @brief Preprocessed table for use with lexicographical comparison @@ -538,14 +629,28 @@ class self_comparator { * `F(i,j)` returns true if and only if row `i` compares lexicographically less than row `j`. * * @tparam Nullate A cudf::nullate type describing whether to check for nulls. - * @param nullate Indicates if either input column contains nulls - * @return A binary callable object + * @tparam PhysicalElementComparator A relational comparator functor that compares individual + * values rather than logical elements, defaults to `NaN` aware relational comparator that + * evaluates `NaN` as greater than all other values. + * @param nullate Indicates if any input column contains nulls. + * @param comparator Physical element relational comparison functor. + * @return A binary callable object. */ - template - less_comparator> device_comparator(Nullate nullate = {}) const + template + auto less(Nullate nullate = {}, PhysicalElementComparator comparator = {}) const noexcept { - return less_comparator>{device_row_comparator( - nullate, *d_t, *d_t, d_t->depths(), d_t->column_order(), d_t->null_precedence())}; + return less_comparator{device_row_comparator{ + nullate, *d_t, *d_t, d_t->depths(), d_t->column_order(), d_t->null_precedence(), comparator}}; + } + + template + auto less_equivalent(Nullate nullate = {}, + PhysicalElementComparator comparator = {}) const noexcept + { + return less_equivalent_comparator{device_row_comparator{ + nullate, *d_t, *d_t, d_t->depths(), d_t->column_order(), d_t->null_precedence(), comparator}}; } private: @@ -555,6 +660,8 @@ class self_comparator { // @cond template struct strong_index_comparator_adapter { + strong_index_comparator_adapter(Comparator const& comparator) : comparator{comparator} {} + __device__ constexpr weak_ordering operator()(lhs_index_type const lhs_index, rhs_index_type const rhs_index) const noexcept { @@ -653,20 +760,40 @@ class two_table_comparator { * `j` of the left table. * * @tparam Nullate A cudf::nullate type describing whether to check for nulls. - * @param nullate Indicates if either input column contains nulls. - * @return A binary callable object + * @tparam PhysicalElementComparator A relational comparator functor that compares individual + * values rather than logical elements, defaults to `NaN` aware relational comparator that + * evaluates `NaN` as greater than all other values. + * @param nullate Indicates if any input column contains nulls. + * @param comparator Physical element relational comparison functor. + * @return A binary callable object. */ - template - less_comparator>> - device_comparator(Nullate nullate = {}) const + template + auto less(Nullate nullate = {}, PhysicalElementComparator comparator = {}) const noexcept + { + return less_comparator{ + strong_index_comparator_adapter{device_row_comparator{nullate, + *d_left_table, + *d_right_table, + d_left_table->depths(), + d_left_table->column_order(), + d_left_table->null_precedence(), + comparator}}}; + } + + template + auto less_equivalent(Nullate nullate = {}, + PhysicalElementComparator comparator = {}) const noexcept { - return less_comparator>>{ - device_row_comparator(nullate, - *d_left_table, - *d_right_table, - d_left_table->depths(), - d_left_table->column_order(), - d_left_table->null_precedence())}; + return less_equivalent_comparator{ + strong_index_comparator_adapter{device_row_comparator{nullate, + *d_left_table, + *d_right_table, + d_left_table->depths(), + d_left_table->column_order(), + d_left_table->null_precedence(), + comparator}}}; } private: @@ -681,12 +808,76 @@ class row_hasher; } namespace equality { + +/** + * @brief Equality comparator functor that compares physical values rather than logical + * elements like lists, strings, or structs. It evaluates `NaN` not equal to all other values for + * IEEE-754 compliance. + */ +struct physical_equality_comparator { + /** + * @brief Operator for equality comparisons. + * + * Note that `NaN != NaN`, following IEEE-754. + * + * @param lhs First element + * @param rhs Second element + * @return `true` if `lhs == rhs` else `false` + */ + template + __device__ constexpr bool operator()(Element const lhs, Element const rhs) const noexcept + { + return lhs == rhs; + } +}; + +/** + * @brief Equality comparator functor that compares physical values rather than logical + * elements like lists, strings, or structs. It evaluates `NaN` as equal to other `NaN`s. + */ +struct nan_equal_physical_equality_comparator { + /** + * @brief Operator for equality comparison of non-floating point values. + * + * @param lhs First element + * @param rhs Second element + * @return `true` if `lhs == rhs` else `false` + */ + template )> + __device__ constexpr bool operator()(Element const lhs, Element const rhs) const noexcept + { + return lhs == rhs; + } + + /** + * @brief Operator for equality comparison of floating point values. + * + * Note that `NaN == NaN`. + * + * @param lhs First element + * @param rhs Second element + * @return `true` if `lhs` == `rhs` else `false` + */ + template )> + __device__ constexpr bool operator()(Element const lhs, Element const rhs) const noexcept + { + return isnan(lhs) and isnan(rhs) ? true : lhs == rhs; + } +}; + /** - * @brief Comparator for performing equality comparison between the rows of two tables. + * @brief Computes the equality comparison between 2 rows. + * + * Equality is determined by comparing rows element by element. The first mismatching element + * returns false, representing unequal rows. If the rows are compared without mismatched elements, + * the rows are equal. * * @tparam Nullate A cudf::nullate type describing whether to check for nulls. + * @tparam PhysicalEqualityComparator A equality comparator functor that compares individual values + * rather than logical elements, defaults to a comparator for which `NaN == NaN`. */ -template +template class device_row_comparator { friend class self_comparator; ///< Allow self_comparator to access private members friend class two_table_comparator; ///< Allow two_table_comparator to access private members @@ -700,11 +891,15 @@ class device_row_comparator { * @param rhs_index The index of the row in the `rhs` table to examine * @return `true` if row from the `lhs` table is equal to the row in the `rhs` table */ - __device__ bool operator()(size_type const lhs_index, size_type const rhs_index) const noexcept + __device__ constexpr bool operator()(size_type const lhs_index, + size_type const rhs_index) const noexcept { auto equal_elements = [=](column_device_view l, column_device_view r) { return cudf::type_dispatcher( - l.type(), element_comparator{check_nulls, l, r, nulls_are_equal}, lhs_index, rhs_index); + l.type(), + element_comparator{check_nulls, l, r, nulls_are_equal, comparator}, + lhs_index, + rhs_index); }; return thrust::equal(thrust::seq, lhs.begin(), lhs.end(), rhs.begin(), equal_elements); @@ -715,16 +910,22 @@ class device_row_comparator { * @brief Construct a function object for performing equality comparison between the rows of two * tables. * - * @param check_nulls Indicates if either input table contains columns with nulls. + * @param check_nulls Indicates if any input column contains nulls. * @param lhs The first table * @param rhs The second table (may be the same table as `lhs`) * @param nulls_are_equal Indicates if two null elements are treated as equivalent + * @param comparator Physical element equality comparison functor. */ device_row_comparator(Nullate check_nulls, table_device_view lhs, table_device_view rhs, - null_equality nulls_are_equal = null_equality::EQUAL) noexcept - : lhs{lhs}, rhs{rhs}, check_nulls{check_nulls}, nulls_are_equal{nulls_are_equal} + null_equality nulls_are_equal = null_equality::EQUAL, + PhysicalEqualityComparator comparator = {}) noexcept + : lhs{lhs}, + rhs{rhs}, + check_nulls{check_nulls}, + nulls_are_equal{nulls_are_equal}, + comparator{comparator} { } @@ -743,12 +944,18 @@ class device_row_comparator { * @param lhs The column containing the first element * @param rhs The column containing the second element (may be the same as lhs) * @param nulls_are_equal Indicates if two null elements are treated as equivalent + * @param comparator Physical element equality comparison functor. */ __device__ element_comparator(Nullate check_nulls, column_device_view lhs, column_device_view rhs, - null_equality nulls_are_equal = null_equality::EQUAL) noexcept - : lhs{lhs}, rhs{rhs}, check_nulls{check_nulls}, nulls_are_equal{nulls_are_equal} + null_equality nulls_are_equal = null_equality::EQUAL, + PhysicalEqualityComparator comparator = {}) noexcept + : lhs{lhs}, + rhs{rhs}, + check_nulls{check_nulls}, + nulls_are_equal{nulls_are_equal}, + comparator{comparator} { } @@ -774,8 +981,8 @@ class device_row_comparator { } } - return equality_compare(lhs.element(lhs_element_index), - rhs.element(rhs_element_index)); + return comparator(lhs.element(lhs_element_index), + rhs.element(rhs_element_index)); } template (lcol.type(), comp); } @@ -875,12 +1082,14 @@ class device_row_comparator { column_device_view const rhs; Nullate const check_nulls; null_equality const nulls_are_equal; + PhysicalEqualityComparator const comparator; }; table_device_view const lhs; table_device_view const rhs; Nullate const check_nulls; null_equality const nulls_are_equal; + PhysicalEqualityComparator const comparator; }; /** @@ -965,16 +1174,21 @@ class self_comparator { * * `F(i,j)` returns true if and only if row `i` compares equal to row `j`. * - * @tparam Nullate A cudf::nullate type describing whether to check for nulls - * @param nullate Indicates if either input column contains nulls - * @param nulls_are_equal Indicates if nulls are equal + * @tparam Nullate A cudf::nullate type describing whether to check for nulls. + * @tparam PhysicalEqualityComparator A equality comparator functor that compares individual + * values rather than logical elements, defaults to a comparator for which `NaN == NaN`. + * @param nullate Indicates if any input column contains nulls. + * @param nulls_are_equal Indicates if nulls are equal. + * @param comparator Physical element equality comparison functor. * @return A binary callable object */ - template - device_row_comparator device_comparator( - Nullate nullate = {}, null_equality nulls_are_equal = null_equality::EQUAL) const + template + auto equal_to(Nullate nullate = {}, + null_equality nulls_are_equal = null_equality::EQUAL, + PhysicalEqualityComparator comparator = {}) const noexcept { - return device_row_comparator(nullate, *d_t, *d_t, nulls_are_equal); + return device_row_comparator{nullate, *d_t, *d_t, nulls_are_equal, comparator}; } private: @@ -984,6 +1198,8 @@ class self_comparator { // @cond template struct strong_index_comparator_adapter { + strong_index_comparator_adapter(Comparator const& comparator) : comparator{comparator} {} + __device__ constexpr bool operator()(lhs_index_type const lhs_index, rhs_index_type const rhs_index) const noexcept { @@ -1060,17 +1276,22 @@ class two_table_comparator { * Similarly, `F(rhs_index_type i, lhs_index_type j)` returns true if and only if row `i` of the * right table compares equal to row `j` of the left table. * - * @tparam Nullate A cudf::nullate type describing whether to check for nulls - * @param nullate Indicates if either input column contains nulls - * @param nulls_are_equal Indicates if nulls are equal + * @tparam Nullate A cudf::nullate type describing whether to check for nulls. + * @tparam PhysicalEqualityComparator A equality comparator functor that compares individual + * values rather than logical elements, defaults to a `NaN == NaN` equality comparator. + * @param nullate Indicates if any input column contains nulls. + * @param nulls_are_equal Indicates if nulls are equal. + * @param comparator Physical element equality comparison functor. * @return A binary callable object */ - template - auto device_comparator(Nullate nullate = {}, - null_equality nulls_are_equal = null_equality::EQUAL) const + template + auto equal_to(Nullate nullate = {}, + null_equality nulls_are_equal = null_equality::EQUAL, + PhysicalEqualityComparator comparator = {}) const noexcept { - return strong_index_comparator_adapter>{ - device_row_comparator(nullate, *d_left_table, *d_right_table, nulls_are_equal)}; + return strong_index_comparator_adapter{ + device_row_comparator(nullate, *d_left_table, *d_right_table, nulls_are_equal, comparator)}; } private: @@ -1294,7 +1515,7 @@ class row_hasher { * `F(i)` returns the hash of row i. * * @tparam Nullate A cudf::nullate type describing whether to check for nulls - * @param nullate Indicates if either input column contains nulls + * @param nullate Indicates if any input column contains nulls * @param seed The seed to use for the hash function * @return A hash operator to use on the device */ diff --git a/cpp/src/groupby/hash/groupby.cu b/cpp/src/groupby/hash/groupby.cu index ab8d0089347..c07833520ab 100644 --- a/cpp/src/groupby/hash/groupby.cu +++ b/cpp/src/groupby/hash/groupby.cu @@ -568,7 +568,7 @@ std::unique_ptr groupby(table_view const& keys, auto preprocessed_keys = cudf::experimental::row::hash::preprocessed_table::create(keys, stream); auto const comparator = cudf::experimental::row::equality::self_comparator{preprocessed_keys}; auto const row_hash = cudf::experimental::row::hash::row_hasher{std::move(preprocessed_keys)}; - auto const d_key_equal = comparator.device_comparator(has_null, null_keys_are_equal); + auto const d_key_equal = comparator.equal_to(has_null, null_keys_are_equal); auto const d_row_hash = row_hash.device_hasher(has_null); size_type constexpr unused_key{std::numeric_limits::max()}; diff --git a/cpp/src/reductions/scan/rank_scan.cu b/cpp/src/reductions/scan/rank_scan.cu index 0ababbf0a3d..c6909bfd601 100644 --- a/cpp/src/reductions/scan/rank_scan.cu +++ b/cpp/src/reductions/scan/rank_scan.cu @@ -53,7 +53,7 @@ std::unique_ptr rank_generator(column_view const& order_by, { auto comp = cudf::experimental::row::equality::self_comparator(table_view{{order_by}}, stream); auto const device_comparator = - comp.device_comparator(nullate::DYNAMIC{has_nested_nulls(table_view({order_by}))}); + comp.equal_to(nullate::DYNAMIC{has_nested_nulls(table_view({order_by}))}); auto ranks = make_fixed_width_column( data_type{type_to_id()}, order_by.size(), mask_state::UNALLOCATED, stream, mr); auto mutable_ranks = ranks->mutable_view(); diff --git a/cpp/src/search/contains_nested.cu b/cpp/src/search/contains_nested.cu index c3143b12e90..f4332efb23f 100644 --- a/cpp/src/search/contains_nested.cu +++ b/cpp/src/search/contains_nested.cu @@ -37,7 +37,7 @@ bool contains_nested_element(column_view const& haystack, auto const comparator = cudf::experimental::row::equality::two_table_comparator(haystack_tv, needle_tv, stream); - auto const d_comp = comparator.device_comparator(nullate::DYNAMIC{has_nulls}); + auto const d_comp = comparator.equal_to(nullate::DYNAMIC{has_nulls}); auto const begin = cudf::experimental::row::lhs_iterator(0); auto const end = begin + haystack.size(); diff --git a/cpp/src/search/search_ordered.cu b/cpp/src/search/search_ordered.cu index 2ae776420e2..01b990facdc 100644 --- a/cpp/src/search/search_ordered.cu +++ b/cpp/src/search/search_ordered.cu @@ -68,7 +68,7 @@ std::unique_ptr search_ordered(table_view const& haystack, auto const comparator = cudf::experimental::row::lexicographic::two_table_comparator( matched_haystack, matched_needles, column_order, null_precedence, stream); auto const has_nulls = has_nested_nulls(matched_haystack) or has_nested_nulls(matched_needles); - auto const d_comparator = comparator.device_comparator(nullate::DYNAMIC{has_nulls}); + auto const d_comparator = comparator.less(nullate::DYNAMIC{has_nulls}); auto const haystack_it = cudf::experimental::row::lhs_iterator(0); auto const needles_it = cudf::experimental::row::rhs_iterator(0); diff --git a/cpp/src/sort/sort_impl.cuh b/cpp/src/sort/sort_impl.cuh index 7f84c49a417..f98fda307b8 100644 --- a/cpp/src/sort/sort_impl.cuh +++ b/cpp/src/sort/sort_impl.cuh @@ -127,7 +127,7 @@ std::unique_ptr sorted_order(table_view input, auto comp = experimental::row::lexicographic::self_comparator(input, column_order, null_precedence, stream); - auto comparator = comp.device_comparator(nullate::DYNAMIC{has_nested_nulls(input)}); + auto comparator = comp.less(nullate::DYNAMIC{has_nested_nulls(input)}); if (stable) { thrust::stable_sort(rmm::exec_policy(stream), diff --git a/cpp/src/stream_compaction/distinct.cu b/cpp/src/stream_compaction/distinct.cu index d3b31dccf77..91d8c85120f 100644 --- a/cpp/src/stream_compaction/distinct.cu +++ b/cpp/src/stream_compaction/distinct.cu @@ -73,7 +73,7 @@ std::unique_ptr
distinct(table_view const& input, experimental::compaction_hash hash_key(row_hash.device_hasher(has_null)); cudf::experimental::row::equality::self_comparator row_equal(preprocessed_keys); - auto key_equal = row_equal.device_comparator(has_null, nulls_equal); + auto key_equal = row_equal.equal_to(has_null, nulls_equal); auto iter = cudf::detail::make_counting_transform_iterator( 0, [] __device__(size_type i) { return cuco::make_pair(i, i); }); diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index c85b10b4eb8..816c5a1c59c 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -310,6 +310,7 @@ ConfigureTest(TRANSPOSE_TEST transpose/transpose_test.cpp) # * table tests ----------------------------------------------------------------------------------- ConfigureTest( TABLE_TEST table/table_tests.cpp table/table_view_tests.cu table/row_operators_tests.cpp + table/experimental_row_operator_tests.cu ) # ################################################################################################## diff --git a/cpp/tests/groupby/lists_tests.cu b/cpp/tests/groupby/lists_tests.cu index 7c145271662..81322d87747 100644 --- a/cpp/tests/groupby/lists_tests.cu +++ b/cpp/tests/groupby/lists_tests.cu @@ -118,7 +118,7 @@ inline void test_hash_based_sum_agg(column_view const& keys, auto const null_keys_are_equal = include_null_keys == null_policy::INCLUDE ? null_equality::EQUAL : null_equality::UNEQUAL; - auto row_equal = comparator.device_comparator(nullate::DYNAMIC{true}, null_keys_are_equal); + auto row_equal = comparator.equal_to(nullate::DYNAMIC{true}, null_keys_are_equal); auto func = match_expected_fn{num_rows, row_equal}; // For each row in expected table `t[0, num_rows)`, there must be a match diff --git a/cpp/tests/table/experimental_row_operator_tests.cu b/cpp/tests/table/experimental_row_operator_tests.cu new file mode 100644 index 00000000000..6b392fe57fe --- /dev/null +++ b/cpp/tests/table/experimental_row_operator_tests.cu @@ -0,0 +1,231 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include + +using namespace cudf::test; +using namespace cudf::experimental::row; + +template +struct TypedTableViewTest : public cudf::test::BaseFixture { +}; + +using NumericTypesNotBool = Concat; +TYPED_TEST_SUITE(TypedTableViewTest, NumericTypesNotBool); + +template +auto self_comparison(cudf::table_view input, + std::vector const& column_order, + PhysicalElementComparator comparator) +{ + rmm::cuda_stream_view stream{}; + + auto const table_comparator = lexicographic::self_comparator{input, column_order, {}, stream}; + auto const less_comparator = table_comparator.less(cudf::nullate::NO{}, comparator); + + auto output = cudf::make_numeric_column( + cudf::data_type(cudf::type_id::BOOL8), input.num_rows(), cudf::mask_state::UNALLOCATED); + + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(input.num_rows()), + thrust::make_counting_iterator(0), + output->mutable_view().data(), + less_comparator); + return output; +} + +template +auto two_table_comparison(cudf::table_view lhs, + cudf::table_view rhs, + std::vector const& column_order, + PhysicalElementComparator comparator) +{ + rmm::cuda_stream_view stream{}; + + auto const table_comparator = + lexicographic::two_table_comparator{lhs, rhs, column_order, {}, stream}; + auto const less_comparator = table_comparator.less(cudf::nullate::NO{}, comparator); + auto const lhs_it = cudf::experimental::row::lhs_iterator(0); + auto const rhs_it = cudf::experimental::row::rhs_iterator(0); + + auto output = cudf::make_numeric_column( + cudf::data_type(cudf::type_id::BOOL8), lhs.num_rows(), cudf::mask_state::UNALLOCATED); + + thrust::transform(rmm::exec_policy(stream), + lhs_it, + lhs_it + lhs.num_rows(), + rhs_it, + output->mutable_view().data(), + less_comparator); + return output; +} + +template +auto self_equality(cudf::table_view input, + std::vector const& column_order, + PhysicalElementComparator comparator) +{ + rmm::cuda_stream_view stream{}; + + auto const table_comparator = equality::self_comparator{input, stream}; + auto const equal_comparator = + table_comparator.equal_to(cudf::nullate::NO{}, cudf::null_equality::EQUAL, comparator); + + auto output = cudf::make_numeric_column( + cudf::data_type(cudf::type_id::BOOL8), input.num_rows(), cudf::mask_state::UNALLOCATED); + + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(input.num_rows()), + thrust::make_counting_iterator(0), + output->mutable_view().data(), + equal_comparator); + return output; +} + +template +auto two_table_equality(cudf::table_view lhs, + cudf::table_view rhs, + std::vector const& column_order, + PhysicalElementComparator comparator) +{ + rmm::cuda_stream_view stream{}; + + auto const table_comparator = equality::two_table_comparator{lhs, rhs, stream}; + auto const equal_comparator = + table_comparator.equal_to(cudf::nullate::NO{}, cudf::null_equality::EQUAL, comparator); + auto const lhs_it = cudf::experimental::row::lhs_iterator(0); + auto const rhs_it = cudf::experimental::row::rhs_iterator(0); + + auto output = cudf::make_numeric_column( + cudf::data_type(cudf::type_id::BOOL8), lhs.num_rows(), cudf::mask_state::UNALLOCATED); + + thrust::transform(rmm::exec_policy(stream), + lhs_it, + lhs_it + lhs.num_rows(), + rhs_it, + output->mutable_view().data(), + equal_comparator); + return output; +} + +TYPED_TEST(TypedTableViewTest, TestLexicographicalComparatorTwoTables) +{ + using T = TypeParam; + + auto const col1 = fixed_width_column_wrapper{{1, 2, 3, 4}}; + auto const col2 = fixed_width_column_wrapper{{0, 1, 4, 3}}; + auto const column_order = std::vector{cudf::order::DESCENDING}; + auto const lhs = cudf::table_view{{col1}}; + auto const rhs = cudf::table_view{{col2}}; + + auto const expected = fixed_width_column_wrapper{{1, 1, 0, 1}}; + auto const got = + two_table_comparison(lhs, rhs, column_order, lexicographic::physical_element_comparator{}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, got->view()); + + auto const sorting_got = two_table_comparison( + lhs, rhs, column_order, lexicographic::sorting_physical_element_comparator{}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, sorting_got->view()); +} + +TYPED_TEST(TypedTableViewTest, TestLexicographicalComparatorSameTable) +{ + using T = TypeParam; + + auto const col1 = fixed_width_column_wrapper{{1, 2, 3, 4}}; + auto const column_order = std::vector{cudf::order::DESCENDING}; + auto const input_table = cudf::table_view{{col1}}; + + auto const expected = fixed_width_column_wrapper{{0, 0, 0, 0}}; + auto const got = + self_comparison(input_table, column_order, lexicographic::physical_element_comparator{}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, got->view()); + + auto const sorting_got = self_comparison( + input_table, column_order, lexicographic::sorting_physical_element_comparator{}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, sorting_got->view()); +} + +template +struct NaNTableViewTest : public cudf::test::BaseFixture { +}; + +TYPED_TEST_SUITE(NaNTableViewTest, FloatingPointTypes); + +TYPED_TEST(NaNTableViewTest, TestLexicographicalComparatorTwoTableNaNCase) +{ + using T = TypeParam; + + auto const col1 = fixed_width_column_wrapper{{T(NAN), T(NAN), T(1), T(1)}}; + auto const col2 = fixed_width_column_wrapper{{T(NAN), T(1), T(NAN), T(1)}}; + auto const column_order = std::vector{cudf::order::DESCENDING}; + + auto const lhs = cudf::table_view{{col1}}; + auto const rhs = cudf::table_view{{col2}}; + + auto const expected = fixed_width_column_wrapper{{0, 0, 0, 0}}; + auto const got = + two_table_comparison(lhs, rhs, column_order, lexicographic::physical_element_comparator{}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, got->view()); + + auto const sorting_expected = fixed_width_column_wrapper{{0, 1, 0, 0}}; + auto const sorting_got = two_table_comparison( + lhs, rhs, column_order, lexicographic::sorting_physical_element_comparator{}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(sorting_expected, sorting_got->view()); +} + +TYPED_TEST(NaNTableViewTest, TestEqualityComparatorTwoTableNaNCase) +{ + using T = TypeParam; + + auto const col1 = fixed_width_column_wrapper{{T(NAN), T(NAN), T(1), T(1)}}; + auto const col2 = fixed_width_column_wrapper{{T(NAN), T(1), T(NAN), T(1)}}; + auto const column_order = std::vector{cudf::order::DESCENDING}; + + auto const lhs = cudf::table_view{{col1}}; + auto const rhs = cudf::table_view{{col2}}; + + auto const expected = fixed_width_column_wrapper{{0, 0, 0, 1}}; + auto const got = + two_table_equality(lhs, rhs, column_order, equality::physical_equality_comparator{}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, got->view()); + + auto const nan_equal_expected = fixed_width_column_wrapper{{1, 0, 0, 1}}; + auto const nan_equal_got = + two_table_equality(lhs, rhs, column_order, equality::nan_equal_physical_equality_comparator{}); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(nan_equal_expected, nan_equal_got->view()); +}