Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve distinct by using cuco::static_map::retrieve_all #10916

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/cmake/thirdparty/get_cucollections.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function(find_and_configure_cucollections)
GLOBAL_TARGETS cuco::cuco
BUILD_EXPORT_SET cudf-exports
CPM_ARGS GITHUB_REPOSITORY NVIDIA/cuCollections
GIT_TAG 8b15f06f38d034e815bc72045ca3403787f75e07
GIT_TAG ebaba1ae378a5272116414b6d7ae5847e5cf5715
EXCLUDE_FROM_ALL ${BUILD_SHARED_LIBS}
OPTIONS "BUILD_TESTS OFF" "BUILD_BENCHMARKS OFF" "BUILD_EXAMPLES OFF"
)
Expand Down
19 changes: 5 additions & 14 deletions cpp/src/stream_compaction/distinct.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>

#include <utility>
#include <vector>
Expand Down Expand Up @@ -79,23 +80,13 @@ std::unique_ptr<table> distinct(table_view const& input,
// insert distinct indices into the map.
key_map.insert(iter, iter + num_rows, hash_key, key_equal, stream.value());

auto counting_iter = thrust::make_counting_iterator<size_type>(0);
rmm::device_uvector<bool> index_exists_in_map(num_rows, stream, mr);
// enumerate all indices to check if they are present in the map.
key_map.contains(counting_iter, counting_iter + num_rows, index_exists_in_map.begin(), hash_key);

auto const output_size{key_map.get_size()};

// write distinct indices to a numeric column
auto distinct_indices = cudf::make_numeric_column(
data_type{type_id::INT32}, output_size, mask_state::UNALLOCATED, stream, mr);
auto mutable_view = mutable_column_device_view::create(*distinct_indices, stream);
thrust::copy_if(rmm::exec_policy(stream),
counting_iter,
counting_iter + num_rows,
index_exists_in_map.begin(),
mutable_view->begin<size_type>(),
thrust::identity<bool>{});
// write distinct indices to a numeric column
key_map.retrieve_all(distinct_indices->mutable_view().begin<cudf::size_type>(),
thrust::make_discard_iterator(),
stream.value());

// run gather operation to establish new order
return detail::gather(input,
Expand Down
9 changes: 6 additions & 3 deletions cpp/tests/stream_compaction/distinct_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,13 +420,15 @@ TEST_F(Distinct, SlicedStructsOfLists)
using lists_col = cudf::test::lists_column_wrapper<int32_t>;
using structs_col = cudf::test::structs_column_wrapper;

auto const idx =
cudf::test::fixed_width_column_wrapper<int>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
auto const structs = [] {
auto child = lists_col{
{0, 0}, {0, 0}, {1}, {1, 1}, {1}, {1, 2}, {2, 2}, {2}, {2}, {2, 1}, {2, 2}, {2, 2}, {5, 5}};
return structs_col{{child}};
}();

auto const input_original = cudf::table_view({structs});
auto const input_original = cudf::table_view({idx, structs});
auto const input = cudf::slice(input_original, {2, 12})[0];

auto const expected_structs = [] {
Expand All @@ -435,8 +437,9 @@ TEST_F(Distinct, SlicedStructsOfLists)
}();
auto const expected = cudf::table_view({expected_structs});

auto const result = cudf::distinct(input, {0});
CUDF_TEST_EXPECT_TABLES_EQUAL(expected, *result);
auto const result = cudf::distinct(input, {1});
auto const sorted_result = cudf::sort_by_key(*result, result->select({0}));
CUDF_TEST_EXPECT_TABLES_EQUAL(expected, cudf::table_view{{sorted_result->get_column(1)}});
}

TEST_F(Distinct, StructWithNullElement)
Expand Down