Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use offsetalator in cudf::strings::replace functions #14824

Merged
merged 44 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
064c824
Use offsetalator in strings:replace functions
davidwendt Jan 22, 2024
298766e
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Jan 23, 2024
db54d67
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Jan 24, 2024
f7b5da3
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Jan 26, 2024
2ba1511
remove errant comment
davidwendt Jan 26, 2024
b755c0d
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Jan 29, 2024
38c1f23
fix offsets call
davidwendt Jan 29, 2024
e3d7d60
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Jan 29, 2024
9b9042c
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Jan 31, 2024
5541317
fix char-parallel output offsets
davidwendt Jan 31, 2024
dead3ed
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Feb 2, 2024
038e0d5
revert some int64 changes in favor of more offsetalators
davidwendt Feb 2, 2024
d12eb3f
fix copy_if_safe; rework count-if call
davidwendt Feb 5, 2024
47143c2
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Feb 5, 2024
8bde743
add create_offsets_child_column utility
davidwendt Feb 7, 2024
0f85b60
split up and rewrite replace.cu
davidwendt Feb 9, 2024
56afcf4
fix merge conflict
davidwendt Feb 14, 2024
01ea163
rework detail interface
davidwendt Feb 14, 2024
2db71f6
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Feb 14, 2024
6ef2490
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Feb 15, 2024
bf9e0ef
fix merge conflicts
davidwendt Feb 22, 2024
f8bfea4
remove temp source file
davidwendt Feb 22, 2024
ac8bda4
use create_offsets_from_positions utility
davidwendt Feb 22, 2024
f9853f0
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Feb 23, 2024
b53b6e6
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Feb 25, 2024
145d6bf
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Feb 27, 2024
77d735d
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Feb 28, 2024
1f8fe47
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Feb 29, 2024
a9d4957
remove unused utility function
davidwendt Feb 29, 2024
f26c68a
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Feb 29, 2024
cf015ee
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Feb 29, 2024
9279164
fix merge conflicts
davidwendt Mar 2, 2024
75d1381
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Mar 4, 2024
3990a03
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Mar 5, 2024
7117738
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Mar 6, 2024
2317129
fix merge conflicts
davidwendt Mar 6, 2024
448f6dd
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Mar 6, 2024
98c99f8
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Mar 8, 2024
c554d64
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Mar 11, 2024
a538b2c
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Mar 12, 2024
c8bf50a
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Mar 13, 2024
af75bb1
Merge branch 'branch-24.04' into replace-offsetalator2
davidwendt Mar 14, 2024
dda2525
fix comments, int32, brackets, exec-policy
davidwendt Mar 14, 2024
5ca01f8
Merge branch 'branch-24.06' into replace-offsetalator2
davidwendt Mar 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 119 additions & 113 deletions cpp/src/strings/replace/multi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
* limitations under the License.
*/

#include "strings/split/split.cuh"

#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/get_value.cuh>
#include <cudf/detail/null_mask.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/utilities/algorithm.cuh>
#include <cudf/detail/utilities/cuda.cuh>
#include <cudf/strings/detail/char_tables.hpp>
#include <cudf/strings/detail/replace.hpp>
#include <cudf/strings/detail/strings_children.cuh>
#include <cudf/strings/detail/strings_column_factories.cuh>
Expand All @@ -42,6 +43,7 @@
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/optional.h>
#include <thrust/scan.h>
#include <thrust/transform.h>
Expand All @@ -67,20 +69,14 @@ constexpr size_type AVG_CHAR_BYTES_THRESHOLD = 256;
* @brief Type used for holding the target position (first) and the
* target index (second).
*/
using target_pair = thrust::pair<size_type, size_type>;
using target_pair = thrust::tuple<int64_t, size_type>;

/**
* @brief Helper functions for performing character-parallel replace
*/
struct replace_multi_parallel_fn {
__device__ char const* get_base_ptr() const { return d_strings.head<char>(); }

__device__ size_type const* get_offsets_ptr() const
{
return d_strings.child(strings_column_view::offsets_column_index).data<size_type>() +
d_strings.offset();
}

__device__ string_view const get_string(size_type idx) const
{
return d_strings.element<string_view>(idx);
Expand All @@ -100,11 +96,12 @@ struct replace_multi_parallel_fn {
* @param idx Index of the byte position in the chars column
* @param chars_bytes Number of bytes in the chars column
*/
__device__ thrust::optional<size_type> has_target(size_type idx, size_type chars_bytes) const
__device__ size_type target_index(int64_t idx, int64_t chars_bytes) const
{
auto const d_offsets = get_offsets_ptr();
auto const d_offsets = d_strings_offsets;
auto const d_chars = get_base_ptr() + d_offsets[0] + idx;
size_type str_idx = -1;
string_view d_str{};
for (std::size_t t = 0; t < d_targets.size(); ++t) {
auto const d_tgt = d_targets[t];
if (!d_tgt.empty() && (idx + d_tgt.size_bytes() <= chars_bytes) &&
Expand All @@ -113,12 +110,24 @@ struct replace_multi_parallel_fn {
auto const idx_itr =
thrust::upper_bound(thrust::seq, d_offsets, d_offsets + d_strings.size(), idx);
str_idx = thrust::distance(d_offsets, idx_itr) - 1;
d_str = get_string(str_idx - d_offsets[0]);
}
auto const d_str = get_string(str_idx - d_offsets[0]);
if ((d_chars + d_tgt.size_bytes()) <= (d_str.data() + d_str.size_bytes())) { return t; }
}
}
return thrust::nullopt;
return -1;
}

__device__ bool has_target(int64_t idx, int64_t chars_bytes) const
{
auto const d_chars = get_base_ptr() + d_strings_offsets[0] + idx;
for (auto& d_tgt : d_targets) {
if (!d_tgt.empty() && (idx + d_tgt.size_bytes() <= chars_bytes) &&
(d_tgt.compare(d_chars, d_tgt.size_bytes()) == 0)) {
return true;
}
}
return false;
}

/**
Expand All @@ -133,28 +142,32 @@ struct replace_multi_parallel_fn {
* @return Number of substrings resulting from the replace operations on this row
*/
__device__ size_type count_strings(size_type idx,
target_pair const* d_positions,
size_type const* d_targets_offsets) const
int64_t const* d_positions,
size_type const* d_indices,
cudf::detail::input_offsetalator d_targets_offsets) const
{
if (!is_valid(idx)) { return 0; }

auto const d_str = get_string(idx);
auto const d_str_end = d_str.data() + d_str.size_bytes();
auto const base_ptr = get_base_ptr();
auto const targets_positions = cudf::device_span<target_pair const>(
d_positions + d_targets_offsets[idx], d_targets_offsets[idx + 1] - d_targets_offsets[idx]);
auto const d_str = get_string(idx);
auto const d_str_end = d_str.data() + d_str.size_bytes();
auto const base_ptr = get_base_ptr();

auto const target_offset = d_targets_offsets[idx];
auto const targets_size = static_cast<size_type>(d_targets_offsets[idx + 1] - target_offset);
auto const positions = d_positions + target_offset;
auto const indices = d_indices + target_offset;

size_type count = 1; // always at least one string
auto str_ptr = d_str.data();
for (auto d_pair : targets_positions) {
auto const d_pos = d_pair.first;
auto const d_tgt = d_targets[d_pair.second];
auto const tgt_ptr = base_ptr + d_pos;
for (std::size_t i = 0; i < targets_size; ++i) {
auto const tgt_idx = indices[i];
auto const d_tgt = d_targets[tgt_idx];
auto const tgt_ptr = base_ptr + positions[i];
if (str_ptr <= tgt_ptr && tgt_ptr < d_str_end) {
auto const keep_size = static_cast<size_type>(thrust::distance(str_ptr, tgt_ptr));
if (keep_size > 0) { count++; } // don't bother counting empty strings

auto const d_repl = get_replacement_string(d_pair.second);
auto const d_repl = get_replacement_string(tgt_idx);
if (!d_repl.empty()) { count++; }

str_ptr += keep_size + d_tgt.size_bytes();
Expand Down Expand Up @@ -182,9 +195,10 @@ struct replace_multi_parallel_fn {
* @return The size in bytes of the output string for this row
*/
__device__ size_type get_strings(size_type idx,
size_type const* d_offsets,
target_pair const* d_positions,
size_type const* d_targets_offsets,
cudf::detail::input_offsetalator const d_offsets,
int64_t const* d_positions,
size_type const* d_indices,
cudf::detail::input_offsetalator d_targets_offsets,
string_index_pair* d_all_strings) const
{
if (!is_valid(idx)) { return 0; }
Expand All @@ -194,22 +208,24 @@ struct replace_multi_parallel_fn {
auto const d_str_end = d_str.data() + d_str.size_bytes();
auto const base_ptr = get_base_ptr();

auto const targets_positions = cudf::device_span<target_pair const>(
d_positions + d_targets_offsets[idx], d_targets_offsets[idx + 1] - d_targets_offsets[idx]);
auto const target_offset = d_targets_offsets[idx];
auto const targets_size = static_cast<size_type>(d_targets_offsets[idx + 1] - target_offset);
auto const positions = d_positions + target_offset;
auto const indices = d_indices + target_offset;

size_type output_idx = 0;
size_type output_size = 0;
auto str_ptr = d_str.data();
for (auto d_pair : targets_positions) {
auto const d_pos = d_pair.first;
auto const d_tgt = d_targets[d_pair.second];
auto const tgt_ptr = base_ptr + d_pos;
for (std::size_t i = 0; i < targets_size; ++i) {
auto const tgt_idx = indices[i];
auto const d_tgt = d_targets[tgt_idx];
auto const tgt_ptr = base_ptr + positions[i];
if (str_ptr <= tgt_ptr && tgt_ptr < d_str_end) {
auto const keep_size = static_cast<size_type>(thrust::distance(str_ptr, tgt_ptr));
if (keep_size > 0) { d_output[output_idx++] = string_index_pair{str_ptr, keep_size}; }
output_size += keep_size;

auto const d_repl = get_replacement_string(d_pair.second);
auto const d_repl = get_replacement_string(tgt_idx);
if (!d_repl.empty()) {
d_output[output_idx++] = string_index_pair{d_repl.data(), d_repl.size_bytes()};
}
Expand All @@ -228,14 +244,19 @@ struct replace_multi_parallel_fn {
}

replace_multi_parallel_fn(column_device_view const& d_strings,
cudf::detail::input_offsetalator d_strings_offsets,
device_span<string_view const> d_targets,
device_span<string_view const> d_replacements)
: d_strings(d_strings), d_targets{d_targets}, d_replacements{d_replacements}
: d_strings(d_strings),
d_strings_offsets(d_strings_offsets),
d_targets{d_targets},
d_replacements{d_replacements}
{
}

protected:
column_device_view d_strings;
cudf::detail::input_offsetalator d_strings_offsets;
device_span<string_view const> d_targets;
device_span<string_view const> d_replacements;
};
Expand All @@ -247,17 +268,16 @@ struct replace_multi_parallel_fn {
* (this happens sometimes when passing device lambdas to thrust algorithms)
*/
struct pair_generator {
__device__ target_pair operator()(int idx) const
__device__ target_pair operator()(int64_t idx) const
{
auto pos = fn.has_target(idx, chars_bytes);
return target_pair{idx, pos.value_or(-1)};
return thrust::make_tuple(idx, fn.target_index(idx, chars_bytes));
}
replace_multi_parallel_fn fn;
size_type chars_bytes;
int64_t chars_bytes;
};

struct copy_if_fn {
__device__ bool operator()(target_pair pos) { return pos.second >= 0; }
__device__ bool operator()(target_pair pos) { return thrust::get<1>(pos) >= 0; }
};

std::unique_ptr<column> replace_character_parallel(strings_column_view const& input,
Expand All @@ -270,92 +290,73 @@ std::unique_ptr<column> replace_character_parallel(strings_column_view const& in

auto const strings_count = input.size();
auto const chars_bytes =
cudf::detail::get_value<size_type>(input.offsets(), input.offset() + strings_count, stream) -
cudf::detail::get_value<size_type>(input.offsets(), input.offset(), stream);
get_offset_value(input.offsets(), input.offset() + strings_count, stream) -
get_offset_value(input.offsets(), input.offset(), stream);

auto d_targets =
create_string_vector_from_column(targets, stream, rmm::mr::get_current_device_resource());
auto d_replacements =
create_string_vector_from_column(repls, stream, rmm::mr::get_current_device_resource());

replace_multi_parallel_fn fn{*d_strings, d_targets, d_replacements};
replace_multi_parallel_fn fn{
*d_strings,
cudf::detail::offsetalator_factory::make_input_iterator(input.offsets(), input.offset()),
d_targets,
d_replacements,
};

// Count the number of targets in the entire column.
// Note this may over-count in the case where a target spans adjacent strings.
auto target_count = thrust::count_if(
rmm::exec_policy(stream),
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
thrust::make_counting_iterator<int64_t>(0),
thrust::make_counting_iterator<int64_t>(chars_bytes),
[fn, chars_bytes] __device__(int64_t idx) { return fn.has_target(idx, chars_bytes); });

// count the number of targets in the entire column
auto const target_count = thrust::count_if(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(chars_bytes),
[fn, chars_bytes] __device__(size_type idx) {
return fn.has_target(idx, chars_bytes).has_value();
});
// Create a vector of every target position in the chars column.
// These may include overlapping targets which will be resolved later.
auto targets_positions = rmm::device_uvector<target_pair>(target_count, stream);
// These may also include overlapping targets which will be resolved later.
auto targets_positions = rmm::device_uvector<int64_t>(target_count, stream);
auto targets_indices = rmm::device_uvector<size_type>(target_count, stream);

// cudf::detail::make_counting_transform_iterator hardcodes size_type
auto const copy_itr = thrust::make_transform_iterator(thrust::counting_iterator<int64_t>(0),
pair_generator{fn, chars_bytes});
auto const out_itr = thrust::make_zip_iterator(
thrust::make_tuple(targets_positions.begin(), targets_indices.begin()));
auto const copy_end =
cudf::detail::copy_if_safe(copy_itr, copy_itr + chars_bytes, out_itr, copy_if_fn{}, stream);

// adjust target count since the copy-if may have eliminated some invalid targets
target_count = std::min(static_cast<int64_t>(std::distance(out_itr, copy_end)), target_count);
targets_positions.resize(target_count, stream);
targets_indices.resize(target_count, stream);
auto d_positions = targets_positions.data();

auto const copy_itr =
cudf::detail::make_counting_transform_iterator(0, pair_generator{fn, chars_bytes});
auto const copy_end = thrust::copy_if(
rmm::exec_policy(stream), copy_itr, copy_itr + chars_bytes, d_positions, copy_if_fn{});
auto d_targets_indices = targets_indices.data();

// create a vector of offsets to each string's set of target positions
auto const targets_offsets = [&] {
auto string_indices = rmm::device_uvector<size_type>(target_count, stream);

auto const pos_itr = cudf::detail::make_counting_transform_iterator(
0, cuda::proclaim_return_type<int64_t>([d_positions] __device__(auto idx) -> int64_t {
return d_positions[idx].first;
}));
auto pos_count = std::distance(d_positions, copy_end);

auto begin =
cudf::detail::offsetalator_factory::make_input_iterator(input.offsets(), input.offset());
auto end = begin + input.offsets().size();
thrust::upper_bound(
rmm::exec_policy(stream), begin, end, pos_itr, pos_itr + pos_count, string_indices.begin());

// compute offsets per string
auto targets_offsets = rmm::device_uvector<size_type>(strings_count + 1, stream);
auto d_targets_offsets = targets_offsets.data();

// memset to zero-out the target counts for any null-entries or strings with no targets
thrust::uninitialized_fill(
rmm::exec_policy(stream), targets_offsets.begin(), targets_offsets.end(), 0);

// next, count the number of targets per string
auto d_string_indices = string_indices.data();
thrust::for_each_n(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
target_count,
[d_string_indices, d_targets_offsets] __device__(size_type idx) {
auto const str_idx = d_string_indices[idx] - 1;
atomicAdd(d_targets_offsets + str_idx, 1);
});
// finally, convert the counts into offsets
thrust::exclusive_scan(rmm::exec_policy(stream),
targets_offsets.begin(),
targets_offsets.end(),
targets_offsets.begin());
return targets_offsets;
}();
auto const d_targets_offsets = targets_offsets.data();
auto const targets_offsets = create_offsets_from_positions(
input, targets_positions, stream, rmm::mr::get_current_device_resource());
auto const d_targets_offsets =
cudf::detail::offsetalator_factory::make_input_iterator(targets_offsets->view());

// compute the number of string segments produced by replace in each string
auto counts = rmm::device_uvector<size_type>(strings_count, stream);
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(strings_count),
thrust::counting_iterator<size_type>(0),
thrust::counting_iterator<size_type>(strings_count),
counts.begin(),
cuda::proclaim_return_type<size_type>(
[fn, d_positions, d_targets_offsets] __device__(size_type idx) -> size_type {
return fn.count_strings(idx, d_positions, d_targets_offsets);
[fn, d_positions, d_targets_indices, d_targets_offsets] __device__(
size_type idx) -> size_type {
return fn.count_strings(
idx, d_positions, d_targets_indices, d_targets_offsets);
}));

// create offsets from the counts
auto offsets =
std::get<0>(cudf::detail::make_offsets_child_column(counts.begin(), counts.end(), stream, mr));
auto const total_strings =
cudf::detail::get_value<size_type>(offsets->view(), strings_count, stream);
auto const d_strings_offsets = offsets->view().data<size_type>();
auto [offsets, total_strings] =
cudf::detail::make_offsets_child_column(counts.begin(), counts.end(), stream, mr);
auto const d_strings_offsets =
cudf::detail::offsetalator_factory::make_input_iterator(offsets->view());

// build a vector of all the positions for all the strings
auto indices = rmm::device_uvector<string_index_pair>(total_strings, stream);
Expand All @@ -365,19 +366,24 @@ std::unique_ptr<column> replace_character_parallel(strings_column_view const& in
rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
strings_count,
[fn, d_strings_offsets, d_positions, d_targets_offsets, d_indices, d_sizes] __device__(
size_type idx) {
d_sizes[idx] =
fn.get_strings(idx, d_strings_offsets, d_positions, d_targets_offsets, d_indices);
[fn,
d_strings_offsets,
d_positions,
d_targets_indices,
d_targets_offsets,
d_indices,
d_sizes] __device__(size_type idx) {
d_sizes[idx] = fn.get_strings(
idx, d_strings_offsets, d_positions, d_targets_indices, d_targets_offsets, d_indices);
});

// use this utility to gather the string parts into a contiguous chars column
auto chars = make_strings_column(indices.begin(), indices.end(), stream, mr);
auto chars_data = chars->release().data;

// create offsets from the sizes
offsets =
std::get<0>(cudf::detail::make_offsets_child_column(counts.begin(), counts.end(), stream, mr));
offsets = std::get<0>(
cudf::strings::detail::make_offsets_child_column(counts.begin(), counts.end(), stream, mr));

// build the strings columns from the chars and offsets
return make_strings_column(strings_count,
Expand Down
Loading
Loading