diff --git a/cpp/src/strings/filling/fill.cu b/cpp/src/strings/filling/fill.cu index b48d56a595c..878d0fe11ba 100644 --- a/cpp/src/strings/filling/fill.cu +++ b/cpp/src/strings/filling/fill.cu @@ -15,10 +15,8 @@ */ #include -#include #include -#include -#include +#include #include #include #include @@ -27,35 +25,34 @@ #include #include +#include namespace cudf { namespace strings { namespace detail { namespace { + struct fill_fn { column_device_view const d_strings; size_type const begin; size_type const end; - string_view const d_value; - size_type* d_offsets{}; - char* d_chars{}; - - __device__ string_view resolve_string_at(size_type idx) const - { - if ((begin <= idx) && (idx < end)) { return d_value; } - return d_strings.is_valid(idx) ? d_strings.element(idx) : string_view{}; - } + string_scalar_device_view const d_value; - __device__ void operator()(size_type idx) const + __device__ string_index_pair operator()(size_type idx) const { - auto const d_str = resolve_string_at(idx); - if (!d_chars) { - d_offsets[idx] = d_str.size_bytes(); + auto d_str = string_view(); + if ((begin <= idx) && (idx < end)) { + if (!d_value.is_valid()) { return string_index_pair{nullptr, 0}; } + d_str = d_value.value(); } else { - copy_string(d_chars + d_offsets[idx], d_str); + if (d_strings.is_null(idx)) { return string_index_pair{nullptr, 0}; } + d_str = d_strings.element(idx); } + return !d_str.empty() ? string_index_pair{d_str.data(), d_str.size_bytes()} + : string_index_pair{"", 0}; } }; + } // namespace std::unique_ptr fill(strings_column_view const& input, @@ -72,33 +69,18 @@ std::unique_ptr fill(strings_column_view const& input, CUDF_EXPECTS(begin <= end, "Parameters [begin,end) have invalid range values"); if (begin == end) { return std::make_unique(input.parent(), stream, mr); } - auto strings_column = column_device_view::create(input.parent(), stream); - auto const d_strings = *strings_column; - auto const is_valid = value.is_valid(stream); - - // create resulting null mask - auto [null_mask, null_count] = [begin, end, is_valid, d_strings, stream, mr] { - if (begin == 0 and end == d_strings.size() and is_valid) { - return std::pair(rmm::device_buffer{}, 0); - } - return cudf::detail::valid_if( - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(d_strings.size()), - [d_strings, begin, end, is_valid] __device__(size_type idx) { - return ((begin <= idx) && (idx < end)) ? is_valid : d_strings.is_valid(idx); - }, - stream, - mr); - }(); - - auto const d_value = const_cast(value); - auto const d_str = is_valid ? d_value.value(stream) : string_view{}; - auto fn = fill_fn{d_strings, begin, end, d_str}; + auto const d_strings = column_device_view::create(input.parent(), stream); + auto const d_value = cudf::get_scalar_device_view(const_cast(value)); - auto [offsets_column, chars] = make_strings_children(fn, strings_count, stream, mr); + auto fn = fill_fn{*d_strings, begin, end, d_value}; + rmm::device_uvector indices(strings_count, stream); + thrust::transform(rmm::exec_policy_nosync(stream), + thrust::counting_iterator(0), + thrust::counting_iterator(strings_count), + indices.begin(), + fn); - return make_strings_column( - strings_count, std::move(offsets_column), chars.release(), null_count, std::move(null_mask)); + return make_strings_column(indices.begin(), indices.end(), stream, mr); } } // namespace detail