Skip to content

Commit

Permalink
Namespace changes and making element comparator private
Browse files Browse the repository at this point in the history
  • Loading branch information
devavret committed Mar 21, 2022
1 parent de95530 commit 6c45cd4
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 106 deletions.
208 changes: 105 additions & 103 deletions cpp/include/cudf/table/experimental/row_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,109 +59,9 @@ struct dispatch_void_if_nested {
using type = std::conditional_t<cudf::is_nested(data_type(t)), void, id_to_type<t>>;
};

namespace lex {
namespace row {

/**
* @brief Performs a relational comparison between two elements in two columns.
*
* @tparam Nullate A cudf::nullate type describing how to check for nulls.
*/
template <typename Nullate>
class element_comparator {
public:
/**
* @brief Construct type-dispatched function object for performing a
* relational comparison between two elements.
*
* @note `lhs` and `rhs` may be the same.
*
* @param has_nulls Indicates if either input column contains nulls.
* @param lhs The column containing the first element
* @param rhs The column containing the second element (may be the same as lhs)
* @param null_precedence Indicates how null values are ordered with other values
* @param depth The depth of the column if part of a nested column @see preprocessed_table::depths
*/
__device__ element_comparator(Nullate has_nulls,
column_device_view lhs,
column_device_view rhs,
null_order null_precedence = null_order::BEFORE,
int depth = 0)
: _lhs{lhs}, _rhs{rhs}, _nulls{has_nulls}, _null_precedence{null_precedence}, _depth{depth}
{
}

/**
* @brief Performs a relational comparison between the specified elements
*
* @param lhs_element_index The index of the first element
* @param rhs_element_index The index of the second element
* @return Indicates the relationship between the elements in the `lhs` and `rhs` columns, along
* with the depth at which a null value was encountered.
*/
template <typename Element, CUDF_ENABLE_IF(cudf::is_relationally_comparable<Element, Element>())>
__device__ cuda::std::pair<weak_ordering, int> operator()(
size_type const lhs_element_index, size_type const rhs_element_index) const noexcept
{
if (_nulls) {
bool const lhs_is_null{_lhs.is_null(lhs_element_index)};
bool const rhs_is_null{_rhs.is_null(rhs_element_index)};

if (lhs_is_null or rhs_is_null) { // at least one is null
return cuda::std::make_pair(null_compare(lhs_is_null, rhs_is_null, _null_precedence),
_depth);
}
}

return cuda::std::make_pair(relational_compare(_lhs.element<Element>(lhs_element_index),
_rhs.element<Element>(rhs_element_index)),
std::numeric_limits<int>::max());
}

template <typename Element,
CUDF_ENABLE_IF(not cudf::is_relationally_comparable<Element, Element>() and
not std::is_same_v<Element, cudf::struct_view>)>
__device__ cuda::std::pair<weak_ordering, int> operator()(size_type const lhs_element_index,
size_type const rhs_element_index)
{
// TODO: make this CUDF_UNREACHABLE
cudf_assert(false && "Attempted to compare elements of uncomparable types.");
return cuda::std::make_pair(weak_ordering::LESS, std::numeric_limits<int>::max());
}

template <typename Element, CUDF_ENABLE_IF(std::is_same_v<Element, cudf::struct_view>)>
__device__ cuda::std::pair<weak_ordering, int> operator()(size_type const lhs_element_index,
size_type const rhs_element_index)
{
column_device_view lcol = _lhs;
column_device_view rcol = _rhs;
int depth = _depth;
while (lcol.type().id() == type_id::STRUCT) {
bool const lhs_is_null{lcol.is_null(lhs_element_index)};
bool const rhs_is_null{rcol.is_null(rhs_element_index)};

if (lhs_is_null or rhs_is_null) { // at least one is null
weak_ordering state = null_compare(lhs_is_null, rhs_is_null, _null_precedence);
return cuda::std::make_pair(state, depth);
}

// Structs have been modified to only have 1 child when using this.
lcol = lcol.children()[0];
rcol = rcol.children()[0];
++depth;
}

auto const comparator = element_comparator{_nulls, lcol, rcol, _null_precedence, depth};
return cudf::type_dispatcher<dispatch_void_if_nested>(
lcol.type(), comparator, lhs_element_index, rhs_element_index);
}

private:
column_device_view const _lhs;
column_device_view const _rhs;
Nullate const _nulls;
null_order const _null_precedence;
int const _depth;
};
namespace lexicographic {

/**
* @brief Computes whether one row is lexicographically *less* than another row.
Expand Down Expand Up @@ -214,6 +114,107 @@ class device_row_comparator {
{
}

/**
* @brief Performs a relational comparison between two elements in two columns.
*/
class element_comparator {
public:
/**
* @brief Construct type-dispatched function object for performing a
* relational comparison between two elements.
*
* @note `lhs` and `rhs` may be the same.
*
* @param has_nulls Indicates if either input column contains nulls.
* @param lhs The column containing the first element
* @param rhs The column containing the second element (may be the same as lhs)
* @param null_precedence Indicates how null values are ordered with other values
* @param depth The depth of the column if part of a nested column @see
* preprocessed_table::depths
*/
__device__ element_comparator(Nullate has_nulls,
column_device_view lhs,
column_device_view rhs,
null_order null_precedence = null_order::BEFORE,
int depth = 0)
: _lhs{lhs}, _rhs{rhs}, _nulls{has_nulls}, _null_precedence{null_precedence}, _depth{depth}
{
}

/**
* @brief Performs a relational comparison between the specified elements
*
* @param lhs_element_index The index of the first element
* @param rhs_element_index The index of the second element
* @return Indicates the relationship between the elements in the `lhs` and `rhs` columns, along
* with the depth at which a null value was encountered.
*/
template <typename Element,
CUDF_ENABLE_IF(cudf::is_relationally_comparable<Element, Element>())>
__device__ cuda::std::pair<weak_ordering, int> operator()(
size_type const lhs_element_index, size_type const rhs_element_index) const noexcept
{
if (_nulls) {
bool const lhs_is_null{_lhs.is_null(lhs_element_index)};
bool const rhs_is_null{_rhs.is_null(rhs_element_index)};

if (lhs_is_null or rhs_is_null) { // at least one is null
return cuda::std::make_pair(null_compare(lhs_is_null, rhs_is_null, _null_precedence),
_depth);
}
}

return cuda::std::make_pair(relational_compare(_lhs.element<Element>(lhs_element_index),
_rhs.element<Element>(rhs_element_index)),
std::numeric_limits<int>::max());
}

template <typename Element,
CUDF_ENABLE_IF(not cudf::is_relationally_comparable<Element, Element>() and
not std::is_same_v<Element, cudf::struct_view>)>
__device__ cuda::std::pair<weak_ordering, int> operator()(size_type const lhs_element_index,
size_type const rhs_element_index)
{
// TODO: make this CUDF_UNREACHABLE
cudf_assert(false && "Attempted to compare elements of uncomparable types.");
return cuda::std::make_pair(weak_ordering::LESS, std::numeric_limits<int>::max());
}

template <typename Element, CUDF_ENABLE_IF(std::is_same_v<Element, cudf::struct_view>)>
__device__ cuda::std::pair<weak_ordering, int> operator()(size_type const lhs_element_index,
size_type const rhs_element_index)
{
column_device_view lcol = _lhs;
column_device_view rcol = _rhs;
int depth = _depth;
while (lcol.type().id() == type_id::STRUCT) {
bool const lhs_is_null{lcol.is_null(lhs_element_index)};
bool const rhs_is_null{rcol.is_null(rhs_element_index)};

if (lhs_is_null or rhs_is_null) { // at least one is null
weak_ordering state = null_compare(lhs_is_null, rhs_is_null, _null_precedence);
return cuda::std::make_pair(state, depth);
}

// Structs have been modified to only have 1 child when using this.
lcol = lcol.children()[0];
rcol = rcol.children()[0];
++depth;
}

auto const comparator = element_comparator{_nulls, lcol, rcol, _null_precedence, depth};
return cudf::type_dispatcher<dispatch_void_if_nested>(
lcol.type(), comparator, lhs_element_index, rhs_element_index);
}

private:
column_device_view const _lhs;
column_device_view const _rhs;
Nullate const _nulls;
null_order const _null_precedence;
int const _depth;
};

public:
/**
* @brief Checks whether the row at `lhs_index` in the `lhs` table compares
Expand Down Expand Up @@ -417,6 +418,7 @@ class self_comparator {
std::shared_ptr<preprocessed_table> d_t;
};

} // namespace lex
} // namespace lexicographic
} // namespace row
} // namespace experimental
} // namespace cudf
3 changes: 2 additions & 1 deletion cpp/src/sort/sort_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ std::unique_ptr<column> sorted_order(table_view input,
mutable_indices_view.end<size_type>(),
0);

auto comp = experimental::lex::self_comparator(input, column_order, null_precedence, stream);
auto comp =
experimental::row::lexicographic::self_comparator(input, column_order, null_precedence, stream);
auto comparator = comp.device_comparator(nullate::DYNAMIC{has_nested_nulls(input)});

if (stable) {
Expand Down
7 changes: 5 additions & 2 deletions cpp/src/table/row_operators.cu
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,9 @@ void check_lex_compatibility(table_view const& input)

} // namespace

namespace lex {
namespace row {

namespace lexicographic {

std::shared_ptr<preprocessed_table> preprocessed_table::create(
table_view const& t,
Expand All @@ -199,6 +201,7 @@ std::shared_ptr<preprocessed_table> preprocessed_table::create(
std::move(d_t), std::move(d_column_order), std::move(d_null_precedence), std::move(d_depths)));
}

} // namespace lex
} // namespace lexicographic
} // namespace row
} // namespace experimental
} // namespace cudf

0 comments on commit 6c45cd4

Please sign in to comment.