diff --git a/cpp/src/stream_compaction/distinct_count.cu b/cpp/src/stream_compaction/distinct_count.cu index 8c50f8d29e8..7185dae77b7 100644 --- a/cpp/src/stream_compaction/distinct_count.cu +++ b/cpp/src/stream_compaction/distinct_count.cu @@ -34,6 +34,8 @@ #include #include +#include + #include #include #include @@ -127,27 +129,27 @@ cudf::size_type distinct_count(table_view const& keys, null_equality nulls_equal, rmm::cuda_stream_view stream) { - auto const num_rows = keys.num_rows(); + auto const num_rows = keys.num_rows(); + if (num_rows == 0) { return 0; } // early exit for empty input auto const has_nulls = nullate::DYNAMIC{cudf::has_nested_nulls(keys)}; - hash_map_type key_map{compute_hash_table_size(num_rows), - cuco::empty_key{COMPACTION_EMPTY_KEY_SENTINEL}, - cuco::empty_value{COMPACTION_EMPTY_VALUE_SENTINEL}, - detail::hash_table_allocator_type{default_allocator{}, stream}, - stream.value()}; - auto const preprocessed_input = cudf::experimental::row::hash::preprocessed_table::create(keys, stream); - auto const row_hasher = cudf::experimental::row::hash::row_hasher(preprocessed_input); auto const hash_key = experimental::compaction_hash(row_hasher.device_hasher(has_nulls)); - - auto const row_comp = cudf::experimental::row::equality::self_comparator(preprocessed_input); - - auto iter = cudf::detail::make_counting_transform_iterator( - 0, [] __device__(size_type i) { return cuco::make_pair(i, i); }); + auto const row_comp = cudf::experimental::row::equality::self_comparator(preprocessed_input); auto const comparator_helper = [&](auto const row_equal) { + using hasher_type = decltype(hash_key); + auto key_set = cuco::experimental::static_set{ + cuco::experimental::extent{compute_hash_table_size(num_rows)}, + cuco::empty_key{COMPACTION_EMPTY_KEY_SENTINEL}, + row_equal, + cuco::experimental::linear_probing<1, hasher_type>{hash_key}, + detail::hash_table_allocator_type{default_allocator{}, stream}, + stream.value()}; + + auto const iter = thrust::counting_iterator(0); // when nulls are equal, insert non-null rows only to improve efficiency if (nulls_equal == null_equality::EQUAL and has_nulls) { thrust::counting_iterator stencil(0); @@ -155,12 +157,11 @@ cudf::size_type distinct_count(table_view const& keys, cudf::detail::bitmask_or(keys, stream, rmm::mr::get_current_device_resource()); row_validity pred{static_cast(row_bitmask.data())}; - key_map.insert_if(iter, iter + num_rows, stencil, pred, hash_key, row_equal, stream.value()); - return key_map.get_size() + static_cast(null_count > 0); + return key_set.insert_if(iter, iter + num_rows, stencil, pred, stream.value()) + + static_cast(null_count > 0); } // otherwise, insert all - key_map.insert(iter, iter + num_rows, hash_key, row_equal, stream.value()); - return key_map.get_size(); + return key_set.insert(iter, iter + num_rows, stream.value()); }; if (cudf::detail::has_nested_columns(keys)) {