Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Large strings support for cudf::gather #15621

Merged
merged 1 commit into from
May 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions cpp/include/cudf/strings/detail/gather.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,19 @@
#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/offsets_iterator_factory.cuh>
#include <cudf/detail/sizes_to_offsets_iterator.cuh>
#include <cudf/detail/utilities/cuda.cuh>
#include <cudf/strings/detail/strings_children.cuh>
#include <cudf/strings/detail/utilities.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>
#include <rmm/resource_ref.hpp>

#include <cuda/functional>
#include <thrust/advance.h>
#include <thrust/binary_search.h>
#include <thrust/distance.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

namespace cudf {
Expand Down Expand Up @@ -226,7 +222,7 @@ rmm::device_uvector<char> gather_chars(StringIterator strings_begin,
MapIterator map_begin,
MapIterator map_end,
cudf::detail::input_offsetalator const offsets,
size_type chars_bytes,
int64_t chars_bytes,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
Expand All @@ -239,9 +235,9 @@ rmm::device_uvector<char> gather_chars(StringIterator strings_begin,
constexpr int warps_per_threadblock = 4;
// String parallel strategy will be used if average string length is above this threshold.
// Otherwise, char parallel strategy will be used.
constexpr size_type string_parallel_threshold = 32;
constexpr int64_t string_parallel_threshold = 32;

size_type average_string_length = chars_bytes / output_count;
int64_t const average_string_length = chars_bytes / output_count;

if (average_string_length > string_parallel_threshold) {
constexpr int max_threadblocks = 65536;
Expand Down Expand Up @@ -302,16 +298,16 @@ std::unique_ptr<cudf::column> gather(strings_column_view const& strings,
strings.is_empty() ? make_empty_column(type_id::INT32)->view() : strings.offsets(),
strings.offset());

auto offsets_itr = thrust::make_transform_iterator(
auto sizes_itr = thrust::make_transform_iterator(
begin,
cuda::proclaim_return_type<size_type>(
[d_strings = *d_strings, d_in_offsets] __device__(size_type idx) {
if (NullifyOutOfBounds && (idx < 0 || idx >= d_strings.size())) { return 0; }
if (not d_strings.is_valid(idx)) { return 0; }
return static_cast<size_type>(d_in_offsets[idx + 1] - d_in_offsets[idx]);
}));
auto [out_offsets_column, total_bytes] =
cudf::detail::make_offsets_child_column(offsets_itr, offsets_itr + output_count, stream, mr);
auto [out_offsets_column, total_bytes] = cudf::strings::detail::make_offsets_child_column(
sizes_itr, sizes_itr + output_count, stream, mr);

// build chars column
auto const offsets_view =
Expand Down
Loading