Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix valid count computation in offset_bitmask_binop kernel #13489

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions cpp/include/cudf/detail/null_mask.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,11 @@ __global__ void offset_bitmask_binop(Binop op,
size_type source_size_bits,
size_type* count_ptr)
{
constexpr auto const word_size{detail::size_in_bits<bitmask_type>()};
auto const tid = threadIdx.x + blockIdx.x * blockDim.x;

auto const last_bit_index = source_size_bits - 1;
auto const last_word_index = cudf::word_index(last_bit_index);

size_type thread_count = 0;

for (size_type destination_word_index = tid; destination_word_index < destination.size();
Expand All @@ -86,20 +88,19 @@ __global__ void offset_bitmask_binop(Binop op,
source_begin_bits[i] + source_size_bits));
}

if (destination_word_index == last_word_index) {
bdice marked this conversation as resolved.
Show resolved Hide resolved
// mask out any bits not part of this word
auto const num_bits_in_last_word = intra_word_index(last_bit_index);
if (num_bits_in_last_word <
static_cast<size_type>(detail::size_in_bits<bitmask_type>() - 1)) {
destination_word &= set_least_significant_bits(num_bits_in_last_word + 1);
}
}

destination[destination_word_index] = destination_word;
thread_count += __popc(destination_word);
}

// Subtract any slack bits from the last word
if (tid == 0) {
size_type const last_bit_index = source_size_bits - 1;
size_type const num_slack_bits = word_size - (last_bit_index % word_size) - 1;
if (num_slack_bits > 0) {
size_type const word_index = cudf::word_index(last_bit_index);
thread_count -= __popc(destination[word_index] & set_most_significant_bits(num_slack_bits));
}
}

using BlockReduce = cub::BlockReduce<size_type, block_size>;
__shared__ typename BlockReduce::TempStorage temp_storage;
size_type block_count = BlockReduce(temp_storage).Sum(thread_count);
Expand Down