From b4674a1b1cab1b9c4338e4251cf940185de98aff Mon Sep 17 00:00:00 2001 From: Mike Wilson Date: Wed, 25 May 2022 14:48:40 -0400 Subject: [PATCH] String support for jcudf row to cudf column conversion (#10871) This PR adds support for string column creation from jcudf row data. It leverages the fixed-width data copy to convert the offsets and lengths stored inside the fixed-width data section and then uses that information to copy the string data itself from the jcudf row format into the cudf column. closes #10286 Authors: - Mike Wilson (https://github.com/hyperbolic2346) Approvers: - Ray Douglass (https://github.com/raydouglass) - Nghia Truong (https://github.com/ttnghia) - https://github.com/nvdbaranec --- java/src/main/native/src/row_conversion.cu | 374 +++++++++++++++------ 1 file changed, 270 insertions(+), 104 deletions(-) diff --git a/java/src/main/native/src/row_conversion.cu b/java/src/main/native/src/row_conversion.cu index 96ee95c476d..8fba7d27bce 100644 --- a/java/src/main/native/src/row_conversion.cu +++ b/java/src/main/native/src/row_conversion.cu @@ -69,7 +69,10 @@ constexpr auto NUM_VALIDITY_TILES_PER_KERNEL_LOADED = 2; constexpr auto MAX_BATCH_SIZE = std::numeric_limits::max(); -constexpr auto NUM_STRING_ROWS_PER_BLOCK = 16; +// Number of rows each block processes in the two kernels. Tuned via nsight +constexpr auto NUM_STRING_ROWS_PER_BLOCK_TO_ROWS = 1024; +constexpr auto NUM_STRING_ROWS_PER_BLOCK_FROM_ROWS = 64; +constexpr auto MIN_STRING_BLOCKS = 32; constexpr auto MAX_STRING_BLOCKS = MAX_BATCH_SIZE; constexpr auto NUM_THREADS = 256; @@ -269,14 +272,14 @@ build_string_row_offsets(table_view const &tbl, size_type fixed_width_and_validi * */ struct string_row_offset_functor { - string_row_offset_functor(device_span _d_row_offsets) - : d_row_offsets(_d_row_offsets){}; + string_row_offset_functor(device_span d_row_offsets) + : d_row_offsets(d_row_offsets){}; __device__ inline size_type operator()(int row_number, int) const { return d_row_offsets[row_number]; } - device_span d_row_offsets; + device_span d_row_offsets; }; /** @@ -907,7 +910,7 @@ __global__ void copy_strings_to_rows(size_type const num_rows, size_type const n size_type const **variable_col_offsets, size_type fixed_width_row_size, RowOffsetIter row_offsets, size_type const batch_row_offset, int8_t *output_data) { - // Each block will take a group of rows controlled by NUM_STRING_ROWS_PER_BLOCK. + // Each block will take a group of rows controlled by NUM_STRING_ROWS_PER_BLOCK_TO_ROWS. // Each warp will copy a row at a time. The base thread will first go through column data and // fill out offset/length information for the column. Then all threads of the warp will // participate in the memcpy of the string data. @@ -918,9 +921,9 @@ __global__ void copy_strings_to_rows(size_type const num_rows, size_type const n #endif auto const start_row = - blockIdx.x * NUM_STRING_ROWS_PER_BLOCK + my_tile.meta_group_rank() + batch_row_offset; + blockIdx.x * NUM_STRING_ROWS_PER_BLOCK_TO_ROWS + my_tile.meta_group_rank() + batch_row_offset; auto const end_row = - std::min(num_rows, static_cast(start_row + NUM_STRING_ROWS_PER_BLOCK)); + std::min(num_rows, static_cast(start_row + NUM_STRING_ROWS_PER_BLOCK_TO_ROWS)); for (int row = start_row; row < end_row; row += my_tile.meta_group_size()) { auto offset = fixed_width_row_size; // initial offset to variable-width data @@ -937,7 +940,14 @@ __global__ void copy_strings_to_rows(size_type const num_rows, size_type const n } auto string_output_dest = &output_data[base_row_offset + offset]; auto string_output_src = &variable_input_data[col][string_start_offset]; - MEMCPY(string_output_dest, string_output_src, string_length, block_barrier); +#ifdef ASYNC_MEMCPY_SUPPORTED + cuda::memcpy_async(my_tile, string_output_dest, string_output_src, string_length, + block_barrier); +#else + for (int c = my_tile.thread_rank(); c < string_length; c += my_tile.size()) { + string_output_dest[c] = string_output_src[c]; + } +#endif offset += string_length; } } @@ -1238,6 +1248,65 @@ copy_validity_from_rows(const size_type num_rows, const size_type num_columns, #endif // ASYNC_MEMCPY_SUPPORTED } +/** + * @brief copies string data from jcudf row format to cudf columns + * + * @tparam RowOffsetIter iterator for row offsets into the destination data + * @param row_offsets offsets for each row in input data + * @param string_row_offsets offset data into jcudf row data for each string + * @param string_lengths length of each incoming string in each column + * @param string_column_offsets offset column data for cudf column + * @param string_col_data output cudf string column data + * @param row_data jcudf row data + * @param num_rows number of rows in data + * @param num_string_columns number of string columns in the table + */ +template +__global__ void copy_strings_from_rows(RowOffsetIter row_offsets, int32_t **string_row_offsets, + int32_t **string_lengths, size_type **string_column_offsets, + char **string_col_data, int8_t const *row_data, + size_type const num_rows, + size_type const num_string_columns) { + // Each warp takes a tile, which is a single column and up to ROWS_PER_BLOCK rows. A tile + // will not wrap around the bottom of the table. The warp will copy the strings for each row + // in the tile. Traversing in row-major order to coalesce the offsets and size reads. + auto my_block = cooperative_groups::this_thread_block(); + auto my_partition = cooperative_groups::tiled_partition<32>(my_block); +#ifdef ASYNC_MEMCPY_SUPPORTED + cuda::barrier block_barrier; +#endif + + // workaround for not being able to take a reference to a constexpr host variable + auto const ROWS_PER_BLOCK = NUM_STRING_ROWS_PER_BLOCK_FROM_ROWS; + auto const tiles_per_col = util::div_rounding_up_unsafe(num_rows, ROWS_PER_BLOCK); + auto const starting_tile = + blockIdx.x * my_partition.meta_group_size() + my_partition.meta_group_rank(); + auto const num_tiles = tiles_per_col * num_string_columns; + auto const tile_stride = my_partition.meta_group_size() * gridDim.x; + // Each warp will copy strings in its tile. This is handled by all the threads of a warp passing + // the same parameters to async_memcpy and all threads in the warp participating in the copy. + for (auto my_tile = starting_tile; my_tile < num_tiles; my_tile += tile_stride) { + auto const starting_row = (my_tile % tiles_per_col) * ROWS_PER_BLOCK; + auto const col = my_tile / tiles_per_col; + auto const str_len = string_lengths[col]; + auto const str_row_off = string_row_offsets[col]; + auto const str_col_off = string_column_offsets[col]; + auto str_col_data = string_col_data[col]; + for (int row = starting_row; row < starting_row + ROWS_PER_BLOCK && row < num_rows; ++row) { + auto const src = &row_data[row_offsets(row, 0) + str_row_off[row]]; + auto dst = &str_col_data[str_col_off[row]]; + +#ifdef ASYNC_MEMCPY_SUPPORTED + cuda::memcpy_async(my_partition, dst, src, str_len[row], block_barrier); +#else + for (int c = my_partition.thread_rank(); c < str_len[row]; c += my_partition.size()) { + dst[c] = src[c]; + } +#endif + } + } +} + /** * @brief Calculate the dimensions of the kernel for fixed width only columns. * @@ -1374,9 +1443,9 @@ static inline int32_t compute_fixed_width_layout(std::vector const &s * @brief column sizes and column start offsets for a table */ struct column_info_s { - size_type fixed_width_size_per_row; - std::vector fixed_width_column_starts; - std::vector fixed_width_column_sizes; + size_type size_per_row; + std::vector column_starts; + std::vector column_sizes; std::vector variable_width_column_starts; column_info_s &operator=(column_info_s const &other) = delete; @@ -1395,42 +1464,43 @@ struct column_info_s { */ template column_info_s compute_column_information(iterator begin, iterator end) { - size_type fixed_width_size_per_row = 0; - std::vector fixed_width_column_starts; - std::vector fixed_width_column_sizes; + size_type size_per_row = 0; + std::vector column_starts; + std::vector column_sizes; std::vector variable_width_column_starts; - for (auto cv = begin; cv != end; ++cv) { - auto col_type = std::get<0>(*cv); - bool const compound_type = is_compound(col_type); + column_starts.reserve(std::distance(begin, end) + 1); + column_sizes.reserve(std::distance(begin, end)); + + for (auto col_type = begin; col_type != end; ++col_type) { + bool const compound_type = is_compound(*col_type); // a list or string column will write a single uint64 // of data here for offset/length - auto const col_size = compound_type ? sizeof(uint32_t) + sizeof(uint32_t) : size_of(col_type); + auto const col_size = compound_type ? sizeof(uint32_t) + sizeof(uint32_t) : size_of(*col_type); // align size for this type - They are the same for fixed width types and 4 bytes for variable // width length/offset combos size_type const alignment_needed = compound_type ? __alignof(uint32_t) : col_size; - fixed_width_size_per_row = util::round_up_unsafe(fixed_width_size_per_row, alignment_needed); + size_per_row = util::round_up_unsafe(size_per_row, alignment_needed); if (compound_type) { - variable_width_column_starts.push_back(fixed_width_size_per_row); - } else { - fixed_width_column_starts.push_back(fixed_width_size_per_row); - fixed_width_column_sizes.push_back(col_size); + variable_width_column_starts.push_back(size_per_row); } - fixed_width_size_per_row += col_size; + column_starts.push_back(size_per_row); + column_sizes.push_back(col_size); + size_per_row += col_size; } // add validity offset to the end of fixed_width offsets - auto validity_offset = fixed_width_size_per_row; - fixed_width_column_starts.push_back(validity_offset); + auto validity_offset = size_per_row; + column_starts.push_back(validity_offset); // validity is byte-aligned in the JCUDF format - fixed_width_size_per_row += + size_per_row += util::div_rounding_up_safe(static_cast(std::distance(begin, end)), CHAR_BIT); - return {fixed_width_size_per_row, std::move(fixed_width_column_starts), - std::move(fixed_width_column_sizes), std::move(variable_width_column_starts)}; + return {size_per_row, std::move(column_starts), std::move(column_sizes), + std::move(variable_width_column_starts)}; } /** @@ -1790,23 +1860,17 @@ std::vector> convert_to_rows( return table_view(cols); }; - // build fixed_width table view with only fixed-width columns - auto fixed_width_table = - fixed_width_only ? tbl : - select_columns(tbl, [](auto col) { return !is_compound(col.type()); }); - - auto const num_fixed_width_columns = fixed_width_table.num_columns(); - auto dev_col_sizes = make_device_uvector_async(column_info.fixed_width_column_sizes, stream); - auto dev_col_starts = make_device_uvector_async(column_info.fixed_width_column_starts, stream); + auto dev_col_sizes = make_device_uvector_async(column_info.column_sizes, stream); + auto dev_col_starts = make_device_uvector_async(column_info.column_starts, stream); // Get the pointers to the input columnar data ready - auto data_begin = thrust::make_transform_iterator(fixed_width_table.begin(), [](auto const &c) { + auto const data_begin = thrust::make_transform_iterator(tbl.begin(), [](auto const &c) { return is_compound(c.type()) ? nullptr : c.template data(); }); - std::vector input_data(data_begin, data_begin + num_fixed_width_columns); + std::vector input_data(data_begin, data_begin + tbl.num_columns()); // validity code handles variable and fixed-width data, so give it everything - auto nm_begin = + auto const nm_begin = thrust::make_transform_iterator(tbl.begin(), [](auto const &c) { return c.null_mask(); }); std::vector input_nm(nm_begin, nm_begin + tbl.num_columns()); @@ -1831,8 +1895,8 @@ std::vector> convert_to_rows( int info_count = 0; detail::determine_tiles( - column_info.fixed_width_column_sizes, column_info.fixed_width_column_starts, first_batch_size, - num_rows, shmem_limit_per_tile, + column_info.column_sizes, column_info.column_starts, first_batch_size, num_rows, + shmem_limit_per_tile, [&gpu_batch_row_boundaries = batch_info.d_batch_row_boundaries, &info_count, &stream](int const start_col, int const end_col, int const tile_height) { int i = detail::compute_tile_counts(gpu_batch_row_boundaries, tile_height, stream); @@ -1844,8 +1908,8 @@ std::vector> convert_to_rows( int tile_offset = 0; detail::determine_tiles( - column_info.fixed_width_column_sizes, column_info.fixed_width_column_starts, first_batch_size, - num_rows, shmem_limit_per_tile, + column_info.column_sizes, column_info.column_starts, first_batch_size, num_rows, + shmem_limit_per_tile, [&gpu_batch_row_boundaries = batch_info.d_batch_row_boundaries, &gpu_tile_infos, num_rows, &tile_offset, stream](int const start_col, int const end_col, int const tile_height) { tile_offset += detail::build_tiles( @@ -1854,24 +1918,25 @@ std::vector> convert_to_rows( }); // blast through the entire table and convert it - dim3 blocks(util::div_rounding_up_unsafe(gpu_tile_infos.size(), NUM_TILES_PER_KERNEL_TO_ROWS)); - dim3 threads(NUM_THREADS); + dim3 const blocks( + util::div_rounding_up_unsafe(gpu_tile_infos.size(), NUM_TILES_PER_KERNEL_TO_ROWS)); + dim3 const threads(NUM_THREADS); // build validity tiles for ALL columns, variable and fixed width. auto validity_tile_infos = detail::build_validity_tile_infos( tbl.num_columns(), num_rows, shmem_limit_per_tile, batch_info.row_batches); auto dev_validity_tile_infos = make_device_uvector_async(validity_tile_infos, stream); - dim3 validity_blocks( + dim3 const validity_blocks( util::div_rounding_up_unsafe(validity_tile_infos.size(), NUM_VALIDITY_TILES_PER_KERNEL)); - dim3 validity_threads( + dim3 const validity_threads( std::min(validity_tile_infos.size() * NUM_VALIDITY_THREADS_PER_TILE, 128lu)); - auto const validity_offset = column_info.fixed_width_column_starts.back(); + auto const validity_offset = column_info.column_starts.back(); detail::copy_to_rows<<>>( - num_rows, num_fixed_width_columns, shmem_limit_per_tile, gpu_tile_infos, - dev_input_data.data(), dev_col_sizes.data(), dev_col_starts.data(), offset_functor, + num_rows, tbl.num_columns(), shmem_limit_per_tile, gpu_tile_infos, dev_input_data.data(), + dev_col_sizes.data(), dev_col_starts.data(), offset_functor, batch_info.d_batch_row_boundaries.data(), reinterpret_cast(dev_output_data.data())); @@ -1884,15 +1949,15 @@ std::vector> convert_to_rows( if (!fixed_width_only) { // build table view for variable-width data only - auto variable_width_table = + auto const variable_width_table = select_columns(tbl, [](auto col) { return is_compound(col.type()); }); CUDF_EXPECTS(!variable_width_table.is_empty(), "No variable-width columns when expected!"); CUDF_EXPECTS(variable_width_offsets.has_value(), "No variable width offset data!"); - auto variable_data_begin = + auto const variable_data_begin = thrust::make_transform_iterator(variable_width_table.begin(), [](auto const &c) { - strings_column_view scv{c}; + strings_column_view const scv{c}; return is_compound(c.type()) ? scv.chars().template data() : nullptr; }); std::vector variable_width_input_data( @@ -1902,19 +1967,19 @@ std::vector> convert_to_rows( auto dev_variable_col_output_offsets = make_device_uvector_async(column_info.variable_width_column_starts, stream); - dim3 string_threads(NUM_THREADS); + dim3 const string_threads(NUM_THREADS); for (uint i = 0; i < batch_info.row_batches.size(); i++) { auto const batch_row_offset = batch_info.batch_row_boundaries[i]; auto const batch_num_rows = batch_info.row_batches[i].row_count; - dim3 string_blocks( - std::min(MAX_STRING_BLOCKS, - util::div_rounding_up_unsafe(batch_num_rows, NUM_STRING_ROWS_PER_BLOCK))); + dim3 const string_blocks(std::min( + MAX_STRING_BLOCKS, + util::div_rounding_up_unsafe(batch_num_rows, NUM_STRING_ROWS_PER_BLOCK_TO_ROWS))); detail::copy_strings_to_rows<<>>( batch_num_rows, variable_width_table.num_columns(), dev_variable_input_data.data(), dev_variable_col_output_offsets.data(), variable_width_offsets->data(), - column_info.fixed_width_size_per_row, offset_functor, batch_row_offset, + column_info.size_per_row, offset_functor, batch_row_offset, reinterpret_cast(output_data[i])); } } @@ -1922,6 +1987,7 @@ std::vector> convert_to_rows( // split up the output buffer into multiple buffers based on row batch sizes // and create list of byte columns std::vector> ret; + ret.reserve(batch_info.row_batches.size()); auto counting_iter = thrust::make_counting_iterator(0); std::transform(counting_iter, counting_iter + batch_info.row_batches.size(), std::back_inserter(ret), [&](auto batch) { @@ -1975,29 +2041,27 @@ std::vector> convert_to_rows(table_view const &tbl, // to that point. These are row batches and they are decided first before building the // tiles so the tiles can be properly cut around them. - auto schema_column_iter = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), [&tbl](auto i) -> std::pair { - return {tbl.column(i).type(), tbl.column(i)}; - }); + auto schema_column_iter = + thrust::make_transform_iterator(tbl.begin(), [](auto const &i) { return i.type(); }); auto column_info = detail::compute_column_information(schema_column_iter, schema_column_iter + num_columns); - auto const fixed_width_size_per_row = column_info.fixed_width_size_per_row; + auto const size_per_row = column_info.size_per_row; if (fixed_width_only) { // total encoded row size. This includes fixed-width data and validity only. It does not include // variable-width data since it isn't copied with the fixed-width and validity kernel. auto row_size_iter = thrust::make_constant_iterator( - util::round_up_unsafe(fixed_width_size_per_row, JCUDF_ROW_ALIGNMENT)); + util::round_up_unsafe(size_per_row, JCUDF_ROW_ALIGNMENT)); auto batch_info = detail::build_batches(num_rows, row_size_iter, fixed_width_only, stream, mr); detail::fixed_width_row_offset_functor offset_functor( - util::round_up_unsafe(fixed_width_size_per_row, JCUDF_ROW_ALIGNMENT)); + util::round_up_unsafe(size_per_row, JCUDF_ROW_ALIGNMENT)); return detail::convert_to_rows(tbl, batch_info, offset_functor, std::move(column_info), std::nullopt, stream, mr); } else { - auto offset_data = detail::build_string_row_offsets(tbl, fixed_width_size_per_row, stream); + auto offset_data = detail::build_string_row_offsets(tbl, size_per_row, stream); auto &row_sizes = std::get<0>(offset_data); auto row_size_iter = cudf::detail::make_counting_transform_iterator( @@ -2093,7 +2157,21 @@ std::unique_ptr convert_from_rows(lists_column_view const &input, CUDF_EXPECTS(list_type == type_id::INT8 || list_type == type_id::UINT8, "Only a list of bytes is supported as input"); - auto const num_columns = schema.size(); + // convert any strings in the schema to two int32 columns + // This allows us to leverage the fixed-width copy code to fill in our offset and string length + // data. + std::vector string_schema; + string_schema.reserve(schema.size()); + for (auto i : schema) { + if (i.id() == type_id::STRING) { + string_schema.push_back(data_type(type_id::INT32)); + string_schema.push_back(data_type(type_id::INT32)); + } else { + string_schema.push_back(i); + } + } + + auto const num_columns = string_schema.size(); auto const num_rows = input.parent().size(); int device_id; @@ -2110,33 +2188,57 @@ std::unique_ptr
convert_from_rows(lists_column_view const &input, int shmem_limit_per_tile = total_shmem_in_bytes / NUM_TILES_PER_KERNEL_LOADED; - auto iter = thrust::make_transform_iterator(thrust::make_counting_iterator(0), [&schema](auto i) { - return std::make_tuple(schema[i], nullptr); - }); - auto column_info = detail::compute_column_information(iter, iter + num_columns); - auto const fixed_width_size_per_row = - util::round_up_unsafe(column_info.fixed_width_size_per_row, JCUDF_ROW_ALIGNMENT); + auto column_info = detail::compute_column_information(string_schema.begin(), string_schema.end()); + auto const size_per_row = util::round_up_unsafe(column_info.size_per_row, JCUDF_ROW_ALIGNMENT); // Ideally we would check that the offsets are all the same, etc. but for now // this is probably fine - CUDF_EXPECTS(fixed_width_size_per_row * num_rows == child.size(), - "The layout of the data appears to be off"); - auto dev_col_starts = make_device_uvector_async(column_info.fixed_width_column_starts, stream); - auto dev_col_sizes = make_device_uvector_async(column_info.fixed_width_column_sizes, stream); + CUDF_EXPECTS(size_per_row * num_rows <= child.size(), "The layout of the data appears to be off"); + auto dev_col_starts = make_device_uvector_async(column_info.column_starts, stream); + auto dev_col_sizes = make_device_uvector_async(column_info.column_sizes, stream); // Allocate the columns we are going to write into std::vector> output_columns; + std::vector> string_row_offset_columns; + std::vector> string_length_columns; std::vector output_data; std::vector output_nm; - for (int i = 0; i < static_cast(num_columns); i++) { - auto column = - make_fixed_width_column(schema[i], num_rows, mask_state::UNINITIALIZED, stream, mr); - auto mut = column->mutable_view(); - output_data.emplace_back(mut.data()); - output_nm.emplace_back(mut.null_mask()); - output_columns.emplace_back(std::move(column)); + std::vector string_row_offsets; + std::vector string_lengths; + for (auto i : schema) { + auto make_col = [&output_data, &output_nm](data_type type, size_type num_rows, bool include_nm, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource *mr) { + auto column = make_fixed_width_column( + type, num_rows, include_nm ? mask_state::UNINITIALIZED : mask_state::UNALLOCATED, stream, + mr); + auto mut = column->mutable_view(); + output_data.emplace_back(mut.data()); + if (include_nm) { + output_nm.emplace_back(mut.null_mask()); + } + return column; + }; + if (i.id() == type_id::STRING) { + auto const int32type = data_type(type_id::INT32); + auto offset_col = + make_col(int32type, num_rows, true, stream, rmm::mr::get_current_device_resource()); + string_row_offsets.push_back(offset_col->mutable_view().data()); + string_row_offset_columns.emplace_back(std::move(offset_col)); + auto length_col = + make_col(int32type, num_rows, false, stream, rmm::mr::get_current_device_resource()); + string_lengths.push_back(length_col->mutable_view().data()); + string_length_columns.emplace_back(std::move(length_col)); + // placeholder + output_columns.emplace_back(make_empty_column(type_id::STRING)); + } else { + output_columns.emplace_back(make_col(i, num_rows, true, stream, mr)); + } } + auto dev_string_row_offsets = make_device_uvector_async(string_row_offsets, stream); + auto dev_string_lengths = make_device_uvector_async(string_lengths, stream); + // build the row_batches from the passed in list column std::vector row_batches; row_batches.push_back( @@ -2156,8 +2258,7 @@ std::unique_ptr
convert_from_rows(lists_column_view const &input, int info_count = 0; detail::determine_tiles( - column_info.fixed_width_column_sizes, column_info.fixed_width_column_starts, num_rows, - num_rows, shmem_limit_per_tile, + column_info.column_sizes, column_info.column_starts, num_rows, num_rows, shmem_limit_per_tile, [&gpu_batch_row_boundaries, &info_count, &stream](int const start_col, int const end_col, int const tile_height) { info_count += detail::compute_tile_counts(gpu_batch_row_boundaries, tile_height, stream); @@ -2168,8 +2269,7 @@ std::unique_ptr
convert_from_rows(lists_column_view const &input, int tile_offset = 0; detail::determine_tiles( - column_info.fixed_width_column_sizes, column_info.fixed_width_column_starts, num_rows, - num_rows, shmem_limit_per_tile, + column_info.column_sizes, column_info.column_starts, num_rows, num_rows, shmem_limit_per_tile, [&gpu_batch_row_boundaries, &gpu_tile_infos, num_rows, &tile_offset, stream](int const start_col, int const end_col, int const tile_height) { tile_offset += detail::build_tiles( @@ -2177,32 +2277,98 @@ std::unique_ptr
convert_from_rows(lists_column_view const &input, gpu_batch_row_boundaries, start_col, end_col, tile_height, num_rows, stream); }); - dim3 blocks(util::div_rounding_up_unsafe(gpu_tile_infos.size(), NUM_TILES_PER_KERNEL_FROM_ROWS)); - dim3 threads(NUM_THREADS); + dim3 const blocks( + util::div_rounding_up_unsafe(gpu_tile_infos.size(), NUM_TILES_PER_KERNEL_FROM_ROWS)); + dim3 const threads(NUM_THREADS); + // validity needs to be calculated based on the actual number of final table columns auto validity_tile_infos = - detail::build_validity_tile_infos(num_columns, num_rows, shmem_limit_per_tile, row_batches); + detail::build_validity_tile_infos(schema.size(), num_rows, shmem_limit_per_tile, row_batches); auto dev_validity_tile_infos = make_device_uvector_async(validity_tile_infos, stream); - dim3 validity_blocks( + dim3 const validity_blocks( util::div_rounding_up_unsafe(validity_tile_infos.size(), NUM_VALIDITY_TILES_PER_KERNEL)); - dim3 validity_threads( + dim3 const validity_threads( std::min(validity_tile_infos.size() * NUM_VALIDITY_THREADS_PER_TILE, 128lu)); - detail::fixed_width_row_offset_functor offset_functor(fixed_width_size_per_row); + if (dev_string_row_offsets.size() == 0) { + detail::fixed_width_row_offset_functor offset_functor(size_per_row); + + detail::copy_from_rows<<>>( + num_rows, num_columns, shmem_limit_per_tile, offset_functor, + gpu_batch_row_boundaries.data(), dev_output_data.data(), dev_col_sizes.data(), + dev_col_starts.data(), gpu_tile_infos, child.data()); - detail::copy_from_rows<<>>( - num_rows, num_columns, shmem_limit_per_tile, offset_functor, gpu_batch_row_boundaries.data(), - dev_output_data.data(), dev_col_sizes.data(), dev_col_starts.data(), gpu_tile_infos, - child.data()); + detail::copy_validity_from_rows<<>>( + num_rows, num_columns, shmem_limit_per_tile, offset_functor, + gpu_batch_row_boundaries.data(), dev_output_nm.data(), column_info.column_starts.back(), + dev_validity_tile_infos, child.data()); - detail::copy_validity_from_rows<<>>( - num_rows, num_columns, shmem_limit_per_tile, offset_functor, gpu_batch_row_boundaries.data(), - dev_output_nm.data(), column_info.fixed_width_column_starts.back(), dev_validity_tile_infos, - child.data()); + } else { + detail::string_row_offset_functor offset_functor(device_span{input.offsets()}); + detail::copy_from_rows<<>>( + num_rows, num_columns, shmem_limit_per_tile, offset_functor, + gpu_batch_row_boundaries.data(), dev_output_data.data(), dev_col_sizes.data(), + dev_col_starts.data(), gpu_tile_infos, child.data()); + + detail::copy_validity_from_rows<<>>( + num_rows, num_columns, shmem_limit_per_tile, offset_functor, + gpu_batch_row_boundaries.data(), dev_output_nm.data(), column_info.column_starts.back(), + dev_validity_tile_infos, child.data()); + + std::vector> string_col_offsets; + std::vector> string_data_cols; + std::vector string_col_offset_ptrs; + std::vector string_data_col_ptrs; + for (auto &col_string_lengths : string_lengths) { + device_uvector output_string_offsets(num_rows + 1, stream, mr); + auto tmp = [num_rows, col_string_lengths] __device__(auto const &i) { + return i < num_rows ? col_string_lengths[i] : 0; + }; + auto bounded_iter = cudf::detail::make_counting_transform_iterator(0, tmp); + thrust::exclusive_scan(rmm::exec_policy(stream), bounded_iter, bounded_iter + num_rows + 1, + output_string_offsets.begin()); + + // allocate destination string column + rmm::device_uvector string_data(output_string_offsets.element(num_rows, stream), stream, + mr); + + string_col_offset_ptrs.push_back(output_string_offsets.data()); + string_data_col_ptrs.push_back(string_data.data()); + string_col_offsets.push_back(std::move(output_string_offsets)); + string_data_cols.push_back(std::move(string_data)); + } + auto dev_string_col_offsets = make_device_uvector_async(string_col_offset_ptrs, stream); + auto dev_string_data_cols = make_device_uvector_async(string_data_col_ptrs, stream); + + dim3 const string_blocks( + std::min(std::max(MIN_STRING_BLOCKS, num_rows / NUM_STRING_ROWS_PER_BLOCK_FROM_ROWS), + MAX_STRING_BLOCKS)); + dim3 const string_threads(NUM_THREADS); + + detail::copy_strings_from_rows<<>>( + offset_functor, dev_string_row_offsets.data(), dev_string_lengths.data(), + dev_string_col_offsets.data(), dev_string_data_cols.data(), child.data(), num_rows, + static_cast(string_col_offsets.size())); + + // merge strings back into output_columns + int string_idx = 0; + for (int i = 0; i < static_cast(schema.size()); ++i) { + if (schema[i].id() == type_id::STRING) { + // stuff real string column + auto string_data = string_row_offset_columns[string_idx].release()->release(); + output_columns[i] = make_strings_column(num_rows, std::move(string_col_offsets[string_idx]), + std::move(string_data_cols[string_idx]), + std::move(*string_data.null_mask.release()), + cudf::UNKNOWN_NULL_COUNT); + string_idx++; + } + } + } return std::make_unique
(std::move(output_columns)); }