diff --git a/cpp/src/stream_compaction/distinct_count.cu b/cpp/src/stream_compaction/distinct_count.cu index 7185dae77b7..f6037ea7785 100644 --- a/cpp/src/stream_compaction/distinct_count.cu +++ b/cpp/src/stream_compaction/distinct_count.cu @@ -150,15 +150,23 @@ cudf::size_type distinct_count(table_view const& keys, stream.value()}; auto const iter = thrust::counting_iterator(0); - // when nulls are equal, insert non-null rows only to improve efficiency + // when nulls are equal, only insert those rows that are not all null to improve efficiency if (nulls_equal == null_equality::EQUAL and has_nulls) { thrust::counting_iterator stencil(0); + // We must consider a row if any of its column entries is valid, + // hence OR together the validities of the columns. auto const [row_bitmask, null_count] = cudf::detail::bitmask_or(keys, stream, rmm::mr::get_current_device_resource()); - row_validity pred{static_cast(row_bitmask.data())}; - return key_set.insert_if(iter, iter + num_rows, stencil, pred, stream.value()) + - static_cast(null_count > 0); + // Unless all columns have a null mask, row_bitmask will be + // null, and null_count will be zero. Equally, unless there is + // some row which is null in all columns, null_count will be + // zero. So, it is only when null_count is not zero that we need + // to do a filtered insertion. + if (null_count > 0) { + row_validity pred{static_cast(row_bitmask.data())}; + return key_set.insert_if(iter, iter + num_rows, stencil, pred, stream.value()) + 1; + } } // otherwise, insert all return key_set.insert(iter, iter + num_rows, stream.value()); diff --git a/cpp/tests/stream_compaction/distinct_count_tests.cpp b/cpp/tests/stream_compaction/distinct_count_tests.cpp index e80244cee41..864ac8f84c6 100644 --- a/cpp/tests/stream_compaction/distinct_count_tests.cpp +++ b/cpp/tests/stream_compaction/distinct_count_tests.cpp @@ -274,6 +274,16 @@ TEST_F(DistinctCount, TableWithNull) EXPECT_EQ(10, cudf::distinct_count(input, null_equality::UNEQUAL)); } +TEST_F(DistinctCount, TableWithSomeNull) +{ + cudf::test::fixed_width_column_wrapper col1{{1, 2, 3, 4, 5, 6}, {1, 0, 1, 0, 1, 0}}; + cudf::test::fixed_width_column_wrapper col2{{1, 1, 1, 1, 1, 1}}; + cudf::table_view input{{col1, col2}}; + + EXPECT_EQ(4, cudf::distinct_count(input, null_equality::EQUAL)); + EXPECT_EQ(6, cudf::distinct_count(input, null_equality::UNEQUAL)); +} + TEST_F(DistinctCount, EmptyColumnedTable) { std::vector cols{};