From 6c45cd4a66a7b8a46f6e50fe52c359f98cb383ae Mon Sep 17 00:00:00 2001 From: Devavret Makkar Date: Mon, 21 Mar 2022 14:15:06 +0530 Subject: [PATCH] Namespace changes and making element comparator private --- .../cudf/table/experimental/row_operators.cuh | 208 +++++++++--------- cpp/src/sort/sort_impl.cuh | 3 +- cpp/src/table/row_operators.cu | 7 +- 3 files changed, 112 insertions(+), 106 deletions(-) diff --git a/cpp/include/cudf/table/experimental/row_operators.cuh b/cpp/include/cudf/table/experimental/row_operators.cuh index 9d0c3ad23f7..09f5217b1ac 100644 --- a/cpp/include/cudf/table/experimental/row_operators.cuh +++ b/cpp/include/cudf/table/experimental/row_operators.cuh @@ -59,109 +59,9 @@ struct dispatch_void_if_nested { using type = std::conditional_t>; }; -namespace lex { +namespace row { -/** - * @brief Performs a relational comparison between two elements in two columns. - * - * @tparam Nullate A cudf::nullate type describing how to check for nulls. - */ -template -class element_comparator { - public: - /** - * @brief Construct type-dispatched function object for performing a - * relational comparison between two elements. - * - * @note `lhs` and `rhs` may be the same. - * - * @param has_nulls Indicates if either input column contains nulls. - * @param lhs The column containing the first element - * @param rhs The column containing the second element (may be the same as lhs) - * @param null_precedence Indicates how null values are ordered with other values - * @param depth The depth of the column if part of a nested column @see preprocessed_table::depths - */ - __device__ element_comparator(Nullate has_nulls, - column_device_view lhs, - column_device_view rhs, - null_order null_precedence = null_order::BEFORE, - int depth = 0) - : _lhs{lhs}, _rhs{rhs}, _nulls{has_nulls}, _null_precedence{null_precedence}, _depth{depth} - { - } - - /** - * @brief Performs a relational comparison between the specified elements - * - * @param lhs_element_index The index of the first element - * @param rhs_element_index The index of the second element - * @return Indicates the relationship between the elements in the `lhs` and `rhs` columns, along - * with the depth at which a null value was encountered. - */ - template ())> - __device__ cuda::std::pair operator()( - size_type const lhs_element_index, size_type const rhs_element_index) const noexcept - { - if (_nulls) { - bool const lhs_is_null{_lhs.is_null(lhs_element_index)}; - bool const rhs_is_null{_rhs.is_null(rhs_element_index)}; - - if (lhs_is_null or rhs_is_null) { // at least one is null - return cuda::std::make_pair(null_compare(lhs_is_null, rhs_is_null, _null_precedence), - _depth); - } - } - - return cuda::std::make_pair(relational_compare(_lhs.element(lhs_element_index), - _rhs.element(rhs_element_index)), - std::numeric_limits::max()); - } - - template () and - not std::is_same_v)> - __device__ cuda::std::pair operator()(size_type const lhs_element_index, - size_type const rhs_element_index) - { - // TODO: make this CUDF_UNREACHABLE - cudf_assert(false && "Attempted to compare elements of uncomparable types."); - return cuda::std::make_pair(weak_ordering::LESS, std::numeric_limits::max()); - } - - template )> - __device__ cuda::std::pair operator()(size_type const lhs_element_index, - size_type const rhs_element_index) - { - column_device_view lcol = _lhs; - column_device_view rcol = _rhs; - int depth = _depth; - while (lcol.type().id() == type_id::STRUCT) { - bool const lhs_is_null{lcol.is_null(lhs_element_index)}; - bool const rhs_is_null{rcol.is_null(rhs_element_index)}; - - if (lhs_is_null or rhs_is_null) { // at least one is null - weak_ordering state = null_compare(lhs_is_null, rhs_is_null, _null_precedence); - return cuda::std::make_pair(state, depth); - } - - // Structs have been modified to only have 1 child when using this. - lcol = lcol.children()[0]; - rcol = rcol.children()[0]; - ++depth; - } - - auto const comparator = element_comparator{_nulls, lcol, rcol, _null_precedence, depth}; - return cudf::type_dispatcher( - lcol.type(), comparator, lhs_element_index, rhs_element_index); - } - - private: - column_device_view const _lhs; - column_device_view const _rhs; - Nullate const _nulls; - null_order const _null_precedence; - int const _depth; -}; +namespace lexicographic { /** * @brief Computes whether one row is lexicographically *less* than another row. @@ -214,6 +114,107 @@ class device_row_comparator { { } + /** + * @brief Performs a relational comparison between two elements in two columns. + */ + class element_comparator { + public: + /** + * @brief Construct type-dispatched function object for performing a + * relational comparison between two elements. + * + * @note `lhs` and `rhs` may be the same. + * + * @param has_nulls Indicates if either input column contains nulls. + * @param lhs The column containing the first element + * @param rhs The column containing the second element (may be the same as lhs) + * @param null_precedence Indicates how null values are ordered with other values + * @param depth The depth of the column if part of a nested column @see + * preprocessed_table::depths + */ + __device__ element_comparator(Nullate has_nulls, + column_device_view lhs, + column_device_view rhs, + null_order null_precedence = null_order::BEFORE, + int depth = 0) + : _lhs{lhs}, _rhs{rhs}, _nulls{has_nulls}, _null_precedence{null_precedence}, _depth{depth} + { + } + + /** + * @brief Performs a relational comparison between the specified elements + * + * @param lhs_element_index The index of the first element + * @param rhs_element_index The index of the second element + * @return Indicates the relationship between the elements in the `lhs` and `rhs` columns, along + * with the depth at which a null value was encountered. + */ + template ())> + __device__ cuda::std::pair operator()( + size_type const lhs_element_index, size_type const rhs_element_index) const noexcept + { + if (_nulls) { + bool const lhs_is_null{_lhs.is_null(lhs_element_index)}; + bool const rhs_is_null{_rhs.is_null(rhs_element_index)}; + + if (lhs_is_null or rhs_is_null) { // at least one is null + return cuda::std::make_pair(null_compare(lhs_is_null, rhs_is_null, _null_precedence), + _depth); + } + } + + return cuda::std::make_pair(relational_compare(_lhs.element(lhs_element_index), + _rhs.element(rhs_element_index)), + std::numeric_limits::max()); + } + + template () and + not std::is_same_v)> + __device__ cuda::std::pair operator()(size_type const lhs_element_index, + size_type const rhs_element_index) + { + // TODO: make this CUDF_UNREACHABLE + cudf_assert(false && "Attempted to compare elements of uncomparable types."); + return cuda::std::make_pair(weak_ordering::LESS, std::numeric_limits::max()); + } + + template )> + __device__ cuda::std::pair operator()(size_type const lhs_element_index, + size_type const rhs_element_index) + { + column_device_view lcol = _lhs; + column_device_view rcol = _rhs; + int depth = _depth; + while (lcol.type().id() == type_id::STRUCT) { + bool const lhs_is_null{lcol.is_null(lhs_element_index)}; + bool const rhs_is_null{rcol.is_null(rhs_element_index)}; + + if (lhs_is_null or rhs_is_null) { // at least one is null + weak_ordering state = null_compare(lhs_is_null, rhs_is_null, _null_precedence); + return cuda::std::make_pair(state, depth); + } + + // Structs have been modified to only have 1 child when using this. + lcol = lcol.children()[0]; + rcol = rcol.children()[0]; + ++depth; + } + + auto const comparator = element_comparator{_nulls, lcol, rcol, _null_precedence, depth}; + return cudf::type_dispatcher( + lcol.type(), comparator, lhs_element_index, rhs_element_index); + } + + private: + column_device_view const _lhs; + column_device_view const _rhs; + Nullate const _nulls; + null_order const _null_precedence; + int const _depth; + }; + public: /** * @brief Checks whether the row at `lhs_index` in the `lhs` table compares @@ -417,6 +418,7 @@ class self_comparator { std::shared_ptr d_t; }; -} // namespace lex +} // namespace lexicographic +} // namespace row } // namespace experimental } // namespace cudf diff --git a/cpp/src/sort/sort_impl.cuh b/cpp/src/sort/sort_impl.cuh index 368bb3b03c5..2f093fd7d2d 100644 --- a/cpp/src/sort/sort_impl.cuh +++ b/cpp/src/sort/sort_impl.cuh @@ -124,7 +124,8 @@ std::unique_ptr sorted_order(table_view input, mutable_indices_view.end(), 0); - auto comp = experimental::lex::self_comparator(input, column_order, null_precedence, stream); + auto comp = + experimental::row::lexicographic::self_comparator(input, column_order, null_precedence, stream); auto comparator = comp.device_comparator(nullate::DYNAMIC{has_nested_nulls(input)}); if (stable) { diff --git a/cpp/src/table/row_operators.cu b/cpp/src/table/row_operators.cu index c6236f5419e..0a9396ccdf7 100644 --- a/cpp/src/table/row_operators.cu +++ b/cpp/src/table/row_operators.cu @@ -177,7 +177,9 @@ void check_lex_compatibility(table_view const& input) } // namespace -namespace lex { +namespace row { + +namespace lexicographic { std::shared_ptr preprocessed_table::create( table_view const& t, @@ -199,6 +201,7 @@ std::shared_ptr preprocessed_table::create( std::move(d_t), std::move(d_column_order), std::move(d_null_precedence), std::move(d_depths))); } -} // namespace lex +} // namespace lexicographic +} // namespace row } // namespace experimental } // namespace cudf