From d8eaa024c2e8a6eb689da3e6a650e441147d81ca Mon Sep 17 00:00:00 2001 From: David Wendt Date: Tue, 13 Feb 2024 17:18:13 -0500 Subject: [PATCH] Use appropriate make_offsets_child_column for building lists columns --- cpp/src/strings/extract/extract_all.cu | 29 ++++++++++++-------------- cpp/src/strings/search/findall.cu | 8 +++---- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/cpp/src/strings/extract/extract_all.cu b/cpp/src/strings/extract/extract_all.cu index 0c0d4ae4fbf..d49addf1324 100644 --- a/cpp/src/strings/extract/extract_all.cu +++ b/cpp/src/strings/extract/extract_all.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -118,12 +118,12 @@ std::unique_ptr extract_all_record(strings_column_view const& input, // Get the match counts for each string. // This column will become the output lists child offsets column. - auto offsets = count_matches(*d_strings, *d_prog, strings_count + 1, stream, mr); - auto d_offsets = offsets->mutable_view().data(); + auto counts = count_matches(*d_strings, *d_prog, strings_count, stream, mr); + auto d_counts = counts->mutable_view().data(); // Compute null output rows auto [null_mask, null_count] = cudf::detail::valid_if( - d_offsets, d_offsets + strings_count, [] __device__(auto v) { return v > 0; }, stream, mr); + d_counts, d_counts + strings_count, [] __device__(auto v) { return v > 0; }, stream, mr); // Return an empty lists column if there are no valid rows if (strings_count == null_count) { @@ -132,18 +132,15 @@ std::unique_ptr extract_all_record(strings_column_view const& input, // Convert counts into offsets. // Multiply each count by the number of groups. - thrust::transform_exclusive_scan( - rmm::exec_policy(stream), - d_offsets, - d_offsets + strings_count + 1, - d_offsets, - [groups] __device__(auto v) { return v * groups; }, - size_type{0}, - thrust::plus{}); - auto const total_groups = - cudf::detail::get_value(offsets->view(), strings_count, stream); - - rmm::device_uvector indices(total_groups, stream); + auto sizes_itr = cudf::detail::make_counting_transform_iterator( + 0, cuda::proclaim_return_type([d_counts, groups] __device__(auto idx) { + return d_counts[idx] * groups; + })); + auto [offsets, total_strings] = + cudf::detail::make_offsets_child_column(sizes_itr, sizes_itr + strings_count, stream, mr); + auto d_offsets = offsets->view().data(); + + rmm::device_uvector indices(total_strings, stream); launch_for_each_kernel( extract_fn{*d_strings, d_offsets, indices.data()}, *d_prog, strings_count, stream); diff --git a/cpp/src/strings/search/findall.cu b/cpp/src/strings/search/findall.cu index 8df1a67d56d..578d0eac3c8 100644 --- a/cpp/src/strings/search/findall.cu +++ b/cpp/src/strings/search/findall.cu @@ -48,7 +48,7 @@ namespace { */ struct findall_fn { column_device_view const d_strings; - cudf::detail::input_offsetalator const d_offsets; + size_type const* d_offsets; string_index_pair* d_indices; __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) @@ -76,7 +76,7 @@ struct findall_fn { std::unique_ptr findall_util(column_device_view const& d_strings, reprog_device& d_prog, int64_t total_matches, - cudf::detail::input_offsetalator const d_offsets, + size_type const* d_offsets, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -104,9 +104,9 @@ std::unique_ptr findall(strings_column_view const& input, // Create lists offsets column auto const sizes = count_matches(*d_strings, *d_prog, strings_count, stream, mr); - auto [offsets, total_matches] = cudf::strings::detail::make_offsets_child_column( + auto [offsets, total_matches] = cudf::detail::make_offsets_child_column( sizes->view().begin(), sizes->view().end(), stream, mr); - auto const d_offsets = cudf::detail::offsetalator_factory::make_input_iterator(offsets->view()); + auto const d_offsets = offsets->view().data(); // Build strings column of the matches auto strings_output = findall_util(*d_strings, *d_prog, total_matches, d_offsets, stream, mr);