diff --git a/cpp/include/cudf/detail/null_mask.cuh b/cpp/include/cudf/detail/null_mask.cuh index 6090477c28d..90abb8dbeae 100644 --- a/cpp/include/cudf/detail/null_mask.cuh +++ b/cpp/include/cudf/detail/null_mask.cuh @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -279,7 +280,8 @@ rmm::device_uvector segmented_count_bits(bitmask_type const* bitmask, OffsetIterator first_bit_indices_end, OffsetIterator last_bit_indices_begin, count_bits_policy count_bits, - rmm::cuda_stream_view stream) + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { auto const num_ranges = static_cast(std::distance(first_bit_indices_begin, first_bit_indices_end)); @@ -329,14 +331,15 @@ rmm::device_uvector segmented_count_bits(bitmask_type const* bitmask, // set bits from the length of the segment. auto segments_begin = thrust::make_zip_iterator(first_bit_indices_begin, last_bit_indices_begin); - auto segments_size = thrust::transform_iterator(segments_begin, [] __device__(auto segment) { - auto const begin = thrust::get<0>(segment); - auto const end = thrust::get<1>(segment); - return end - begin; - }); + auto segment_length_iterator = + thrust::transform_iterator(segments_begin, [] __device__(auto const& segment) { + auto const begin = thrust::get<0>(segment); + auto const end = thrust::get<1>(segment); + return end - begin; + }); thrust::transform(rmm::exec_policy(stream), - segments_size, - segments_size + num_ranges, + segment_length_iterator, + segment_length_iterator + num_ranges, d_bit_counts.data(), d_bit_counts.data(), [] __device__(auto segment_size, auto segment_bit_count) { @@ -438,7 +441,8 @@ std::vector segmented_count_bits(bitmask_type const* bitmask, first_bit_indices_end, last_bit_indices_begin, count_bits, - stream); + stream, + rmm::mr::get_current_device_resource()); // Copy the results back to the host. return make_std_vector_sync(d_bit_counts, stream); @@ -501,6 +505,110 @@ std::vector segmented_null_count(bitmask_type const* bitmask, return detail::segmented_count_unset_bits(bitmask, indices_begin, indices_end, stream); } +/** + * @brief Reduce an input null mask using segments defined by offset indices + * into an output null mask. + * + * @tparam OffsetIterator Random-access input iterator type. + * @param bitmask Null mask residing in device memory whose segments will be + * reduced into a new mask. + * @param first_bit_indices_begin Random-access input iterator to the beginning + * of a sequence of indices of the first bit in each segment (inclusive). + * @param first_bit_indices_end Random-access input iterator to the end of a + * sequence of indices of the first bit in each segment (inclusive). + * @param last_bit_indices_begin Random-access input iterator to the beginning + * of a sequence of indices of the last bit in each segment (exclusive). + * @param null_handling If `INCLUDE`, all elements in a segment must be valid + * for the reduced value to be valid. If `EXCLUDE`, the reduction is valid if + * any element in the segment is valid. + * @param stream CUDA stream used for device memory operations and kernel launches. + * @param mr Device memory resource used to allocate the returned buffer's device memory. + * @return A pair containing the reduced null mask and number of nulls. + */ +template +std::pair segmented_null_mask_reduction( + bitmask_type const* bitmask, + OffsetIterator first_bit_indices_begin, + OffsetIterator first_bit_indices_end, + OffsetIterator last_bit_indices_begin, + null_policy null_handling, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto const segments_begin = + thrust::make_zip_iterator(first_bit_indices_begin, last_bit_indices_begin); + auto const segment_length_iterator = + thrust::make_transform_iterator(segments_begin, [] __device__(auto const& segment) { + auto const begin = thrust::get<0>(segment); + auto const end = thrust::get<1>(segment); + return end - begin; + }); + + // Empty segments are always null in the output mask + auto const num_segments = std::distance(first_bit_indices_begin, first_bit_indices_end); + auto [output_null_mask, output_null_count] = cudf::detail::valid_if( + segment_length_iterator, + segment_length_iterator + num_segments, + [] __device__(auto const& len) { return len > 0; }, + stream, + mr); + + if (bitmask != nullptr) { + [[maybe_unused]] auto const [null_policy_bitmask, _] = [&]() { + if (null_handling == null_policy::EXCLUDE) { + // Output null mask should be valid if any element in the segment is + // valid and the segment is non-empty. + auto const valid_counts = + cudf::detail::segmented_count_bits(bitmask, + first_bit_indices_begin, + first_bit_indices_end, + last_bit_indices_begin, + cudf::detail::count_bits_policy::SET_BITS, + stream, + rmm::mr::get_current_device_resource()); + return cudf::detail::valid_if( + valid_counts.begin(), + valid_counts.end(), + [] __device__(auto const valid_count) { return valid_count > 0; }, + stream); + } else { + // Output null mask should be valid if all elements in the segment are + // valid and the segment is non-empty. + auto const null_counts = + cudf::detail::segmented_count_bits(bitmask, + first_bit_indices_begin, + first_bit_indices_end, + last_bit_indices_begin, + cudf::detail::count_bits_policy::UNSET_BITS, + stream, + rmm::mr::get_current_device_resource()); + return cudf::detail::valid_if( + null_counts.begin(), + null_counts.end(), + [] __device__(auto const null_count) { return null_count == 0; }, + stream); + } + }(); + + std::vector masks{ + reinterpret_cast(output_null_mask.data()), + reinterpret_cast(null_policy_bitmask.data())}; + std::vector begin_bits{0, 0}; + cudf::detail::inplace_bitmask_and( + device_span(reinterpret_cast(output_null_mask.data()), + num_bitmask_words(num_segments)), + masks, + begin_bits, + num_segments, + stream, + mr); + + // TODO: inplace_bitmask_and should return its null count (PR 9904) + output_null_count = cudf::UNKNOWN_NULL_COUNT; + } + return std::make_pair(std::move(output_null_mask), output_null_count); +} + } // namespace detail } // namespace cudf diff --git a/cpp/include/cudf/reduction.hpp b/cpp/include/cudf/reduction.hpp index 978a42fe843..0795d90cd19 100644 --- a/cpp/include/cudf/reduction.hpp +++ b/cpp/include/cudf/reduction.hpp @@ -103,7 +103,9 @@ std::unique_ptr reduce( * @param agg Aggregation operator applied by the reduction * @param offsets Indices to segment boundaries * @param output_dtype The computation and output precision. - * @param null_handling `INCLUDE` + * @param null_handling If `INCLUDE`, all elements in a segment must be valid + * for the reduced value to be valid. If `EXCLUDE`, the reduction is valid if + * any element in the segment is valid. * @param mr Device memory resource used to allocate the returned scalar's device memory * @returns Output column with segment's reduce result. */ diff --git a/cpp/src/reductions/simple_segmented.cuh b/cpp/src/reductions/simple_segmented.cuh index 18aae0e7312..9829ecfc529 100644 --- a/cpp/src/reductions/simple_segmented.cuh +++ b/cpp/src/reductions/simple_segmented.cuh @@ -49,6 +49,9 @@ namespace detail { * @param col Input column of data to reduce * @param offsets Indices to segment boundaries + * @param null_handling If `INCLUDE`, all elements in a segment must be valid + * for the reduced value to be valid. If `EXCLUDE`, the reduction is valid if + * any element in the segment is valid. * @param stream Used for device memory operations and kernel launches. * @param mr Device memory resource used to allocate the returned column's device memory * @return Output column in device memory @@ -79,85 +82,20 @@ std::unique_ptr simple_segmented_reduction(column_view const& col, } }(); - // Compute output null mask - auto const bitmask = col.null_mask(); - - // Compute segment lengths to get the output null mask + // Compute the output null mask + auto const bitmask = col.null_mask(); auto const first_bit_indices_begin = offsets.begin(); auto const first_bit_indices_end = offsets.end() - 1; auto const last_bit_indices_begin = offsets.begin() + 1; - - // TODO: Investigate segment length iterator? Seems reusable. - auto const indices_start_end_pair_iterator = - thrust::make_zip_iterator(first_bit_indices_begin, last_bit_indices_begin); - auto const segment_length_iterator = - thrust::make_transform_iterator(indices_start_end_pair_iterator, [] __device__(auto const& p) { - auto const start = thrust::get<0>(p); - auto const end = thrust::get<1>(p); - return end - start; - }); - - [[maybe_unused]] auto [output_null_mask, _] = cudf::detail::valid_if( - segment_length_iterator, - segment_length_iterator + col.size(), - [] __device__(auto const& len) { return len > 0; }, - stream, - mr); - - if (bitmask != nullptr) { - [[maybe_unused]] auto const [null_policy_bitmask, _] = [&]() { - if (null_handling == null_policy::EXCLUDE) { - // Output null mask should be valid if any element in the segment is - // valid and the segment is non-empty. - - // TODO: This needs a nicer function wrapping segmented_count_bits on device - auto const valid_counts = - cudf::detail::segmented_count_bits(bitmask, - first_bit_indices_begin, - first_bit_indices_end, - last_bit_indices_begin, - cudf::detail::count_bits_policy::SET_BITS, - stream); - return cudf::detail::valid_if( - valid_counts.begin(), - valid_counts.end(), - [] __device__(auto const valid_count) { return valid_count > 0; }, - stream); - } else { - // Output null mask should be valid if all elements in the segment are - // valid and the segment is non-empty. - - // TODO: This needs a nicer function wrapping segmented_count_bits on device - auto const null_counts = - cudf::detail::segmented_count_bits(bitmask, - first_bit_indices_begin, - first_bit_indices_end, - last_bit_indices_begin, - cudf::detail::count_bits_policy::UNSET_BITS, - stream); - return cudf::detail::valid_if( - null_counts.begin(), - null_counts.end(), - [] __device__(auto const null_count) { return null_count == 0; }, - stream); - } - }(); - - // TODO: inplace_bitmask_and should return its null count (bdice working on PR) - std::vector masks{ - reinterpret_cast(output_null_mask.data()), - reinterpret_cast(null_policy_bitmask.data())}; - std::vector begin_bits{0, 0}; - cudf::detail::inplace_bitmask_and( - device_span(reinterpret_cast(output_null_mask.data()), - num_bitmask_words(col.size())), - masks, - begin_bits, - col.size(), - stream, - mr); - } - result->set_null_mask(output_null_mask, cudf::UNKNOWN_NULL_COUNT, stream); + auto const [output_null_mask, output_null_count] = + cudf::detail::segmented_null_mask_reduction(bitmask, + first_bit_indices_begin, + first_bit_indices_end, + last_bit_indices_begin, + null_handling, + stream, + mr); + result->set_null_mask(output_null_mask, output_null_count, stream); return result; }