Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Segmented sort #7122

Merged
merged 35 commits into from
Feb 4, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
b2464b1
add segmented sort for lists_column_view
karthikeyann Jan 12, 2021
e7e2e4f
add unit test for segmented sort for lists_column_view
karthikeyann Jan 12, 2021
f16f924
add segmented_sort(table_view)
karthikeyann Jan 25, 2021
44ff02a
add unit test segmented_sort(table_view) all valid
karthikeyann Jan 25, 2021
4d5fbbf
fix interfaces
karthikeyann Jan 25, 2021
edfa549
add null list column test
karthikeyann Jan 25, 2021
9a8c9d5
Merge branch 'branch-0.18' of github.com:rapidsai/cudf into fea-segme…
karthikeyann Jan 25, 2021
54c0c5e
fix null_order example, code, tests for sort_lists segmented sort
karthikeyann Jan 27, 2021
4b6ae05
documentation update for segmented_sort_by_key
karthikeyann Jan 27, 2021
a936f4e
replace segmented_sort by segmented_sort_by_key
karthikeyann Jan 27, 2021
2baf424
Merge branch 'branch-0.18' of github.com:rapidsai/cudf into fea-segme…
karthikeyann Jan 27, 2021
d03f032
conda yml include new hpp
karthikeyann Jan 27, 2021
23c8f26
add more checks for input list table
karthikeyann Jan 28, 2021
0825984
add error tests for inputs
karthikeyann Jan 28, 2021
663662b
rename API segmented_sort to sort_lists
karthikeyann Jan 28, 2021
e842e1d
rename tests
karthikeyann Jan 28, 2021
bb0e7c1
update sort_lists to use segmented_sorted_order
karthikeyann Jan 28, 2021
e273dd8
reorder code to use segmented_sort in sort_lists
karthikeyann Jan 29, 2021
5e63c7a
add sort_lists key,value tests
karthikeyann Jan 29, 2021
d249c25
rename sortlists tests
karthikeyann Jan 29, 2021
837655e
rename tests sort_lists
karthikeyann Jan 29, 2021
2303fea
Apply suggestions from code review (codereport)
karthikeyann Jan 30, 2021
2a8ee21
style fix clang format
karthikeyann Feb 1, 2021
71e048a
fix vec<unq_ptr> constructor
karthikeyann Feb 1, 2021
0a7d565
add segmented_sort, detail::segmented_sorted_order
karthikeyann Feb 1, 2021
d3cadc3
add segmented sort tests
karthikeyann Feb 1, 2021
8f26011
segmented_sort unit tests
karthikeyann Feb 1, 2021
6cb960b
move list_size function to list_device_view.cuh
karthikeyann Feb 1, 2021
f6c058a
Apply suggestions from code review (jake)
karthikeyann Feb 1, 2021
df49138
address review comments (codereport)
karthikeyann Feb 2, 2021
95dda47
Merge branch 'branch-0.18' of github.com:rapidsai/cudf into fea-segme…
karthikeyann Feb 2, 2021
1141485
remove sort_lists(table_view)
karthikeyann Feb 3, 2021
476420d
add cudf::segmented_sorted_order in header
karthikeyann Feb 3, 2021
8732fbb
update copyright year
karthikeyann Feb 3, 2021
e0696ef
add more tests, enable sliced list column
karthikeyann Feb 3, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions conda/recipes/libcudf/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ test:
- test -f $PREFIX/include/cudf/lists/contains.hpp
- test -f $PREFIX/include/cudf/lists/gather.hpp
- test -f $PREFIX/include/cudf/lists/lists_column_view.hpp
- test -f $PREFIX/include/cudf/lists/sorting.hpp
- test -f $PREFIX/include/cudf/merge.hpp
- test -f $PREFIX/include/cudf/null_mask.hpp
- test -f $PREFIX/include/cudf/partitioning.hpp
Expand Down
13 changes: 13 additions & 0 deletions cpp/include/cudf/detail/sorting.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,18 @@ std::unique_ptr<table> sort_by_key(
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @copydoc cudf::sort_lists
*
* @param[in] stream CUDA stream used for device memory operations and kernel launches.
*/
std::unique_ptr<table> sort_lists(
table_view const& values,
table_view const& keys,
std::vector<order> const& column_order = {},
std::vector<null_order> const& null_precedence = {},
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

} // namespace detail
} // namespace cudf
59 changes: 59 additions & 0 deletions cpp/include/cudf/lists/sorting.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cudf/column/column.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/lists/lists_column_view.hpp>

namespace cudf {
namespace lists {
/**
* @addtogroup lists_sort
* @{
* @file
*/

/**
* @brief Segmented sort of the elements within a list in each row of a list column.
*
* `source_column` with depth 1 is only supported.
*
* * @code{.pseudo}
* source_column : [{4, 2, 3, 1}, {1, 2, NULL, 4}, {-10, 10, 0}]
*
* Ascending, Null After : [{1, 2, 3, 4}, {1, 2, 4, NULL}, {-10, 0, 10}]
* Ascending, Null Before : [{1, 2, 3, 4}, {NULL, 1, 2, 4}, {-10, 0, 10}]
* Descending, Null After : [{4, 3, 2, 1}, {NULL, 4, 2, 1}, {10, 0, -10}]
* Descending, Null Before : [{4, 3, 2, 1}, {4, 2, 1, NULL}, {10, 0, -10}]
* @endcode
*
* @param source_column View of the list column of numeric types to sort
* @param column_order The desired sort order
* @param null_precedence The desired order of null compared to other elements in the list
* @param mr Device memory resource to allocate any returned objects
* @return list column with elements in each list sorted.
*
*/
std::unique_ptr<column> sort_lists(
lists_column_view const& source_column,
order column_order,
null_order null_precedence,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/** @} */ // end of group
} // namespace lists
} // namespace cudf
30 changes: 29 additions & 1 deletion cpp/include/cudf/sorting.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ std::unique_ptr<table> sort(
* @param values The table to reorder
* @param keys The table that determines the ordering
* @param column_order The desired order for each column in `keys`. Size must be
* equal to `input.num_columns()` or empty. If empty, all columns are sorted in
* equal to `keys.num_columns()` or empty. If empty, all columns are sorted in
* ascending order.
* @param null_precedence The desired order of a null element compared to other
* elements for each column in `keys`. Size must be equal to
Expand Down Expand Up @@ -184,5 +184,33 @@ std::unique_ptr<column> rank(
bool percentage,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Performs a lexicographic segmented sort of the list in each row of a table of list columns
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this function. How do you do a lexicographic sort of a table of list columns?

Copy link
Contributor Author

@karthikeyann karthikeyann Feb 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this clear?
Performs a lexicographic sort of each list in a table.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So is it the same as calling sort independently on each column?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no. It will be similar to calling sort on each row of the table. Each row is a list.
The relative row order in the table remains same. In each row, the elements in list are sorted.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So it's the same as calling sort_lists on each column in the table?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no. it's similar to treating each row as table, then sort it. then repeat it for all rows.

Copying the example from #6541

row# a b c
0 [21, 22, 23, 22] [13, 14, 12, 11] ["a", "b", "c", "d"]
1 [22, 21, 23, 22] [14, 13, 12, 11] ["a", "b", "c", "d"]

Here [...] is a list.

row# a b c
0
  • 21
  • 22
  • 23
  • 22
  • 13
  • 14
  • 12
  • 11
  • "a"
  • "b"
  • "c"
  • "d"
1
  • 22
  • 21
  • 23
  • 22
  • 14
  • 13
  • 12
  • 11
  • "a"
  • "b"
  • "c"
  • "d"

sort_lists(values={a,b,c}, keys={a,b}) output will be

row# a b c
0
  • 21
  • 22
  • 22
  • 23
  • 13
  • 11
  • 14
  • 12
  • "a"
  • "d"
  • "b"
  • "c"
1
  • 21
  • 22
  • 22
  • 23
  • 13
  • 11
  • 14
  • 12
  • "b"
  • "d"
  • "a"
  • "c"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does that require all of the lists in a given row to have the same number of elements? What would the result be for this input?

row# a b c
0 [21, 22, 23, 22] [13, 14] ["a", "c", "d"]
1 [22, 21, 23, 22] [14] ["c", "d"]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.
For this input, it will throw a logic error.

*
* `keys` with list columns of depth 1 is only supported.
* @throws cudf::logic_error if `values.num_rows() != keys.num_rows()`.
* @throws cudf::logic_error if any list sizes of corresponding row in each column are not equal.
* @throws cudf::logic_error if any column of `keys` or `values` is not a list column.
*
* @param values The table with list columns to reorder
* @param keys The table with list coumns that determines the ordering of elements in each list
* @param column_order The desired order for each column in `keys`. Size must be
* equal to `keys.num_columns()` or empty. If empty, all columns are sorted in
* ascending order.
* @param null_precedence The desired order of a null element compared to other
* elements for each column in `keys`. Size must be equal to
* `keys.num_columns()` or empty. If empty, all columns will be sorted with
* `null_order::BEFORE`.
* @param mr Device memory resource to allocate any returned objects
* @return table with list columns with elements in each list sorted.
*
*/
std::unique_ptr<table> sort_lists(
table_view const& values,
table_view const& keys,
std::vector<order> const& column_order = {},
std::vector<null_order> const& null_precedence = {},
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/** @} */ // end of group
} // namespace cudf
246 changes: 246 additions & 0 deletions cpp/src/lists/segmented_sort.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf/column/column.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/copy.hpp>
#include <cudf/detail/gather.cuh>
#include <cudf/detail/iterator.cuh>
#include <cudf/detail/null_mask.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/lists/sorting.hpp>
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/type_dispatcher.hpp>

#include <thrust/copy.h>
#include <cub/device/device_segmented_radix_sort.cuh>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_buffer.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>
karthikeyann marked this conversation as resolved.
Show resolved Hide resolved

namespace cudf {
namespace lists {
namespace detail {

struct SortPairs {
template <typename KeyT, typename ValueT, typename OffsetIteratorT>
void SortPairsAscending(KeyT const* keys_in,
KeyT* keys_out,
ValueT const* values_in,
ValueT* values_out,
int num_items,
int num_segments,
OffsetIteratorT begin_offsets,
OffsetIteratorT end_offsets,
rmm::cuda_stream_view stream)
{
rmm::device_buffer d_temp_storage;
size_t temp_storage_bytes = 0;
cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage.data(),
jrhemstad marked this conversation as resolved.
Show resolved Hide resolved
temp_storage_bytes,
keys_in,
keys_out,
values_in,
values_out,
num_items,
num_segments,
begin_offsets,
end_offsets,
0,
sizeof(KeyT) * 8,
stream.value());
d_temp_storage = rmm::device_buffer{temp_storage_bytes, stream};

cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage.data(),
temp_storage_bytes,
keys_in,
keys_out,
values_in,
values_out,
num_items,
num_segments,
begin_offsets,
end_offsets,
0,
sizeof(KeyT) * 8,
stream.value());
}

template <typename KeyT, typename ValueT, typename OffsetIteratorT>
void SortPairsDescending(KeyT const* keys_in,
KeyT* keys_out,
ValueT const* values_in,
ValueT* values_out,
int num_items,
int num_segments,
OffsetIteratorT begin_offsets,
OffsetIteratorT end_offsets,
rmm::cuda_stream_view stream)
{
rmm::device_buffer d_temp_storage;
size_t temp_storage_bytes = 0;
cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage.data(),
temp_storage_bytes,
keys_in,
keys_out,
values_in,
values_out,
num_items,
num_segments,
begin_offsets,
end_offsets,
0,
sizeof(KeyT) * 8,
stream.value());
d_temp_storage = rmm::device_buffer{temp_storage_bytes, stream};

cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage.data(),
temp_storage_bytes,
keys_in,
keys_out,
values_in,
values_out,
num_items,
num_segments,
begin_offsets,
end_offsets,
0,
sizeof(KeyT) * 8,
stream.value());
}
template <typename T>
std::enable_if_t<not is_numeric<T>(), std::unique_ptr<column>> operator()(
column_view const& child,
column_view const& offsets,
order column_order,
null_order null_precedence,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_FAIL("segmented sort is not supported for non-numeric list types");
}
template <typename T>
std::enable_if_t<is_numeric<T>(), std::unique_ptr<column>> operator()(
column_view const& child,
column_view const& offsets,
order column_order,
null_order null_precedence,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto output =
cudf::detail::allocate_like(child, child.size(), mask_allocation_policy::NEVER, stream, mr);
mutable_column_view mutable_output_view = output->mutable_view();

auto keys = [&]() {
if (child.nullable()) {
rmm::device_uvector<T> keys(child.size(), stream);
auto null_replace_T = null_precedence == null_order::AFTER ? std::numeric_limits<T>::max()
: std::numeric_limits<T>::min();
karthikeyann marked this conversation as resolved.
Show resolved Hide resolved
auto device_child = column_device_view::create(child, stream);
auto keys_in =
cudf::detail::make_null_replacement_iterator<T>(*device_child, null_replace_T);
thrust::copy_n(rmm::exec_policy(stream), keys_in, child.size(), keys.begin());
return keys;
}
return rmm::device_uvector<T>{0, stream};
}();

std::unique_ptr<column> sorted_indices = cudf::make_numeric_column(
data_type(type_to_id<size_type>()), child.size(), mask_state::UNALLOCATED, stream, mr);
mutable_column_view mutable_indices_view = sorted_indices->mutable_view();
thrust::sequence(rmm::exec_policy(stream),
mutable_indices_view.begin<size_type>(),
mutable_indices_view.end<size_type>(),
0);

if (column_order == order::ASCENDING)
SortPairsAscending(child.nullable() ? keys.data() : child.begin<T>(),
mutable_output_view.begin<T>(),
mutable_indices_view.begin<size_type>(),
mutable_indices_view.begin<size_type>(),
child.size(),
offsets.size() - 1,
offsets.begin<size_type>(),
offsets.begin<size_type>() + 1,
stream);
else
SortPairsDescending(child.nullable() ? keys.data() : child.begin<T>(),
mutable_output_view.begin<T>(),
mutable_indices_view.begin<size_type>(),
mutable_indices_view.begin<size_type>(),
child.size(),
offsets.size() - 1,
offsets.begin<size_type>(),
offsets.begin<size_type>() + 1,
stream);
std::vector<std::unique_ptr<column>> output_cols;
output_cols.push_back(std::move(output));
codereport marked this conversation as resolved.
Show resolved Hide resolved
// rearrange the null_mask.
cudf::detail::gather_bitmask(cudf::table_view{{child}},
mutable_indices_view.begin<size_type>(),
output_cols,
cudf::detail::gather_bitmask_op::DONT_CHECK,
stream,
mr);
return std::move(output_cols.front());
}
};

std::unique_ptr<column> sort_lists(lists_column_view const& input,
order column_order,
null_order null_precedence,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
if (input.size() == 0) return {};
karthikeyann marked this conversation as resolved.
Show resolved Hide resolved

auto output_child = type_dispatcher(input.child().type(),
SortPairs{},
input.get_sliced_child(stream),
input.offsets(),
column_order,
null_precedence,
stream,
mr);

// Copy list offsets.
auto output_offset = std::make_unique<column>(input.offsets(), stream, mr);
auto null_mask = cudf::detail::copy_bitmask(input.parent(), stream, mr);

// Assemble list column & return
return make_lists_column(input.size(),
std::move(output_offset),
std::move(output_child),
input.null_count(),
std::move(null_mask));
}
} // namespace detail

std::unique_ptr<column> sort_lists(lists_column_view const& input,
order column_order,
null_order null_precedence,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::sort_lists(input, column_order, null_precedence, rmm::cuda_stream_default, mr);
}

} // namespace lists
} // namespace cudf
Loading