Skip to content

Commit

Permalink
Replace device_vector with device_uvector in sort_impl (#7925)
Browse files Browse the repository at this point in the history
Reference #7287

This replaces usages of `rmm::device_vector` with `rmm::device_uvector` in `cpp/src/sort/sort_impl.cuh` which is used internally by the `cudf::sort` APIs.
The `make_device_uvector_async` utility are called to convert a `std::vector` to a temporary `rmm::device_uvector`.

Also updated `vector_factories.hpp` utility to add the missing `include <vector>`.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - https://github.com/nvdbaranec
  - Nghia Truong (https://github.com/ttnghia)
  - Christopher Harris (https://github.com/cwharris)

URL: #7925
  • Loading branch information
davidwendt authored Apr 13, 2021
1 parent d6479a2 commit 18964ff
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
2 changes: 2 additions & 0 deletions cpp/include/cudf/detail/utilities/vector_factories.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

#include <vector>

namespace cudf {
namespace detail {

Expand Down
22 changes: 14 additions & 8 deletions cpp/src/sort/sort_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@

#include <cudf/column/column_factories.hpp>
#include <cudf/detail/gather.hpp>
#include <cudf/detail/utilities/vector_factories.hpp>
#include <cudf/table/row_operators.cuh>
#include <cudf/table/table_device_view.cuh>
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/traits.hpp>

#include <structs/utilities.hpp>

#include <rmm/cuda_stream_view.hpp>
Expand Down Expand Up @@ -123,14 +125,14 @@ std::unique_ptr<column> sorted_order(table_view input,
}

auto flattened = structs::detail::flatten_nested_columns(input, column_order, null_precedence);
auto& input_flattened = std::get<0>(flattened);
auto device_table = table_device_view::create(input_flattened, stream);
rmm::device_vector<order> d_column_order(std::get<1>(flattened));
auto& input_flattened = std::get<0>(flattened);
auto device_table = table_device_view::create(input_flattened, stream);
auto const d_column_order = make_device_uvector_async(std::get<1>(flattened), stream);

if (has_nulls(input_flattened)) {
rmm::device_vector<null_order> d_null_precedence(std::get<2>(flattened));
auto comparator = row_lexicographic_comparator<true>(
*device_table, *device_table, d_column_order.data().get(), d_null_precedence.data().get());
auto const d_null_precedence = make_device_uvector_async(std::get<2>(flattened), stream);
auto const comparator = row_lexicographic_comparator<true>(
*device_table, *device_table, d_column_order.data(), d_null_precedence.data());
if (stable) {
thrust::stable_sort(rmm::exec_policy(stream),
mutable_indices_view.begin<size_type>(),
Expand All @@ -142,9 +144,11 @@ std::unique_ptr<column> sorted_order(table_view input,
mutable_indices_view.end<size_type>(),
comparator);
}
// protection for temporary d_column_order and d_null_precedence
stream.synchronize();
} else {
auto comparator = row_lexicographic_comparator<false>(
*device_table, *device_table, d_column_order.data().get());
auto const comparator =
row_lexicographic_comparator<false>(*device_table, *device_table, d_column_order.data());
if (stable) {
thrust::stable_sort(rmm::exec_policy(stream),
mutable_indices_view.begin<size_type>(),
Expand All @@ -156,6 +160,8 @@ std::unique_ptr<column> sorted_order(table_view input,
mutable_indices_view.end<size_type>(),
comparator);
}
// protection for temporary d_column_order
stream.synchronize();
}

return sorted_indices;
Expand Down

0 comments on commit 18964ff

Please sign in to comment.