diff --git a/cpp/src/io/comp/uncomp.cpp b/cpp/src/io/comp/uncomp.cpp index 1af45b41d8e..d4d6f46b99a 100644 --- a/cpp/src/io/comp/uncomp.cpp +++ b/cpp/src/io/comp/uncomp.cpp @@ -538,8 +538,10 @@ size_t decompress_zstd(host_span src, CUDF_EXPECTS(hd_stats[0].status == compression_status::SUCCESS, "ZSTD decompression failed"); // Copy temporary output to `dst` - CUDF_CUDA_TRY(cudaMemcpyAsync( - dst.data(), d_dst.data(), hd_stats[0].bytes_written, cudaMemcpyDefault, stream.value())); + cudf::detail::cuda_memcpy_async( + dst.subspan(0, hd_stats[0].bytes_written), + device_span{d_dst.data(), hd_stats[0].bytes_written}, + stream); return hd_stats[0].bytes_written; } diff --git a/cpp/src/io/csv/reader_impl.cu b/cpp/src/io/csv/reader_impl.cu index 8c32fc85f78..72fca75c56b 100644 --- a/cpp/src/io/csv/reader_impl.cu +++ b/cpp/src/io/csv/reader_impl.cu @@ -21,6 +21,7 @@ #include "csv_common.hpp" #include "csv_gpu.hpp" +#include "cudf/detail/utilities/cuda_memcpy.hpp" #include "io/comp/io_uncomp.hpp" #include "io/utilities/column_buffer.hpp" #include "io/utilities/hostdevice_vector.hpp" @@ -275,11 +276,10 @@ std::pair, selected_rows_offsets> load_data_and_gather auto const read_offset = byte_range_offset + input_pos + previous_data_size; auto const read_size = target_pos - input_pos - previous_data_size; if (data.has_value()) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_data.data() + previous_data_size, - data->data() + read_offset, - target_pos - input_pos - previous_data_size, - cudaMemcpyDefault, - stream.value())); + cudf::detail::cuda_memcpy_async( + device_span{d_data.data() + previous_data_size, read_size}, + data->subspan(read_offset, read_size), + stream); } else { if (source->is_device_read_preferred(read_size)) { source->device_read(read_offset, @@ -288,12 +288,11 @@ std::pair, selected_rows_offsets> load_data_and_gather stream); } else { auto const buffer = source->host_read(read_offset, read_size); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_data.data() + previous_data_size, - buffer->data(), - buffer->size(), - cudaMemcpyDefault, - stream.value())); - stream.synchronize(); // To prevent buffer going out of scope before we copy the data. + // Use sync version to prevent buffer going out of scope before we copy the data. + cudf::detail::cuda_memcpy( + device_span{d_data.data() + previous_data_size, read_size}, + host_span{reinterpret_cast(buffer->data()), buffer->size()}, + stream); } } @@ -311,12 +310,10 @@ std::pair, selected_rows_offsets> load_data_and_gather range_end, skip_rows, stream); - CUDF_CUDA_TRY(cudaMemcpyAsync(row_ctx.host_ptr(), - row_ctx.device_ptr(), - num_blocks * sizeof(uint64_t), - cudaMemcpyDefault, - stream.value())); - stream.synchronize(); + + cudf::detail::cuda_memcpy(host_span{row_ctx}.subspan(0, num_blocks), + device_span{row_ctx}.subspan(0, num_blocks), + stream); // Sum up the rows in each character block, selecting the row count that // corresponds to the current input context. Also stores the now known input @@ -331,11 +328,9 @@ std::pair, selected_rows_offsets> load_data_and_gather // At least one row in range in this batch all_row_offsets.resize(total_rows - skip_rows, stream); - CUDF_CUDA_TRY(cudaMemcpyAsync(row_ctx.device_ptr(), - row_ctx.host_ptr(), - num_blocks * sizeof(uint64_t), - cudaMemcpyDefault, - stream.value())); + cudf::detail::cuda_memcpy_async(device_span{row_ctx}.subspan(0, num_blocks), + host_span{row_ctx}.subspan(0, num_blocks), + stream); // Pass 2: Output row offsets cudf::io::csv::gpu::gather_row_offsets(parse_opts.view(), @@ -352,12 +347,9 @@ std::pair, selected_rows_offsets> load_data_and_gather stream); // With byte range, we want to keep only one row out of the specified range if (range_end < data_size) { - CUDF_CUDA_TRY(cudaMemcpyAsync(row_ctx.host_ptr(), - row_ctx.device_ptr(), - num_blocks * sizeof(uint64_t), - cudaMemcpyDefault, - stream.value())); - stream.synchronize(); + cudf::detail::cuda_memcpy(host_span{row_ctx}.subspan(0, num_blocks), + device_span{row_ctx}.subspan(0, num_blocks), + stream); size_t rows_out_of_range = 0; for (uint32_t i = 0; i < num_blocks; i++) { @@ -401,12 +393,9 @@ std::pair, selected_rows_offsets> load_data_and_gather // Remove header rows and extract header auto const header_row_index = std::max(header_rows, 1) - 1; if (header_row_index + 1 < row_offsets.size()) { - CUDF_CUDA_TRY(cudaMemcpyAsync(row_ctx.host_ptr(), - row_offsets.data() + header_row_index, - 2 * sizeof(uint64_t), - cudaMemcpyDefault, - stream.value())); - stream.synchronize(); + cudf::detail::cuda_memcpy(host_span{row_ctx}.subspan(0, 2), + device_span{row_offsets.data() + header_row_index, 2}, + stream); auto const header_start = input_pos + row_ctx[0]; auto const header_end = input_pos + row_ctx[1]; diff --git a/cpp/src/io/csv/writer_impl.cu b/cpp/src/io/csv/writer_impl.cu index b84446b5f3e..2bbe05ced84 100644 --- a/cpp/src/io/csv/writer_impl.cu +++ b/cpp/src/io/csv/writer_impl.cu @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -405,13 +406,8 @@ void write_chunked(data_sink* out_sink, out_sink->device_write(ptr_all_bytes, total_num_bytes, stream); } else { // copy the bytes to host to write them out - thrust::host_vector h_bytes(total_num_bytes); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_bytes.data(), - ptr_all_bytes, - total_num_bytes * sizeof(char), - cudaMemcpyDefault, - stream.value())); - stream.synchronize(); + auto const h_bytes = cudf::detail::make_host_vector_sync( + device_span{ptr_all_bytes, total_num_bytes}, stream); out_sink->host_write(h_bytes.data(), total_num_bytes); } diff --git a/cpp/src/io/json/host_tree_algorithms.cu b/cpp/src/io/json/host_tree_algorithms.cu index 7ee652e0239..d06338c6f69 100644 --- a/cpp/src/io/json/host_tree_algorithms.cu +++ b/cpp/src/io/json/host_tree_algorithms.cu @@ -198,17 +198,13 @@ NodeIndexT get_row_array_parent_col_id(device_span col_ids, bool is_enabled_lines, rmm::cuda_stream_view stream) { - NodeIndexT value = parent_node_sentinel; - if (!col_ids.empty()) { - auto const list_node_index = is_enabled_lines ? 0 : 1; - CUDF_CUDA_TRY(cudaMemcpyAsync(&value, - col_ids.data() + list_node_index, - sizeof(NodeIndexT), - cudaMemcpyDefault, - stream.value())); - stream.synchronize(); - } - return value; + if (col_ids.empty()) { return parent_node_sentinel; } + + auto const list_node_index = is_enabled_lines ? 0 : 1; + auto const value = cudf::detail::make_host_vector_sync( + device_span{col_ids.data() + list_node_index, 1}, stream); + + return value[0]; } /** * @brief Holds member data pointers of `d_json_column` @@ -818,11 +814,7 @@ std::pair, hashmap_of_device_columns> build_tree column_categories.cbegin(), expected_types.begin(), [](auto exp, auto cat) { return exp == NUM_NODE_CLASSES ? cat : exp; }); - cudaMemcpyAsync(d_column_tree.node_categories.begin(), - expected_types.data(), - expected_types.size() * sizeof(column_categories[0]), - cudaMemcpyDefault, - stream.value()); + cudf::detail::cuda_memcpy_async(d_column_tree.node_categories, expected_types, stream); return {is_pruned, columns}; } diff --git a/cpp/src/io/json/json_column.cu b/cpp/src/io/json/json_column.cu index 4584f71775f..7e4d975e431 100644 --- a/cpp/src/io/json/json_column.cu +++ b/cpp/src/io/json/json_column.cu @@ -513,16 +513,14 @@ table_with_metadata device_parse_nested_json(device_span d_input, #endif bool const is_array_of_arrays = [&]() { - std::array h_node_categories = {NC_ERR, NC_ERR}; - auto const size_to_copy = std::min(size_t{2}, gpu_tree.node_categories.size()); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_node_categories.data(), - gpu_tree.node_categories.data(), - sizeof(node_t) * size_to_copy, - cudaMemcpyDefault, - stream.value())); - stream.synchronize(); + auto const size_to_copy = std::min(size_t{2}, gpu_tree.node_categories.size()); + if (size_to_copy == 0) return false; + auto const h_node_categories = cudf::detail::make_host_vector_sync( + device_span{gpu_tree.node_categories.data(), size_to_copy}, stream); + if (options.is_enabled_lines()) return h_node_categories[0] == NC_LIST; - return h_node_categories[0] == NC_LIST and h_node_categories[1] == NC_LIST; + return h_node_categories.size() >= 2 and h_node_categories[0] == NC_LIST and + h_node_categories[1] == NC_LIST; }(); auto [gpu_col_id, gpu_row_offsets] = diff --git a/cpp/src/io/json/json_tree.cu b/cpp/src/io/json/json_tree.cu index d949635c1cc..e2fe926ea19 100644 --- a/cpp/src/io/json/json_tree.cu +++ b/cpp/src/io/json/json_tree.cu @@ -264,16 +264,13 @@ tree_meta_t get_tree_representation(device_span tokens, error_count > 0) { auto const error_location = thrust::find(rmm::exec_policy(stream), tokens.begin(), tokens.end(), token_t::ErrorBegin); - SymbolOffsetT error_index; - CUDF_CUDA_TRY( - cudaMemcpyAsync(&error_index, - token_indices.data() + thrust::distance(tokens.begin(), error_location), - sizeof(SymbolOffsetT), - cudaMemcpyDefault, - stream.value())); - stream.synchronize(); + auto error_index = cudf::detail::make_host_vector_sync( + device_span{ + token_indices.data() + thrust::distance(tokens.begin(), error_location), 1}, + stream); + CUDF_FAIL("JSON Parser encountered an invalid format at location " + - std::to_string(error_index)); + std::to_string(error_index[0])); } auto const num_tokens = tokens.size(); diff --git a/cpp/src/io/json/read_json.cu b/cpp/src/io/json/read_json.cu index c424d2b3b62..8a740ae17ef 100644 --- a/cpp/src/io/json/read_json.cu +++ b/cpp/src/io/json/read_json.cu @@ -315,13 +315,12 @@ device_span ingest_raw_input(device_span buffer, // Reading to host because decompression of a single block is much faster on the CPU sources[0]->host_read(range_offset, remaining_bytes_to_read, hbuffer.data()); auto uncomp_data = decompress(compression, hbuffer); - CUDF_CUDA_TRY(cudaMemcpyAsync(buffer.data(), - reinterpret_cast(uncomp_data.data()), - uncomp_data.size() * sizeof(char), - cudaMemcpyHostToDevice, - stream.value())); - stream.synchronize(); - return buffer.first(uncomp_data.size()); + auto ret_buffer = buffer.first(uncomp_data.size()); + cudf::detail::cuda_memcpy( + ret_buffer, + host_span{reinterpret_cast(uncomp_data.data()), uncomp_data.size()}, + stream); + return ret_buffer; } table_with_metadata read_json(host_span> sources, diff --git a/cpp/src/io/orc/reader_impl_chunking.cu b/cpp/src/io/orc/reader_impl_chunking.cu index 9cc77e8e738..fcaee9c548e 100644 --- a/cpp/src/io/orc/reader_impl_chunking.cu +++ b/cpp/src/io/orc/reader_impl_chunking.cu @@ -516,10 +516,10 @@ void reader_impl::load_next_stripe_data(read_mode mode) _stream.synchronize(); stream_synchronized = true; } - device_read_tasks.push_back( - std::pair(source_ptr->device_read_async( - read_info.offset, read_info.length, dst_base + read_info.dst_pos, _stream), - read_info.length)); + device_read_tasks.emplace_back( + source_ptr->device_read_async( + read_info.offset, read_info.length, dst_base + read_info.dst_pos, _stream), + read_info.length); } else { auto buffer = source_ptr->host_read(read_info.offset, read_info.length); diff --git a/cpp/src/io/orc/writer_impl.cu b/cpp/src/io/orc/writer_impl.cu index 03020eb649f..d432deb8e79 100644 --- a/cpp/src/io/orc/writer_impl.cu +++ b/cpp/src/io/orc/writer_impl.cu @@ -19,6 +19,7 @@ * @brief cuDF-IO ORC writer class implementation */ +#include "cudf/detail/utilities/cuda_memcpy.hpp" #include "io/comp/nvcomp_adapter.hpp" #include "io/orc/orc_gpu.hpp" #include "io/statistics/column_statistics.cuh" @@ -1408,7 +1409,8 @@ encoded_footer_statistics finish_statistic_blobs(Footer const& footer, num_entries_seen += stripes_per_col; } - std::vector file_stats_merge(num_file_blobs); + auto file_stats_merge = + cudf::detail::make_host_vector(num_file_blobs, stream); for (auto i = 0u; i < num_file_blobs; ++i) { auto col_stats = &file_stats_merge[i]; col_stats->col_dtype = per_chunk_stats.col_types[i]; @@ -1418,11 +1420,10 @@ encoded_footer_statistics finish_statistic_blobs(Footer const& footer, } auto d_file_stats_merge = stats_merge.device_ptr(num_stripe_blobs); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_file_stats_merge, - file_stats_merge.data(), - num_file_blobs * sizeof(statistics_merge_group), - cudaMemcpyDefault, - stream.value())); + cudf::detail::cuda_memcpy_async( + device_span{stats_merge.device_ptr(num_stripe_blobs), num_file_blobs}, + file_stats_merge, + stream); auto file_stat_chunks = stat_chunks.data() + num_stripe_blobs; detail::merge_group_statistics( @@ -1573,7 +1574,7 @@ void write_index_stream(int32_t stripe_id, * @param[in] strm_desc Stream's descriptor * @param[in] enc_stream Chunk's streams * @param[in] compressed_data Compressed stream data - * @param[in,out] stream_out Temporary host output buffer + * @param[in,out] bounce_buffer Pinned memory bounce buffer for D2H data transfer * @param[in,out] stripe Stream's parent stripe * @param[in,out] streams List of all streams * @param[in] compression_kind The compression kind @@ -1584,7 +1585,7 @@ void write_index_stream(int32_t stripe_id, std::future write_data_stream(gpu::StripeStream const& strm_desc, gpu::encoder_chunk_streams const& enc_stream, uint8_t const* compressed_data, - uint8_t* stream_out, + host_span bounce_buffer, StripeInformation* stripe, orc_streams* streams, CompressionKind compression_kind, @@ -1604,11 +1605,10 @@ std::future write_data_stream(gpu::StripeStream const& strm_desc, if (out_sink->is_device_write_preferred(length)) { return out_sink->device_write_async(stream_in, length, stream); } else { - CUDF_CUDA_TRY( - cudaMemcpyAsync(stream_out, stream_in, length, cudaMemcpyDefault, stream.value())); - stream.synchronize(); + cudf::detail::cuda_memcpy( + bounce_buffer.subspan(0, length), device_span{stream_in, length}, stream); - out_sink->host_write(stream_out, length); + out_sink->host_write(bounce_buffer.data(), length); return std::async(std::launch::deferred, [] {}); } }(); @@ -2616,7 +2616,7 @@ void writer::impl::write_orc_data_to_sink(encoded_data const& enc_data, strm_desc, enc_data.streams[strm_desc.column_id][segmentation.stripes[stripe_id].first], compressed_data.data(), - bounce_buffer.data(), + bounce_buffer, &stripe, &streams, _compression_kind, diff --git a/cpp/src/io/parquet/predicate_pushdown.cpp b/cpp/src/io/parquet/predicate_pushdown.cpp index f0a0bc0b51b..32e922b04bb 100644 --- a/cpp/src/io/parquet/predicate_pushdown.cpp +++ b/cpp/src/io/parquet/predicate_pushdown.cpp @@ -454,15 +454,18 @@ std::optional>> aggregate_reader_metadata::fi CUDF_EXPECTS(predicate.type().id() == cudf::type_id::BOOL8, "Filter expression must return a boolean column"); - auto num_bitmasks = num_bitmask_words(predicate.size()); - std::vector host_bitmask(num_bitmasks, ~bitmask_type{0}); - if (predicate.nullable()) { - CUDF_CUDA_TRY(cudaMemcpyAsync(host_bitmask.data(), - predicate.null_mask(), - num_bitmasks * sizeof(bitmask_type), - cudaMemcpyDefault, - stream.value())); - } + auto const host_bitmask = [&] { + auto const num_bitmasks = num_bitmask_words(predicate.size()); + if (predicate.nullable()) { + return cudf::detail::make_host_vector_sync( + device_span(predicate.null_mask(), num_bitmasks), stream); + } else { + auto bitmask = cudf::detail::make_host_vector(num_bitmasks, stream); + std::fill(bitmask.begin(), bitmask.end(), ~bitmask_type{0}); + return bitmask; + } + }(); + auto validity_it = cudf::detail::make_counting_transform_iterator( 0, [bitmask = host_bitmask.data()](auto bit_index) { return bit_is_set(bitmask, bit_index); }); diff --git a/cpp/src/io/parquet/reader_impl.cpp b/cpp/src/io/parquet/reader_impl.cpp index f0865c715bc..0705ff6f5cc 100644 --- a/cpp/src/io/parquet/reader_impl.cpp +++ b/cpp/src/io/parquet/reader_impl.cpp @@ -78,7 +78,7 @@ void reader::impl::decode_page_data(read_mode mode, size_t skip_rows, size_t num // TODO: This step is somewhat redundant if size info has already been calculated (nested schema, // chunked reader). auto const has_strings = (kernel_mask & STRINGS_MASK) != 0; - std::vector col_string_sizes(_input_columns.size(), 0L); + auto col_string_sizes = cudf::detail::make_host_vector(_input_columns.size(), _stream); if (has_strings) { // need to compute pages bounds/sizes if we lack page indexes or are using custom bounds // TODO: we could probably dummy up size stats for FLBA data since we know the width diff --git a/cpp/src/io/parquet/reader_impl.hpp b/cpp/src/io/parquet/reader_impl.hpp index 62ffc4d3077..3aa9b94ed6b 100644 --- a/cpp/src/io/parquet/reader_impl.hpp +++ b/cpp/src/io/parquet/reader_impl.hpp @@ -284,7 +284,7 @@ class reader::impl { * * @return Vector of total string data sizes for each column */ - std::vector calculate_page_string_offsets(); + cudf::detail::host_vector calculate_page_string_offsets(); /** * @brief Converts the page data and outputs to columns. diff --git a/cpp/src/io/parquet/reader_impl_chunking.hpp b/cpp/src/io/parquet/reader_impl_chunking.hpp index 3a3cdd34a58..a0c2dbd3e44 100644 --- a/cpp/src/io/parquet/reader_impl_chunking.hpp +++ b/cpp/src/io/parquet/reader_impl_chunking.hpp @@ -107,7 +107,7 @@ struct subpass_intermediate_data { * rowgroups may represent less than all of the rowgroups to be read for the file. */ struct pass_intermediate_data { - std::vector> raw_page_data; + std::vector raw_page_data; // rowgroup, chunk and page information for the current pass. bool has_compressed_data{false}; diff --git a/cpp/src/io/parquet/reader_impl_preprocess.cu b/cpp/src/io/parquet/reader_impl_preprocess.cu index 60e35465793..f03f1214b9a 100644 --- a/cpp/src/io/parquet/reader_impl_preprocess.cu +++ b/cpp/src/io/parquet/reader_impl_preprocess.cu @@ -218,7 +218,7 @@ void generate_depth_remappings( */ [[nodiscard]] std::future read_column_chunks_async( std::vector> const& sources, - std::vector>& page_data, + cudf::host_span page_data, cudf::detail::hostdevice_vector& chunks, size_t begin_chunk, size_t end_chunk, @@ -251,23 +251,24 @@ void generate_depth_remappings( if (source->is_device_read_preferred(io_size)) { // Buffer needs to be padded. // Required by `gpuDecodePageData`. - auto buffer = + page_data[chunk] = rmm::device_buffer(cudf::util::round_up_safe(io_size, BUFFER_PADDING_MULTIPLE), stream); auto fut_read_size = source->device_read_async( - io_offset, io_size, static_cast(buffer.data()), stream); + io_offset, io_size, static_cast(page_data[chunk].data()), stream); read_tasks.emplace_back(std::move(fut_read_size)); - page_data[chunk] = datasource::buffer::create(std::move(buffer)); } else { auto const read_buffer = source->host_read(io_offset, io_size); // Buffer needs to be padded. // Required by `gpuDecodePageData`. - auto tmp_buffer = rmm::device_buffer( + page_data[chunk] = rmm::device_buffer( cudf::util::round_up_safe(read_buffer->size(), BUFFER_PADDING_MULTIPLE), stream); - CUDF_CUDA_TRY(cudaMemcpyAsync( - tmp_buffer.data(), read_buffer->data(), read_buffer->size(), cudaMemcpyDefault, stream)); - page_data[chunk] = datasource::buffer::create(std::move(tmp_buffer)); + CUDF_CUDA_TRY(cudaMemcpyAsync(page_data[chunk].data(), + read_buffer->data(), + read_buffer->size(), + cudaMemcpyDefault, + stream)); } - auto d_compdata = page_data[chunk]->data(); + auto d_compdata = static_cast(page_data[chunk].data()); do { chunks[chunk].compressed_data = d_compdata; d_compdata += chunks[chunk].compressed_size; @@ -980,7 +981,7 @@ std::pair> reader::impl::read_column_chunks() std::vector chunk_source_map(num_chunks); // Tracker for eventually deallocating compressed and uncompressed data - raw_page_data = std::vector>(num_chunks); + raw_page_data = std::vector(num_chunks); // Keep track of column chunk file offsets std::vector column_chunk_offsets(num_chunks); @@ -1695,15 +1696,14 @@ void reader::impl::allocate_columns(read_mode mode, size_t skip_rows, size_t num nullmask_bufs, std::numeric_limits::max(), _stream); } -std::vector reader::impl::calculate_page_string_offsets() +cudf::detail::host_vector reader::impl::calculate_page_string_offsets() { auto& pass = *_pass_itm_data; auto& subpass = *pass.subpass; auto page_keys = make_page_key_iterator(subpass.pages); - std::vector col_sizes(_input_columns.size(), 0L); - rmm::device_uvector d_col_sizes(col_sizes.size(), _stream); + rmm::device_uvector d_col_sizes(_input_columns.size(), _stream); // use page_index to fetch page string sizes in the proper order auto val_iter = thrust::make_transform_iterator(subpass.pages.device_begin(), @@ -1717,7 +1717,7 @@ std::vector reader::impl::calculate_page_string_offsets() page_offset_output_iter{subpass.pages.device_ptr()}); // now sum up page sizes - rmm::device_uvector reduce_keys(col_sizes.size(), _stream); + rmm::device_uvector reduce_keys(d_col_sizes.size(), _stream); thrust::reduce_by_key(rmm::exec_policy_nosync(_stream), page_keys, page_keys + subpass.pages.size(), @@ -1725,14 +1725,7 @@ std::vector reader::impl::calculate_page_string_offsets() reduce_keys.begin(), d_col_sizes.begin()); - cudaMemcpyAsync(col_sizes.data(), - d_col_sizes.data(), - sizeof(size_t) * col_sizes.size(), - cudaMemcpyDeviceToHost, - _stream); - _stream.synchronize(); - - return col_sizes; + return cudf::detail::make_host_vector_sync(d_col_sizes, _stream); } } // namespace cudf::io::parquet::detail diff --git a/cpp/src/io/parquet/writer_impl.cu b/cpp/src/io/parquet/writer_impl.cu index 190f13eb688..f865c9a7643 100644 --- a/cpp/src/io/parquet/writer_impl.cu +++ b/cpp/src/io/parquet/writer_impl.cu @@ -183,7 +183,7 @@ struct aggregate_writer_metadata { std::vector row_groups; std::vector key_value_metadata; std::vector offset_indexes; - std::vector> column_indexes; + std::vector> column_indexes; }; std::vector files; std::optional> column_orders = std::nullopt; @@ -1543,12 +1543,7 @@ void encode_pages(hostdevice_2dvector& chunks, d_chunks.flat_view(), {column_stats, pages.size()}, column_index_truncate_length, stream); } - auto h_chunks = chunks.host_view(); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_chunks.data(), - d_chunks.data(), - d_chunks.flat_view().size_bytes(), - cudaMemcpyDefault, - stream.value())); + chunks.device_to_host_async(stream); if (comp_stats.has_value()) { comp_stats.value() += collect_compression_statistics(comp_in, comp_res, stream); @@ -2559,12 +2554,11 @@ void writer::impl::write_parquet_data_to_sink( } else { CUDF_EXPECTS(bounce_buffer.size() >= ck.compressed_size, "Bounce buffer was not properly initialized."); - CUDF_CUDA_TRY(cudaMemcpyAsync(bounce_buffer.data(), - dev_bfr + ck.ck_stat_size, - ck.compressed_size, - cudaMemcpyDefault, - _stream.value())); - _stream.synchronize(); + cudf::detail::cuda_memcpy( + host_span{bounce_buffer}.subspan(0, ck.compressed_size), + device_span{dev_bfr + ck.ck_stat_size, ck.compressed_size}, + _stream); + _out_sink[p]->host_write(bounce_buffer.data(), ck.compressed_size); } @@ -2600,13 +2594,8 @@ void writer::impl::write_parquet_data_to_sink( auto const& column_chunk_meta = row_group.columns[i].meta_data; // start transfer of the column index - std::vector column_idx; - column_idx.resize(ck.column_index_size); - CUDF_CUDA_TRY(cudaMemcpyAsync(column_idx.data(), - ck.column_index_blob, - ck.column_index_size, - cudaMemcpyDefault, - _stream.value())); + auto column_idx = cudf::detail::make_host_vector_async( + device_span{ck.column_index_blob, ck.column_index_size}, _stream); // calculate offsets while the column index is transferring int64_t curr_pg_offset = column_chunk_meta.data_page_offset; diff --git a/cpp/src/io/text/bgzip_data_chunk_source.cu b/cpp/src/io/text/bgzip_data_chunk_source.cu index badcd3f58f9..06069630685 100644 --- a/cpp/src/io/text/bgzip_data_chunk_source.cu +++ b/cpp/src/io/text/bgzip_data_chunk_source.cu @@ -74,8 +74,8 @@ class bgzip_data_chunk_reader : public data_chunk_reader { // Buffer needs to be padded. // Required by `inflate_kernel`. device.resize(cudf::util::round_up_safe(host.size(), BUFFER_PADDING_MULTIPLE), stream); - CUDF_CUDA_TRY(cudaMemcpyAsync( - device.data(), host.data(), host.size() * sizeof(T), cudaMemcpyDefault, stream.value())); + cudf::detail::cuda_memcpy_async( + device_span{device}.subspan(0, host.size()), host, stream); } struct decompression_blocks { diff --git a/cpp/src/io/text/data_chunk_source_factories.cpp b/cpp/src/io/text/data_chunk_source_factories.cpp index 58faa0ebfe4..4baea8655e0 100644 --- a/cpp/src/io/text/data_chunk_source_factories.cpp +++ b/cpp/src/io/text/data_chunk_source_factories.cpp @@ -87,8 +87,10 @@ class datasource_chunk_reader : public data_chunk_reader { _source->host_read(_offset, read_size, reinterpret_cast(h_ticket.buffer.data())); // copy the host-pinned data on to device - CUDF_CUDA_TRY(cudaMemcpyAsync( - chunk.data(), h_ticket.buffer.data(), read_size, cudaMemcpyDefault, stream.value())); + cudf::detail::cuda_memcpy_async( + device_span{chunk}.subspan(0, read_size), + host_span{h_ticket.buffer}.subspan(0, read_size), + stream); // record the host-to-device copy. CUDF_CUDA_TRY(cudaEventRecord(h_ticket.event, stream.value())); @@ -153,8 +155,10 @@ class istream_data_chunk_reader : public data_chunk_reader { auto chunk = rmm::device_uvector(read_size, stream); // copy the host-pinned data on to device - CUDF_CUDA_TRY(cudaMemcpyAsync( - chunk.data(), h_ticket.buffer.data(), read_size, cudaMemcpyDefault, stream.value())); + cudf::detail::cuda_memcpy_async( + device_span{chunk}.subspan(0, read_size), + host_span{h_ticket.buffer}.subspan(0, read_size), + stream); // record the host-to-device copy. CUDF_CUDA_TRY(cudaEventRecord(h_ticket.event, stream.value())); @@ -193,12 +197,10 @@ class host_span_data_chunk_reader : public data_chunk_reader { auto chunk = rmm::device_uvector(read_size, stream); // copy the host data to device - CUDF_CUDA_TRY(cudaMemcpyAsync( // - chunk.data(), - _data.data() + _position, - read_size, - cudaMemcpyDefault, - stream.value())); + cudf::detail::cuda_memcpy_async( + cudf::device_span{chunk}.subspan(0, read_size), + cudf::host_span{_data}.subspan(_position, read_size), + stream); _position += read_size; diff --git a/cpp/src/io/utilities/datasource.cpp b/cpp/src/io/utilities/datasource.cpp index 2daaecadca6..a63c13b01de 100644 --- a/cpp/src/io/utilities/datasource.cpp +++ b/cpp/src/io/utilities/datasource.cpp @@ -18,6 +18,7 @@ #include "getenv_or.hpp" #include +#include #include #include #include @@ -246,17 +247,18 @@ class device_buffer_source final : public datasource { size_t host_read(size_t offset, size_t size, uint8_t* dst) override { auto const count = std::min(size, this->size() - offset); - auto const stream = cudf::get_default_stream(); - CUDF_CUDA_TRY( - cudaMemcpyAsync(dst, _d_buffer.data() + offset, count, cudaMemcpyDefault, stream.value())); - stream.synchronize(); + auto const stream = cudf::detail::global_cuda_stream_pool().get_stream(); + cudf::detail::cuda_memcpy(host_span{dst, count}, + device_span{ + reinterpret_cast(_d_buffer.data() + offset), count}, + stream); return count; } std::unique_ptr host_read(size_t offset, size_t size) override { auto const count = std::min(size, this->size() - offset); - auto const stream = cudf::get_default_stream(); + auto const stream = cudf::detail::global_cuda_stream_pool().get_stream(); auto h_data = cudf::detail::make_host_vector_async( cudf::device_span{_d_buffer.data() + offset, count}, stream); stream.synchronize();