Skip to content

Commit

Permalink
Large strings support for cudf::fill (#15555)
Browse files Browse the repository at this point in the history
Updates the `cudf::fill` strings specialization logic to use gather-based `make_strings_column` instead of the `make_strings_children` since the gather-based function already efficiently supports longs.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Muhammad Haseeb (https://github.com/mhaseeb123)
  - Mark Harris (https://github.com/harrism)

URL: #15555
  • Loading branch information
davidwendt authored Apr 23, 2024
1 parent 7804ba7 commit 73306f1
Showing 1 changed file with 24 additions and 42 deletions.
66 changes: 24 additions & 42 deletions cpp/src/strings/filling/fill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
*/

#include <cudf/column/column_device_view.cuh>
#include <cudf/detail/valid_if.cuh>
#include <cudf/strings/detail/fill.hpp>
#include <cudf/strings/detail/strings_children.cuh>
#include <cudf/strings/detail/utilities.cuh>
#include <cudf/strings/detail/strings_column_factories.cuh>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/error.hpp>
Expand All @@ -27,35 +25,34 @@
#include <rmm/resource_ref.hpp>

#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

namespace cudf {
namespace strings {
namespace detail {
namespace {

struct fill_fn {
column_device_view const d_strings;
size_type const begin;
size_type const end;
string_view const d_value;
size_type* d_offsets{};
char* d_chars{};

__device__ string_view resolve_string_at(size_type idx) const
{
if ((begin <= idx) && (idx < end)) { return d_value; }
return d_strings.is_valid(idx) ? d_strings.element<string_view>(idx) : string_view{};
}
string_scalar_device_view const d_value;

__device__ void operator()(size_type idx) const
__device__ string_index_pair operator()(size_type idx) const
{
auto const d_str = resolve_string_at(idx);
if (!d_chars) {
d_offsets[idx] = d_str.size_bytes();
auto d_str = string_view();
if ((begin <= idx) && (idx < end)) {
if (!d_value.is_valid()) { return string_index_pair{nullptr, 0}; }
d_str = d_value.value();
} else {
copy_string(d_chars + d_offsets[idx], d_str);
if (d_strings.is_null(idx)) { return string_index_pair{nullptr, 0}; }
d_str = d_strings.element<string_view>(idx);
}
return !d_str.empty() ? string_index_pair{d_str.data(), d_str.size_bytes()}
: string_index_pair{"", 0};
}
};

} // namespace

std::unique_ptr<column> fill(strings_column_view const& input,
Expand All @@ -72,33 +69,18 @@ std::unique_ptr<column> fill(strings_column_view const& input,
CUDF_EXPECTS(begin <= end, "Parameters [begin,end) have invalid range values");
if (begin == end) { return std::make_unique<column>(input.parent(), stream, mr); }

auto strings_column = column_device_view::create(input.parent(), stream);
auto const d_strings = *strings_column;
auto const is_valid = value.is_valid(stream);

// create resulting null mask
auto [null_mask, null_count] = [begin, end, is_valid, d_strings, stream, mr] {
if (begin == 0 and end == d_strings.size() and is_valid) {
return std::pair(rmm::device_buffer{}, 0);
}
return cudf::detail::valid_if(
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(d_strings.size()),
[d_strings, begin, end, is_valid] __device__(size_type idx) {
return ((begin <= idx) && (idx < end)) ? is_valid : d_strings.is_valid(idx);
},
stream,
mr);
}();

auto const d_value = const_cast<string_scalar&>(value);
auto const d_str = is_valid ? d_value.value(stream) : string_view{};
auto fn = fill_fn{d_strings, begin, end, d_str};
auto const d_strings = column_device_view::create(input.parent(), stream);
auto const d_value = cudf::get_scalar_device_view(const_cast<string_scalar&>(value));

auto [offsets_column, chars] = make_strings_children(fn, strings_count, stream, mr);
auto fn = fill_fn{*d_strings, begin, end, d_value};
rmm::device_uvector<string_index_pair> indices(strings_count, stream);
thrust::transform(rmm::exec_policy_nosync(stream),
thrust::counting_iterator<size_type>(0),
thrust::counting_iterator<size_type>(strings_count),
indices.begin(),
fn);

return make_strings_column(
strings_count, std::move(offsets_column), chars.release(), null_count, std::move(null_mask));
return make_strings_column(indices.begin(), indices.end(), stream, mr);
}

} // namespace detail
Expand Down

0 comments on commit 73306f1

Please sign in to comment.