Skip to content

Commit

Permalink
Move struct to implementation header
Browse files Browse the repository at this point in the history
Unify template parameter for stable sort as a sort_method enum member
rather than bool.
  • Loading branch information
wence- committed Feb 26, 2024
1 parent 033158d commit 809cca1
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 79 deletions.
33 changes: 0 additions & 33 deletions cpp/include/cudf/detail/sorting.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,48 +19,15 @@
#include <cudf/sorting.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/default_stream.hpp>
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/traits.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/functional.h>
#include <thrust/sort.h>

#include <memory>
#include <vector>

namespace cudf {
namespace detail {

/**
 * @brief Type-dispatched functor that sorts the elements of a column in place.
 *
 * @tparam stable If true `thrust::stable_sort` is used, otherwise `thrust::sort`.
 */
template <bool stable>
struct inplace_column_sort_fn {
  /**
   * @brief Sort a fixed-width column in place.
   *
   * Fails (via `CUDF_EXPECTS`) if the column reports nulls.
   *
   * @tparam T Element type of the column; this overload is selected for fixed-width types.
   * @param col Column to sort, modified in place.
   * @param ascending Sort with `thrust::less` if true, with `thrust::greater` otherwise.
   * @param stream CUDA stream the sort is executed on.
   */
  template <typename T, std::enable_if_t<cudf::is_fixed_width<T>()>* = nullptr>
  void operator()(mutable_column_view& col, bool ascending, rmm::cuda_stream_view stream) const
  {
    CUDF_EXPECTS(!col.has_nulls(), "Nulls not supported for in-place sort");

    // `if constexpr` ensures only the requested thrust sort variant is instantiated.
    auto const sort_with = [&](auto const comparator) {
      if constexpr (!stable) {
        thrust::sort(rmm::exec_policy(stream), col.begin<T>(), col.end<T>(), comparator);
      } else {
        thrust::stable_sort(rmm::exec_policy(stream), col.begin<T>(), col.end<T>(), comparator);
      }
    };

    if (!ascending) {
      sort_with(thrust::greater<T>());
      return;
    }
    sort_with(thrust::less<T>());
  }

  /**
   * @brief Overload selected for element types that are not fixed-width; always fails.
   */
  template <typename T, std::enable_if_t<!cudf::is_fixed_width<T>()>* = nullptr>
  void operator()(mutable_column_view&, bool, rmm::cuda_stream_view) const
  {
    CUDF_FAIL("Column type must be relationally comparable and fixed-width");
  }
};

/**
* @copydoc cudf::sorted_order
*
Expand Down
96 changes: 96 additions & 0 deletions cpp/src/sort/common_sort_impl.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_view.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/traits.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/functional.h>
#include <thrust/sort.h>

#include <type_traits>

namespace cudf {
namespace detail {

/**
 * @brief The enum specifying which sorting method to use (stable or unstable).
 *
 * Used as a template parameter so that the choice is made at compile time:
 * - `STABLE`: the sort is performed with `thrust::stable_sort`.
 * - `UNSTABLE`: the sort is performed with `thrust::sort`.
 */
enum class sort_method { STABLE, UNSTABLE };

/**
 * @brief Functor performing a fast-path sort of a single column, in place.
 *
 * Designed for use with `cudf::type_dispatcher`. Precondition: `is_usable()` returned true for
 * the column (or single-column table) being sorted; `operator()` does not check for nulls
 * itself.
 *
 * @tparam method Whether to use a stable sort or not.
 */
template <sort_method method>
struct inplace_column_sort_fn {
  /**
   * @brief Can one use a fast-path in-place sort for this column?
   *
   * The fast path requires a fixed-width, non-floating-point column without nulls.
   *
   * @param column Column to check.
   * @return true if fast-path sort is available, false otherwise.
   */
  [[nodiscard]] static bool is_usable(column_view const& column)
  {
    return !column.has_nulls() && cudf::is_fixed_width(column.type()) &&
           !cudf::is_floating_point(column.type());
  }

  /**
   * @brief Can one use a fast-path in-place sort for this table?
   *
   * The fast path requires the table to consist of exactly one column which is itself usable.
   *
   * @param table Table to check.
   * @return true if fast-path sort is available, false otherwise.
   */
  [[nodiscard]] static bool is_usable(table_view const& table)
  {
    return table.num_columns() == 1 && is_usable(table.column(0));
  }

  /**
   * @brief Sort a fixed-width column in place.
   *
   * @tparam T Element type of the column; this overload is selected for fixed-width types.
   * @param col Column to sort, modified in place.
   * @param column_order Ascending or descending sort order.
   * @param stream CUDA stream used for device memory operations and kernel launches.
   */
  template <typename T, std::enable_if_t<cudf::is_fixed_width<T>()>* = nullptr>
  void operator()(mutable_column_view& col, order column_order, rmm::cuda_stream_view stream) const
  {
    // `if constexpr` ensures only the thrust sort variant selected by `method` is instantiated.
    auto const do_sort = [&](auto const cmp) {
      if constexpr (method == sort_method::STABLE) {
        thrust::stable_sort(rmm::exec_policy(stream), col.begin<T>(), col.end<T>(), cmp);
      } else {
        thrust::sort(rmm::exec_policy(stream), col.begin<T>(), col.end<T>(), cmp);
      }
    };
    if (column_order == order::ASCENDING) {
      do_sort(thrust::less<T>());
    } else {
      do_sort(thrust::greater<T>());
    }
  }

  /**
   * @brief Overload selected for element types that are not fixed-width; always fails.
   */
  template <typename T, std::enable_if_t<!cudf::is_fixed_width<T>()>* = nullptr>
  void operator()(mutable_column_view&, order, rmm::cuda_stream_view) const
  {
    CUDF_FAIL("Column type must be relationally comparable and fixed-width");
  }
};

} // namespace detail
} // namespace cudf
9 changes: 3 additions & 6 deletions cpp/src/sort/segmented_sort_impl.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,6 +14,8 @@
* limitations under the License.
*/

#include "common_sort_impl.cuh"

#include <cudf/column/column_factories.hpp>
#include <cudf/detail/copy.hpp>
#include <cudf/detail/gather.hpp>
Expand All @@ -29,11 +31,6 @@
namespace cudf {
namespace detail {

/**
 * @brief The enum specifying which sorting method to use (stable or unstable).
 *
 * `STABLE` requests the stable variant of a sort, `UNSTABLE` the non-stable variant.
 */
enum class sort_method { STABLE, UNSTABLE };

/**
* @brief Functor performs faster segmented sort on eligible columns
*/
Expand Down
15 changes: 7 additions & 8 deletions cpp/src/sort/sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include "common_sort_impl.cuh"
#include "sort_impl.cuh"

#include <cudf/column/column.hpp>
Expand All @@ -37,7 +38,7 @@ std::unique_ptr<column> sorted_order(table_view const& input,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
return sorted_order<false>(input, column_order, null_precedence, stream, mr);
return sorted_order<sort_method::UNSTABLE>(input, column_order, null_precedence, stream, mr);
}

std::unique_ptr<table> sort_by_key(table_view const& values,
Expand Down Expand Up @@ -68,14 +69,12 @@ std::unique_ptr<table> sort(table_view const& input,
rmm::mr::device_memory_resource* mr)
{
// fast-path sort conditions: single, non-floating-point, fixed-width column with no nulls
if (input.num_columns() == 1 && !input.column(0).has_nulls() &&
cudf::is_fixed_width(input.column(0).type()) &&
!cudf::is_floating_point(input.column(0).type())) {
auto output = std::make_unique<column>(input.column(0), stream, mr);
auto view = output->mutable_view();
bool ascending = (column_order.empty() ? true : column_order.front() == order::ASCENDING);
if (inplace_column_sort_fn<sort_method::UNSTABLE>::is_usable(input)) {
auto output = std::make_unique<column>(input.column(0), stream, mr);
auto view = output->mutable_view();
auto order = (column_order.empty() ? order::ASCENDING : column_order.front());
cudf::type_dispatcher<dispatch_storage_type>(
output->type(), inplace_column_sort_fn<false>{}, view, ascending, stream);
output->type(), inplace_column_sort_fn<sort_method::UNSTABLE>{}, view, order, stream);
std::vector<std::unique_ptr<column>> columns;
columns.emplace_back(std::move(output));
return std::make_unique<table>(std::move(columns));
Expand Down
15 changes: 8 additions & 7 deletions cpp/src/sort/sort_column.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include "common_sort_impl.cuh"
#include "sort_column_impl.cuh"

#include <cudf/column/column_factories.hpp>
Expand All @@ -30,19 +31,19 @@ namespace detail {
* sorted_order(column_view&,order,null_order,rmm::cuda_stream_view,rmm::mr::device_memory_resource*)
*/
template <>
std::unique_ptr<column> sorted_order<false>(column_view const& input,
order column_order,
null_order null_precedence,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
std::unique_ptr<column> sorted_order<sort_method::UNSTABLE>(column_view const& input,
order column_order,
null_order null_precedence,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto sorted_indices = cudf::make_numeric_column(
data_type(type_to_id<size_type>()), input.size(), mask_state::UNALLOCATED, stream, mr);
mutable_column_view indices_view = sorted_indices->mutable_view();
thrust::sequence(
rmm::exec_policy(stream), indices_view.begin<size_type>(), indices_view.end<size_type>(), 0);
cudf::type_dispatcher<dispatch_storage_type>(input.type(),
column_sorted_order_fn<false>{},
column_sorted_order_fn<sort_method::UNSTABLE>{},
input,
indices_view,
column_order == order::ASCENDING,
Expand Down
14 changes: 8 additions & 6 deletions cpp/src/sort/sort_column_impl.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,8 @@

#pragma once

#include "common_sort_impl.cuh"

#include <cudf/column/column_device_view.cuh>
#include <cudf/table/experimental/row_operators.cuh>
#include <cudf/utilities/error.hpp>
Expand All @@ -36,7 +38,7 @@ namespace detail {
* This API offers fast sorting for primitive types. It cannot handle nested types and will not
* consider `NaN` as equivalent to other `NaN`.
*
* @tparam stable Whether to use stable sort
* @tparam method Whether to use stable sort
* @param input Column to sort. The column data is not modified.
* @param column_order Ascending or descending sort order
* @param null_precedence How null rows are to be ordered
Expand All @@ -45,7 +47,7 @@ namespace detail {
* @param mr Device memory resource used to allocate the returned column's device memory
* @return Sorted indices for the input column.
*/
template <bool stable>
template <sort_method method>
std::unique_ptr<column> sorted_order(column_view const& input,
order column_order,
null_order null_precedence,
Expand Down Expand Up @@ -78,7 +80,7 @@ struct simple_comparator {
null_order null_precedence{};
};

template <bool stable>
template <sort_method method>
struct column_sorted_order_fn {
/**
* @brief Compile time check for allowing faster sort.
Expand Down Expand Up @@ -121,7 +123,7 @@ struct column_sorted_order_fn {
auto const do_sort = [&](auto const comp) {
// Compiling `thrust::*sort*` APIs is expensive.
// Thus, we should optimize that by using constexpr condition to only compile what we need.
if constexpr (stable) {
if constexpr (method == sort_method::STABLE) {
thrust::stable_sort_by_key(rmm::exec_policy(stream),
d_col.begin<T>(),
d_col.end<T>(),
Expand Down Expand Up @@ -165,7 +167,7 @@ struct column_sorted_order_fn {
auto comp = simple_comparator<T>{*keys, input.has_nulls(), ascending, null_precedence};
// Compiling `thrust::*sort*` APIs is expensive.
// Thus, we should optimize that by using constexpr condition to only compile what we need.
if constexpr (stable) {
if constexpr (method == sort_method::STABLE) {
thrust::stable_sort(
rmm::exec_policy(stream), indices.begin<size_type>(), indices.end<size_type>(), comp);
} else {
Expand Down
7 changes: 4 additions & 3 deletions cpp/src/sort/sort_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include "common_sort_impl.cuh"
#include "sort_column_impl.cuh"

#include <cudf/column/column_factories.hpp>
Expand All @@ -30,7 +31,7 @@ namespace detail {
* @tparam method Whether to use stable sort
* @param stream CUDA stream used for device memory operations and kernel launches
*/
template <bool stable>
template <sort_method method>
std::unique_ptr<column> sorted_order(table_view input,
std::vector<order> const& column_order,
std::vector<null_order> const& null_precedence,
Expand All @@ -57,7 +58,7 @@ std::unique_ptr<column> sorted_order(table_view input,
auto const single_col = input.column(0);
auto const col_order = column_order.empty() ? order::ASCENDING : column_order.front();
auto const null_prec = null_precedence.empty() ? null_order::BEFORE : null_precedence.front();
return sorted_order<stable>(single_col, col_order, null_prec, stream, mr);
return sorted_order<method>(single_col, col_order, null_prec, stream, mr);
}

std::unique_ptr<column> sorted_indices = cudf::make_numeric_column(
Expand All @@ -71,7 +72,7 @@ std::unique_ptr<column> sorted_order(table_view input,
auto const do_sort = [&](auto const comparator) {
// Compiling `thrust::*sort*` APIs is expensive.
// Thus, we should optimize that by using constexpr condition to only compile what we need.
if constexpr (stable) {
if constexpr (method == sort_method::STABLE) {
thrust::stable_sort(rmm::exec_policy(stream),
mutable_indices_view.begin<size_type>(),
mutable_indices_view.end<size_type>(),
Expand Down
16 changes: 7 additions & 9 deletions cpp/src/sort/stable_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include "common_sort_impl.cuh"
#include "sort_impl.cuh"

#include <cudf/column/column.hpp>
Expand All @@ -34,7 +35,7 @@ std::unique_ptr<column> stable_sorted_order(table_view const& input,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
return sorted_order<true>(input, column_order, null_precedence, stream, mr);
return sorted_order<sort_method::STABLE>(input, column_order, null_precedence, stream, mr);
}

std::unique_ptr<table> stable_sort(table_view const& input,
Expand All @@ -43,15 +44,12 @@ std::unique_ptr<table> stable_sort(table_view const& input,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
// fast-path sort conditions: single, non-floating-point, fixed-width column with no nulls
if (input.num_columns() == 1 && !input.column(0).has_nulls() &&
cudf::is_fixed_width(input.column(0).type()) &&
!cudf::is_floating_point(input.column(0).type())) {
auto output = std::make_unique<column>(input.column(0), stream, mr);
auto view = output->mutable_view();
bool ascending = (column_order.empty() ? true : column_order.front() == order::ASCENDING);
if (inplace_column_sort_fn<sort_method::STABLE>::is_usable(input)) {
auto output = std::make_unique<column>(input.column(0), stream, mr);
auto view = output->mutable_view();
auto order = (column_order.empty() ? order::ASCENDING : column_order.front());
cudf::type_dispatcher<dispatch_storage_type>(
output->type(), inplace_column_sort_fn<true>{}, view, ascending, stream);
output->type(), inplace_column_sort_fn<sort_method::STABLE>{}, view, order, stream);
std::vector<std::unique_ptr<column>> columns;
columns.emplace_back(std::move(output));
return std::make_unique<table>(std::move(columns));
Expand Down
Loading

0 comments on commit 809cca1

Please sign in to comment.