-
Notifications
You must be signed in to change notification settings - Fork 915
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Return weak orderings from device_row_comparator
.
#10793
Changes from 5 commits
bf1c6ee
5d87db2
fd716b9
2dd2045
7ba960e
84833e7
4d197ea
08092fe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,6 +45,7 @@ | |
#include <utility> | ||
|
||
namespace cudf { | ||
|
||
namespace experimental { | ||
|
||
/** | ||
|
@@ -68,16 +69,17 @@ struct dispatch_void_if_nested { | |
}; | ||
|
||
namespace row { | ||
|
||
namespace lexicographic { | ||
|
||
/** | ||
* @brief Computes whether one row is lexicographically *less* than another row. | ||
* @brief Computes the lexicographic comparison between 2 rows. | ||
* | ||
* Lexicographic ordering is determined by: | ||
* - Two rows are compared element by element. | ||
* - The first mismatching element defines which row is lexicographically less | ||
* or greater than the other. | ||
* - If the rows are compared without mismatched elements, the rows are equivalent. | ||
* | ||
* | ||
* Lexicographic ordering is exactly equivalent to doing an alphabetical sort of | ||
* two words, for example, `aac` would be *less* than (or precede) `abb`. The | ||
|
@@ -88,8 +90,8 @@ namespace lexicographic { | |
*/ | ||
template <typename Nullate> | ||
class device_row_comparator { | ||
// friend class device_less_comparator<Nullate>; | ||
rwlee marked this conversation as resolved.
Show resolved
Hide resolved
|
||
friend class self_comparator; | ||
|
||
/** | ||
* @brief Construct a function object for performing a lexicographic | ||
* comparison between the rows of two tables. | ||
|
@@ -145,7 +147,11 @@ class device_row_comparator { | |
column_device_view rhs, | ||
null_order null_precedence = null_order::BEFORE, | ||
int depth = 0) | ||
: _lhs{lhs}, _rhs{rhs}, _nulls{check_nulls}, _null_precedence{null_precedence}, _depth{depth} | ||
: _lhs{lhs}, | ||
_rhs{rhs}, | ||
_check_nulls{check_nulls}, | ||
_null_precedence{null_precedence}, | ||
_depth{depth} | ||
{ | ||
} | ||
|
||
|
@@ -162,7 +168,7 @@ class device_row_comparator { | |
__device__ cuda::std::pair<weak_ordering, int> operator()( | ||
size_type const lhs_element_index, size_type const rhs_element_index) const noexcept | ||
{ | ||
if (_nulls) { | ||
if (_check_nulls) { | ||
bool const lhs_is_null{_lhs.is_null(lhs_element_index)}; | ||
bool const rhs_is_null{_rhs.is_null(rhs_element_index)}; | ||
|
||
|
@@ -211,29 +217,30 @@ class device_row_comparator { | |
++depth; | ||
} | ||
|
||
auto const comparator = element_comparator{_nulls, lcol, rcol, _null_precedence, depth}; | ||
auto const comparator = element_comparator{_check_nulls, lcol, rcol, _null_precedence, depth}; | ||
return cudf::type_dispatcher<dispatch_void_if_nested>( | ||
lcol.type(), comparator, lhs_element_index, rhs_element_index); | ||
} | ||
|
||
private: | ||
column_device_view const _lhs; | ||
column_device_view const _rhs; | ||
Nullate const _nulls; | ||
Nullate const _check_nulls; | ||
null_order const _null_precedence; | ||
int const _depth; | ||
}; | ||
|
||
public: | ||
/** | ||
* @brief Checks whether the row at `lhs_index` in the `lhs` table compares | ||
* lexicographically less than the row at `rhs_index` in the `rhs` table. | ||
* lexicographically less, greater, or equivalent to the row at `rhs_index` in the `rhs` table. | ||
* | ||
* @param lhs_index The index of row in the `lhs` table to examine | ||
* @param rhs_index The index of the row in the `rhs` table to examine | ||
* @return `true` if row from the `lhs` table compares less than row in the `rhs` table | ||
* @return weak ordering comparison of the row in the `lhs` table relative to the row in the `rhs` | ||
* table | ||
*/ | ||
__device__ bool operator()(size_type const lhs_index, size_type const rhs_index) const noexcept | ||
__device__ weak_ordering operator()(size_type lhs_index, size_type rhs_index) const noexcept | ||
{ | ||
int last_null_depth = std::numeric_limits<int>::max(); | ||
for (size_type i = 0; i < _lhs.num_columns(); ++i) { | ||
|
@@ -248,16 +255,17 @@ class device_row_comparator { | |
|
||
auto const comparator = | ||
element_comparator{_check_nulls, _lhs.column(i), _rhs.column(i), null_precedence, depth}; | ||
|
||
weak_ordering state; | ||
cuda::std::tie(state, last_null_depth) = | ||
cudf::type_dispatcher(_lhs.column(i).type(), comparator, lhs_index, rhs_index); | ||
|
||
if (state == weak_ordering::EQUIVALENT) { continue; } | ||
|
||
return state == (ascending ? weak_ordering::LESS : weak_ordering::GREATER); | ||
return ascending | ||
? state | ||
: (state == weak_ordering::GREATER ? weak_ordering::LESS : weak_ordering::GREATER); | ||
} | ||
return false; | ||
return weak_ordering::EQUIVALENT; | ||
} | ||
|
||
private: | ||
|
@@ -269,6 +277,37 @@ class device_row_comparator { | |
std::optional<device_span<null_order const>> const _null_precedence; | ||
}; // class device_row_comparator | ||
|
||
/** | ||
* @brief Wraps and interprets the result of a templated Comparator that returns a weak_ordering. | ||
* Returns true if the weak_ordering matches any of the templated values. | ||
rwlee marked this conversation as resolved.
Show resolved
Hide resolved
|
||
* | ||
* @tparam Comparator generic comparator that returns a weak_ordering. | ||
* @tparam values weak_ordering parameter pack of orderings to interpret as true | ||
*/ | ||
template <typename Comparator, weak_ordering... values> | ||
struct weak_ordering_comparator_impl { | ||
  /**
   * @brief Invokes the wrapped comparator and tests its result against `values`.
   *
   * The call operator is `const` because it does not modify the wrapper; this lets it be
   * invoked through a const object (for example a copy captured by a non-mutable lambda).
   * `Comparator` must therefore be callable on a const object. Indices are taken by value.
   *
   * @param lhs Index forwarded as the first argument of the wrapped comparator
   * @param rhs Index forwarded as the second argument of the wrapped comparator
   * @return `true` if the wrapped comparator's result equals any of `values`, else `false`
   */
  __device__ bool operator()(size_type const lhs, size_type const rhs) const noexcept | ||
  { | ||
    weak_ordering const result = comparator(lhs, rhs); | ||
    // Fold over the parameter pack: true when `result` matches any accepted ordering.
    return ((result == values) || ...); | ||
  } | ||
  // Wrapped comparator whose weak_ordering result is interpreted by operator().
  Comparator comparator; | ||
}; | ||
|
||
/** | ||
* @brief Wraps and interprets the result of device_row_comparator; returns true if the result is | ||
* weak_ordering::LESS, meaning one row is lexicographically *less* than another row. | ||
* | ||
* @tparam Nullate A cudf::nullate type describing whether to check for nulls. | ||
*/ | ||
template <typename Nullate> | ||
using less_comparator = | ||
weak_ordering_comparator_impl<device_row_comparator<Nullate>, weak_ordering::LESS>; | ||
|
||
/**
 * @brief Wraps and interprets the result of device_row_comparator; returns true if the result is
 * weak_ordering::LESS or weak_ordering::EQUIVALENT.
 *
 * @tparam Nullate A cudf::nullate type describing whether to check for nulls.
 */
template <typename Nullate> | ||
using less_equivalent_comparator = weak_ordering_comparator_impl<device_row_comparator<Nullate>, | ||
weak_ordering::LESS, | ||
weak_ordering::EQUIVALENT>; | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would be useful to have aliases for the remaining comparator too (i.e., There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Then how about There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. #9452 uses |
||
struct preprocessed_table { | ||
using table_device_view_owner = | ||
std::invoke_result_t<decltype(table_device_view::create), table_view, rmm::cuda_stream_view>; | ||
|
@@ -417,10 +456,10 @@ class self_comparator { | |
* @tparam Nullate A cudf::nullate type describing whether to check for nulls. | ||
*/ | ||
template <typename Nullate> | ||
device_row_comparator<Nullate> device_comparator(Nullate nullate = {}) const | ||
less_comparator<Nullate> device_comparator(Nullate nullate = {}) const | ||
{ | ||
return device_row_comparator( | ||
nullate, *d_t, *d_t, d_t->depths(), d_t->column_order(), d_t->null_precedence()); | ||
return less_comparator<Nullate>{device_row_comparator<Nullate>( | ||
nullate, *d_t, *d_t, d_t->depths(), d_t->column_order(), d_t->null_precedence())}; | ||
} | ||
|
||
private: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file has several layers of nested namespaces. For consistency, I would recommend leaving this unchanged.