diff --git a/cpp/src/lists/combine/concatenate_list_elements.cu b/cpp/src/lists/combine/concatenate_list_elements.cu index c5a28a8ec5f..fb6bff3f129 100644 --- a/cpp/src/lists/combine/concatenate_list_elements.cu +++ b/cpp/src/lists/combine/concatenate_list_elements.cu @@ -51,13 +51,9 @@ std::unique_ptr concatenate_lists_ignore_null(column_view const& input, auto out_offsets = make_numeric_column( data_type{type_id::INT32}, num_rows + 1, mask_state::UNALLOCATED, stream, mr); - // The array of int8_t stores validities for the output list elements. - auto validities = rmm::device_uvector(build_null_mask ? num_rows : 0, stream); - auto const d_out_offsets = out_offsets->mutable_view().template begin(); auto const d_row_offsets = lists_column_view(input).offsets_begin(); auto const d_list_offsets = lists_column_view(lists_column_view(input).child()).offsets_begin(); - auto const lists_dv_ptr = column_device_view::create(lists_column_view(input).child()); // Concatenating the lists at the same row by converting the entry offsets from the child column // into row offsets of the root column. Those entry offsets are subtracted by the first entry @@ -67,22 +63,7 @@ std::unique_ptr concatenate_lists_ignore_null(column_view const& input, iter, iter + num_rows + 1, d_out_offsets, - [d_row_offsets, - d_list_offsets, - lists_dv = *lists_dv_ptr, - d_validities = validities.begin(), - build_null_mask, - iter] __device__(auto const idx) { - if (build_null_mask) { - // The output row will be null only if all lists on the input row are null. - auto const is_valid = thrust::any_of(thrust::seq, - iter + d_row_offsets[idx], - iter + d_row_offsets[idx + 1], - [&] __device__(auto const list_idx) { - return lists_dv.is_valid(list_idx); - }); - d_validities[idx] = static_cast(is_valid); - } + [d_row_offsets, d_list_offsets] __device__(auto const idx) { auto const start_offset = d_list_offsets[d_row_offsets[0]]; return d_list_offsets[d_row_offsets[idx]] - start_offset; }); @@ -92,10 +73,23 @@ std::unique_ptr concatenate_lists_ignore_null(column_view const& input, lists_column_view(lists_column_view(input).get_sliced_child(stream)).get_sliced_child(stream)); auto [null_mask, null_count] = [&] { - return build_null_mask - ? cudf::detail::valid_if( - validities.begin(), validities.end(), thrust::identity{}, stream, mr) - : std::make_pair(cudf::detail::copy_bitmask(input, stream, mr), input.null_count()); + if (!build_null_mask) + return std::make_pair(cudf::detail::copy_bitmask(input, stream, mr), input.null_count()); + + // The output row will be null only if all lists on the input row are null. + auto const lists_dv_ptr = column_device_view::create(lists_column_view(input).child(), stream); + return cudf::detail::valid_if( + iter, + iter + num_rows, + [d_row_offsets, lists_dv = *lists_dv_ptr, iter] __device__(auto const idx) { + return thrust::any_of( + thrust::seq, + iter + d_row_offsets[idx], + iter + d_row_offsets[idx + 1], + [&] __device__(auto const list_idx) { return lists_dv.is_valid(list_idx); }); + }, + stream, + mr); }(); return make_lists_column(num_rows,