Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Large strings support for cudf::fill #15555

Merged
merged 2 commits into from
Apr 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 24 additions & 42 deletions cpp/src/strings/filling/fill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
*/

#include <cudf/column/column_device_view.cuh>
#include <cudf/detail/valid_if.cuh>
#include <cudf/strings/detail/fill.hpp>
#include <cudf/strings/detail/strings_children.cuh>
#include <cudf/strings/detail/utilities.cuh>
#include <cudf/strings/detail/strings_column_factories.cuh>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/error.hpp>
Expand All @@ -27,35 +25,34 @@
#include <rmm/resource_ref.hpp>

#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

namespace cudf {
namespace strings {
namespace detail {
namespace {

struct fill_fn {
column_device_view const d_strings;
size_type const begin;
size_type const end;
string_view const d_value;
size_type* d_offsets{};
char* d_chars{};

__device__ string_view resolve_string_at(size_type idx) const
{
if ((begin <= idx) && (idx < end)) { return d_value; }
return d_strings.is_valid(idx) ? d_strings.element<string_view>(idx) : string_view{};
}
string_scalar_device_view const d_value;

__device__ void operator()(size_type idx) const
__device__ string_index_pair operator()(size_type idx) const
{
auto const d_str = resolve_string_at(idx);
if (!d_chars) {
d_offsets[idx] = d_str.size_bytes();
auto d_str = string_view();
if ((begin <= idx) && (idx < end)) {
if (!d_value.is_valid()) { return string_index_pair{nullptr, 0}; }
mhaseeb123 marked this conversation as resolved.
Show resolved Hide resolved
d_str = d_value.value();
} else {
copy_string(d_chars + d_offsets[idx], d_str);
if (d_strings.is_null(idx)) { return string_index_pair{nullptr, 0}; }
d_str = d_strings.element<string_view>(idx);
}
return !d_str.empty() ? string_index_pair{d_str.data(), d_str.size_bytes()}
: string_index_pair{"", 0};
}
};

} // namespace

std::unique_ptr<column> fill(strings_column_view const& input,
Expand All @@ -72,33 +69,18 @@ std::unique_ptr<column> fill(strings_column_view const& input,
CUDF_EXPECTS(begin <= end, "Parameters [begin,end) have invalid range values");
if (begin == end) { return std::make_unique<column>(input.parent(), stream, mr); }

auto strings_column = column_device_view::create(input.parent(), stream);
auto const d_strings = *strings_column;
auto const is_valid = value.is_valid(stream);

// create resulting null mask
auto [null_mask, null_count] = [begin, end, is_valid, d_strings, stream, mr] {
if (begin == 0 and end == d_strings.size() and is_valid) {
return std::pair(rmm::device_buffer{}, 0);
}
return cudf::detail::valid_if(
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(d_strings.size()),
[d_strings, begin, end, is_valid] __device__(size_type idx) {
return ((begin <= idx) && (idx < end)) ? is_valid : d_strings.is_valid(idx);
},
stream,
mr);
}();

auto const d_value = const_cast<string_scalar&>(value);
auto const d_str = is_valid ? d_value.value(stream) : string_view{};
auto fn = fill_fn{d_strings, begin, end, d_str};
auto const d_strings = column_device_view::create(input.parent(), stream);
auto const d_value = cudf::get_scalar_device_view(const_cast<string_scalar&>(value));

auto [offsets_column, chars] = make_strings_children(fn, strings_count, stream, mr);
auto fn = fill_fn{*d_strings, begin, end, d_value};
rmm::device_uvector<string_index_pair> indices(strings_count, stream);
thrust::transform(rmm::exec_policy_nosync(stream),
thrust::counting_iterator<size_type>(0),
thrust::counting_iterator<size_type>(strings_count),
indices.begin(),
fn);

return make_strings_column(
strings_count, std::move(offsets_column), chars.release(), null_count, std::move(null_mask));
return make_strings_column(indices.begin(), indices.end(), stream, mr);
}

} // namespace detail
Expand Down
Loading