Free temp memory no longer needed in multibyte_split processing #16091

Merged
merged 29 commits into branch-24.08 from davidwendt:mbs-free-temp-mem on Jul 9, 2024

Changes from all commits (29 commits)
3e941bc
Add stream parameter to cudf::io::text::multibyte_split
davidwendt Jun 14, 2024
85f3d5e
Merge branch 'branch-24.08' into stream-mutlibyte-split
davidwendt Jun 24, 2024
925aa96
fix doxygen
davidwendt Jun 24, 2024
ed1f883
Merge branch 'branch-24.08' into stream-mutlibyte-split
davidwendt Jun 25, 2024
6e69dc3
Merge branch 'branch-24.08' into stream-mutlibyte-split
davidwendt Jun 25, 2024
055e6b4
Free temp memory no longer needed in multibyte_split processing
davidwendt Jun 25, 2024
8580a58
empty commit to trigger CI
davidwendt Jun 25, 2024
9056352
Merge branch 'stream-mutlibyte-split' into mbs-free-temp-mem
davidwendt Jun 25, 2024
f09feb8
Merge branch 'stream-mutlibyte-split' of github.com:davidwendt/cudf i…
davidwendt Jun 26, 2024
d19ad46
Merge branch 'branch-24.08' into mbs-free-temp-mem
davidwendt Jun 26, 2024
31fe8fb
Merge branch 'branch-24.08' into stream-mutlibyte-split
davidwendt Jun 26, 2024
6533c49
Merge branch 'stream-mutlibyte-split' into mbs-free-temp-mem
davidwendt Jun 26, 2024
f390836
Merge branch 'branch-24.08' into stream-mutlibyte-split
davidwendt Jun 27, 2024
3e19f09
Merge branch 'stream-mutlibyte-split' into mbs-free-temp-mem
davidwendt Jun 27, 2024
8f71bcf
Merge branch 'branch-24.08' into stream-mutlibyte-split
davidwendt Jun 27, 2024
c51028c
Merge branch 'branch-24.08' into stream-mutlibyte-split
davidwendt Jun 28, 2024
1e73e6a
add multibyte_split to io_readers doxygen group
davidwendt Jun 28, 2024
a770b46
add data-chunk and parse-options to io_readers
davidwendt Jun 28, 2024
dd2b09b
add io::text to conf.py
davidwendt Jun 28, 2024
9f88ad2
Merge branch 'stream-mutlibyte-split' of github.com:davidwendt/cudf i…
davidwendt Jun 28, 2024
da11adc
Merge branch 'branch-24.08' into stream-mutlibyte-split
davidwendt Jun 28, 2024
70576b6
Merge branch 'stream-mutlibyte-split' into mbs-free-temp-mem
davidwendt Jun 28, 2024
76ab655
fix merge conflict
davidwendt Jul 1, 2024
c6d9bd8
Merge branch 'mbs-free-temp-mem' of github.com:davidwendt/cudf into m…
davidwendt Jul 1, 2024
052ae9c
Merge branch 'branch-24.08' into mbs-free-temp-mem
davidwendt Jul 1, 2024
a289f0f
Merge branch 'branch-24.08' into mbs-free-temp-mem
davidwendt Jul 2, 2024
021b469
Merge branch 'branch-24.08' into mbs-free-temp-mem
davidwendt Jul 3, 2024
50aedd5
Merge branch 'branch-24.08' into mbs-free-temp-mem
davidwendt Jul 8, 2024
1c871a4
update namespace decls
davidwendt Jul 9, 2024
324 changes: 162 additions & 162 deletions cpp/src/io/text/multibyte_split.cu
@@ -55,6 +55,8 @@
#include <numeric>
#include <optional>

namespace cudf::io::text {
namespace detail {
namespace {

using cudf::io::text::detail::multistate;
@@ -299,11 +301,6 @@ CUDF_KERNEL __launch_bounds__(THREADS_PER_TILE) void byte_split_kernel(

} // namespace

namespace cudf {
namespace io {
namespace text {
namespace detail {

std::unique_ptr<cudf::column> multibyte_split(cudf::io::text::data_chunk_source const& source,
std::string const& delimiter,
byte_range_info byte_range,
@@ -336,173 +333,181 @@ std::unique_ptr<cudf::column> multibyte_split(cudf::io::text::data_chunk_source
CUDF_EXPECTS(delimiter.size() < multistate::max_segment_value,
"delimiter contains too many total tokens to produce a deterministic result.");

auto const concurrency = 2;

// must be at least 32 when using warp-reduce on partials
// must be at least 1 more than max possible concurrent tiles
// best when at least 32 more than max possible concurrent tiles, due to rolling `invalid`s
auto num_tile_states = std::max(32, TILES_PER_CHUNK * concurrency + 32);
auto tile_multistates =
scan_tile_state<multistate>(num_tile_states, stream, rmm::mr::get_current_device_resource());
auto tile_offsets =
scan_tile_state<output_offset>(num_tile_states, stream, rmm::mr::get_current_device_resource());

multibyte_split_init_kernel<<<TILES_PER_CHUNK,
THREADS_PER_TILE,
0,
stream.value()>>>( //
-TILES_PER_CHUNK,
TILES_PER_CHUNK,
tile_multistates,
tile_offsets,
cudf::io::text::detail::scan_tile_status::oob);

auto multistate_seed = multistate();
multistate_seed.enqueue(0, 0); // this represents the first state in the pattern.

// Seeding the tile state with an identity value allows the 0th tile to follow the same logic as
// the Nth tile, assuming it can look up an inclusive prefix. Without this seed, the 0th block
// would have to follow separate logic.
cudf::detail::device_single_thread(
[tm = scan_tile_state_view<multistate>(tile_multistates),
to = scan_tile_state_view<output_offset>(tile_offsets),
multistate_seed] __device__() mutable {
tm.set_inclusive_prefix(-1, multistate_seed);
to.set_inclusive_prefix(-1, 0);
},
stream);

auto reader = source.create_reader();
auto chunk_offset = std::max<byte_offset>(0, byte_range.offset() - delimiter.size());
auto const byte_range_end = byte_range.offset() + byte_range.size();
reader->skip_bytes(chunk_offset);
// amortize output chunk allocations over 8 worst-case outputs. This limits the overallocation
constexpr auto max_growth = 8;
output_builder<byte_offset> row_offset_storage(ITEMS_PER_CHUNK, max_growth, stream);
output_builder<char> char_storage(ITEMS_PER_CHUNK, max_growth, stream);

auto streams = cudf::detail::fork_streams(stream, concurrency);

cudaEvent_t last_launch_event;
CUDF_CUDA_TRY(cudaEventCreate(&last_launch_event));

auto& read_stream = streams[0];
auto& scan_stream = streams[1];
auto chunk = reader->get_next_chunk(ITEMS_PER_CHUNK, read_stream);
int64_t base_tile_idx = 0;
auto chunk_offset = std::max<byte_offset>(0, byte_range.offset() - delimiter.size());
std::optional<byte_offset> first_row_offset;
std::optional<byte_offset> last_row_offset;
bool found_last_offset = false;
if (byte_range.offset() == 0) { first_row_offset = 0; }
std::swap(read_stream, scan_stream);

while (chunk->size() > 0) {
// if we found the last delimiter, or didn't find delimiters inside the byte range at all: abort
if (last_row_offset.has_value() or
(not first_row_offset.has_value() and chunk_offset >= byte_range_end)) {
break;
}

auto tiles_in_launch =
cudf::util::div_rounding_up_safe(chunk->size(), static_cast<std::size_t>(ITEMS_PER_TILE));

auto row_offsets = row_offset_storage.next_output(scan_stream);
std::optional<byte_offset> last_row_offset;

// reset the next chunk of tile state
multibyte_split_init_kernel<<<tiles_in_launch,
auto [global_offsets, chars] = [&] {
// must be at least 32 when using warp-reduce on partials
// must be at least 1 more than max possible concurrent tiles
// best when at least 32 more than max possible concurrent tiles, due to rolling `invalid`s
auto const concurrency = 2;
auto num_tile_states = std::max(32, TILES_PER_CHUNK * concurrency + 32);
auto tile_multistates =
scan_tile_state<multistate>(num_tile_states, stream, rmm::mr::get_current_device_resource());
auto tile_offsets = scan_tile_state<output_offset>(
num_tile_states, stream, rmm::mr::get_current_device_resource());

multibyte_split_init_kernel<<<TILES_PER_CHUNK,
THREADS_PER_TILE,
0,
scan_stream.value()>>>( //
base_tile_idx,
tiles_in_launch,
stream.value()>>>( //
-TILES_PER_CHUNK,
TILES_PER_CHUNK,
tile_multistates,
tile_offsets);
tile_offsets,
cudf::io::text::detail::scan_tile_status::oob);

CUDF_CUDA_TRY(cudaStreamWaitEvent(scan_stream.value(), last_launch_event));
auto multistate_seed = multistate();
multistate_seed.enqueue(0, 0); // this represents the first state in the pattern.

if (delimiter.size() == 1) {
// the single-byte case allows for a much more efficient kernel, so we special-case it
byte_split_kernel<<<tiles_in_launch,
THREADS_PER_TILE,
0,
scan_stream.value()>>>( //
base_tile_idx,
chunk_offset,
row_offset_storage.size(),
tile_offsets,
delimiter[0],
*chunk,
row_offsets);
} else {
multibyte_split_kernel<<<tiles_in_launch,
THREADS_PER_TILE,
0,
scan_stream.value()>>>( //
// Seeding the tile state with an identity value allows the 0th tile to follow the same logic as
// the Nth tile, assuming it can look up an inclusive prefix. Without this seed, the 0th block
// would have to follow separate logic.
cudf::detail::device_single_thread(
[tm = scan_tile_state_view<multistate>(tile_multistates),
to = scan_tile_state_view<output_offset>(tile_offsets),
multistate_seed] __device__() mutable {
tm.set_inclusive_prefix(-1, multistate_seed);
to.set_inclusive_prefix(-1, 0);
},
stream);

auto reader = source.create_reader();
auto const byte_range_end = byte_range.offset() + byte_range.size();
reader->skip_bytes(chunk_offset);
// amortize output chunk allocations over 8 worst-case outputs. This limits the overallocation
constexpr auto max_growth = 8;
output_builder<byte_offset> row_offset_storage(ITEMS_PER_CHUNK, max_growth, stream);
output_builder<char> char_storage(ITEMS_PER_CHUNK, max_growth, stream);

auto streams = cudf::detail::fork_streams(stream, concurrency);

cudaEvent_t last_launch_event;
CUDF_CUDA_TRY(cudaEventCreate(&last_launch_event));

auto& read_stream = streams[0];
auto& scan_stream = streams[1];
auto chunk = reader->get_next_chunk(ITEMS_PER_CHUNK, read_stream);
int64_t base_tile_idx = 0;
bool found_last_offset = false;
std::swap(read_stream, scan_stream);

while (chunk->size() > 0) {
// if we found the last delimiter, or didn't find delimiters inside the byte range at all:
// abort
if (last_row_offset.has_value() or
(not first_row_offset.has_value() and chunk_offset >= byte_range_end)) {
break;
}

auto tiles_in_launch =
cudf::util::div_rounding_up_safe(chunk->size(), static_cast<std::size_t>(ITEMS_PER_TILE));

auto row_offsets = row_offset_storage.next_output(scan_stream);

// reset the next chunk of tile state
multibyte_split_init_kernel<<<tiles_in_launch,
THREADS_PER_TILE,
0,
scan_stream.value()>>>( //
base_tile_idx,
chunk_offset,
row_offset_storage.size(),
tiles_in_launch,
tile_multistates,
tile_offsets,
{device_delim.data(), static_cast<std::size_t>(device_delim.size())},
*chunk,
row_offsets);
}
tile_offsets);

CUDF_CUDA_TRY(cudaStreamWaitEvent(scan_stream.value(), last_launch_event));

if (delimiter.size() == 1) {
// the single-byte case allows for a much more efficient kernel, so we special-case it
byte_split_kernel<<<tiles_in_launch,
THREADS_PER_TILE,
0,
scan_stream.value()>>>( //
base_tile_idx,
chunk_offset,
row_offset_storage.size(),
tile_offsets,
delimiter[0],
*chunk,
row_offsets);
} else {
multibyte_split_kernel<<<tiles_in_launch,
THREADS_PER_TILE,
0,
scan_stream.value()>>>( //
base_tile_idx,
chunk_offset,
row_offset_storage.size(),
tile_multistates,
tile_offsets,
{device_delim.data(), static_cast<std::size_t>(device_delim.size())},
*chunk,
row_offsets);
}

// load the next chunk
auto next_chunk = reader->get_next_chunk(ITEMS_PER_CHUNK, read_stream);
// while that is running, determine how many offsets we output (synchronizes)
auto const new_offsets = [&] {
auto const new_offsets_unclamped =
tile_offsets.get_inclusive_prefix(base_tile_idx + tiles_in_launch - 1, scan_stream) -
static_cast<output_offset>(row_offset_storage.size());
// if we are not in the last chunk, we can use all offsets
if (chunk_offset + static_cast<output_offset>(chunk->size()) < byte_range_end) {
return new_offsets_unclamped;
// load the next chunk
auto next_chunk = reader->get_next_chunk(ITEMS_PER_CHUNK, read_stream);
// while that is running, determine how many offsets we output (synchronizes)
auto const new_offsets = [&] {
auto const new_offsets_unclamped =
tile_offsets.get_inclusive_prefix(base_tile_idx + tiles_in_launch - 1, scan_stream) -
static_cast<output_offset>(row_offset_storage.size());
// if we are not in the last chunk, we can use all offsets
if (chunk_offset + static_cast<output_offset>(chunk->size()) < byte_range_end) {
return new_offsets_unclamped;
}
// if we are in the last chunk, we need to find the first out-of-bounds offset
auto const it = thrust::make_counting_iterator(output_offset{});
auto const end_loc =
*thrust::find_if(rmm::exec_policy_nosync(scan_stream),
it,
it + new_offsets_unclamped,
[row_offsets, byte_range_end] __device__(output_offset i) {
return row_offsets[i] >= byte_range_end;
});
// if we had no out-of-bounds offset, we copy all offsets
if (end_loc == new_offsets_unclamped) { return end_loc; }
// otherwise we copy only up to (including) the first out-of-bounds delimiter
found_last_offset = true;
return end_loc + 1;
}();
row_offset_storage.advance_output(new_offsets, scan_stream);
// determine if we found the first or last field offset for the byte range
if (new_offsets > 0 and not first_row_offset) {
first_row_offset = row_offset_storage.front_element(scan_stream);
}
if (found_last_offset) { last_row_offset = row_offset_storage.back_element(scan_stream); }
// copy over the characters we need, if we already encountered the first field delimiter
if (first_row_offset.has_value()) {
auto const begin =
chunk->data() + std::max<byte_offset>(0, *first_row_offset - chunk_offset);
auto const sentinel = last_row_offset.value_or(std::numeric_limits<byte_offset>::max());
auto const end =
chunk->data() + std::min<byte_offset>(sentinel - chunk_offset, chunk->size());
auto const output_size = end - begin;
auto char_output = char_storage.next_output(scan_stream);
thrust::copy(rmm::exec_policy_nosync(scan_stream), begin, end, char_output.begin());
char_storage.advance_output(output_size, scan_stream);
}
// if we are in the last chunk, we need to find the first out-of-bounds offset
auto const it = thrust::make_counting_iterator(output_offset{});
auto const end_loc =
*thrust::find_if(rmm::exec_policy_nosync(scan_stream),
it,
it + new_offsets_unclamped,
[row_offsets, byte_range_end] __device__(output_offset i) {
return row_offsets[i] >= byte_range_end;
});
// if we had no out-of-bounds offset, we copy all offsets
if (end_loc == new_offsets_unclamped) { return end_loc; }
// otherwise we copy only up to (including) the first out-of-bounds delimiter
found_last_offset = true;
return end_loc + 1;
}();
row_offset_storage.advance_output(new_offsets, scan_stream);
// determine if we found the first or last field offset for the byte range
if (new_offsets > 0 and not first_row_offset) {
first_row_offset = row_offset_storage.front_element(scan_stream);
}
if (found_last_offset) { last_row_offset = row_offset_storage.back_element(scan_stream); }
// copy over the characters we need, if we already encountered the first field delimiter
if (first_row_offset.has_value()) {
auto const begin = chunk->data() + std::max<byte_offset>(0, *first_row_offset - chunk_offset);
auto const sentinel = last_row_offset.value_or(std::numeric_limits<byte_offset>::max());
auto const end =
chunk->data() + std::min<byte_offset>(sentinel - chunk_offset, chunk->size());
auto const output_size = end - begin;
auto char_output = char_storage.next_output(scan_stream);
thrust::copy(rmm::exec_policy_nosync(scan_stream), begin, end, char_output.begin());
char_storage.advance_output(output_size, scan_stream);
}

CUDF_CUDA_TRY(cudaEventRecord(last_launch_event, scan_stream.value()));
CUDF_CUDA_TRY(cudaEventRecord(last_launch_event, scan_stream.value()));

std::swap(read_stream, scan_stream);
base_tile_idx += tiles_in_launch;
chunk_offset += chunk->size();
chunk = std::move(next_chunk);
}
std::swap(read_stream, scan_stream);
base_tile_idx += tiles_in_launch;
chunk_offset += chunk->size();
chunk = std::move(next_chunk);
}

CUDF_CUDA_TRY(cudaEventDestroy(last_launch_event));

CUDF_CUDA_TRY(cudaEventDestroy(last_launch_event));
cudf::detail::join_streams(streams, stream);

cudf::detail::join_streams(streams, stream);
auto chars = char_storage.gather(stream, mr);
auto global_offsets = row_offset_storage.gather(stream, mr);
return std::pair{std::move(global_offsets), std::move(chars)};
}();

// if the input was empty, we didn't find a delimiter at all,
// or the first delimiter was also the last: empty output
@@ -511,9 +516,6 @@ std::unique_ptr<cudf::column> multibyte_split(cudf::io::text::data_chunk_source
return make_empty_column(type_id::STRING);
}

auto chars = char_storage.gather(stream, mr);
auto global_offsets = row_offset_storage.gather(stream, mr);

// insert an offset at the beginning if we started at the beginning of the input
bool const insert_begin = first_row_offset.value_or(0) == 0;
// insert an offset at the end if we have not terminated the last row
@@ -591,6 +593,4 @@ std::unique_ptr<cudf::column> multibyte_split(cudf::io::text::data_chunk_source
return result;
}

} // namespace text
} // namespace io
} // namespace cudf
} // namespace cudf::io::text
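
The substance of the change above is that the scan-time temporaries (the tile-state buffers, the forked stream pool, and the per-chunk output builders) are now scoped inside an immediately-invoked lambda that returns only the gathered global_offsets and chars, so their device memory is released before the final strings column is assembled. Below is a minimal host-side sketch of that scoping pattern, assuming ordinary containers in place of the RMM-backed device buffers; the names are illustrative and not part of the cudf API.

// Minimal sketch (assumptions noted above): scratch buffers are confined to an
// immediately-invoked lambda so they are destroyed as soon as the outputs have
// been gathered, instead of living until the end of the enclosing function.
#include <utility>
#include <vector>

std::vector<int> build_result()
{
  auto [offsets, chars] = [&] {
    // large temporaries needed only while scanning the input
    std::vector<int> tile_state(1 << 20);
    std::vector<char> scratch(1 << 20);

    // ... run the scan, filling tile_state/scratch, then gather the outputs ...
    std::vector<int> gathered_offsets{0, 4, 9};
    std::vector<char> gathered_chars{'a', 'b', 'c'};

    return std::pair{std::move(gathered_offsets), std::move(gathered_chars)};
  }();  // tile_state and scratch are freed here, before the final assembly

  // only the gathered outputs stay allocated for the remaining work
  std::vector<int> result = std::move(offsets);
  result.push_back(static_cast<int>(chars.size()));
  return result;
}

In the PR itself the same pattern lets tile_multistates, tile_offsets, row_offset_storage, char_storage, and the forked streams go out of scope right after the gather calls, so that scratch memory is no longer held while the offsets column and the final strings column are built.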