Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cudf::stable_sorted_order for NaN and -NaN in FLOAT64 columns #11874

Merged
merged 9 commits into from
Oct 14, 2022
104 changes: 6 additions & 98 deletions cpp/src/sort/sort_column.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,108 +14,16 @@
* limitations under the License.
*/

#include <sort/sort_impl.cuh>
#include <sort/sort_column_impl.cuh>

#include <cudf/column/column_factories.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/utilities/type_dispatcher.hpp>

#include <thrust/functional.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>

namespace cudf {
namespace detail {
namespace {

/**
* @brief Type-dispatched functor for sorting a single column.
*/
struct column_sorted_order_fn {
/**
* @brief Compile time check for allowing radix sort for column type.
*
* Floating point is removed here for special handling of NaNs.
*/
template <typename T>
static constexpr bool is_radix_sort_supported()
{
return cudf::is_fixed_width<T>() && !cudf::is_floating_point<T>();
}

/**
* @brief Sorts fixed-width columns using faster thrust sort.
*
* @param input Column to sort
* @param indices Output sorted indices
* @param ascending True if sort order is ascending
* @param stream CUDA stream used for device memory operations and kernel launches
*/
template <typename T, std::enable_if_t<is_radix_sort_supported<T>()>* = nullptr>
void radix_sort(column_view const& input,
mutable_column_view& indices,
bool ascending,
rmm::cuda_stream_view stream)
{
// A non-stable sort on a column of arithmetic type with no nulls will use a radix sort
// if specifying only the `thrust::less` or `thrust::greater` comparators.
// But this also requires making a copy of the input data.
auto temp_col = column(input, stream);
auto d_col = temp_col.mutable_view();
if (ascending) {
thrust::sort_by_key(rmm::exec_policy(stream),
d_col.begin<T>(),
d_col.end<T>(),
indices.begin<size_type>(),
thrust::less<T>());
} else {
thrust::sort_by_key(rmm::exec_policy(stream),
d_col.begin<T>(),
d_col.end<T>(),
indices.begin<size_type>(),
thrust::greater<T>());
}
}
template <typename T, std::enable_if_t<!is_radix_sort_supported<T>()>* = nullptr>
void radix_sort(column_view const&, mutable_column_view&, bool, rmm::cuda_stream_view)
{
CUDF_FAIL("Only fixed-width types are suitable for faster sorting");
}

/**
* @brief Sorts a single column with a relationally comparable type.
*
* This includes numeric, timestamp, duration, and string types.
*
* @param input Column to sort
* @param indices Output sorted indices
* @param ascending True if sort order is ascending
* @param null_precedence How null rows are to be ordered
* @param stream CUDA stream used for device memory operations and kernel launches
*/
template <typename T, std::enable_if_t<cudf::is_relationally_comparable<T, T>()>* = nullptr>
void operator()(column_view const& input,
mutable_column_view& indices,
bool ascending,
null_order null_precedence,
rmm::cuda_stream_view stream)
{
// column with nulls or non-supported types will also use a comparator
if (input.has_nulls() || !is_radix_sort_supported<T>()) {
auto keys = column_device_view::create(input, stream);
thrust::sort(rmm::exec_policy(stream),
indices.begin<size_type>(),
indices.end<size_type>(),
simple_comparator<T>{*keys, input.has_nulls(), ascending, null_precedence});
} else {
radix_sort<T>(input, indices, ascending, stream);
}
}

template <typename T, std::enable_if_t<!cudf::is_relationally_comparable<T, T>()>* = nullptr>
void operator()(column_view const&, mutable_column_view&, bool, null_order, rmm::cuda_stream_view)
{
CUDF_FAIL("Column type must be relationally comparable");
}
};

} // namespace

/**
* @copydoc
Expand All @@ -134,7 +42,7 @@ std::unique_ptr<column> sorted_order<false>(column_view const& input,
thrust::sequence(
rmm::exec_policy(stream), indices_view.begin<size_type>(), indices_view.end<size_type>(), 0);
cudf::type_dispatcher<dispatch_storage_type>(input.type(),
column_sorted_order_fn{},
column_sorted_order_fn<false>{},
input,
indices_view,
column_order == order::ASCENDING,
Expand Down
162 changes: 162 additions & 0 deletions cpp/src/sort/sort_column_impl.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <sort/sort_impl.cuh>

#include <thrust/sequence.h>
#include <thrust/sort.h>

namespace cudf {
namespace detail {

template <bool stable>
struct column_sorted_order_fn {
/**
* @brief Compile time check for allowing faster sort.
*
* Faster sort is defined for fixed-width types where only
* the primitive comparators thrust::greater or thrust::less
* are needed.
*
* Floating point is removed here for special handling of NaNs
* which require the row-comparator.
*/
template <typename T>
static constexpr bool is_faster_sort_supported()
{
return cudf::is_fixed_width<T>() && !cudf::is_floating_point<T>();
}

/**
* @brief Sorts fixed-width columns using faster thrust sort.
*
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
* Should not be called if `input.has_nulls()==true`
*
* @param input Column to sort
* @param indices Output sorted indices
* @param ascending True if sort order is ascending
* @param stream CUDA stream used for device memory operations and kernel launches
*/
template <typename T>
void faster_sort(column_view const& input,
mutable_column_view& indices,
bool ascending,
rmm::cuda_stream_view stream)
{
// A thrust sort on a column of primitive types will use a radix sort.
// For other fixed-width types, thrust will use merge-sort.
// But this also requires making a copy of the input data.
auto temp_col = column(input, stream);
auto d_col = temp_col.mutable_view();
if (ascending) {
if constexpr (stable) {
thrust::stable_sort_by_key(rmm::exec_policy(stream),
d_col.begin<T>(),
d_col.end<T>(),
indices.begin<size_type>(),
thrust::less<T>());
} else {
thrust::sort_by_key(rmm::exec_policy(stream),
d_col.begin<T>(),
d_col.end<T>(),
indices.begin<size_type>(),
thrust::less<T>());
}
} else {
if constexpr (stable) {
thrust::stable_sort_by_key(rmm::exec_policy(stream),
d_col.begin<T>(),
d_col.end<T>(),
indices.begin<size_type>(),
thrust::greater<T>());
} else {
thrust::sort_by_key(rmm::exec_policy(stream),
d_col.begin<T>(),
d_col.end<T>(),
indices.begin<size_type>(),
thrust::greater<T>());
}
}
}

/**
* @brief Sorts a single column with a relationally comparable type.
*
* This is used when a comparator is required.
*
* @param input Column to sort
* @param indices Output sorted indices
* @param ascending True if sort order is ascending
* @param null_precedence How null rows are to be ordered
* @param stream CUDA stream used for device memory operations and kernel launches
*/
template <typename T>
void sorted_order(column_view const& input,
mutable_column_view& indices,
bool ascending,
null_order null_precedence,
rmm::cuda_stream_view stream)
{
auto keys = column_device_view::create(input, stream);
auto comp = simple_comparator<T>{*keys, input.has_nulls(), ascending, null_precedence};
if constexpr (stable) {
thrust::stable_sort(
rmm::exec_policy(stream), indices.begin<size_type>(), indices.end<size_type>(), comp);
} else {
thrust::sort(
rmm::exec_policy(stream), indices.begin<size_type>(), indices.end<size_type>(), comp);
}
}

template <typename T,
CUDF_ENABLE_IF(cudf::is_relationally_comparable<T, T>() and
is_faster_sort_supported<T>())>
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
void operator()(column_view const& input,
mutable_column_view& indices,
bool ascending,
null_order null_precedence,
rmm::cuda_stream_view stream)
{
if (input.has_nulls()) {
sorted_order<T>(input, indices, ascending, null_precedence, stream);
} else {
faster_sort<T>(input, indices, ascending, stream);
}
}

template <typename T,
CUDF_ENABLE_IF(cudf::is_relationally_comparable<T, T>() and
not is_faster_sort_supported<T>())>
void operator()(column_view const& input,
mutable_column_view& indices,
bool ascending,
null_order null_precedence,
rmm::cuda_stream_view stream)
{
sorted_order<T>(input, indices, ascending, null_precedence, stream);
}

template <typename T, CUDF_ENABLE_IF(not cudf::is_relationally_comparable<T, T>())>
void operator()(column_view const&, mutable_column_view&, bool, null_order, rmm::cuda_stream_view)
{
CUDF_FAIL("Column type must be relationally comparable");
}
};

} // namespace detail
} // namespace cudf
72 changes: 6 additions & 66 deletions cpp/src/sort/stable_sort_column.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,76 +14,16 @@
* limitations under the License.
*/

#include <sort/sort_impl.cuh>
#include <sort/sort_column_impl.cuh>

#include <cudf/column/column_factories.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/utilities/type_dispatcher.hpp>

#include <thrust/sequence.h>
#include <thrust/sort.h>

namespace cudf {
namespace detail {
namespace {

struct column_stable_sorted_order_fn {
/**
* @brief Stable sort of fixed-width columns using a thrust sort with no comparator.
*
* @param input Column to sort
* @param indices Output sorted indices
* @param stream CUDA stream used for device memory operations and kernel launches
*/
template <typename T, std::enable_if_t<cudf::is_fixed_width<T>()>* = nullptr>
void faster_stable_sort(column_view const& input,
mutable_column_view& indices,
rmm::cuda_stream_view stream)
{
auto temp_col = column(input, stream);
auto d_col = temp_col.mutable_view();
thrust::stable_sort_by_key(
rmm::exec_policy(stream), d_col.begin<T>(), d_col.end<T>(), indices.begin<size_type>());
}
template <typename T, std::enable_if_t<!cudf::is_fixed_width<T>()>* = nullptr>
void faster_stable_sort(column_view const&, mutable_column_view&, rmm::cuda_stream_view)
{
CUDF_FAIL("Only fixed-width types are suitable for faster stable sorting");
}

/**
* @brief Stable sorts a single column with a relationally comparable type.
*
* This includes numeric, timestamp, duration, and string types.
*
* @param input Column to sort
* @param indices Output sorted indices
* @param ascending True if sort order is ascending
* @param null_precedence How null rows are to be ordered
* @param stream CUDA stream used for device memory operations and kernel launches
*/
template <typename T, std::enable_if_t<cudf::is_relationally_comparable<T, T>()>* = nullptr>
void operator()(column_view const& input,
mutable_column_view& indices,
bool ascending,
null_order null_precedence,
rmm::cuda_stream_view stream)
{
if (!ascending || input.has_nulls() || !cudf::is_fixed_width<T>()) {
auto keys = column_device_view::create(input, stream);
thrust::stable_sort(
rmm::exec_policy(stream),
indices.begin<size_type>(),
indices.end<size_type>(),
simple_comparator<T>{*keys, input.has_nulls(), ascending, null_precedence});
} else {
faster_stable_sort<T>(input, indices, stream);
}
}
template <typename T, std::enable_if_t<!cudf::is_relationally_comparable<T, T>()>* = nullptr>
void operator()(column_view const&, mutable_column_view&, bool, null_order, rmm::cuda_stream_view)
{
CUDF_FAIL("Column type must be relationally comparable");
}
};

} // namespace

/**
* @copydoc
Expand All @@ -102,7 +42,7 @@ std::unique_ptr<column> sorted_order<true>(column_view const& input,
thrust::sequence(
rmm::exec_policy(stream), indices_view.begin<size_type>(), indices_view.end<size_type>(), 0);
cudf::type_dispatcher<dispatch_storage_type>(input.type(),
column_stable_sorted_order_fn{},
column_sorted_order_fn<true>{},
input,
indices_view,
column_order == order::ASCENDING,
Expand Down
Loading