From f2015f1defccdb9685bea4aaa62b586195fbd86d Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Mon, 25 Nov 2024 16:32:26 -0800 Subject: [PATCH 1/6] Add primitive_row_operator header --- .../cudf/table/primitive_row_operators.cuh | 448 ++++++++++++++++++ 1 file changed, 448 insertions(+) create mode 100644 cpp/include/cudf/table/primitive_row_operators.cuh diff --git a/cpp/include/cudf/table/primitive_row_operators.cuh b/cpp/include/cudf/table/primitive_row_operators.cuh new file mode 100644 index 00000000000..d5683f9e6bb --- /dev/null +++ b/cpp/include/cudf/table/primitive_row_operators.cuh @@ -0,0 +1,448 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace CUDF_EXPORT cudf { +namespace row::primitive { +namespace detail { +/** + * @brief Compare the elements ordering with respect to `lhs`. + * + * @param lhs first element + * @param rhs second element + * @return Indicates the relationship between the elements in + * the `lhs` and `rhs` columns. + */ +template +__device__ weak_ordering compare_elements(Element lhs, Element rhs) +{ + if (lhs < rhs) { + return weak_ordering::LESS; + } else if (rhs < lhs) { + return weak_ordering::GREATER; + } + return weak_ordering::EQUIVALENT; +} +} // namespace detail + +/** + * @brief A specialization for floating-point `Element` type relational comparison + * to derive the order of the elements with respect to `lhs`. + */ +template )> +__device__ weak_ordering relational_compare(Element lhs, Element rhs) +{ + if (isnan(lhs) and isnan(rhs)) { + return weak_ordering::EQUIVALENT; + } else if (isnan(rhs)) { + return weak_ordering::LESS; + } else if (isnan(lhs)) { + return weak_ordering::GREATER; + } + + return detail::compare_elements(lhs, rhs); +} + +/** + * @brief A specialization for non-floating-point `Element` type relational + * comparison to derive the order of the elements with respect to `lhs`. + * + * @param lhs The first element + * @param rhs The second element + * @return Indicates the relationship between the elements in the `lhs` and `rhs` columns + */ +template )> +__device__ weak_ordering relational_compare(Element lhs, Element rhs) +{ + return detail::compare_elements(lhs, rhs); +} + +/** + * @brief A specialization for floating-point `Element` type to check if + * `lhs` is equivalent to `rhs`. `nan == nan`. + * + * @param lhs first element + * @param rhs second element + * @return `true` if `lhs` == `rhs` else `false`. + */ +template )> +__device__ bool equality_compare(Element lhs, Element rhs) +{ + if (isnan(lhs) and isnan(rhs)) { return true; } + return lhs == rhs; +} + +/** + * @brief A specialization for non-floating-point `Element` type to check if + * `lhs` is equivalent to `rhs`. + * + * @param lhs first element + * @param rhs second element + * @return `true` if `lhs` == `rhs` else `false`. + */ +template )> +__device__ bool equality_compare(Element const lhs, Element const rhs) +{ + return lhs == rhs; +} + +/** + * @brief Performs an equality comparison between two elements in two columns. + */ +class element_equality_comparator { + public: + /** + * @brief Construct type-dispatched function object for comparing equality + * between two elements. + * + * @note `lhs` and `rhs` may be the same. + * + * @param lhs The column containing the first element + * @param rhs The column containing the second element (may be the same as lhs) + */ + __host__ __device__ element_equality_comparator(column_device_view lhs, column_device_view rhs) + : lhs{lhs}, rhs{rhs} + { + } + + /** + * @brief Compares the specified elements for equality. + * + * @param lhs_element_index The index of the first element + * @param rhs_element_index The index of the second element + * @return True if lhs and rhs element are equal + */ + template ())> + __device__ bool operator()(size_type lhs_element_index, + size_type rhs_element_index) const noexcept + { + return equality_compare(lhs.element(lhs_element_index), + rhs.element(rhs_element_index)); + } + + // @cond + template ())> + __device__ bool operator()(size_type lhs_element_index, size_type rhs_element_index) + { + CUDF_UNREACHABLE("Attempted to compare elements of uncomparable types."); + } + // @endcond + + private: + column_device_view lhs; + column_device_view rhs; +}; + +/** + * @brief Performs a relational comparison between two elements in two tables. + */ +class row_equality_comparator { + public: + /** + * @brief Construct a new row equality comparator object + * + * @param lhs The column containing the first element + * @param rhs The column containing the second element (may be the same as lhs) + */ + row_equality_comparator(table_device_view lhs, table_device_view rhs) : lhs{lhs}, rhs{rhs} + { + CUDF_EXPECTS(lhs.num_columns() == rhs.num_columns(), "Mismatched number of columns."); + } + + /** + * @brief Compares the specified rows for equality. + * + * @param lhs_row_index The index of the first row to compare (in the lhs table) + * @param rhs_row_index The index of the second row to compare (in the rhs table) + * @return true if both rows are equal, otherwise false + */ + __device__ bool operator()(size_type lhs_row_index, size_type rhs_row_index) const noexcept + { + return cudf::type_dispatcher(lhs.begin().type(), + element_equality_comparator{lhs.begin(), rhs.begin()}, + lhs_row_index, + rhs_row_index); + } + + private: + table_device_view lhs; + table_device_view rhs; +}; + +/** + * @brief Performs a relational comparison between two elements in two columns. + */ +class element_relational_comparator { + public: + /** + * @brief Construct type-dispatched function object for performing a + * relational comparison between two elements. + * + * @note `lhs` and `rhs` may be the same. + * + * @param lhs The column containing the first element + * @param rhs The column containing the second element (may be the same as lhs) + */ + __host__ __device__ element_relational_comparator(column_device_view lhs, column_device_view rhs) + : lhs{lhs}, rhs{rhs} + { + } + + /** + * @brief Construct type-dispatched function object for performing a relational comparison between + * two elements in two columns. + * + * @param lhs The column containing the first element + * @param rhs The column containing the second element (may be the same as lhs) + */ + __host__ __device__ element_relational_comparator(column_device_view lhs, column_device_view rhs) + : lhs{lhs}, rhs{rhs} + { + } + + /** + * @brief Performs a relational comparison between the specified elements + * + * @param lhs_element_index The index of the first element + * @param rhs_element_index The index of the second element + * @return Indicates the relationship between the elements in + * the `lhs` and `rhs` columns. + */ + template () >)> + __device__ weak_ordering operator()(size_type lhs_element_index, + size_type rhs_element_index) const noexcept + { + return relational_compare(lhs.element(lhs_element_index), + rhs.element(rhs_element_index)); + } + + // @cond + template ())> + __device__ weak_ordering operator()(size_type lhs_element_index, size_type rhs_element_index) + { + CUDF_UNREACHABLE("Attempted to compare elements of uncomparable types."); + } + // @endcond + + private: + column_device_view lhs; + column_device_view rhs; +}; + +/** + * @brief Computes whether one row is lexicographically *less* than another row. + * + * Lexicographic ordering is determined by: + * - Two rows are compared element by element. + * - The first mismatching element defines which row is lexicographically less + * or greater than the other. + * + * Lexicographic ordering is exactly equivalent to doing an alphabetical sort of + * two words, for example, `aac` would be *less* than (or precede) `abb`. The + * second letter in both words is the first non-equal letter, and `a < b`, thus + * `aac < abb`. + */ +class row_lexicographic_comparator { + public: + /** + * @brief Construct a function object for performing a lexicographic + * comparison between the rows of two tables. + * + * Behavior is undefined if called with incomparable column types. + * + * @throws cudf::logic_error if `lhs.num_columns() != rhs.num_columns()` + * + * @param lhs The first table + * @param rhs The second table (may be the same table as `lhs`) + * @param column_order Optional, device array the same length as a row that + * indicates the desired ascending/descending order of each column in a row. + * If `nullptr`, it is assumed all columns are sorted in ascending order. + */ + row_lexicographic_comparator(table_device_view lhs, + table_device_view rhs, + order const* column_order = nullptr) + : _lhs{lhs}, _rhs{rhs}, _column_order{column_order}, + { + CUDF_EXPECTS(_lhs.num_columns() == _rhs.num_columns(), "Mismatched number of columns."); + } + + /** + * @brief Checks whether the row at `lhs_index` in the `lhs` table compares + * lexicographically less than the row at `rhs_index` in the `rhs` table. + * + * @param lhs_index The index of the row in the `lhs` table to examine + * @param rhs_index The index of the row in the `rhs` table to examine + * @return `true` if row from the `lhs` table compares less than row in the + * `rhs` table + */ + __device__ bool operator()(size_type lhs_index, size_type rhs_index) const noexcept + { + for (size_type i = 0; i < _lhs.num_columns(); ++i) { + bool ascending = (_column_order == nullptr) or (_column_order[i] == order::ASCENDING); + + auto comparator = element_relational_comparator{_lhs.column(i), _rhs.column(i)}; + + weak_ordering state = + cudf::type_dispatcher(_lhs.column(i).type(), comparator, lhs_index, rhs_index); + + if (state == weak_ordering::EQUIVALENT) { continue; } + + return state == (ascending ? weak_ordering::LESS : weak_ordering::GREATER); + } + return false; + } + + private: + table_device_view _lhs; + table_device_view _rhs; + order const* _column_order{}; +}; // class row_lexicographic_comparator + +/** + * @brief Computes the hash value of an element in the given column. + * + * @tparam Hash Hash functor to use for hashing elements. + */ +template