From 8a8afaa436ec657828ef5052ec985b797e50a825 Mon Sep 17 00:00:00 2001 From: Mike Wilson Date: Fri, 11 Feb 2022 03:53:50 +0000 Subject: [PATCH] Support for Pascal GPUs (which lack memcpy_async) * first pass * pascal tomfoolery * Mithun fixed it * Figured out conditional compile. ** Must happen in __device__ context. * Experiments with __CUDA_ARCH__: ** Got it working with __global__, __device__, and thrust. * Initial stab at ASYNC_MEMCPY_SUPPORTED: ** 1. Found out that __host__ code does not have __CUDA_ARCH__ set. Everywhere else, this can be used reliably. ** 2. Replaced all the __CUDA_ARCH__ checks with ASYNC_MEMCPY_SUPPORTED. This is correct for all sites, EXCEPT convert_to/from_rows(), because those are __host__. ** 3. Running out of memory on Ampere box, for some reason. * Completed changes for __host__ code: ** 1. Changed convert_to_rows(), convert_from_rows() to use ifndef __CUDA_ARCH__. ** 2. Added comments for barrier initialization. * Reduced scope of ASYNC_MEMCPY_SUPPORTED in some if statements. * Formatting. * Updated JNI/row_conversion.cu. --- cpp/CMakeLists.txt | 1 + cpp/benchmarks/CMakeLists.txt | 4 + .../row_conversion/row_conversion.cpp | 181 ++ cpp/include/cudf/row_conversion.hpp | 51 + cpp/src/row_conversion/row_conversion.cu | 2301 +++++++++++++++++ cpp/src/row_conversion/row_conversion.cu.bk | 2234 ++++++++++++++++ cpp/tests/CMakeLists.txt | 2 + cpp/tests/row_conversion/row_conversion.cu | 1048 ++++++++ java/src/main/native/src/row_conversion.cu | 168 +- 9 files changed, 5929 insertions(+), 61 deletions(-) create mode 100644 cpp/benchmarks/row_conversion/row_conversion.cpp create mode 100644 cpp/include/cudf/row_conversion.hpp create mode 100644 cpp/src/row_conversion/row_conversion.cu create mode 100644 cpp/src/row_conversion/row_conversion.cu.bk create mode 100644 cpp/tests/row_conversion/row_conversion.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 407e1f9a858..fa6e0e21165 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -388,6 +388,7 @@ add_library( src/rolling/rolling.cu src/rolling/rolling_collect_list.cu src/round/round.cu + src/row_conversion/row_conversion.cu src/scalar/scalar.cpp src/scalar/scalar_factories.cpp src/search/search.cu diff --git a/cpp/benchmarks/CMakeLists.txt b/cpp/benchmarks/CMakeLists.txt index 3bc6dc10fdf..8ee43e6ffc5 100644 --- a/cpp/benchmarks/CMakeLists.txt +++ b/cpp/benchmarks/CMakeLists.txt @@ -276,6 +276,10 @@ ConfigureBench(JSON_BENCH string/json.cpp) # * io benchmark --------------------------------------------------------------------- ConfigureBench(MULTIBYTE_SPLIT_BENCHMARK io/text/multibyte_split.cpp) +################################################################################################### +# - row conversion benchmark --------------------------------------------------------- +ConfigureBench(ROW_CONVERSION_BENCH row_conversion/row_conversion.cpp) + add_custom_target( run_benchmarks DEPENDS CUDF_BENCHMARKS diff --git a/cpp/benchmarks/row_conversion/row_conversion.cpp b/cpp/benchmarks/row_conversion/row_conversion.cpp new file mode 100644 index 00000000000..1e5a1866376 --- /dev/null +++ b/cpp/benchmarks/row_conversion/row_conversion.cpp @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
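The conditional-compilation behaviour described in the commit message can be illustrated with a minimal sketch (not part of this patch; EXAMPLE_ASYNC_MEMCPY_SUPPORTED, copy_bytes and host_only_setup are made-up names). __CUDA_ARCH__ is defined only while nvcc compiles device code for a particular architecture and is never defined in the host pass, so a macro keyed off it works inside __global__/__device__ code, while host functions such as convert_to_rows() must use a plain #ifndef __CUDA_ARCH__ check instead.

#include <cstring>

// Defined only in device compilation passes targeting Volta (sm_70) or newer.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
#define EXAMPLE_ASYNC_MEMCPY_SUPPORTED
#endif

__host__ __device__ void copy_bytes(void* dst, void const* src, size_t n)
{
#ifdef EXAMPLE_ASYNC_MEMCPY_SUPPORTED
  memcpy(dst, src, n);  // device-only path: this is where cuda::memcpy_async can be used
#else
  memcpy(dst, src, n);  // pre-Volta device passes and every host pass land here
#endif
}

void host_only_setup()
{
#ifndef __CUDA_ARCH__
  // Host code never has __CUDA_ARCH__ defined, regardless of the GPUs being targeted,
  // which is why the __host__ conversion entry points cannot rely on the macro above.
#endif
}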
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include + +class RowConversion : public cudf::benchmark { +}; + +static void BM_old_to_row(benchmark::State& state) +{ + cudf::size_type const n_rows{(cudf::size_type)state.range(0)}; + auto const table = create_random_table({cudf::type_id::INT8, + cudf::type_id::INT32, + cudf::type_id::INT16, + cudf::type_id::INT64, + cudf::type_id::INT32, + cudf::type_id::BOOL8, + cudf::type_id::UINT16, + cudf::type_id::UINT8, + cudf::type_id::UINT64}, + 212, + row_count{n_rows}); + + cudf::size_type total_bytes = 0; + for (int i = 0; i < table->num_columns(); ++i) { + auto t = table->get_column(i).type(); + total_bytes += cudf::size_of(t); + } + + for (auto _ : state) { + cuda_event_timer raii(state, true, rmm::cuda_stream_default); + + auto rows = cudf::convert_to_rows_fixed_width_optimized(table->view()); + } + + state.SetBytesProcessed(state.iterations() * total_bytes * 2 * table->num_rows()); +} + +static void BM_new_to_row(benchmark::State& state) +{ + cudf::size_type const n_rows{(cudf::size_type)state.range(0)}; + auto const table = create_random_table({cudf::type_id::INT8, + cudf::type_id::INT32, + cudf::type_id::INT16, + cudf::type_id::INT64, + cudf::type_id::INT32, + cudf::type_id::BOOL8, + cudf::type_id::UINT16, + cudf::type_id::UINT8, + cudf::type_id::UINT64}, + 212, + row_count{n_rows}); + + cudf::size_type total_bytes = 0; + for (int i = 0; i < table->num_columns(); ++i) { + auto t = table->get_column(i).type(); + total_bytes += cudf::size_of(t); + } + + for (auto _ : state) { + cuda_event_timer raii(state, true, rmm::cuda_stream_default); + + auto new_rows = cudf::convert_to_rows(table->view()); + } + + state.SetBytesProcessed(state.iterations() * total_bytes * 2 * table->num_rows()); +} + +static void BM_old_from_row(benchmark::State& state) +{ + cudf::size_type const n_rows{(cudf::size_type)state.range(0)}; + auto const table = create_random_table({cudf::type_id::INT8, + cudf::type_id::INT32, + cudf::type_id::INT16, + cudf::type_id::INT64, + cudf::type_id::INT32, + cudf::type_id::BOOL8, + cudf::type_id::UINT16, + cudf::type_id::UINT8, + cudf::type_id::UINT64}, + 256, + row_count{n_rows}); + + std::vector schema; + cudf::size_type total_bytes = 0; + for (int i = 0; i < table->num_columns(); ++i) { + auto t = table->get_column(i).type(); + schema.push_back(t); + total_bytes += cudf::size_of(t); + } + + auto rows = cudf::convert_to_rows_fixed_width_optimized(table->view()); + cudf::lists_column_view const first_list(rows.front()->view()); + + for (auto _ : state) { + cuda_event_timer raii(state, true, rmm::cuda_stream_default); + + auto out = cudf::convert_from_rows_fixed_width_optimized(first_list, schema); + } + + state.SetBytesProcessed(state.iterations() * total_bytes * 2 * table->num_rows()); +} + +static void BM_new_from_row(benchmark::State& state) +{ + cudf::size_type const n_rows{(cudf::size_type)state.range(0)}; + auto const table = create_random_table({cudf::type_id::INT8, + cudf::type_id::INT32, + cudf::type_id::INT16, + cudf::type_id::INT64, + cudf::type_id::INT32, + cudf::type_id::BOOL8, + 
cudf::type_id::UINT16, + cudf::type_id::UINT8, + cudf::type_id::UINT64}, + 256, + row_count{n_rows}); + + std::vector schema; + cudf::size_type total_bytes = 0; + for (int i = 0; i < table->num_columns(); ++i) { + auto t = table->get_column(i).type(); + schema.push_back(t); + total_bytes += cudf::size_of(t); + } + + auto rows = cudf::convert_to_rows_fixed_width_optimized(table->view()); + cudf::lists_column_view const first_list(rows.front()->view()); + + for (auto _ : state) { + cuda_event_timer raii(state, true, rmm::cuda_stream_default); + + auto out = cudf::convert_from_rows(first_list, schema); + } + + state.SetBytesProcessed(state.iterations() * total_bytes * 2 * table->num_rows()); +} + +#define TO_ROW_CONVERSION_BENCHMARK_DEFINE(name, f) \ + BENCHMARK_DEFINE_F(RowConversion, name) \ + (::benchmark::State & st) { f(st); } \ + BENCHMARK_REGISTER_F(RowConversion, name) \ + ->RangeMultiplier(8) \ + ->Ranges({{1 << 20, 1 << 20}}) \ + ->UseManualTime() \ + ->Unit(benchmark::kMillisecond); + +TO_ROW_CONVERSION_BENCHMARK_DEFINE(old_to_row_conversion, BM_old_to_row) +TO_ROW_CONVERSION_BENCHMARK_DEFINE(new_to_row_conversion, BM_new_to_row) + +#define FROM_ROW_CONVERSION_BENCHMARK_DEFINE(name, f) \ + BENCHMARK_DEFINE_F(RowConversion, name) \ + (::benchmark::State & st) { f(st); } \ + BENCHMARK_REGISTER_F(RowConversion, name) \ + ->RangeMultiplier(8) \ + ->Ranges({{1 << 20, 1 << 20}}) \ + ->UseManualTime() \ + ->Unit(benchmark::kMillisecond); + +FROM_ROW_CONVERSION_BENCHMARK_DEFINE(old_from_row_conversion, BM_old_from_row) +FROM_ROW_CONVERSION_BENCHMARK_DEFINE(new_from_row_conversion, BM_new_from_row) diff --git a/cpp/include/cudf/row_conversion.hpp b/cpp/include/cudf/row_conversion.hpp new file mode 100644 index 00000000000..5d799f4c596 --- /dev/null +++ b/cpp/include/cudf/row_conversion.hpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +#include +#include +#include + +namespace cudf { + +std::vector> convert_to_rows_fixed_width_optimized( + cudf::table_view const& tbl, + // TODO need something for validity + rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +std::vector> convert_to_rows( + cudf::table_view const& tbl, + // TODO need something for validity + rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +std::unique_ptr convert_from_rows_fixed_width_optimized( + cudf::lists_column_view const& input, + std::vector const& schema, + rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +std::unique_ptr convert_from_rows( + cudf::lists_column_view const& input, + std::vector const& schema, + rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +} // namespace cudf diff --git a/cpp/src/row_conversion/row_conversion.cu b/cpp/src/row_conversion/row_conversion.cu new file mode 100644 index 00000000000..f316d88848e --- /dev/null +++ b/cpp/src/row_conversion/row_conversion.cu @@ -0,0 +1,2301 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 +#define ASYNC_MEMCPY_SUPPORTED +#endif + +#if !defined(__CUDA_ARCH__) || defined(ASYNC_MEMCPY_SUPPORTED) +#include +#endif // #if !defined(__CUDA_ARCH__) || defined(ASYNC_MEMCPY_SUPPORTED) + +#include +#include +#include +#include +#include +#include +#include + +constexpr auto JCUDF_ROW_ALIGNMENT = 8; + +constexpr auto NUM_TILES_PER_KERNEL_FROM_ROWS = 2; +constexpr auto NUM_TILES_PER_KERNEL_TO_ROWS = 2; +constexpr auto NUM_TILES_PER_KERNEL_LOADED = 2; +constexpr auto NUM_VALIDITY_TILES_PER_KERNEL = 8; +constexpr auto NUM_VALIDITY_TILES_PER_KERNEL_LOADED = 2; + +constexpr auto MAX_BATCH_SIZE = std::numeric_limits::max(); + +using namespace cudf; +using detail::make_device_uvector_async; +using rmm::device_uvector; + +#ifdef ASYNC_MEMCPY_SUPPORTED +using cuda::aligned_size_t; +#else +template +using aligned_size_t = size_t; // Local stub for cuda::aligned_size_t. +#endif // ASYNC_MEMCPY_SUPPORTED + +namespace cudf { +// namespace jni { +namespace detail { + +/************************************************************************ + * This module converts data from row-major to column-major and from column-major + * to row-major. It is a transpose of the data of sorts, but there are a few + * complicating factors. 
They are spelled out below: + * + * Row Batches: + * The row data has to fit inside a + * cuDF column, which limits it to 2 gigs currently. The calling code attempts + * to keep the data size under 2 gigs, but due to padding this isn't always + * the case, so being able to break this up into multiple columns is necessary. + * Internally, this is referred to as the row batch, which is a group of rows + * that will fit into this 2 gig space requirement. There are typically 1 of + * these batches, but there can be 2. + * + * Async Memcpy: + * The CUDA blocks are using memcpy_async, which allows for the device to + * schedule memcpy operations and then wait on them to complete at a later + * time with a barrier. The recommendation is to double-buffer the work + * so that processing can occur while a copy operation is being completed. + * On Ampere or later hardware there is dedicated hardware to do this copy + * and on pre-Ampere it should generate the same code that a hand-rolled + * loop would generate, so performance should be the same or better than + * a hand-rolled kernel. + * + * Tile Info: + * Each CUDA block will work on NUM_TILES_PER_KERNEL_*_ROWS tile infos + * before exiting. It will have enough shared memory available to load + * NUM_TILES_PER_KERNEL_LOADED tiles at one time. The block will load + * as many tiles as it can fit into shared memory and then wait on the + * first tile to completely load before processing. Processing in this + * case means copying the data from shared memory back out to device + * memory via memcpy_async. This kernel is completely memory bound. + * + * Batch Data: + * This structure contains all the row batches and some book-keeping + * data necessary for the batches such as row numbers for the batches. + * + * Tiles: + * The tile info describes a tile of data to process. In a GPU with + * 48KB of shared memory each tile uses approximately 24KB of memory + * which equates to about 144 bytes in each direction. The tiles are + * kept as square as possible to attempt to coalesce memory operations. + * The taller a tile is the better coalescing of columns, but row + * coalescing suffers. The wider a tile is the better the row coalescing, + * but columns coalescing suffers. The code attempts to produce a square + * tile to balance the coalescing. It starts by figuring out the optimal + * byte length and then adding columns to the data until the tile is too + * large. Since rows are different width with different alignment + * requirements, this isn't typically exact. Once a width is found the + * tiles are generated vertically with that width and height and then + * the process repeats. This means all the tiles will be the same + * height, but will have different widths based on what columns they + * encompass. Tiles in a vertical row will all have the same dimensions. + * + * -------------------------------- + * | 4 5.0f || True 8 3 1 | + * | 3 6.0f || False 3 1 1 | + * | 2 7.0f || True 7 4 1 | + * | 1 8.0f || False 2 5 1 | + * -------------------------------- + * | 0 9.0f || True 6 7 1 | + * ... + ************************************************************************/ + +/** + * @brief The CUDA blocks work on one or more tile_info structs of data. + * This structure defines the workspaces for the blocks. 
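The double-buffered memcpy_async scheme described in the Async Memcpy note above can be reduced to the following sketch (an illustration, not code from this file; it assumes sm_70 or newer and libcu++'s <cuda/barrier>, and staged_copy is a made-up name). A single block-scoped cuda::barrier tracks one in-flight buffer; the real kernels keep NUM_TILES_PER_KERNEL_LOADED of them and cycle through so one tile can be processed while the next is still copying.

#include <cooperative_groups.h>
#include <cuda/barrier>

__global__ void staged_copy(int8_t const* in, int8_t* out, int n)
{
  extern __shared__ int8_t smem[];
  auto group = cooperative_groups::this_thread_block();

  // One thread constructs the barrier, then the whole block synchronizes so every
  // thread sees a fully initialized barrier before its first arrival.
  __shared__ cuda::barrier<cuda::thread_scope_block> bar;
  if (group.thread_rank() == 0) { init(&bar, group.size()); }
  group.sync();

  // Each thread queues async copies of its slice into shared memory (byte-at-a-time
  // here only for brevity).
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    cuda::memcpy_async(&smem[i], &in[i], 1, bar);
  }
  bar.arrive_and_wait();  // all queued copies into shared memory have completed

  // The barrier is phased, so it can be reused to track the copy back out.
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    cuda::memcpy_async(&out[i], &smem[i], 1, bar);
  }
  bar.arrive_and_wait();
}

On hardware without memcpy_async support, the patch's MEMCPY macro swaps each cuda::memcpy_async call for a plain memcpy and the barrier waits for group.sync(), which is the Pascal fallback this commit adds.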
+ * + */ +struct tile_info { + int start_col; + int start_row; + int end_col; + int end_row; + int batch_number; + + __device__ inline size_type get_shared_row_size(size_type const* const col_offsets, + size_type const* const col_sizes) const + { + return util::round_up_unsafe(col_offsets[end_col] + col_sizes[end_col] - col_offsets[start_col], + JCUDF_ROW_ALIGNMENT); + } + + __device__ inline size_type num_cols() const { return end_col - start_col + 1; } + + __device__ inline size_type num_rows() const { return end_row - start_row + 1; } +}; + +/** + * @brief Returning rows is done in a byte cudf column. This is limited in size by + * `size_type` and so output is broken into batches of rows that fit inside + * this limit. + * + */ +struct row_batch { + size_type num_bytes; // number of bytes in this batch + size_type row_count; // number of rows in the batch + device_uvector row_offsets; // offsets column of output cudf column +}; + +/** + * @brief Holds information about the batches of data to be processed + * + */ +struct batch_data { + device_uvector batch_row_offsets; // offset column of returned cudf column + device_uvector d_batch_row_boundaries; // row numbers for the start of each batch + std::vector + batch_row_boundaries; // row numbers for the start of each batch: 0, 1500, 2700 + std::vector row_batches; // information about each batch such as byte count +}; + +/** + * @brief builds row size information for tables that contain strings + * + * @param tbl table from which to compute row size information + * @param fixed_width_and_validity_size size of fixed-width and validity data in this table + * @param stream cuda stream on which to operate + * @return device vector of size_types of the row sizes of the table + */ +rmm::device_uvector build_string_row_sizes(table_view const& tbl, + size_type fixed_width_and_validity_size, + rmm::cuda_stream_view stream) +{ + auto const num_rows = tbl.num_rows(); + rmm::device_uvector d_row_sizes(num_rows, stream); + thrust::uninitialized_fill(rmm::exec_policy(stream), d_row_sizes.begin(), d_row_sizes.end(), 0); + + auto d_offsets_iterators = [&]() { + std::vector offsets_iterators; + auto offsets_iter = thrust::make_transform_iterator( + tbl.begin(), [](auto const& col) -> strings_column_view::offset_iterator { + if (!is_fixed_width(col.type())) { + CUDF_EXPECTS(col.type().id() == type_id::STRING, "only string columns are supported!"); + return strings_column_view(col).offsets_begin(); + } else { + return nullptr; + } + }); + std::copy_if(offsets_iter, + offsets_iter + tbl.num_columns(), + std::back_inserter(offsets_iterators), + [](auto const& offset_ptr) { return offset_ptr != nullptr; }); + return make_device_uvector_async(offsets_iterators, stream); + }(); + + auto const num_columns = static_cast(d_offsets_iterators.size()); + + thrust::for_each(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_columns * num_rows), + [d_offsets_iterators = d_offsets_iterators.data(), + num_columns, + num_rows, + d_row_sizes = d_row_sizes.data()] __device__(auto element_idx) { + auto const row = element_idx % num_rows; + auto const col = element_idx / num_rows; + auto const val = + d_offsets_iterators[col][row + 1] - d_offsets_iterators[col][row]; + atomicAdd(&d_row_sizes[row], val); + }); + + // transform the row sizes to include fixed width size and alignment + thrust::transform(rmm::exec_policy(stream), + d_row_sizes.begin(), + d_row_sizes.end(), + d_row_sizes.begin(), + [fixed_width_and_validity_size] 
__device__(auto row_size) { + return util::round_up_unsafe(fixed_width_and_validity_size + row_size, + JCUDF_ROW_ALIGNMENT); + }); + + return d_row_sizes; +} + +/** + * @brief functor to return the offset of a row in a table with string columns + * + */ +struct string_row_offset_functor { + string_row_offset_functor(device_span _d_row_offsets) + : d_row_offsets(_d_row_offsets){}; + + __device__ inline size_type operator()(int row_number, int) const + { + return d_row_offsets[row_number]; + } + + device_span d_row_offsets; +}; + +/** + * @brief functor to return the offset of a row in a table with only fixed-width columns + * + */ +struct fixed_width_row_offset_functor { + fixed_width_row_offset_functor(size_type fixed_width_only_row_size) + : _fixed_width_only_row_size(fixed_width_only_row_size){}; + + __device__ inline size_type operator()(int row_number, int tile_row_start) const + { + return (row_number - tile_row_start) * _fixed_width_only_row_size; + } + + size_type _fixed_width_only_row_size; +}; + +/** + * @brief Copies data from row-based JCUDF format to column-based cudf format. + * + * This optimized version of the conversion is faster for fixed-width tables + * that do not have more than 100 columns. + * + * @param num_rows number of rows in the incoming table + * @param num_columns number of columns in the incoming table + * @param row_size length in bytes of each row + * @param input_offset_in_row offset to each row of data + * @param num_bytes total number of bytes in the incoming data + * @param output_data array of pointers to the output data + * @param output_nm array of pointers to the output null masks + * @param input_data pointing to the incoming row data + */ +__global__ void copy_from_rows_fixed_width_optimized(const size_type num_rows, + const size_type num_columns, + const size_type row_size, + const size_type* input_offset_in_row, + const size_type* num_bytes, + int8_t** output_data, + bitmask_type** output_nm, + const int8_t* input_data) +{ + // We are going to copy the data in two passes. + // The first pass copies a chunk of data into shared memory. + // The second pass copies that chunk from shared memory out to the final location. + + // Because shared memory is limited we copy a subset of the rows at a time. + // For simplicity we will refer to this as a row_group + + // In practice we have found writing more than 4 columns of data per thread + // results in performance loss. As such we are using a 2 dimensional + // kernel in terms of threads, but not in terms of blocks. Columns are + // controlled by the y dimension (there is no y dimension in blocks). Rows + // are controlled by the x dimension (there are multiple blocks in the x + // dimension). 
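The 2-D thread layout described in the comment above can be reduced to a short sketch (illustrative only; scatter_rows_to_columns is a made-up name and it handles single-byte columns only): threadIdx.x picks the row and the y dimension strides across that row's columns.

__global__ void scatter_rows_to_columns(int8_t const* rows, int row_size,
                                        int num_rows, int num_columns,
                                        int const* col_offsets, int8_t* const* col_out)
{
  int const row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row < num_rows) {
    for (int col = threadIdx.y; col < num_columns; col += blockDim.y) {
      col_out[col][row] = rows[row * row_size + col_offsets[col]];
    }
  }
  // Note the surrounding kernel keeps out-of-range threads alive rather than returning,
  // because they still participate in __ballot_sync and in staging data through shared memory.
}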
+ + size_type const rows_per_group = blockDim.x; + size_type const row_group_start = blockIdx.x; + size_type const row_group_stride = gridDim.x; + size_type const row_group_end = (num_rows + rows_per_group - 1) / rows_per_group + 1; + + extern __shared__ int8_t shared_data[]; + + // Because we are copying fixed width only data and we stride the rows + // this thread will always start copying from shared data in the same place + int8_t* row_tmp = &shared_data[row_size * threadIdx.x]; + int8_t* row_vld_tmp = &row_tmp[input_offset_in_row[num_columns - 1] + num_bytes[num_columns - 1]]; + + for (auto row_group_index = row_group_start; row_group_index < row_group_end; + row_group_index += row_group_stride) { + // Step 1: Copy the data into shared memory + // We know row_size is always aligned with and a multiple of int64_t; + int64_t* long_shared = reinterpret_cast(shared_data); + int64_t const* long_input = reinterpret_cast(input_data); + + auto const shared_output_index = threadIdx.x + (threadIdx.y * blockDim.x); + auto const shared_output_stride = blockDim.x * blockDim.y; + auto const row_index_end = std::min(num_rows, ((row_group_index + 1) * rows_per_group)); + auto const num_rows_in_group = row_index_end - (row_group_index * rows_per_group); + auto const shared_length = row_size * num_rows_in_group; + + size_type const shared_output_end = shared_length / sizeof(int64_t); + + auto const start_input_index = (row_size * row_group_index * rows_per_group) / sizeof(int64_t); + + for (size_type shared_index = shared_output_index; shared_index < shared_output_end; + shared_index += shared_output_stride) { + long_shared[shared_index] = long_input[start_input_index + shared_index]; + } + // Wait for all of the data to be in shared memory + __syncthreads(); + + // Step 2 copy the data back out + + // Within the row group there should be 1 thread for each row. This is a + // requirement for launching the kernel + auto const row_index = (row_group_index * rows_per_group) + threadIdx.x; + // But we might not use all of the threads if the number of rows does not go + // evenly into the thread count. We don't want those threads to exit yet + // because we may need them to copy data in for the next row group. + uint32_t active_mask = __ballot_sync(0xffffffff, row_index < num_rows); + if (row_index < num_rows) { + auto const col_index_start = threadIdx.y; + auto const col_index_stride = blockDim.y; + for (auto col_index = col_index_start; col_index < num_columns; + col_index += col_index_stride) { + auto const col_size = num_bytes[col_index]; + int8_t const* col_tmp = &(row_tmp[input_offset_in_row[col_index]]); + int8_t* col_output = output_data[col_index]; + switch (col_size) { + case 1: { + col_output[row_index] = *col_tmp; + break; + } + case 2: { + int16_t* short_col_output = reinterpret_cast(col_output); + short_col_output[row_index] = *reinterpret_cast(col_tmp); + break; + } + case 4: { + int32_t* int_col_output = reinterpret_cast(col_output); + int_col_output[row_index] = *reinterpret_cast(col_tmp); + break; + } + case 8: { + int64_t* long_col_output = reinterpret_cast(col_output); + long_col_output[row_index] = *reinterpret_cast(col_tmp); + break; + } + default: { + auto const output_offset = col_size * row_index; + // TODO this should just not be supported for fixed width columns, but just in case... 
+ for (auto b = 0; b < col_size; b++) { + col_output[b + output_offset] = col_tmp[b]; + } + break; + } + } + + bitmask_type* nm = output_nm[col_index]; + int8_t* valid_byte = &row_vld_tmp[col_index / 8]; + size_type byte_bit_offset = col_index % 8; + int predicate = *valid_byte & (1 << byte_bit_offset); + uint32_t bitmask = __ballot_sync(active_mask, predicate); + if (row_index % 32 == 0) { nm[word_index(row_index)] = bitmask; } + } // end column loop + } // end row copy + // wait for the row_group to be totally copied before starting on the next row group + __syncthreads(); + } +} + +__global__ void copy_to_rows_fixed_width_optimized(const size_type start_row, + const size_type num_rows, + const size_type num_columns, + const size_type row_size, + const size_type* output_offset_in_row, + const size_type* num_bytes, + const int8_t** input_data, + const bitmask_type** input_nm, + int8_t* output_data) +{ + // We are going to copy the data in two passes. + // The first pass copies a chunk of data into shared memory. + // The second pass copies that chunk from shared memory out to the final location. + + // Because shared memory is limited we copy a subset of the rows at a time. + // We do not support copying a subset of the columns in a row yet, so we don't + // currently support a row that is wider than shared memory. + // For simplicity we will refer to this as a row_group + + // In practice we have found reading more than 4 columns of data per thread + // results in performance loss. As such we are using a 2 dimensional + // kernel in terms of threads, but not in terms of blocks. Columns are + // controlled by the y dimension (there is no y dimension in blocks). Rows + // are controlled by the x dimension (there are multiple blocks in the x + // dimension). + + size_type rows_per_group = blockDim.x; + size_type row_group_start = blockIdx.x; + size_type row_group_stride = gridDim.x; + size_type row_group_end = (num_rows + rows_per_group - 1) / rows_per_group + 1; + + extern __shared__ int8_t shared_data[]; + + // Because we are copying fixed width only data and we stride the rows + // this thread will always start copying to shared data in the same place + int8_t* row_tmp = &shared_data[row_size * threadIdx.x]; + int8_t* row_vld_tmp = + &row_tmp[output_offset_in_row[num_columns - 1] + num_bytes[num_columns - 1]]; + + for (size_type row_group_index = row_group_start; row_group_index < row_group_end; + row_group_index += row_group_stride) { + // Within the row group there should be 1 thread for each row. This is a + // requirement for launching the kernel + size_type row_index = start_row + (row_group_index * rows_per_group) + threadIdx.x; + // But we might not use all of the threads if the number of rows does not go + // evenly into the thread count. We don't want those threads to exit yet + // because we may need them to copy data back out. 
+ if (row_index < (start_row + num_rows)) { + size_type col_index_start = threadIdx.y; + size_type col_index_stride = blockDim.y; + for (size_type col_index = col_index_start; col_index < num_columns; + col_index += col_index_stride) { + size_type col_size = num_bytes[col_index]; + int8_t* col_tmp = &(row_tmp[output_offset_in_row[col_index]]); + const int8_t* col_input = input_data[col_index]; + switch (col_size) { + case 1: { + *col_tmp = col_input[row_index]; + break; + } + case 2: { + const int16_t* short_col_input = reinterpret_cast(col_input); + *reinterpret_cast(col_tmp) = short_col_input[row_index]; + break; + } + case 4: { + const int32_t* int_col_input = reinterpret_cast(col_input); + *reinterpret_cast(col_tmp) = int_col_input[row_index]; + break; + } + case 8: { + const int64_t* long_col_input = reinterpret_cast(col_input); + *reinterpret_cast(col_tmp) = long_col_input[row_index]; + break; + } + default: { + size_type input_offset = col_size * row_index; + // TODO this should just not be supported for fixed width columns, but just in case... + for (size_type b = 0; b < col_size; b++) { + col_tmp[b] = col_input[b + input_offset]; + } + break; + } + } + // atomicOr only works on 32 bit or 64 bit aligned values, and not byte aligned + // so we have to rewrite the addresses to make sure that it is 4 byte aligned + int8_t* valid_byte = &row_vld_tmp[col_index / 8]; + size_type byte_bit_offset = col_index % 8; + uint64_t fixup_bytes = reinterpret_cast(valid_byte) % 4; + int32_t* valid_int = reinterpret_cast(valid_byte - fixup_bytes); + size_type int_bit_offset = byte_bit_offset + (fixup_bytes * 8); + // Now copy validity for the column + if (input_nm[col_index]) { + if (bit_is_set(input_nm[col_index], row_index)) { + atomicOr_block(valid_int, 1 << int_bit_offset); + } else { + atomicAnd_block(valid_int, ~(1 << int_bit_offset)); + } + } else { + // It is valid so just set the bit + atomicOr_block(valid_int, 1 << int_bit_offset); + } + } // end column loop + } // end row copy + // wait for the row_group to be totally copied into shared memory + __syncthreads(); + + // Step 2: Copy the data back out + // We know row_size is always aligned with and a multiple of int64_t; + int64_t* long_shared = reinterpret_cast(shared_data); + int64_t* long_output = reinterpret_cast(output_data); + + size_type shared_input_index = threadIdx.x + (threadIdx.y * blockDim.x); + size_type shared_input_stride = blockDim.x * blockDim.y; + size_type row_index_end = ((row_group_index + 1) * rows_per_group); + if (row_index_end > num_rows) { row_index_end = num_rows; } + size_type num_rows_in_group = row_index_end - (row_group_index * rows_per_group); + size_type shared_length = row_size * num_rows_in_group; + + size_type shared_input_end = shared_length / sizeof(int64_t); + + size_type start_output_index = (row_size * row_group_index * rows_per_group) / sizeof(int64_t); + + for (size_type shared_index = shared_input_index; shared_index < shared_input_end; + shared_index += shared_input_stride) { + long_output[start_output_index + shared_index] = long_shared[shared_index]; + } + __syncthreads(); + // Go for the next round + } +} + +#ifdef ASYNC_MEMCPY_SUPPORTED +#define MEMCPY(dst, src, size, barrier) cuda::memcpy_async(dst, src, size, barrier) +#else +#define MEMCPY(dst, src, size, barrier) memcpy(dst, src, size) +#endif // ASYNC_MEMCPY_SUPPORTED + +/** + * @brief copy data from cudf columns into JCUDF format, which is row-based + * + * @tparam RowOffsetIter iterator that gives the size of a specific row of the 
table. + * @param num_rows total number of rows in the table + * @param num_columns total number of columns in the table + * @param shmem_used_per_tile shared memory amount each `tile_info` is using + * @param tile_infos span of `tile_info` structs the define the work + * @param input_data pointer to raw table data + * @param col_sizes array of sizes for each element in a column - one per column + * @param col_offsets offset into input data row for each column's start + * @param row_offsets offset to a specific row in the output data + * @param batch_row_boundaries row numbers for batch starts + * @param output_data pointer to output data + * + */ +template +__global__ void copy_to_rows(const size_type num_rows, + const size_type num_columns, + const size_type shmem_used_per_tile, + device_span tile_infos, + const int8_t** input_data, + const size_type* col_sizes, + const size_type* col_offsets, + RowOffsetIter row_offsets, + size_type const* batch_row_boundaries, + int8_t** output_data) +{ + // We are going to copy the data in two passes. + // The first pass copies a chunk of data into shared memory. + // The second pass copies that chunk from shared memory out to the final location. + + // Because shared memory is limited we copy a subset of the rows at a time. + // This has been broken up for us in the tile_info struct, so we don't have + // any calculation to do here, but it is important to note. + + constexpr unsigned stages_count = NUM_TILES_PER_KERNEL_LOADED; + auto group = cooperative_groups::this_thread_block(); + extern __shared__ int8_t shared_data[]; + int8_t* shared[stages_count] = {shared_data, shared_data + shmem_used_per_tile}; + +#ifdef ASYNC_MEMCPY_SUPPORTED + __shared__ cuda::barrier tile_barrier[NUM_TILES_PER_KERNEL_LOADED]; + if (group.thread_rank() == 0) { + for (int i = 0; i < NUM_TILES_PER_KERNEL_LOADED; ++i) { + init(&tile_barrier[i], group.size()); + } + } + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED + + auto const tiles_remaining = + std::min(static_cast(tile_infos.size()) - blockIdx.x * NUM_TILES_PER_KERNEL_TO_ROWS, + static_cast(NUM_TILES_PER_KERNEL_TO_ROWS)); + + size_t fetch_index; //< tile we are currently fetching + size_t processing_index; //< tile we are currently processing + for (processing_index = fetch_index = 0; processing_index < tiles_remaining; ++processing_index) { + // Fetch ahead up to NUM_TILES_PER_KERNEL_LOADED + for (; fetch_index < tiles_remaining && fetch_index < (processing_index + stages_count); + ++fetch_index) { + auto const fetch_tile = tile_infos[blockIdx.x * NUM_TILES_PER_KERNEL_TO_ROWS + fetch_index]; + auto const num_fetch_cols = fetch_tile.num_cols(); + auto const num_fetch_rows = fetch_tile.num_rows(); + auto const num_elements_in_tile = num_fetch_cols * num_fetch_rows; + auto const fetch_tile_row_size = fetch_tile.get_shared_row_size(col_offsets, col_sizes); + auto const starting_column_offset = col_offsets[fetch_tile.start_col]; +#ifdef ASYNC_MEMCPY_SUPPORTED + auto& fetch_barrier = tile_barrier[fetch_index % NUM_TILES_PER_KERNEL_LOADED]; + // wait for the last use of the memory to be completed + if (fetch_index >= NUM_TILES_PER_KERNEL_LOADED) { fetch_barrier.arrive_and_wait(); } +#else + // wait for the last use of the memory to be completed + if (fetch_index >= NUM_TILES_PER_KERNEL_LOADED) { group.sync(); } +#endif // ASYNC_MEMCPY_SUPPORTED + + // to do the copy we need to do n column copies followed by m element copies OR + // we have to do m element copies followed by r row copies. 
When going from column + // to row it is much easier to copy by elements first otherwise we would need a running + // total of the column sizes for our tile, which isn't readily available. This makes it + // more appealing to copy element-wise from input data into shared matching the end layout + // and do row-based memcopies out. + + auto const shared_buffer_base = shared[fetch_index % stages_count]; + for (auto el = static_cast(threadIdx.x); el < num_elements_in_tile; el += blockDim.x) { + auto const relative_col = el / num_fetch_rows; + auto const relative_row = el % num_fetch_rows; + auto const absolute_col = relative_col + fetch_tile.start_col; + if (input_data[absolute_col] == nullptr) { + // variable-width data + continue; + } + auto const absolute_row = relative_row + fetch_tile.start_row; + auto const col_size = col_sizes[absolute_col]; + auto const col_offset = col_offsets[absolute_col]; + auto const relative_col_offset = col_offset - starting_column_offset; + + auto const shared_offset = relative_row * fetch_tile_row_size + relative_col_offset; + auto const input_src = input_data[absolute_col] + col_size * absolute_row; + + // copy the element from global memory + switch (col_size) { + case 2: + MEMCPY(&shared_buffer_base[shared_offset], + input_src, + aligned_size_t<2>(col_size), + fetch_barrier); + break; + case 4: + MEMCPY(&shared_buffer_base[shared_offset], + input_src, + aligned_size_t<4>(col_size), + fetch_barrier); + break; + case 8: + MEMCPY(&shared_buffer_base[shared_offset], + input_src, + aligned_size_t<8>(col_size), + fetch_barrier); + break; + default: + MEMCPY(&shared_buffer_base[shared_offset], input_src, col_size, fetch_barrier); + break; + } + } + } + +#ifdef ASYNC_MEMCPY_SUPPORTED + auto& processing_barrier = tile_barrier[processing_index % NUM_TILES_PER_KERNEL_LOADED]; + processing_barrier.arrive_and_wait(); +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED + + auto const tile = tile_infos[blockIdx.x * NUM_TILES_PER_KERNEL_TO_ROWS + processing_index]; + auto const tile_row_size = tile.get_shared_row_size(col_offsets, col_sizes); + auto const column_offset = col_offsets[tile.start_col]; + auto const tile_output_buffer = output_data[tile.batch_number]; + auto const row_batch_start = + tile.batch_number == 0 ? 
0 : batch_row_boundaries[tile.batch_number]; + + // copy entire row 8 bytes at a time + constexpr auto bytes_per_chunk = 8; + auto const chunks_per_row = util::div_rounding_up_unsafe(tile_row_size, bytes_per_chunk); + auto const total_chunks = chunks_per_row * tile.num_rows(); + + for (auto i = threadIdx.x; i < total_chunks; i += blockDim.x) { + // determine source address of my chunk + auto const relative_row = i / chunks_per_row; + auto const relative_chunk_offset = (i % chunks_per_row) * bytes_per_chunk; + auto const output_dest = tile_output_buffer + + row_offsets(relative_row + tile.start_row, row_batch_start) + + column_offset + relative_chunk_offset; + auto const input_src = &shared[processing_index % stages_count] + [tile_row_size * relative_row + relative_chunk_offset]; + + MEMCPY(output_dest, + input_src, + aligned_size_t{bytes_per_chunk}, + processing_barrier); + } + } + +#ifdef ASYNC_MEMCPY_SUPPORTED + // wait on the last copies to complete + for (uint i = 0; i < std::min(stages_count, tiles_remaining); ++i) { + tile_barrier[i].arrive_and_wait(); + } +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED +} + +/** + * @brief copy data from row-based format to cudf columns + * + * @tparam RowOffsetIter iterator that gives the size of a specific row of the table. + * @param num_rows total number of rows in the table + * @param num_columns total number of columns in the table + * @param shmem_used_per_tile amount of shared memory that is used by a tile + * @param row_offsets offset to a specific row in the output data + * @param batch_row_boundaries row numbers for batch starts + * @param output_data pointer to output data, partitioned by data size + * @param validity_offsets offset into input data row for validity data + * @param tile_infos information about the tiles of work + * @param input_nm pointer to input data + * + */ +template +__global__ void copy_validity_to_rows(const size_type num_rows, + const size_type num_columns, + const size_type shmem_used_per_tile, + RowOffsetIter row_offsets, + size_type const* batch_row_boundaries, + int8_t** output_data, + const size_type validity_offset, + device_span tile_infos, + const bitmask_type** input_nm) +{ + extern __shared__ int8_t shared_data[]; + int8_t* shared_tiles[NUM_VALIDITY_TILES_PER_KERNEL_LOADED] = { + shared_data, shared_data + shmem_used_per_tile / 2}; + + using cudf::detail::warp_size; + + // each thread of warp reads a single int32 of validity - so we read 128 bytes + // then ballot_sync the bits and write the result to shmem + // after we fill shared mem memcpy it out in a blob. + // probably need knobs for number of rows vs columns to balance read/write + auto group = cooperative_groups::this_thread_block(); + + int const tiles_remaining = + std::min(static_cast(tile_infos.size()) - blockIdx.x * NUM_VALIDITY_TILES_PER_KERNEL, + static_cast(NUM_VALIDITY_TILES_PER_KERNEL)); + +#ifdef ASYNC_MEMCPY_SUPPORTED + // Initialize cuda barriers for each tile. 
+ __shared__ cuda::barrier + shared_tile_barriers[NUM_VALIDITY_TILES_PER_KERNEL_LOADED]; + if (group.thread_rank() == 0) { + for (int i = 0; i < NUM_VALIDITY_TILES_PER_KERNEL_LOADED; ++i) { + init(&shared_tile_barriers[i], group.size()); + } + } + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED + + for (int validity_tile = 0; validity_tile < tiles_remaining; ++validity_tile) { + if (validity_tile >= NUM_VALIDITY_TILES_PER_KERNEL_LOADED) { +#ifdef ASYNC_MEMCPY_SUPPORTED + shared_tile_barriers[validity_tile % NUM_VALIDITY_TILES_PER_KERNEL_LOADED].arrive_and_wait(); +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED + } + int8_t* this_shared_tile = shared_tiles[validity_tile % NUM_VALIDITY_TILES_PER_KERNEL_LOADED]; + auto tile = tile_infos[blockIdx.x * NUM_VALIDITY_TILES_PER_KERNEL + validity_tile]; + + auto const num_tile_cols = tile.num_cols(); + auto const num_tile_rows = tile.num_rows(); + + auto const num_sections_x = util::div_rounding_up_unsafe(num_tile_cols, 32); + auto const num_sections_y = util::div_rounding_up_unsafe(num_tile_rows, 32); + auto const validity_data_row_length = util::round_up_unsafe( + util::div_rounding_up_unsafe(num_tile_cols, CHAR_BIT), JCUDF_ROW_ALIGNMENT); + auto const total_sections = num_sections_x * num_sections_y; + + int const warp_id = threadIdx.x / warp_size; + int const lane_id = threadIdx.x % warp_size; + auto const warps_per_tile = std::max(1u, blockDim.x / warp_size); + + // the tile is divided into sections. A warp operates on a section at a time. + for (int my_section_idx = warp_id; my_section_idx < total_sections; + my_section_idx += warps_per_tile) { + // convert to rows and cols + auto const section_x = my_section_idx % num_sections_x; + auto const section_y = my_section_idx / num_sections_x; + auto const relative_col = section_x * 32 + lane_id; + auto const relative_row = section_y * 32; + auto const absolute_col = relative_col + tile.start_col; + auto const absolute_row = relative_row + tile.start_row; + auto const participation_mask = __ballot_sync(0xFFFFFFFF, absolute_col < num_columns); + + if (absolute_col < num_columns) { + auto my_data = input_nm[absolute_col] != nullptr ? input_nm[absolute_col][absolute_row / 32] + : std::numeric_limits::max(); + + // every thread that is participating in the warp has 4 bytes, but it's column-based + // data and we need it in row-based. So we shuffle the bits around with ballot_sync to + // make the bytes we actually write. 
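In isolation, that ballot trick looks like the following sketch (illustrative; pack_flags is a made-up name and the block size is assumed to be a multiple of 32): every lane supplies one predicate, __ballot_sync packs the warp's 32 bits into a single word visible on every lane, and only lane 0 stores it.

__global__ void pack_flags(bool const* flags, unsigned* packed, int n)
{
  int const i = blockIdx.x * blockDim.x + threadIdx.x;
  // Threads past the end drop out of the mask but still take part in the first ballot.
  unsigned const active = __ballot_sync(0xFFFFFFFFu, i < n);
  if (i < n) {
    unsigned const word = __ballot_sync(active, flags[i]);  // 32 bits, one per lane
    if (threadIdx.x % 32 == 0) { packed[i / 32] = word; }   // lane 0 writes the packed word
  }
}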
+ bitmask_type dw_mask = 1; + for (int i = 0; i < 32 && relative_row + i < num_rows; ++i, dw_mask <<= 1) { + auto validity_data = __ballot_sync(participation_mask, my_data & dw_mask); + // lead thread in each warp writes data + auto const validity_write_offset = + validity_data_row_length * (relative_row + i) + relative_col / CHAR_BIT; + if (threadIdx.x % warp_size == 0) { + *reinterpret_cast(&this_shared_tile[validity_write_offset]) = validity_data; + } + } + } + } + + // make sure entire tile has finished copy + group.sync(); + + auto const output_data_base = + output_data[tile.batch_number] + validity_offset + tile.start_col / CHAR_BIT; + + // now async memcpy the shared memory out to the final destination 4 bytes at a time since we do + // 32-row chunks + constexpr auto bytes_per_chunk = 8; + auto const row_bytes = util::div_rounding_up_unsafe(num_tile_cols, CHAR_BIT); + auto const chunks_per_row = util::div_rounding_up_unsafe(row_bytes, bytes_per_chunk); + auto const total_chunks = chunks_per_row * tile.num_rows(); +#ifdef ASYNC_MEMCPY_SUPPORTED + auto& processing_barrier = + shared_tile_barriers[validity_tile % NUM_VALIDITY_TILES_PER_KERNEL_LOADED]; +#endif // ASYNC_MEMCPY_SUPPORTED + auto const tail_bytes = row_bytes % bytes_per_chunk; + auto const row_batch_start = + tile.batch_number == 0 ? 0 : batch_row_boundaries[tile.batch_number]; + + for (auto i = threadIdx.x; i < total_chunks; i += blockDim.x) { + // determine source address of my chunk + auto const relative_row = i / chunks_per_row; + auto const col_chunk = i % chunks_per_row; + auto const relative_chunk_offset = col_chunk * bytes_per_chunk; + auto const output_dest = output_data_base + + row_offsets(relative_row + tile.start_row, row_batch_start) + + relative_chunk_offset; + auto const input_src = + &this_shared_tile[validity_data_row_length * relative_row + relative_chunk_offset]; + + if (tail_bytes > 0 && col_chunk == chunks_per_row - 1) + MEMCPY(output_dest, input_src, tail_bytes, processing_barrier); + else + MEMCPY(output_dest, + input_src, + aligned_size_t(bytes_per_chunk), + processing_barrier); + } + } + +#ifdef ASYNC_MEMCPY_SUPPORTED + // wait for last tiles of data to arrive + for (int validity_tile = 0; + validity_tile < tiles_remaining % NUM_VALIDITY_TILES_PER_KERNEL_LOADED; + ++validity_tile) { + shared_tile_barriers[validity_tile].arrive_and_wait(); + } +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED +} + +/** + * @brief copy data from row-based format to cudf columns + * + * @tparam RowOffsetIter iterator that gives the size of a specific row of the table. 
+ * @param num_rows total number of rows in the table + * @param num_columns total number of columns in the table + * @param shmem_used_per_tile amount of shared memory that is used by a tile + * @param row_offsets offset to a specific row in the input data + * @param batch_row_boundaries row numbers for batch starts + * @param output_data pointers to column data + * @param col_sizes array of sizes for each element in a column - one per column + * @param col_offsets offset into input data row for each column's start + * @param tile_infos information about the tiles of work + * @param input_data pointer to input data + * + */ +template +__global__ void copy_from_rows(const size_type num_rows, + const size_type num_columns, + const size_type shmem_used_per_tile, + RowOffsetIter row_offsets, + size_type const* batch_row_boundaries, + int8_t** output_data, + const size_type* col_sizes, + const size_type* col_offsets, + device_span tile_infos, + const int8_t* input_data) +{ + // We are going to copy the data in two passes. + // The first pass copies a chunk of data into shared memory. + // The second pass copies that chunk from shared memory out to the final location. + + // Because shared memory is limited we copy a subset of the rows at a time. + // This has been broken up for us in the tile_info struct, so we don't have + // any calculation to do here, but it is important to note. + + // to speed up some of the random access memory we do, we copy col_sizes and col_offsets + // to shared memory for each of the tiles that we work on + + constexpr unsigned stages_count = NUM_TILES_PER_KERNEL_LOADED; + auto group = cooperative_groups::this_thread_block(); + extern __shared__ int8_t shared_data[]; + int8_t* shared[stages_count] = {shared_data, shared_data + shmem_used_per_tile}; + +#ifdef ASYNC_MEMCPY_SUPPORTED + // Initialize cuda barriers for each tile. + __shared__ cuda::barrier tile_barrier[NUM_TILES_PER_KERNEL_LOADED]; + if (group.thread_rank() == 0) { + for (int i = 0; i < NUM_TILES_PER_KERNEL_LOADED; ++i) { + init(&tile_barrier[i], group.size()); + } + } + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED + + auto tiles_remaining = + std::min(static_cast(tile_infos.size()) - blockIdx.x * NUM_TILES_PER_KERNEL_FROM_ROWS, + static_cast(NUM_TILES_PER_KERNEL_FROM_ROWS)); + + size_t fetch_index; + size_t processing_index; + for (processing_index = fetch_index = 0; processing_index < tiles_remaining; ++processing_index) { + // Fetch ahead up to stages_count groups + for (; fetch_index < static_cast(tiles_remaining) && + fetch_index < (processing_index + stages_count); + ++fetch_index) { + auto const fetch_tile = tile_infos[blockIdx.x * NUM_TILES_PER_KERNEL_FROM_ROWS + fetch_index]; + auto const fetch_tile_start_row = fetch_tile.start_row; + auto const starting_col_offset = col_offsets[fetch_tile.start_col]; + auto const fetch_tile_row_size = fetch_tile.get_shared_row_size(col_offsets, col_sizes); + auto const row_batch_start = + fetch_tile.batch_number == 0 ? 
0 : batch_row_boundaries[fetch_tile.batch_number]; +#ifdef ASYNC_MEMCPY_SUPPORTED + auto& fetch_barrier = tile_barrier[fetch_index % NUM_TILES_PER_KERNEL_LOADED]; + // if we have fetched all buffers, we need to wait for processing + // to complete on them before we can use them again + if (fetch_index > NUM_TILES_PER_KERNEL_LOADED) { fetch_barrier.arrive_and_wait(); } +#else + if (fetch_index >= NUM_TILES_PER_KERNEL_LOADED) { group.sync(); } +#endif // ASYNC_MEMCPY_SUPPORTED + + for (auto row = fetch_tile_start_row + static_cast(threadIdx.x); + row <= fetch_tile.end_row; + row += blockDim.x) { + auto shared_offset = (row - fetch_tile_start_row) * fetch_tile_row_size; + // copy the data + MEMCPY(&shared[fetch_index % stages_count][shared_offset], + &input_data[row_offsets(row, row_batch_start) + starting_col_offset], + fetch_tile_row_size, + fetch_barrier); + } + } + +#ifdef ASYNC_MEMCPY_SUPPORTED + auto& processing_barrier = tile_barrier[processing_index % NUM_TILES_PER_KERNEL_LOADED]; + // ensure our data is ready + processing_barrier.arrive_and_wait(); +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED + + auto const tile = tile_infos[blockIdx.x * NUM_TILES_PER_KERNEL_FROM_ROWS + processing_index]; + auto const rows_in_tile = tile.num_rows(); + auto const cols_in_tile = tile.num_cols(); + auto const tile_row_size = tile.get_shared_row_size(col_offsets, col_sizes); + + // now we copy from shared memory to final destination. + // the data is laid out in rows in shared memory, so the reads + // for a column will be "vertical". Because of this and the different + // sizes for each column, this portion is handled on row/column basis. + // to prevent each thread working on a single row and also to ensure + // that all threads can do work in the case of more threads than rows, + // we do a global index instead of a double for loop with col/row. + for (int index = threadIdx.x; index < rows_in_tile * cols_in_tile; index += blockDim.x) { + auto const relative_col = index % cols_in_tile; + auto const relative_row = index / cols_in_tile; + auto const absolute_col = relative_col + tile.start_col; + auto const absolute_row = relative_row + tile.start_row; + + auto const shared_memory_row_offset = tile_row_size * relative_row; + auto const shared_memory_offset = + col_offsets[absolute_col] - col_offsets[tile.start_col] + shared_memory_row_offset; + auto const column_size = col_sizes[absolute_col]; + + int8_t* shmem_src = &shared[processing_index % stages_count][shared_memory_offset]; + int8_t* dst = &output_data[absolute_col][absolute_row * column_size]; + + MEMCPY(dst, shmem_src, column_size, processing_barrier); + } + group.sync(); + } + +#ifdef ASYNC_MEMCPY_SUPPORTED + // wait on the last copies to complete + for (uint i = 0; i < std::min(stages_count, tiles_remaining); ++i) { + tile_barrier[i].arrive_and_wait(); + } +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED +} + +/** + * @brief copy data from row-based format to cudf columns + * + * @tparam RowOffsetIter iterator that gives the size of a specific row of the table. 
+ * @param num_rows total number of rows in the table + * @param num_columns total number of columns in the table + * @param shmem_used_per_tile amount of shared memory that is used by a tile + * @param row_offsets offset to a specific row in the input data + * @param batch_row_boundaries row numbers for batch starts + * @param output_nm pointers to null masks for columns + * @param validity_offsets offset into input data row for validity data + * @param tile_infos information about the tiles of work + * @param input_data pointer to input data + * + */ +template +__global__ void copy_validity_from_rows(const size_type num_rows, + const size_type num_columns, + const size_type shmem_used_per_tile, + RowOffsetIter row_offsets, + size_type const* batch_row_boundaries, + bitmask_type** output_nm, + const size_type validity_offset, + device_span tile_infos, + const int8_t* input_data) +{ + extern __shared__ int8_t shared_data[]; + int8_t* shared_tiles[NUM_VALIDITY_TILES_PER_KERNEL_LOADED] = { + shared_data, shared_data + shmem_used_per_tile / 2}; + + using cudf::detail::warp_size; + + // each thread of warp reads a single byte of validity - so we read 32 bytes + // then ballot_sync the bits and write the result to shmem + // after we fill shared mem memcpy it out in a blob. + // probably need knobs for number of rows vs columns to balance read/write + auto group = cooperative_groups::this_thread_block(); + + int const tiles_remaining = + std::min(static_cast(tile_infos.size()) - blockIdx.x * NUM_VALIDITY_TILES_PER_KERNEL, + static_cast(NUM_VALIDITY_TILES_PER_KERNEL)); + +#ifdef ASYNC_MEMCPY_SUPPORTED + // Initialize cuda barriers for each tile. + __shared__ cuda::barrier + shared_tile_barriers[NUM_VALIDITY_TILES_PER_KERNEL_LOADED]; + if (group.thread_rank() == 0) { + for (int i = 0; i < NUM_VALIDITY_TILES_PER_KERNEL_LOADED; ++i) { + init(&shared_tile_barriers[i], group.size()); + } + } + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED + + for (int validity_tile = 0; validity_tile < tiles_remaining; ++validity_tile) { + if (validity_tile >= NUM_VALIDITY_TILES_PER_KERNEL_LOADED) { +#ifdef ASYNC_MEMCPY_SUPPORTED + auto const validity_index = validity_tile % NUM_VALIDITY_TILES_PER_KERNEL_LOADED; + shared_tile_barriers[validity_index].arrive_and_wait(); +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED + } + int8_t* this_shared_tile = shared_tiles[validity_tile % 2]; + auto const tile = tile_infos[blockIdx.x * NUM_VALIDITY_TILES_PER_KERNEL + validity_tile]; + auto const tile_start_col = tile.start_col; + auto const tile_start_row = tile.start_row; + auto const num_tile_cols = tile.num_cols(); + auto const num_tile_rows = tile.num_rows(); + constexpr auto rows_per_read = 32; + auto const num_sections_x = util::div_rounding_up_safe(num_tile_cols, CHAR_BIT); + auto const num_sections_y = util::div_rounding_up_safe(num_tile_rows, rows_per_read); + auto const validity_data_col_length = num_sections_y * 4; // words to bytes + auto const total_sections = num_sections_x * num_sections_y; + int const warp_id = threadIdx.x / warp_size; + int const lane_id = threadIdx.x % warp_size; + auto const warps_per_tile = std::max(1u, blockDim.x / warp_size); + + // the tile is divided into sections. A warp operates on a section at a time. 
+ for (int my_section_idx = warp_id; my_section_idx < total_sections; + my_section_idx += warps_per_tile) { + // convert section to row and col + auto const section_x = my_section_idx % num_sections_x; + auto const section_y = my_section_idx / num_sections_x; + auto const relative_col = section_x * CHAR_BIT; + auto const relative_row = section_y * rows_per_read + lane_id; + auto const absolute_col = relative_col + tile_start_col; + auto const absolute_row = relative_row + tile_start_row; + auto const row_batch_start = + tile.batch_number == 0 ? 0 : batch_row_boundaries[tile.batch_number]; + + auto const participation_mask = __ballot_sync(0xFFFFFFFF, absolute_row < num_rows); + + if (absolute_row < num_rows) { + auto const my_byte = input_data[row_offsets(absolute_row, row_batch_start) + + validity_offset + absolute_col / CHAR_BIT]; + + // so every thread that is participating in the warp has a byte, but it's row-based + // data and we need it in column-based. So we shuffle the bits around to make + // the bytes we actually write. + for (int i = 0, byte_mask = 1; i < CHAR_BIT && relative_col + i < num_columns; + ++i, byte_mask <<= 1) { + auto validity_data = __ballot_sync(participation_mask, my_byte & byte_mask); + // lead thread in each warp writes data + if (threadIdx.x % warp_size == 0) { + auto const validity_write_offset = + validity_data_col_length * (relative_col + i) + relative_row / CHAR_BIT; + + *reinterpret_cast(&this_shared_tile[validity_write_offset]) = validity_data; + } + } + } + } + + // make sure entire tile has finished copy + group.sync(); + + // now async memcpy the shared memory out to the final destination 8 bytes at a time + constexpr auto bytes_per_chunk = 8; + auto const col_bytes = util::div_rounding_up_unsafe(num_tile_rows, CHAR_BIT); + auto const chunks_per_col = util::div_rounding_up_unsafe(col_bytes, bytes_per_chunk); + auto const total_chunks = chunks_per_col * num_tile_cols; + auto const tail_bytes = col_bytes % bytes_per_chunk; +#ifdef ASYNC_MEMCPY_SUPPORTED + auto& processing_barrier = + shared_tile_barriers[validity_tile % NUM_VALIDITY_TILES_PER_KERNEL_LOADED]; +#endif // ASYNC_MEMCPY_SUPPORTED + + for (auto i = threadIdx.x; i < total_chunks; i += blockDim.x) { + // determine source address of my chunk + auto const relative_col = i / chunks_per_col; + auto const row_chunk = i % chunks_per_col; + auto const absolute_col = relative_col + tile_start_col; + auto const relative_chunk_byte_offset = row_chunk * bytes_per_chunk; + auto output_dest = reinterpret_cast(output_nm[absolute_col] + + word_index(tile_start_row) + row_chunk * 2); + auto const input_src = + &this_shared_tile[validity_data_col_length * relative_col + relative_chunk_byte_offset]; + + if (tail_bytes > 0 && row_chunk == chunks_per_col - 1) { + MEMCPY(output_dest, input_src, tail_bytes, processing_barrier); + } else { + MEMCPY(output_dest, + input_src, + aligned_size_t(bytes_per_chunk), + processing_barrier); + } + } + } + +#ifdef ASYNC_MEMCPY_SUPPORTED + // wait for last tiles of data to arrive + auto const num_tiles_to_wait = tiles_remaining > NUM_VALIDITY_TILES_PER_KERNEL_LOADED + ? NUM_VALIDITY_TILES_PER_KERNEL_LOADED + : tiles_remaining; + for (int validity_tile = 0; validity_tile < num_tiles_to_wait; ++validity_tile) { + shared_tile_barriers[validity_tile].arrive_and_wait(); + } +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED +} + +/** + * @brief Calculate the dimensions of the kernel for fixed width only columns. + * + * @param [in] num_columns the number of columns being copied. 
+ * @param [in] num_rows the number of rows being copied. + * @param [in] size_per_row the size each row takes up when padded. + * @param [out] blocks the size of the blocks for the kernel + * @param [out] threads the size of the threads for the kernel + * @return the size in bytes of shared memory needed for each block. + */ +static int calc_fixed_width_kernel_dims(const size_type num_columns, + const size_type num_rows, + const size_type size_per_row, + dim3& blocks, + dim3& threads) +{ + // We have found speed degrades when a thread handles more than 4 columns. + // Each block is 2 dimensional. The y dimension indicates the columns. + // We limit this to 32 threads in the y dimension so we can still + // have at least 32 threads in the x dimension (1 warp) which should + // result in better coalescing of memory operations. We also + // want to guarantee that we are processing a multiple of 32 threads + // in the x dimension because we use atomic operations at the block + // level when writing validity data out to main memory, and that would + // need to change if we split a word of validity data between blocks. + int const y_block_size = min(util::div_rounding_up_safe(num_columns, 4), 32); + int const x_possible_block_size = 1024 / y_block_size; + // 48KB is the default setting for shared memory per block according to the cuda tutorials + // If someone configures the GPU to only have 16 KB this might not work. + int const max_shared_size = 48 * 1024; + // If we don't have enough shared memory there is no point in having more threads + // per block that will just sit idle + auto const max_block_size = std::min(x_possible_block_size, max_shared_size / size_per_row); + // Make sure that the x dimension is a multiple of 32 this not only helps + // coalesce memory access it also lets us do a ballot sync for validity to write + // the data back out the warp level. If x is a multiple of 32 then each thread in the y + // dimension is associated with one or more warps, that should correspond to the validity + // words directly. + int const block_size = (max_block_size / 32) * 32; + CUDF_EXPECTS(block_size != 0, "Row size is too large to fit in shared memory"); + + // The maximum number of blocks supported in the x dimension is 2 ^ 31 - 1 + // but in practice haveing too many can cause some overhead that I don't totally + // understand. Playing around with this haveing as little as 600 blocks appears + // to be able to saturate memory on V100, so this is an order of magnitude higher + // to try and future proof this a bit. + int const num_blocks = std::clamp((num_rows + block_size - 1) / block_size, 1, 10240); + + blocks.x = num_blocks; + blocks.y = 1; + blocks.z = 1; + threads.x = block_size; + threads.y = y_block_size; + threads.z = 1; + return size_per_row * block_size; +} + +/** + * When converting to rows it is possible that the size of the table was too big to fit + * in a single column. This creates an output column for a subset of the rows in a table + * going from start row and containing the next num_rows. Most of the parameters passed + * into this function are common between runs and should be calculated once. 
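+ * For example, the caller computes column_start, column_size and the zero/step
+ * scalars once, then calls this for each batch of up to max_rows_per_batch rows
+ * with a different start_row.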
+ */ +static std::unique_ptr fixed_width_convert_to_rows( + const size_type start_row, + const size_type num_rows, + const size_type num_columns, + const size_type size_per_row, + rmm::device_uvector& column_start, + rmm::device_uvector& column_size, + rmm::device_uvector& input_data, + rmm::device_uvector& input_nm, + const scalar& zero, + const scalar& scalar_size_per_row, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + int64_t const total_allocation = size_per_row * num_rows; + // We made a mistake in the split somehow + CUDF_EXPECTS(total_allocation < std::numeric_limits::max(), + "Table is too large to fit!"); + + // Allocate and set the offsets row for the byte array + std::unique_ptr offsets = + cudf::detail::sequence(num_rows + 1, zero, scalar_size_per_row, stream); + + std::unique_ptr data = make_numeric_column(data_type(type_id::INT8), + static_cast(total_allocation), + mask_state::UNALLOCATED, + stream, + mr); + + dim3 blocks; + dim3 threads; + int shared_size = + detail::calc_fixed_width_kernel_dims(num_columns, num_rows, size_per_row, blocks, threads); + + copy_to_rows_fixed_width_optimized<<>>( + start_row, + num_rows, + num_columns, + size_per_row, + column_start.data(), + column_size.data(), + input_data.data(), + input_nm.data(), + data->mutable_view().data()); + + return make_lists_column(num_rows, + std::move(offsets), + std::move(data), + 0, + rmm::device_buffer{0, rmm::cuda_stream_default, mr}, + stream, + mr); +} + +static inline bool are_all_fixed_width(std::vector const& schema) +{ + return std::all_of( + schema.begin(), schema.end(), [](const data_type& t) { return is_fixed_width(t); }); +} + +/** + * @brief Given a set of fixed width columns, calculate how the data will be laid out in memory. + * + * @param [in] schema the types of columns that need to be laid out. + * @param [out] column_start the byte offset where each column starts in the row. + * @param [out] column_size the size in bytes of the data for each columns in the row. + * @return the size in bytes each row needs. + */ +static inline int32_t compute_fixed_width_layout(std::vector const& schema, + std::vector& column_start, + std::vector& column_size) +{ + // We guarantee that the start of each column is 64-bit aligned so anything can go + // there, but to make the code simple we will still do an alignment for it. + int32_t at_offset = 0; + for (auto col = schema.begin(); col < schema.end(); col++) { + size_type s = size_of(*col); + column_size.emplace_back(s); + std::size_t allocation_needed = s; + std::size_t alignment_needed = allocation_needed; // They are the same for fixed width types + at_offset = util::round_up_unsafe(at_offset, static_cast(alignment_needed)); + column_start.emplace_back(at_offset); + at_offset += allocation_needed; + } + + // Now we need to add in space for validity + // Eventually we can think about nullable vs not nullable, but for now we will just always add + // it in + int32_t const validity_bytes_needed = + util::div_rounding_up_safe(schema.size(), CHAR_BIT); + // validity comes at the end and is byte aligned so we can pack more in. + at_offset += validity_bytes_needed; + // Now we need to pad the end so all rows are 64 bit aligned + return util::round_up_unsafe(at_offset, JCUDF_ROW_ALIGNMENT); +} + +/** + * @brief Compute information about a table such as bytes per row and offsets. 
+ * + * @tparam iterator iterator of column schema data + * @param begin starting iterator of column schema + * @param end ending iterator of column schema + * @param column_starts column start offsets + * @param column_sizes size in bytes of each column + * @return size of the fixed_width data portion of a row. + */ +template +static size_type compute_column_information(iterator begin, + iterator end, + std::vector& column_starts, + std::vector& column_sizes) +{ + size_type fixed_width_size_per_row = 0; + for (auto cv = begin; cv != end; ++cv) { + auto col_type = std::get<0>(*cv); + bool nested_type = is_compound(col_type); + + // a list or string column will write a single uint64 + // of data here for offset/length + auto col_size = nested_type ? 8 : size_of(col_type); + + // align size for this type + size_type const alignment_needed = col_size; // They are the same for fixed width types + fixed_width_size_per_row = util::round_up_unsafe(fixed_width_size_per_row, alignment_needed); + column_starts.push_back(fixed_width_size_per_row); + column_sizes.push_back(col_size); + fixed_width_size_per_row += col_size; + } + + auto validity_offset = fixed_width_size_per_row; + column_starts.push_back(validity_offset); + + return fixed_width_size_per_row + + util::div_rounding_up_safe(static_cast(std::distance(begin, end)), CHAR_BIT); +} + +/** + * @brief Build `tile_info` for the validity data to break up the work. + * + * @param num_columns number of columns in the table + * @param num_rows number of rows in the table + * @param shmem_limit_per_tile size of shared memory available to a single gpu tile + * @param row_batches batched row information for multiple output locations + * @return vector of `tile_info` structs for validity data + */ +std::vector build_validity_tile_infos(size_type const& num_columns, + size_type const& num_rows, + size_type const& shmem_limit_per_tile, + std::vector const& row_batches) +{ + auto const desired_rows_and_columns = static_cast(sqrt(shmem_limit_per_tile)); + auto const column_stride = util::round_up_unsafe( + [&]() { + if (desired_rows_and_columns > num_columns) { + // not many columns, group it into 8s and ship it off + return std::min(CHAR_BIT, num_columns); + } else { + return util::round_down_safe(desired_rows_and_columns, CHAR_BIT); + } + }(), + JCUDF_ROW_ALIGNMENT); + + // we fit as much as we can given the column stride + // note that an element in the table takes just 1 bit, but a row with a single + // element still takes 8 bytes! 
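+  // Hypothetical sizing example (numbers are illustrative, assuming
+  // JCUDF_ROW_ALIGNMENT is 8): with shmem_limit_per_tile = 24576 bytes,
+  // desired_rows_and_columns = sqrt(24576) ~ 156. A 100-column table is
+  // narrower than that, so column_stride = round_up(min(8, 100), 8) = 8
+  // columns, a tile row of validity occupies round_up(8 / CHAR_BIT, 8) = 8
+  // bytes, and the row_stride below becomes
+  // min(num_rows, round_down(24576 / 8, 64)) = min(num_rows, 3072) rows.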
+ auto const bytes_per_row = + util::round_up_safe(util::div_rounding_up_unsafe(column_stride, CHAR_BIT), JCUDF_ROW_ALIGNMENT); + auto const row_stride = + std::min(num_rows, util::round_down_safe(shmem_limit_per_tile / bytes_per_row, 64)); + + std::vector validity_tile_infos; + validity_tile_infos.reserve(num_columns / column_stride * num_rows / row_stride); + for (int col = 0; col < num_columns; col += column_stride) { + int current_tile_row_batch = 0; + int rows_left_in_batch = row_batches[current_tile_row_batch].row_count; + int row = 0; + while (row < num_rows) { + if (rows_left_in_batch == 0) { + current_tile_row_batch++; + rows_left_in_batch = row_batches[current_tile_row_batch].row_count; + } + int const tile_height = std::min(row_stride, rows_left_in_batch); + + validity_tile_infos.emplace_back(detail::tile_info{ + col, row, std::min(col + column_stride - 1, num_columns - 1), row + tile_height - 1}); + row += tile_height; + rows_left_in_batch -= tile_height; + } + } + + return validity_tile_infos; +} + +/** + * @brief functor that returns the size of a row or 0 is row is greater than the number of rows in + * the table + * + * @tparam RowSize iterator that returns the size of a specific row + */ +template +struct row_size_functor { + row_size_functor(size_type row_end, RowSize row_sizes, size_type last_row_end) + : _row_end(row_end), _row_sizes(row_sizes), _last_row_end(last_row_end) + { + } + + __device__ inline uint64_t operator()(int i) const + { + return i >= _row_end ? 0 : _row_sizes[i + _last_row_end]; + } + + size_type _row_end; + RowSize _row_sizes; + size_type _last_row_end; +}; + +/** + * @brief Builds batches of rows that will fit in the size limit of a column. + * + * @tparam RowSize iterator that gives the size of a specific row of the table. + * @param num_rows Total number of rows in the table + * @param row_sizes iterator that gives the size of a specific row of the table. + * @param all_fixed_width bool indicating all data in this table is fixed width + * @param stream stream to operate on for this work + * @param mr memory resource used to allocate any returned data + * @returns vector of size_type's that indicate row numbers for batch boundaries and a + * device_uvector of row offsets + */ +template +batch_data build_batches(size_type num_rows, + RowSize row_sizes, + bool all_fixed_width, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto const total_size = thrust::reduce(rmm::exec_policy(stream), row_sizes, row_sizes + num_rows); + auto const num_batches = static_cast( + util::div_rounding_up_safe(total_size, static_cast(MAX_BATCH_SIZE))); + auto const num_offsets = num_batches + 1; + std::vector row_batches; + std::vector batch_row_boundaries; + device_uvector batch_row_offsets(all_fixed_width ? 0 : num_rows, stream); + + // at most max gpu memory / 2GB iterations. 
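+  // Illustrative example (hypothetical sizes): if the rows of the table total
+  // roughly 5 GB of JCUDF data and MAX_BATCH_SIZE is ~2 GB, num_batches is
+  // ceil(5 GB / 2 GB) = 3 and num_offsets is 4, so batch_row_boundaries ends
+  // up as {0, r1, r2, num_rows}, where each boundary is found below by
+  // searching the cumulative row sizes for the last row that still fits.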
+ batch_row_boundaries.reserve(num_offsets); + batch_row_boundaries.push_back(0); + size_type last_row_end = 0; + device_uvector cumulative_row_sizes(num_rows, stream); + thrust::inclusive_scan( + rmm::exec_policy(stream), row_sizes, row_sizes + num_rows, cumulative_row_sizes.begin()); + + while (static_cast(batch_row_boundaries.size()) < num_offsets) { + // find the next MAX_BATCH_SIZE boundary + size_type const row_end = + ((thrust::lower_bound(rmm::exec_policy(stream), + cumulative_row_sizes.begin(), + cumulative_row_sizes.begin() + (num_rows - last_row_end), + MAX_BATCH_SIZE) - + cumulative_row_sizes.begin()) + + last_row_end); + + // build offset list for each row in this batch + auto const num_rows_in_batch = row_end - last_row_end; + + // build offset list for each row in this batch + auto const num_entries = row_end - last_row_end + 1; + device_uvector output_batch_row_offsets(num_entries, stream, mr); + + auto row_size_iter_bounded = cudf::detail::make_counting_transform_iterator( + 0, row_size_functor(row_end, row_sizes, last_row_end)); + + thrust::exclusive_scan(rmm::exec_policy(stream), + row_size_iter_bounded, + row_size_iter_bounded + num_entries, + output_batch_row_offsets.begin()); + + auto const batch_bytes = output_batch_row_offsets.element(num_rows_in_batch, stream); + + // The output_batch_row_offsets vector is used as the offset column of the returned data. This + // needs to be individually allocated, but the kernel needs a contiguous array of offsets or + // more global lookups are necessary. + if (!all_fixed_width) { + cudaMemcpy(batch_row_offsets.data() + last_row_end, + output_batch_row_offsets.data(), + num_rows_in_batch * sizeof(size_type), + cudaMemcpyDeviceToDevice); + } + + batch_row_boundaries.push_back(row_end); + row_batches.push_back({batch_bytes, num_rows_in_batch, std::move(output_batch_row_offsets)}); + + last_row_end = row_end; + } + + return {std::move(batch_row_offsets), + make_device_uvector_async(batch_row_boundaries, stream), + std::move(batch_row_boundaries), + std::move(row_batches)}; +} + +/** + * @brief Computes the number of tiles necessary given a tile height and batch offsets + * + * @param batch_row_boundaries row boundaries for each batch + * @param desired_tile_height height of each tile in the table + * @param stream stream to use + * @return number of tiles necessary + */ +int compute_tile_counts(device_span const& batch_row_boundaries, + int desired_tile_height, + rmm::cuda_stream_view stream) +{ + size_type const num_batches = batch_row_boundaries.size() - 1; + device_uvector num_tiles(num_batches, stream); + auto iter = thrust::make_counting_iterator(0); + thrust::transform( + rmm::exec_policy(stream), + iter, + iter + num_batches, + num_tiles.begin(), + [desired_tile_height, + batch_row_boundaries = batch_row_boundaries.data()] __device__(auto batch_index) -> size_type { + return util::div_rounding_up_unsafe( + batch_row_boundaries[batch_index + 1] - batch_row_boundaries[batch_index], + desired_tile_height); + }); + return thrust::reduce(rmm::exec_policy(stream), num_tiles.begin(), num_tiles.end()); +} + +/** + * @brief Builds the `tile_info` structs for a given table. 
+ * + * @param tiles span of tiles to populate + * @param batch_row_boundaries boundary to row batches + * @param column_start starting column of the tile + * @param column_end ending column of the tile + * @param desired_tile_height height of the tile + * @param total_number_of_rows total number of rows in the table + * @param stream stream to use + * @return number of tiles created + */ +size_type build_tiles( + device_span tiles, + device_uvector const& batch_row_boundaries, // comes from build_batches + int column_start, + int column_end, + int desired_tile_height, + int total_number_of_rows, + rmm::cuda_stream_view stream) +{ + size_type const num_batches = batch_row_boundaries.size() - 1; + device_uvector num_tiles(num_batches, stream); + auto iter = thrust::make_counting_iterator(0); + thrust::transform( + rmm::exec_policy(stream), + iter, + iter + num_batches, + num_tiles.begin(), + [desired_tile_height, + batch_row_boundaries = batch_row_boundaries.data()] __device__(auto batch_index) -> size_type { + return util::div_rounding_up_unsafe( + batch_row_boundaries[batch_index + 1] - batch_row_boundaries[batch_index], + desired_tile_height); + }); + + size_type const total_tiles = + thrust::reduce(rmm::exec_policy(stream), num_tiles.begin(), num_tiles.end()); + + device_uvector tile_starts(num_batches + 1, stream); + auto tile_iter = cudf::detail::make_counting_transform_iterator( + 0, [num_tiles = num_tiles.data(), num_batches] __device__(auto i) { + return (i < num_batches) ? num_tiles[i] : 0; + }); + thrust::exclusive_scan(rmm::exec_policy(stream), + tile_iter, + tile_iter + num_batches + 1, + tile_starts.begin()); // in tiles + + thrust::transform( + rmm::exec_policy(stream), + iter, + iter + total_tiles, + tiles.begin(), + [ =, + tile_starts = tile_starts.data(), + batch_row_boundaries = batch_row_boundaries.data()] __device__(size_type tile_index) { + // what batch this tile falls in + auto const batch_index_iter = + thrust::upper_bound(thrust::seq, tile_starts, tile_starts + num_batches, tile_index); + auto const batch_index = std::distance(tile_starts, batch_index_iter) - 1; + // local index within the tile + int const local_tile_index = tile_index - tile_starts[batch_index]; + // the start row for this batch. + int const batch_row_start = batch_row_boundaries[batch_index]; + // the start row for this tile + int const tile_row_start = batch_row_start + (local_tile_index * desired_tile_height); + // the end row for this tile + int const max_row = std::min(total_number_of_rows - 1, + batch_index + 1 > num_batches + ? std::numeric_limits::max() + : static_cast(batch_row_boundaries[batch_index + 1]) - 1); + int const tile_row_end = + std::min(batch_row_start + ((local_tile_index + 1) * desired_tile_height) - 1, max_row); + + // stuff the tile + return tile_info{ + column_start, tile_row_start, column_end, tile_row_end, static_cast(batch_index)}; + }); + + return total_tiles; +} + +/** + * @brief Determines what data should be operated on by each tile for the incoming table. 
+ * + * @tparam TileCallback Callback that receives the start and end columns of tiles + * @param column_sizes vector of the size of each column + * @param column_starts vector of the offset of each column + * @param first_row_batch_size size of the first row batch to limit max tile size since a tile + * is unable to span batches + * @param total_number_of_rows total number of rows in the table + * @param shmem_limit_per_tile shared memory allowed per tile + * @param f callback function called when building a tile + */ +template +void determine_tiles(std::vector const& column_sizes, + std::vector const& column_starts, + size_type const first_row_batch_size, + size_type const total_number_of_rows, + size_type const& shmem_limit_per_tile, + TileCallback f) +{ + // tile infos are organized with the tile going "down" the columns + // this provides the most coalescing of memory access + int current_tile_width = 0; + int current_tile_start_col = 0; + + // the ideal tile height has lots of 8-byte reads and 8-byte writes. The optimal read/write + // would be memory cache line sized access, but since other tiles will read/write the edges + // this may not turn out to be overly important. For now, we will attempt to build a square + // tile as far as byte sizes. x * y = shared_mem_size. Which translates to x^2 = + // shared_mem_size since we want them equal, so height and width are sqrt(shared_mem_size). The + // trick is that it's in bytes, not rows or columns. + auto const optimal_square_len = static_cast(sqrt(shmem_limit_per_tile)); + auto const tile_height = + std::clamp(util::round_up_safe( + std::min(optimal_square_len / column_sizes[0], total_number_of_rows), 32), + 1, + first_row_batch_size); + + int row_size = 0; + + // march each column and build the tiles of appropriate sizes + for (uint col = 0; col < column_sizes.size(); ++col) { + auto const col_size = column_sizes[col]; + + // align size for this type + auto const alignment_needed = col_size; // They are the same for fixed width types + auto const row_size_aligned = util::round_up_unsafe(row_size, alignment_needed); + auto const row_size_with_this_col = row_size_aligned + col_size; + auto const row_size_with_end_pad = + util::round_up_unsafe(row_size_with_this_col, JCUDF_ROW_ALIGNMENT); + + if (row_size_with_end_pad * tile_height > shmem_limit_per_tile) { + // too large, close this tile, generate vertical tiles and restart + f(current_tile_start_col, col == 0 ? 
col : col - 1, tile_height); + + row_size = + util::round_up_unsafe((column_starts[col] + column_sizes[col]) & 7, alignment_needed); + row_size += col_size; // alignment required for shared memory tile boundary to match + // alignment of output row + current_tile_start_col = col; + current_tile_width = 0; + } else { + row_size = row_size_with_this_col; + current_tile_width++; + } + } + + // build last set of tiles + if (current_tile_width > 0) { + f(current_tile_start_col, static_cast(column_sizes.size()) - 1, tile_height); + } +} + +/** + * @brief convert cudf table into JCUDF row format + * + * @tparam offsetFunctor functor type for offset functor + * @param tbl table to convert to JCUDF row format + * @param batch_info information about the batches of data + * @param offset_functor functor that returns the starting offset of each row + * @param column_starts starting offset of a column in a row + * @param column_sizes size of each element in a column + * @param fixed_width_size_per_row size of fixed-width data in a row of this table + * @param stream stream used + * @param mr selected memory resource for returned data + * @return vector of list columns containing byte columns of the JCUDF row data + */ +template +std::vector> convert_to_rows(table_view const& tbl, + batch_data& batch_info, + offsetFunctor offset_functor, + std::vector const& column_starts, + std::vector const& column_sizes, + size_type const fixed_width_size_per_row, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + int device_id; + CUDA_TRY(cudaGetDevice(&device_id)); + int total_shmem_in_bytes; + CUDA_TRY( + cudaDeviceGetAttribute(&total_shmem_in_bytes, cudaDevAttrMaxSharedMemoryPerBlock, device_id)); + +#ifndef __CUDA_ARCH__ // __host__ code. + // Need to reduce total shmem available by the size of barriers in the kernel's shared memory + total_shmem_in_bytes -= + sizeof(cuda::barrier) * NUM_TILES_PER_KERNEL_LOADED; +#endif // __CUDA_ARCH__ + + auto const shmem_limit_per_tile = total_shmem_in_bytes / NUM_TILES_PER_KERNEL_LOADED; + + auto const num_rows = tbl.num_rows(); + auto const num_columns = tbl.num_columns(); + auto dev_col_sizes = make_device_uvector_async(column_sizes, stream, mr); + auto dev_col_starts = make_device_uvector_async(column_starts, stream, mr); + + // Get the pointers to the input columnar data ready + auto data_begin = thrust::make_transform_iterator( + tbl.begin(), [](auto const& c) { return c.template data(); }); + std::vector input_data(data_begin, data_begin + tbl.num_columns()); + + auto nm_begin = + thrust::make_transform_iterator(tbl.begin(), [](auto const& c) { return c.null_mask(); }); + std::vector input_nm(nm_begin, nm_begin + tbl.num_columns()); + + auto dev_input_data = make_device_uvector_async(input_data, stream, mr); + auto dev_input_nm = make_device_uvector_async(input_nm, stream, mr); + + // the first batch always exists unless we were sent an empty table + auto const first_batch_size = batch_info.row_batches[0].row_count; + + std::vector output_buffers; + std::vector output_data; + output_data.reserve(batch_info.row_batches.size()); + output_buffers.reserve(batch_info.row_batches.size()); + std::transform( + batch_info.row_batches.begin(), + batch_info.row_batches.end(), + std::back_inserter(output_buffers), + [&](auto const& batch) { return rmm::device_buffer(batch.num_bytes, stream, mr); }); + std::transform( + output_buffers.begin(), output_buffers.end(), std::back_inserter(output_data), [](auto& buf) { + return static_cast(buf.data()); + }); + + 
auto dev_output_data = make_device_uvector_async(output_data, stream, mr); + + int info_count = 0; + detail::determine_tiles( + column_sizes, + column_starts, + first_batch_size, + num_rows, + shmem_limit_per_tile, + [&gpu_batch_row_boundaries = batch_info.d_batch_row_boundaries, &info_count, &stream]( + int const start_col, int const end_col, int const tile_height) { + int i = detail::compute_tile_counts(gpu_batch_row_boundaries, tile_height, stream); + info_count += i; + }); + + // allocate space for tiles + device_uvector gpu_tile_infos(info_count, stream); + int tile_offset = 0; + + detail::determine_tiles( + column_sizes, + column_starts, + first_batch_size, + num_rows, + shmem_limit_per_tile, + [&gpu_batch_row_boundaries = batch_info.d_batch_row_boundaries, + &gpu_tile_infos, + num_rows, + &tile_offset, + stream](int const start_col, int const end_col, int const tile_height) { + tile_offset += detail::build_tiles( + {gpu_tile_infos.data() + tile_offset, gpu_tile_infos.size() - tile_offset}, + gpu_batch_row_boundaries, + start_col, + end_col, + tile_height, + num_rows, + stream); + }); + + // blast through the entire table and convert it + dim3 blocks(util::div_rounding_up_unsafe(gpu_tile_infos.size(), NUM_TILES_PER_KERNEL_TO_ROWS)); + dim3 threads(256); + + auto validity_tile_infos = detail::build_validity_tile_infos( + num_columns, num_rows, shmem_limit_per_tile, batch_info.row_batches); + + auto dev_validity_tile_infos = make_device_uvector_async(validity_tile_infos, stream); + dim3 validity_blocks( + util::div_rounding_up_unsafe(validity_tile_infos.size(), NUM_VALIDITY_TILES_PER_KERNEL)); + dim3 validity_threads(std::min(validity_tile_infos.size() * 32, 128lu)); + + detail::copy_to_rows<<>>( + num_rows, + num_columns, + shmem_limit_per_tile, + gpu_tile_infos, + dev_input_data.data(), + dev_col_sizes.data(), + dev_col_starts.data(), + offset_functor, + batch_info.d_batch_row_boundaries.data(), + reinterpret_cast(dev_output_data.data())); + + detail::copy_validity_to_rows<<>>(num_rows, + num_columns, + shmem_limit_per_tile, + offset_functor, + batch_info.d_batch_row_boundaries.data(), + dev_output_data.data(), + column_starts.back(), + dev_validity_tile_infos, + dev_input_nm.data()); + + // split up the output buffer into multiple buffers based on row batch sizes + // and create list of byte columns + std::vector> ret; + auto counting_iter = thrust::make_counting_iterator(0); + std::transform(counting_iter, + counting_iter + batch_info.row_batches.size(), + std::back_inserter(ret), + [&](auto batch) { + auto const offset_count = batch_info.row_batches[batch].row_offsets.size(); + auto offsets = + std::make_unique(data_type{type_id::INT32}, + (size_type)offset_count, + batch_info.row_batches[batch].row_offsets.release()); + auto data = std::make_unique(data_type{type_id::INT8}, + batch_info.row_batches[batch].num_bytes, + std::move(output_buffers[batch])); + + return make_lists_column(batch_info.row_batches[batch].row_count, + std::move(offsets), + std::move(data), + 0, + rmm::device_buffer{0, rmm::cuda_stream_default, mr}, + stream, + mr); + }); + + return ret; +} + +} // namespace detail + +std::vector> convert_to_rows(table_view const& tbl, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto const num_columns = tbl.num_columns(); + auto const num_rows = tbl.num_rows(); + + auto const fixed_width_only = std::all_of( + tbl.begin(), tbl.end(), [](column_view const& c) { return is_fixed_width(c.type()); }); + + // break up the work into tiles, which are a 
starting and ending row/col #. + // this tile size is calculated based on the shared memory size available + // we want a single tile to fill up the entire shared memory space available + // for the transpose-like conversion. + + // There are two different processes going on here. The GPU conversion of the data + // and the writing of the data into the list of byte columns that are a maximum of + // 2 gigs each due to offset maximum size. The GPU conversion portion has to understand + // this limitation because the column must own the data inside and as a result it must be + // a distinct allocation for that column. Copying the data into these final buffers would + // be prohibitively expensive, so care is taken to ensure the GPU writes to the proper buffer. + // The tiles are broken at the boundaries of specific rows based on the row sizes up + // to that point. These are row batches and they are decided first before building the + // tiles so the tiles can be properly cut around them. + + std::vector column_sizes; // byte size of each column + std::vector column_starts; // offset of column inside a row including alignment + column_sizes.reserve(num_columns); + column_starts.reserve(num_columns + 1); // we add a final offset for validity data start + + auto schema_column_iter = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), [&tbl](auto i) -> std::tuple { + return {tbl.column(i).type(), tbl.column(i)}; + }); + + auto const fixed_width_size_per_row = detail::compute_column_information( + schema_column_iter, schema_column_iter + num_columns, column_starts, column_sizes); + if (fixed_width_only) { + // total encoded row size. This includes fixed-width data and validity only. It does not include + // variable-width data since it isn't copied with the fixed-width and validity kernel. 
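+    // For example (hypothetical schema, assuming JCUDF_ROW_ALIGNMENT is 8): an
+    // INT8, INT32, INT64 table packs its data at offsets 0, 4 and 8 for 16
+    // bytes, plus ceil(3 / 8) = 1 validity byte, giving
+    // fixed_width_size_per_row = 17, which rounds up to 24 bytes per row here.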
+ auto row_size_iter = thrust::make_constant_iterator( + util::round_up_unsafe(fixed_width_size_per_row, JCUDF_ROW_ALIGNMENT)); + + auto batch_info = detail::build_batches(num_rows, row_size_iter, fixed_width_only, stream, mr); + + detail::fixed_width_row_offset_functor offset_functor( + util::round_up_unsafe(fixed_width_size_per_row, JCUDF_ROW_ALIGNMENT)); + + return detail::convert_to_rows(tbl, + batch_info, + offset_functor, + column_starts, + column_sizes, + fixed_width_size_per_row, + stream, + mr); + } else { + auto row_sizes = detail::build_string_row_sizes(tbl, fixed_width_size_per_row, stream); + + auto row_size_iter = cudf::detail::make_counting_transform_iterator( + 0, detail::row_size_functor(num_rows, row_sizes.data(), 0)); + + auto batch_info = detail::build_batches(num_rows, row_size_iter, fixed_width_only, stream, mr); + + detail::string_row_offset_functor offset_functor(batch_info.batch_row_offsets); + + return detail::convert_to_rows(tbl, + batch_info, + offset_functor, + column_starts, + column_sizes, + fixed_width_size_per_row, + stream, + mr); + } +} + +std::vector> convert_to_rows_fixed_width_optimized( + table_view const& tbl, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) +{ + auto const num_columns = tbl.num_columns(); + + std::vector schema; + schema.resize(num_columns); + std::transform( + tbl.begin(), tbl.end(), schema.begin(), [](auto i) -> data_type { return i.type(); }); + + if (detail::are_all_fixed_width(schema)) { + std::vector column_start; + std::vector column_size; + + int32_t const size_per_row = + detail::compute_fixed_width_layout(schema, column_start, column_size); + auto dev_column_start = make_device_uvector_async(column_start, stream, mr); + auto dev_column_size = make_device_uvector_async(column_size, stream, mr); + + // Make the number of rows per batch a multiple of 32 so we don't have to worry about + // splitting validity at a specific row offset. This might change in the future. + auto const max_rows_per_batch = + util::round_down_safe(std::numeric_limits::max() / size_per_row, 32); + + auto const num_rows = tbl.num_rows(); + + // Get the pointers to the input columnar data ready + std::vector input_data; + std::vector input_nm; + for (size_type column_number = 0; column_number < num_columns; column_number++) { + column_view cv = tbl.column(column_number); + input_data.emplace_back(cv.data()); + input_nm.emplace_back(cv.null_mask()); + } + auto dev_input_data = make_device_uvector_async(input_data, stream, mr); + auto dev_input_nm = make_device_uvector_async(input_nm, stream, mr); + + using ScalarType = scalar_type_t; + auto zero = make_numeric_scalar(data_type(type_id::INT32), stream.value()); + zero->set_valid_async(true, stream); + static_cast(zero.get())->set_value(0, stream); + + auto step = make_numeric_scalar(data_type(type_id::INT32), stream.value()); + step->set_valid_async(true, stream); + static_cast(step.get())->set_value(static_cast(size_per_row), stream); + + std::vector> ret; + for (size_type row_start = 0; row_start < num_rows; row_start += max_rows_per_batch) { + size_type row_count = num_rows - row_start; + row_count = row_count > max_rows_per_batch ? 
max_rows_per_batch : row_count; + ret.emplace_back(detail::fixed_width_convert_to_rows(row_start, + row_count, + num_columns, + size_per_row, + dev_column_start, + dev_column_size, + dev_input_data, + dev_input_nm, + *zero, + *step, + stream, + mr)); + } + + return ret; + } else { + CUDF_FAIL("Only fixed width types are currently supported"); + } +} + +std::unique_ptr convert_from_rows(lists_column_view const& input, + std::vector const& schema, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + // verify that the types are what we expect + column_view child = input.child(); + auto const list_type = child.type().id(); + CUDF_EXPECTS(list_type == type_id::INT8 || list_type == type_id::UINT8, + "Only a list of bytes is supported as input"); + + auto const num_columns = schema.size(); + auto const num_rows = input.parent().size(); + + int device_id; + CUDA_TRY(cudaGetDevice(&device_id)); + int total_shmem_in_bytes; + CUDA_TRY( + cudaDeviceGetAttribute(&total_shmem_in_bytes, cudaDevAttrMaxSharedMemoryPerBlock, device_id)); + +#ifndef __CUDA_ARCH__ // __host__ code. + // Need to reduce total shmem available by the size of barriers in the kernel's shared memory + total_shmem_in_bytes -= + sizeof(cuda::barrier) * NUM_TILES_PER_KERNEL_LOADED; +#endif // __CUDA_ARCH__ + + int shmem_limit_per_tile = total_shmem_in_bytes / NUM_TILES_PER_KERNEL_LOADED; + + std::vector column_starts; + std::vector column_sizes; + + auto iter = thrust::make_transform_iterator(thrust::make_counting_iterator(0), [&schema](auto i) { + return std::make_tuple(schema[i], nullptr); + }); + auto const fixed_width_size_per_row = util::round_up_unsafe( + detail::compute_column_information(iter, iter + num_columns, column_starts, column_sizes), + JCUDF_ROW_ALIGNMENT); + + // Ideally we would check that the offsets are all the same, etc. but for now + // this is probably fine + CUDF_EXPECTS(fixed_width_size_per_row * num_rows == child.size(), + "The layout of the data appears to be off"); + auto dev_col_starts = make_device_uvector_async(column_starts, stream, mr); + auto dev_col_sizes = make_device_uvector_async(column_sizes, stream, mr); + + // Allocate the columns we are going to write into + std::vector> output_columns; + std::vector output_data; + std::vector output_nm; + for (int i = 0; i < static_cast(num_columns); i++) { + auto column = + make_fixed_width_column(schema[i], num_rows, mask_state::UNINITIALIZED, stream, mr); + auto mut = column->mutable_view(); + output_data.emplace_back(mut.data()); + output_nm.emplace_back(mut.null_mask()); + output_columns.emplace_back(std::move(column)); + } + + // build the row_batches from the passed in list column + std::vector row_batches; + row_batches.push_back( + {detail::row_batch{child.size(), num_rows, device_uvector(0, stream)}}); + + auto dev_output_data = make_device_uvector_async(output_data, stream, mr); + auto dev_output_nm = make_device_uvector_async(output_nm, stream, mr); + + // only ever get a single batch when going from rows, so boundaries + // are 0, num_rows + constexpr auto num_batches = 2; + device_uvector gpu_batch_row_boundaries(num_batches, stream); + + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_batches), + gpu_batch_row_boundaries.begin(), + [num_rows] __device__(auto i) { return i == 0 ? 
0 : num_rows; }); + + int info_count = 0; + detail::determine_tiles(column_sizes, + column_starts, + num_rows, + num_rows, + shmem_limit_per_tile, + [&gpu_batch_row_boundaries, &info_count, &stream]( + int const start_col, int const end_col, int const tile_height) { + info_count += detail::compute_tile_counts( + gpu_batch_row_boundaries, tile_height, stream); + }); + + // allocate space for tiles + device_uvector gpu_tile_infos(info_count, stream); + + int tile_offset = 0; + detail::determine_tiles( + column_sizes, + column_starts, + num_rows, + num_rows, + shmem_limit_per_tile, + [&gpu_batch_row_boundaries, &gpu_tile_infos, num_rows, &tile_offset, stream]( + int const start_col, int const end_col, int const tile_height) { + tile_offset += detail::build_tiles( + {gpu_tile_infos.data() + tile_offset, gpu_tile_infos.size() - tile_offset}, + gpu_batch_row_boundaries, + start_col, + end_col, + tile_height, + num_rows, + stream); + }); + + dim3 blocks(util::div_rounding_up_unsafe(gpu_tile_infos.size(), NUM_TILES_PER_KERNEL_FROM_ROWS)); + dim3 threads(std::min(std::min(256, shmem_limit_per_tile / 8), static_cast(child.size()))); + + auto validity_tile_infos = + detail::build_validity_tile_infos(num_columns, num_rows, shmem_limit_per_tile, row_batches); + + auto dev_validity_tile_infos = make_device_uvector_async(validity_tile_infos, stream); + + dim3 validity_blocks( + util::div_rounding_up_unsafe(validity_tile_infos.size(), NUM_VALIDITY_TILES_PER_KERNEL)); + + dim3 validity_threads(std::min(validity_tile_infos.size() * 32, 128lu)); + + detail::fixed_width_row_offset_functor offset_functor(fixed_width_size_per_row); + + detail::copy_from_rows<<>>( + num_rows, + num_columns, + shmem_limit_per_tile, + offset_functor, + gpu_batch_row_boundaries.data(), + dev_output_data.data(), + dev_col_sizes.data(), + dev_col_starts.data(), + gpu_tile_infos, + child.data()); + + detail::copy_validity_from_rows<<>>(num_rows, + num_columns, + shmem_limit_per_tile, + offset_functor, + gpu_batch_row_boundaries.data(), + dev_output_nm.data(), + column_starts.back(), + dev_validity_tile_infos, + child.data()); + + return std::make_unique
<table>(std::move(output_columns));
+}
+
+std::unique_ptr<table>
convert_from_rows_fixed_width_optimized(lists_column_view const& input, + std::vector const& schema, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + // verify that the types are what we expect + column_view child = input.child(); + auto const list_type = child.type().id(); + CUDF_EXPECTS(list_type == type_id::INT8 || list_type == type_id::UINT8, + "Only a list of bytes is supported as input"); + + auto const num_columns = schema.size(); + + if (detail::are_all_fixed_width(schema)) { + std::vector column_start; + std::vector column_size; + + auto const num_rows = input.parent().size(); + auto const size_per_row = detail::compute_fixed_width_layout(schema, column_start, column_size); + + // Ideally we would check that the offsets are all the same, etc. but for now + // this is probably fine + CUDF_EXPECTS(size_per_row * num_rows == child.size(), + "The layout of the data appears to be off"); + auto dev_column_start = make_device_uvector_async(column_start, stream); + auto dev_column_size = make_device_uvector_async(column_size, stream); + + // Allocate the columns we are going to write into + std::vector> output_columns; + std::vector output_data; + std::vector output_nm; + for (int i = 0; i < static_cast(num_columns); i++) { + auto column = + make_fixed_width_column(schema[i], num_rows, mask_state::UNINITIALIZED, stream, mr); + auto mut = column->mutable_view(); + output_data.emplace_back(mut.data()); + output_nm.emplace_back(mut.null_mask()); + output_columns.emplace_back(std::move(column)); + } + + auto dev_output_data = make_device_uvector_async(output_data, stream, mr); + auto dev_output_nm = make_device_uvector_async(output_nm, stream, mr); + + dim3 blocks; + dim3 threads; + int shared_size = + detail::calc_fixed_width_kernel_dims(num_columns, num_rows, size_per_row, blocks, threads); + + detail::copy_from_rows_fixed_width_optimized<<>>( + num_rows, + num_columns, + size_per_row, + dev_column_start.data(), + dev_column_size.data(), + dev_output_data.data(), + dev_output_nm.data(), + child.data()); + + return std::make_unique
(std::move(output_columns)); + } else { + CUDF_FAIL("Only fixed width types are currently supported"); + } +} + +//} // namespace jni + +} // namespace cudf diff --git a/cpp/src/row_conversion/row_conversion.cu.bk b/cpp/src/row_conversion/row_conversion.cu.bk new file mode 100644 index 00000000000..5f7c9d65aac --- /dev/null +++ b/cpp/src/row_conversion/row_conversion.cu.bk @@ -0,0 +1,2234 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include "thrust/scan.h" + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +class cuda_event_timer { + public: + /** + * @brief This c'tor clears the L2$ by cudaMemset'ing a buffer of L2$ size + * and starts the timer. + * + * @param[in,out] state This is the benchmark::State whose timer we are going + * to update. + * @param[in] flush_l2_cache_ whether or not to flush the L2 cache before + * every iteration. + * @param[in] stream_ The CUDA stream we are measuring time on. + */ + + void start_timer(const char* name, + bool flush_l2_cache, + rmm::cuda_stream_view stream = rmm::cuda_stream_default); + + // The user must provide a benchmark::State object to set + // the timer so we disable the default c'tor. 
+ void stop_timer(); + + private: + const char* name; + cudaEvent_t start; + cudaEvent_t stop; + rmm::cuda_stream_view stream; +}; + +#include + +#include +#include + +void cuda_event_timer::start_timer(const char* name, + bool flush_l2_cache, + rmm::cuda_stream_view stream) +{ + this->name = name; + this->stream = stream; + + // flush all of L2$ + if (flush_l2_cache) { + int current_device = 0; + CUDA_TRY(cudaGetDevice(¤t_device)); + + int l2_cache_bytes = 0; + CUDA_TRY(cudaDeviceGetAttribute(&l2_cache_bytes, cudaDevAttrL2CacheSize, current_device)); + + if (l2_cache_bytes > 0) { + const int memset_value = 0; + rmm::device_buffer l2_cache_buffer(l2_cache_bytes, stream); + CUDA_TRY( + cudaMemsetAsync(l2_cache_buffer.data(), memset_value, l2_cache_bytes, stream.value())); + } + } + + CUDA_TRY(cudaEventCreate(&start)); + CUDA_TRY(cudaEventCreate(&stop)); + CUDA_TRY(cudaEventRecord(start, stream.value())); +} + +void cuda_event_timer::stop_timer() +{ + CUDA_TRY(cudaEventRecord(stop, stream.value())); + CUDA_TRY(cudaEventSynchronize(stop)); + + float milliseconds = 0.0f; + CUDA_TRY(cudaEventElapsedTime(&milliseconds, start, stop)); + printf("%s took %.2fms\n", name, milliseconds); + CUDA_TRY(cudaEventDestroy(start)); + CUDA_TRY(cudaEventDestroy(stop)); +} + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 +constexpr auto NUM_BLOCKS_PER_KERNEL_FROM_ROWS = 2; +constexpr auto NUM_BLOCKS_PER_KERNEL_TO_ROWS = 2; +constexpr auto NUM_BLOCKS_PER_KERNEL_LOADED = 2; +constexpr auto NUM_VALIDITY_BLOCKS_PER_KERNEL = 8; +constexpr auto NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED = 2; + +// needed to suppress warning about cuda::barrier +#pragma diag_suppress static_var_with_dynamic_init +#endif + +using cudf::detail::make_device_uvector_async; +using rmm::device_uvector; +namespace cudf { +// namespace java { +namespace detail { + +static inline __host__ __device__ int32_t align_offset(int32_t offset, std::size_t alignment) +{ + return (offset + alignment - 1) & ~(alignment - 1); +} + +__global__ void copy_from_rows_fixed_width_optimized(const cudf::size_type num_rows, + const cudf::size_type num_columns, + const cudf::size_type row_size, + const cudf::size_type* input_offset_in_row, + const cudf::size_type* num_bytes, + int8_t** output_data, + cudf::bitmask_type** output_nm, + const int8_t* input_data) +{ + // We are going to copy the data in two passes. + // The first pass copies a chunk of data into shared memory. + // The second pass copies that chunk from shared memory out to the final location. + + // Because shared memory is limited we copy a subset of the rows at a time. + // For simplicity we will refer to this as a row_group + + // In practice we have found writing more than 4 columns of data per thread + // results in performance loss. As such we are using a 2 dimensional + // kernel in terms of threads, but not in terms of blocks. Columns are + // controlled by the y dimension (there is no y dimension in blocks). Rows + // are controlled by the x dimension (there are multiple blocks in the x + // dimension). 
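+  // As a concrete (hypothetical) launch shape: blockDim = (64, 8) stages 64
+  // rows of the row-format data in shared memory per pass, with 8 threads in
+  // the y dimension splitting each row's columns between them.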
+ + cudf::size_type rows_per_group = blockDim.x; + cudf::size_type row_group_start = blockIdx.x; + cudf::size_type row_group_stride = gridDim.x; + cudf::size_type row_group_end = (num_rows + rows_per_group - 1) / rows_per_group + 1; + + extern __shared__ int8_t shared_data[]; + + // Because we are copying fixed width only data and we stride the rows + // this thread will always start copying from shared data in the same place + int8_t* row_tmp = &shared_data[row_size * threadIdx.x]; + int8_t* row_vld_tmp = &row_tmp[input_offset_in_row[num_columns - 1] + num_bytes[num_columns - 1]]; + + for (cudf::size_type row_group_index = row_group_start; row_group_index < row_group_end; + row_group_index += row_group_stride) { + // Step 1: Copy the data into shared memory + // We know row_size is always aligned with and a multiple of int64_t; + int64_t* long_shared = reinterpret_cast(shared_data); + const int64_t* long_input = reinterpret_cast(input_data); + + cudf::size_type shared_output_index = threadIdx.x + (threadIdx.y * blockDim.x); + cudf::size_type shared_output_stride = blockDim.x * blockDim.y; + cudf::size_type row_index_end = ((row_group_index + 1) * rows_per_group); + if (row_index_end > num_rows) { row_index_end = num_rows; } + cudf::size_type num_rows_in_group = row_index_end - (row_group_index * rows_per_group); + cudf::size_type shared_length = row_size * num_rows_in_group; + + cudf::size_type shared_output_end = shared_length / sizeof(int64_t); + + cudf::size_type start_input_index = + (row_size * row_group_index * rows_per_group) / sizeof(int64_t); + + for (cudf::size_type shared_index = shared_output_index; shared_index < shared_output_end; + shared_index += shared_output_stride) { + long_shared[shared_index] = long_input[start_input_index + shared_index]; + } + // Wait for all of the data to be in shared memory + __syncthreads(); + + // Step 2 copy the data back out + + // Within the row group there should be 1 thread for each row. This is a + // requirement for launching the kernel + cudf::size_type row_index = (row_group_index * rows_per_group) + threadIdx.x; + // But we might not use all of the threads if the number of rows does not go + // evenly into the thread count. We don't want those threads to exit yet + // because we may need them to copy data in for the next row group. + uint32_t active_mask = __ballot_sync(0xffffffff, row_index < num_rows); + if (row_index < num_rows) { + cudf::size_type col_index_start = threadIdx.y; + cudf::size_type col_index_stride = blockDim.y; + for (cudf::size_type col_index = col_index_start; col_index < num_columns; + col_index += col_index_stride) { + cudf::size_type col_size = num_bytes[col_index]; + const int8_t* col_tmp = &(row_tmp[input_offset_in_row[col_index]]); + int8_t* col_output = output_data[col_index]; + switch (col_size) { + case 1: { + col_output[row_index] = *col_tmp; + break; + } + case 2: { + int16_t* short_col_output = reinterpret_cast(col_output); + short_col_output[row_index] = *reinterpret_cast(col_tmp); + break; + } + case 4: { + int32_t* int_col_output = reinterpret_cast(col_output); + int_col_output[row_index] = *reinterpret_cast(col_tmp); + break; + } + case 8: { + int64_t* long_col_output = reinterpret_cast(col_output); + long_col_output[row_index] = *reinterpret_cast(col_tmp); + break; + } + default: { + cudf::size_type output_offset = col_size * row_index; + // TODO this should just not be supported for fixed width columns, but just in case... 
+ for (cudf::size_type b = 0; b < col_size; b++) { + col_output[b + output_offset] = col_tmp[b]; + } + break; + } + } + + cudf::bitmask_type* nm = output_nm[col_index]; + int8_t* valid_byte = &row_vld_tmp[col_index / 8]; + cudf::size_type byte_bit_offset = col_index % 8; + int predicate = *valid_byte & (1 << byte_bit_offset); + uint32_t bitmask = __ballot_sync(active_mask, predicate); + if (row_index % 32 == 0) { nm[word_index(row_index)] = bitmask; } + } // end column loop + } // end row copy + // wait for the row_group to be totally copied before starting on the next row group + __syncthreads(); + } +} + +__global__ void copy_to_rows_fixed_width_optimized(const cudf::size_type start_row, + const cudf::size_type num_rows, + const cudf::size_type num_columns, + const cudf::size_type row_size, + const cudf::size_type* output_offset_in_row, + const cudf::size_type* num_bytes, + const int8_t** input_data, + const cudf::bitmask_type** input_nm, + int8_t* output_data) +{ + // We are going to copy the data in two passes. + // The first pass copies a chunk of data into shared memory. + // The second pass copies that chunk from shared memory out to the final location. + + // Because shared memory is limited we copy a subset of the rows at a time. + // We do not support copying a subset of the columns in a row yet, so we don't + // currently support a row that is wider than shared memory. + // For simplicity we will refer to this as a row_group + + // In practice we have found reading more than 4 columns of data per thread + // results in performance loss. As such we are using a 2 dimensional + // kernel in terms of threads, but not in terms of blocks. Columns are + // controlled by the y dimension (there is no y dimension in blocks). Rows + // are controlled by the x dimension (there are multiple blocks in the x + // dimension). + + cudf::size_type rows_per_group = blockDim.x; + cudf::size_type row_group_start = blockIdx.x; + cudf::size_type row_group_stride = gridDim.x; + cudf::size_type row_group_end = (num_rows + rows_per_group - 1) / rows_per_group + 1; + + extern __shared__ int8_t shared_data[]; + + // Because we are copying fixed width only data and we stride the rows + // this thread will always start copying to shared data in the same place + int8_t* row_tmp = &shared_data[row_size * threadIdx.x]; + int8_t* row_vld_tmp = + &row_tmp[output_offset_in_row[num_columns - 1] + num_bytes[num_columns - 1]]; + + for (cudf::size_type row_group_index = row_group_start; row_group_index < row_group_end; + row_group_index += row_group_stride) { + // Within the row group there should be 1 thread for each row. This is a + // requirement for launching the kernel + cudf::size_type row_index = start_row + (row_group_index * rows_per_group) + threadIdx.x; + // But we might not use all of the threads if the number of rows does not go + // evenly into the thread count. We don't want those threads to exit yet + // because we may need them to copy data back out. 
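+    // Note: the guard below protects the copy instead of returning early so
+    // that every thread still reaches the __syncthreads() calls that follow.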
+ if (row_index < (start_row + num_rows)) { + cudf::size_type col_index_start = threadIdx.y; + cudf::size_type col_index_stride = blockDim.y; + for (cudf::size_type col_index = col_index_start; col_index < num_columns; + col_index += col_index_stride) { + cudf::size_type col_size = num_bytes[col_index]; + int8_t* col_tmp = &(row_tmp[output_offset_in_row[col_index]]); + const int8_t* col_input = input_data[col_index]; + switch (col_size) { + case 1: { + *col_tmp = col_input[row_index]; + break; + } + case 2: { + const int16_t* short_col_input = reinterpret_cast(col_input); + *reinterpret_cast(col_tmp) = short_col_input[row_index]; + break; + } + case 4: { + const int32_t* int_col_input = reinterpret_cast(col_input); + *reinterpret_cast(col_tmp) = int_col_input[row_index]; + break; + } + case 8: { + const int64_t* long_col_input = reinterpret_cast(col_input); + *reinterpret_cast(col_tmp) = long_col_input[row_index]; + break; + } + default: { + cudf::size_type input_offset = col_size * row_index; + // TODO this should just not be supported for fixed width columns, but just in case... + for (cudf::size_type b = 0; b < col_size; b++) { + col_tmp[b] = col_input[b + input_offset]; + } + break; + } + } + // atomicOr only works on 32 bit or 64 bit aligned values, and not byte aligned + // so we have to rewrite the addresses to make sure that it is 4 byte aligned + int8_t* valid_byte = &row_vld_tmp[col_index / 8]; + cudf::size_type byte_bit_offset = col_index % 8; + uint64_t fixup_bytes = reinterpret_cast(valid_byte) % 4; + int32_t* valid_int = reinterpret_cast(valid_byte - fixup_bytes); + cudf::size_type int_bit_offset = byte_bit_offset + (fixup_bytes * 8); + // Now copy validity for the column + if (input_nm[col_index]) { + if (bit_is_set(input_nm[col_index], row_index)) { + atomicOr_block(valid_int, 1 << int_bit_offset); + } else { + atomicAnd_block(valid_int, ~(1 << int_bit_offset)); + } + } else { + // It is valid so just set the bit + atomicOr_block(valid_int, 1 << int_bit_offset); + } + } // end column loop + } // end row copy + // wait for the row_group to be totally copied into shared memory + __syncthreads(); + + // Step 2: Copy the data back out + // We know row_size is always aligned with and a multiple of int64_t; + int64_t* long_shared = reinterpret_cast(shared_data); + int64_t* long_output = reinterpret_cast(output_data); + + cudf::size_type shared_input_index = threadIdx.x + (threadIdx.y * blockDim.x); + cudf::size_type shared_input_stride = blockDim.x * blockDim.y; + cudf::size_type row_index_end = ((row_group_index + 1) * rows_per_group); + if (row_index_end > num_rows) { row_index_end = num_rows; } + cudf::size_type num_rows_in_group = row_index_end - (row_group_index * rows_per_group); + cudf::size_type shared_length = row_size * num_rows_in_group; + + cudf::size_type shared_input_end = shared_length / sizeof(int64_t); + + cudf::size_type start_output_index = + (row_size * row_group_index * rows_per_group) / sizeof(int64_t); + + for (cudf::size_type shared_index = shared_input_index; shared_index < shared_input_end; + shared_index += shared_input_stride) { + long_output[start_output_index + shared_index] = long_shared[shared_index]; + } + __syncthreads(); + // Go for the next round + } +} + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 + +/** + * @brief The GPU blocks work on one or more block_info structs of data. + * This structure defined the workspace for the block. 
+ * + */ +struct block_info { + int start_col; + int start_row; + int end_col; + int end_row; + int batch_number; + + __host__ __device__ size_type get_shared_row_size(size_type const* const col_offsets, + size_type const* const col_sizes) const + { + auto ret = align_offset(col_offsets[end_col] + col_sizes[end_col] - col_offsets[start_col], 8); + /* if (ret % 32 == 0) { + // bank collision, adjust size to pad. + ret += 1; + }*/ + return ret; + } + __host__ __device__ size_type num_cols() const { return end_col - start_col + 1; } + + __host__ __device__ size_type num_rows() const { return end_row - start_row + 1; } +}; + +/** + * @brief Returning rows is done in a byte cudf column. This is limited in size by + * `size_type` and so output is broken into batches of rows that fit inside + * this limit. + * + */ +struct row_batch { + size_type num_bytes; + size_type row_count; + device_uvector row_offsets; +}; + +__device__ void memcpy_dumb(int8_t* dest, int8_t* src, size_t bytes) +{ + for (uint i = 0; i < bytes; ++i) { + dest[i] = src[i]; + } +} + +/** + * @brief copy data from cudf columns into JCUDF format, which is row-based + * + * @param num_rows total number of rows in the table + * @param num_columns total number of columns in the table + * @param shmem_used_per_block shared memory amount each `block_info` is using + * @param block_infos span of `block_info` structs the define the work + * @param input_data pointer to raw table data + * @param col_sizes array of sizes for each element in a column - one per column + * @param col_offsets offset into input data row for each column's start + * @param row_offsets offset to a specific row in the output data + * @param output_data pointer to output data + * + */ +__global__ void copy_to_rows(const size_type num_rows, + const size_type num_columns, + const size_type shmem_used_per_block, + device_span block_infos, + const int8_t** input_data, + const size_type* col_sizes, + const size_type* col_offsets, + const size_type* row_offsets, + int8_t** output_data) +{ + // We are going to copy the data in two passes. + // The first pass copies a chunk of data into shared memory. + // The second pass copies that chunk from shared memory out to the final location. + + // Because shared memory is limited we copy a subset of the rows at a time. + // This has been broken up for us in the block_info struct, so we don't have + // any calculation to do here, but it is important to note. 
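+  // Sketch of the pipeline below: NUM_BLOCKS_PER_KERNEL_LOADED (2) shared
+  // memory buffers are cycled so the fetch of one block_info can overlap with
+  // writing out the previously fetched one, with cuda::barrier objects
+  // guarding reuse of each buffer.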
+ + constexpr unsigned stages_count = NUM_BLOCKS_PER_KERNEL_LOADED; + auto group = cooperative_groups::this_thread_block(); + extern __shared__ int8_t shared_data[]; + int8_t* shared[stages_count] = {shared_data, shared_data + shmem_used_per_block}; + + __shared__ cuda::barrier block_barrier[NUM_BLOCKS_PER_KERNEL_LOADED]; + if (group.thread_rank() == 0) { + for (int i = 0; i < NUM_BLOCKS_PER_KERNEL_LOADED; ++i) { + init(&block_barrier[i], group.size()); + } + } + + group.sync(); + + auto const blocks_remaining = + std::min((uint)block_infos.size() - blockIdx.x * NUM_BLOCKS_PER_KERNEL_TO_ROWS, + (uint)NUM_BLOCKS_PER_KERNEL_TO_ROWS); + + size_t fetch; + size_t subset; + for (subset = fetch = 0; subset < blocks_remaining; ++subset) { + // Fetch ahead up to stages_count subsets + for (; fetch < blocks_remaining && fetch < (subset + stages_count); ++fetch) { + auto const fetch_block = block_infos[blockIdx.x * NUM_BLOCKS_PER_KERNEL_TO_ROWS + fetch]; + auto const num_fetch_cols = fetch_block.num_cols(); + auto const num_fetch_rows = fetch_block.num_rows(); + auto const num_elements_in_block = num_fetch_cols * num_fetch_rows; + auto const fetch_block_row_size = fetch_block.get_shared_row_size(col_offsets, col_sizes); + auto const fetch_block_start_col = fetch_block.start_col; + auto const fetch_block_start_row = fetch_block.start_row; + auto const starting_column_offset = col_offsets[fetch_block.start_col]; + auto& fetch_barrier = block_barrier[fetch % NUM_BLOCKS_PER_KERNEL_LOADED]; + + // wait for the last use of the memory to be completed + if (fetch >= NUM_BLOCKS_PER_KERNEL_LOADED) { fetch_barrier.arrive_and_wait(); } + + // to do the copy we need to do n column copies followed by m element copies OR + // we have to do m element copies followed by r row copies. When going from column + // to row it is much easier to copy by elements first otherwise we would need a running + // total of the column sizes for our block, which isn't readily available. This makes it + // more appealing to copy element-wise from input data into shared matching the end layout + // and do row-based memcopies out. 
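+      //
+      // Concretely, elements are numbered down each column first: with
+      // num_fetch_rows == 4, elements 0..3 are rows 0..3 of the first column and
+      // element 4 is row 0 of the second column. Each element lands in shared
+      // memory at relative_row * fetch_block_row_size + relative_col_offset
+      // (plus the alignment shim below), which is already the row-major layout
+      // the output rows need.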
+ + // offset for alignment shim in order to match shared memory with final dest + uint8_t const dest_shim_offset = + (row_offsets[fetch_block.start_row] + col_offsets[fetch_block.start_col]) & 7; + auto const shared_buffer_base = shared[fetch % stages_count]; + for (auto el = (int)threadIdx.x; el < num_elements_in_block; el += blockDim.x) { + auto const relative_col = el / num_fetch_rows; + auto const relative_row = el % num_fetch_rows; + auto const absolute_col = relative_col + fetch_block_start_col; + auto const absolute_row = relative_row + fetch_block_start_row; + auto const col_size = col_sizes[absolute_col]; + auto const col_offset = col_offsets[absolute_col]; + auto const relative_col_offset = col_offset - starting_column_offset; + + auto const shared_offset = + relative_row * fetch_block_row_size + relative_col_offset + dest_shim_offset; + auto const input_src = input_data[absolute_col] + col_size * absolute_row; + + // copy the element from global memory + switch (col_size) { + case 2: + *(int16_t*)(&shared_buffer_base[shared_offset]) = *(int16_t*)input_src; + /* cuda::memcpy_async(&shared_buffer_base[shared_offset], + input_src, + cuda::aligned_size_t<2>(col_size), + fetch_barrier);*/ + break; + case 4: + *(int32_t*)(&shared_buffer_base[shared_offset]) = *(int32_t*)input_src; + /* cuda::memcpy_async(&shared_buffer_base[shared_offset], + input_src, + cuda::aligned_size_t<4>(col_size), + fetch_barrier);*/ + break; + case 8: + *(int64_t*)(&shared_buffer_base[shared_offset]) = *(int64_t*)input_src; + /* cuda::memcpy_async(&shared_buffer_base[shared_offset], + input_src, + cuda::aligned_size_t<8>(col_size), + fetch_barrier);*/ + break; + default: + *(int8_t*)(&shared_buffer_base[shared_offset]) = *(int8_t*)input_src; + /* cuda::memcpy_async( + &shared_buffer_base[shared_offset], input_src, col_size, fetch_barrier);*/ + break; + } + } + } + + auto& subset_barrier = block_barrier[subset % NUM_BLOCKS_PER_KERNEL_LOADED]; + subset_barrier.arrive_and_wait(); + + auto const block = block_infos[blockIdx.x * NUM_BLOCKS_PER_KERNEL_TO_ROWS + subset]; + auto const block_row_size = block.get_shared_row_size(col_offsets, col_sizes); + auto const block_num_cols = block.num_cols(); + auto const column_offset = col_offsets[block.start_col]; + auto const block_output_buffer = output_data[block.batch_number]; + uint8_t const dest_shim_offset = (row_offsets[block.start_row] + column_offset) & 7; + + // single byte copy + /* for (int element=threadIdx.x; element 0 && byte_offset % block_row_size == 0) { + // first byte with leading pad + auto const num_single_bytes = 8 - dest_shim_offset; + for (auto i = 0; i < num_single_bytes; ++i) { + output_ptr[i] = input_ptr[i + dest_shim_offset]; + } + } else if ((byte_offset + 8) % block_row_size == 0 && + (block_row_size + dest_shim_offset) % 8 > 0) { + // last bytes of a row + auto const num_single_bytes = (block_row_size + dest_shim_offset) % 8; + for (auto i = 0; i < num_single_bytes; ++i) { + output_ptr[i] = input_ptr[i + dest_shim_offset]; + } + } else { + // copy 8 bytes aligned + const int64_t* long_col_input = reinterpret_cast(input_ptr); + *reinterpret_cast(output_ptr) = *long_col_input; + } + } + /* for (auto absolute_row = block.start_row; absolute_row <= block.end_row; + absolute_row++) { + auto const relative_row = absolute_row - block.start_row; + auto const output_dest = block_output_buffer + row_offsets[absolute_row] + column_offset; + auto const shared_offset = block_row_size * relative_row; + + cuda::memcpy_async(group, output_dest, + 
&shared[subset % stages_count][shared_offset], + block_row_size, + subset_barrier); + }*/ + } + + // wait on the last copies to complete + for (uint i = 0; i < std::min(stages_count, blocks_remaining); ++i) { + block_barrier[i].arrive_and_wait(); + } +} + +/** + * @brief copy data from row-based format to cudf columns + * + * @param num_rows total number of rows in the table + * @param num_columns total number of columns in the table + * @param shmem_used_per_block amount of shared memory that is used by a block + * @param row_offsets offset to a specific row in the output data + * @param output_data pointer to output data, partitioned by data size + * @param validity_offsets offset into input data row for validity data + * @param block_infos information about the blocks of work + * @param input_data pointer to input data + * + */ +__global__ void copy_validity_to_rows(const size_type num_rows, + const size_type num_columns, + const size_type shmem_used_per_block, + const size_type* row_offsets, + int8_t** output_data, + const size_type validity_offset, + device_span block_infos, + const bitmask_type** input_nm) +{ + extern __shared__ int8_t shared_data[]; + int8_t* shared_blocks[NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED] = { + shared_data, shared_data + shmem_used_per_block / 2}; + + using cudf::detail::warp_size; + + // each thread of warp reads a single int32 of validity - so we read 128 bytes + // then ballot_sync the bits and write the result to shmem + // after we fill shared mem memcpy it out in a blob. + // probably need knobs for number of rows vs columns to balance read/write + auto group = cooperative_groups::this_thread_block(); + + int const blocks_remaining = + std::min((uint)block_infos.size() - blockIdx.x * NUM_VALIDITY_BLOCKS_PER_KERNEL, + (uint)NUM_VALIDITY_BLOCKS_PER_KERNEL); + + __shared__ cuda::barrier + shared_block_barriers[NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED]; + if (group.thread_rank() == 0) { + for (int i = 0; i < NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED; ++i) { + init(&shared_block_barriers[i], group.size()); + } + } + + group.sync(); + + for (int validity_block = 0; validity_block < blocks_remaining; ++validity_block) { + if (validity_block >= NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED) { + shared_block_barriers[validity_block % NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED] + .arrive_and_wait(); + } + int8_t* this_shared_block = shared_blocks[validity_block % 2]; + auto block = block_infos[blockIdx.x * NUM_VALIDITY_BLOCKS_PER_KERNEL + validity_block]; + + auto const num_block_cols = block.num_cols(); + auto const num_block_rows = block.num_rows(); + + auto const num_sections_x = util::div_rounding_up_unsafe(num_block_cols, 32); + auto const num_sections_y = util::div_rounding_up_unsafe(num_block_rows, 32); + auto const validity_data_row_length = + align_offset(util::div_rounding_up_unsafe(num_block_cols, 8), 8); + auto const total_sections = num_sections_x * num_sections_y; + + int const warp_id = threadIdx.x / warp_size; + int const lane_id = threadIdx.x % warp_size; + auto const warps_per_block = std::max(1u, blockDim.x / warp_size); + + // the block is divided into sections. A warp operates on a section at a time. 
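+    // Each section is up to 32 columns wide and 32 rows tall. Every lane of the
+    // warp owns one column of the section and loads a single 32-row validity word
+    // for it; 32 __ballot_sync calls then transpose those column-major bits into
+    // row-major bytes that are staged in shared memory before being copied out to
+    // the row buffers.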
+ for (int my_section_idx = warp_id; my_section_idx < total_sections; + my_section_idx += warps_per_block) { + // convert to rows and cols + auto const section_x = my_section_idx % num_sections_x; + auto const section_y = my_section_idx / num_sections_x; + auto const relative_col = section_x * 32 + lane_id; + auto const relative_row = section_y * 32; + auto const absolute_col = relative_col + block.start_col; + auto const absolute_row = relative_row + block.start_row; + auto const cols_left = num_columns - absolute_col; + auto const participation_mask = __ballot_sync(0xFFFFFFFF, absolute_col < num_columns); + + if (absolute_col < num_columns) { + auto my_data = input_nm[absolute_col] != nullptr ? input_nm[absolute_col][absolute_row / 32] + : std::numeric_limits::max(); + + // every thread that is participating in the warp has 4 bytes, but it's column-based + // data and we need it in row-based. So we shuffle the bits around with ballot_sync to + // make the bytes we actually write. + bitmask_type dw_mask = 1; + for (int i = 0; i < 32 && relative_row + i < num_rows; ++i, dw_mask <<= 1) { + auto validity_data = __ballot_sync(participation_mask, my_data & dw_mask); + // lead thread in each warp writes data + auto const validity_write_offset = + validity_data_row_length * (relative_row + i) + relative_col / 8; + if (threadIdx.x % warp_size == 0) { + if (cols_left <= 8) { + // write byte + this_shared_block[validity_write_offset] = validity_data & 0xFF; + } else if (cols_left <= 16) { + // write int16 + *reinterpret_cast(&this_shared_block[validity_write_offset]) = + validity_data & 0xFFFF; + } else if (cols_left <= 24) { + // write int16 and then int8 + *reinterpret_cast(&this_shared_block[validity_write_offset]) = + validity_data & 0xFFFF; + shared_data[validity_write_offset + 2] = (validity_data >> 16) & 0xFF; + } else { + // write int32 + *reinterpret_cast(&this_shared_block[validity_write_offset]) = + validity_data; + } + } + } + } + } + + // make sure entire block has finished copy + group.sync(); + + auto const output_data_base = + output_data[block.batch_number] + validity_offset + block.start_col / 8; + + auto const num_row_bytes = block.num_cols() / 8; + for (int element = threadIdx.x; element < block.num_rows() * num_row_bytes; + element += gridDim.x) { + auto const relative_row = element / num_row_bytes; + auto const absolute_row = block.start_row + relative_row; + auto const relative_byte = element % num_row_bytes; + auto const output_ptr = output_data_base + row_offsets[absolute_row] + relative_byte; + auto const input_ptr = + &this_shared_block[validity_data_row_length * relative_row + relative_byte]; + *output_ptr = *input_ptr; + } + + // now async memcpy the shared memory out to the final destination + /* for (int row = block.start_row + threadIdx.x; row <= block.end_row; row += blockDim.x) { + auto const relative_row = row - block.start_row; + auto const output_ptr = output_data_base + row_offsets[row]; + auto const input_ptr = &this_shared_block[validity_data_row_length * relative_row]; + auto const num_bytes = util::div_rounding_up_unsafe(num_block_cols, 8); + + cuda::memcpy_async( + output_ptr, + input_ptr, + num_bytes, + shared_block_barriers[validity_block % NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED]); + }*/ + } + + // wait for last blocks of data to arrive + for (int validity_block = 0; + validity_block < blocks_remaining % NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED; + ++validity_block) { + shared_block_barriers[validity_block].arrive_and_wait(); + } +} + +/** + * @brief Admin 
data is data stored in shared memory that isn't actual column data + * + * @param col_size_size size of the column size data. + * @param col_offset_size size of the column offset data. + * @param num_cols number of columns in the block. + * @return tuple of the size of column and offset admin data. + */ +static __device__ std::tuple get_admin_data_sizes(size_t col_size_size, + size_t col_offset_size, + int const num_cols) +{ + auto const col_size_bytes = num_cols * col_size_size; + auto const col_offset_bytes = num_cols * col_offset_size; + + return {col_size_bytes, col_offset_bytes}; +} + +/** + * @brief copy data from row-based format to cudf columns + * + * @param num_rows total number of rows in the table + * @param num_columns total number of columns in the table + * @param shmem_used_per_block amount of shared memory that is used by a block + * @param row_offsets offset to a specific row in the input data + * @param output_data pointers to column data + * @param col_sizes array of sizes for each element in a column - one per column + * @param col_offsets offset into input data row for each column's start + * @param block_infos information about the blocks of work + * @param input_data pointer to input data + * + */ +__global__ void copy_from_rows(const size_type num_rows, + const size_type num_columns, + const size_type shmem_used_per_block, + const size_type* row_offsets, + int8_t** output_data, + const size_type* _col_sizes, + const size_type* _col_offsets, + device_span block_infos, + const int8_t* input_data) +{ + // We are going to copy the data in two passes. + // The first pass copies a chunk of data into shared memory. + // The second pass copies that chunk from shared memory out to the final location. + + // Because shared memory is limited we copy a subset of the rows at a time. + // This has been broken up for us in the block_info struct, so we don't have + // any calculation to do here, but it is important to note. 
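+  //
+  // Each staging buffer is laid out as
+  //   [ col_sizes for this block | col_offsets for this block | pad to 8 bytes | row 0 | row 1 | ... ]
+  // so a single block_info carries both its admin data and its row data.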
+ + // to speed up some of the random access memory we do, we copy col_sizes and col_offsets + // to shared memory for each of the blocks that we work on + + constexpr unsigned stages_count = NUM_BLOCKS_PER_KERNEL_LOADED; + auto group = cooperative_groups::this_thread_block(); + extern __shared__ int8_t shared_data[]; + int8_t* shared[stages_count] = {shared_data, shared_data + shmem_used_per_block}; + + __shared__ cuda::barrier block_barrier[NUM_BLOCKS_PER_KERNEL_LOADED]; + if (group.thread_rank() == 0) { + for (int i = 0; i < NUM_BLOCKS_PER_KERNEL_LOADED; ++i) { + init(&block_barrier[i], group.size()); + } + } + + group.sync(); + + auto blocks_remaining = + std::min((uint)block_infos.size() - blockIdx.x * NUM_BLOCKS_PER_KERNEL_FROM_ROWS, + (uint)NUM_BLOCKS_PER_KERNEL_FROM_ROWS); + + size_t fetch_index; + size_t processing_index; + for (processing_index = fetch_index = 0; processing_index < blocks_remaining; + ++processing_index) { + // Fetch ahead up to stages_count groups + for (; fetch_index < static_cast(blocks_remaining) && + fetch_index < (processing_index + stages_count); + ++fetch_index) { + auto const fetch_block = + block_infos[blockIdx.x * NUM_BLOCKS_PER_KERNEL_FROM_ROWS + fetch_index]; + auto const fetch_block_start_row = fetch_block.start_row; + auto const fetch_block_end_row = fetch_block.end_row; + auto const starting_col_offset = _col_offsets[fetch_block.start_col]; + auto const fetch_block_row_size = fetch_block.get_shared_row_size(_col_offsets, _col_sizes); + auto const num_fetch_cols = fetch_block.num_cols(); + auto [col_size_bytes, col_offset_bytes] = get_admin_data_sizes( + sizeof(decltype(*_col_sizes)), sizeof(decltype(*_col_offsets)), num_fetch_cols); + auto& fetch_barrier = block_barrier[fetch_index % NUM_BLOCKS_PER_KERNEL_LOADED]; + + // if we have fetched all buffers, we need to wait for processing + // to complete on them before we can use them again + if (fetch_index > NUM_BLOCKS_PER_KERNEL_LOADED) { fetch_barrier.arrive_and_wait(); } + + auto shared_row_offset = 0; + // copy the data for column sizes + cuda::memcpy_async(group, + &shared[fetch_index % stages_count][shared_row_offset], + &_col_sizes[fetch_block.start_col], + col_size_bytes, + fetch_barrier); + shared_row_offset += col_size_bytes; + // copy the data for column offsets + cuda::memcpy_async(group, + &shared[fetch_index % stages_count][shared_row_offset], + &_col_offsets[fetch_block.start_col], + col_offset_bytes, + fetch_barrier); + shared_row_offset += col_offset_bytes; + shared_row_offset = align_offset(shared_row_offset, 8); + + for (auto row = fetch_block_start_row + static_cast(threadIdx.x); + row <= fetch_block_end_row; + row += blockDim.x) { + auto shared_offset = + (row - fetch_block_start_row) * fetch_block_row_size + shared_row_offset; + // copy the main + cuda::memcpy_async(&shared[fetch_index % stages_count][shared_offset], + &input_data[row_offsets[row] + starting_col_offset], + fetch_block_row_size, + fetch_barrier); + } + } + + auto& processing_barrier = block_barrier[processing_index % NUM_BLOCKS_PER_KERNEL_LOADED]; + + // ensure our data is ready + processing_barrier.arrive_and_wait(); + + auto block = block_infos[blockIdx.x * NUM_BLOCKS_PER_KERNEL_FROM_ROWS + processing_index]; + auto const rows_in_block = block.num_rows(); + auto const cols_in_block = block.num_cols(); + + auto [col_size_bytes, col_offset_bytes] = get_admin_data_sizes( + sizeof(decltype(*_col_sizes)), sizeof(decltype(*_col_offsets)), cols_in_block); + auto shared_col_sizes = 
reinterpret_cast(shared[processing_index % stages_count]); + auto shared_col_offsets = + reinterpret_cast(&shared[processing_index % stages_count][col_size_bytes]); + + auto const shared_row_offset = align_offset(col_size_bytes + col_offset_bytes, 8); + + auto block_row_size = block.get_shared_row_size(_col_offsets, _col_sizes); + + // now we copy from shared memory to final destination. + // the data is laid out in rows in shared memory, so the reads + // for a column will be "vertical". Because of this and the different + // sizes for each column, this portion is handled on row/column basis. + // to prevent each thread working on a single row and also to ensure + // that all threads can do work in the case of more threads than rows, + // we do a global index instead of a double for loop with col/row. + for (int index = threadIdx.x; index < rows_in_block * cols_in_block; index += blockDim.x) { + auto const relative_col = index % cols_in_block; + auto const relative_row = index / cols_in_block; + auto const absolute_col = relative_col + block.start_col; + auto const absolute_row = relative_row + block.start_row; + + auto const shared_memory_row_offset = block_row_size * relative_row; + auto const shared_memory_offset = shared_col_offsets[relative_col] - shared_col_offsets[0] + + shared_memory_row_offset + shared_row_offset; + auto const column_size = shared_col_sizes[relative_col]; + + int8_t* shmem_src = &shared[processing_index % stages_count][shared_memory_offset]; + int8_t* dst = &output_data[absolute_col][absolute_row * column_size]; + + cuda::memcpy_async(dst, shmem_src, column_size, processing_barrier); + } + group.sync(); + } + + // wait on the last copies to complete + for (uint i = 0; i < std::min(stages_count, blocks_remaining); ++i) { + block_barrier[i].arrive_and_wait(); + } +} + +/** + * @brief copy data from row-based format to cudf columns + * + * @param num_rows total number of rows in the table + * @param num_columns total number of columns in the table + * @param shmem_used_per_block amount of shared memory that is used by a block + * @param row_offsets offset to a specific row in the input data + * @param output_nm pointers to null masks for columns + * @param validity_offsets offset into input data row for validity data + * @param block_infos information about the blocks of work + * @param input_data pointer to input data + * + */ +__global__ void copy_validity_from_rows(const size_type num_rows, + const size_type num_columns, + const size_type shmem_used_per_block, + const size_type* row_offsets, + cudf::bitmask_type** output_nm, + const size_type validity_offset, + device_span block_infos, + const int8_t* input_data) +{ + extern __shared__ int8_t shared_data[]; + int8_t* shared_blocks[NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED] = { + shared_data, shared_data + shmem_used_per_block / 2}; + + using cudf::detail::warp_size; + + // each thread of warp reads a single byte of validity - so we read 32 bytes + // then ballot_sync the bits and write the result to shmem + // after we fill shared mem memcpy it out in a blob. 
+ // probably need knobs for number of rows vs columns to balance read/write + auto group = cooperative_groups::this_thread_block(); + + int const blocks_remaining = + std::min((uint)block_infos.size() - blockIdx.x * NUM_VALIDITY_BLOCKS_PER_KERNEL, + (uint)NUM_VALIDITY_BLOCKS_PER_KERNEL); + + __shared__ cuda::barrier + shared_block_barriers[NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED]; + if (group.thread_rank() == 0) { + for (int i = 0; i < NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED; ++i) { + init(&shared_block_barriers[i], group.size()); + } + } + + group.sync(); + + for (int validity_block = 0; validity_block < blocks_remaining; ++validity_block) { + if (validity_block >= NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED) { + auto const validity_index = validity_block % NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED; + shared_block_barriers[validity_index].arrive_and_wait(); + } + int8_t* this_shared_block = shared_blocks[validity_block % 2]; + auto const block = block_infos[blockIdx.x * NUM_VALIDITY_BLOCKS_PER_KERNEL + validity_block]; + auto const block_start_col = block.start_col; + auto const block_start_row = block.start_row; + auto const num_block_cols = block.num_cols(); + auto const num_block_rows = block.num_rows(); + auto const num_sections_x = (num_block_cols + 7) / 8; + auto const num_sections_y = (num_block_rows + 31) / 32; + auto const validity_data_col_length = num_sections_y * 4; // words to bytes + auto const total_sections = num_sections_x * num_sections_y; + int const warp_id = threadIdx.x / warp_size; + int const lane_id = threadIdx.x % warp_size; + auto const warps_per_block = std::max(1u, blockDim.x / warp_size); + + // the block is divided into sections. A warp operates on a section at a time. + for (int my_section_idx = warp_id; my_section_idx < total_sections; + my_section_idx += warps_per_block) { + // convert section to row and col + auto const section_x = my_section_idx % num_sections_x; + auto const section_y = my_section_idx / num_sections_x; + auto const relative_col = section_x * 8; + auto const relative_row = section_y * 32 + lane_id; + auto const absolute_col = relative_col + block_start_col; + auto const absolute_row = relative_row + block_start_row; + auto const rows_left = num_rows - absolute_row; + + auto const participation_mask = __ballot_sync(0xFFFFFFFF, absolute_row < num_rows); + + if (absolute_row < num_rows) { + auto const my_byte = + input_data[row_offsets[absolute_row] + validity_offset + absolute_col / 8]; + + // so every thread that is participating in the warp has a byte, but it's row-based + // data and we need it in column-based. So we shuffle the bits around to make + // the bytes we actually write. 
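+        // Each lane owns one row here and my_byte holds the validity bits for 8
+        // columns of that row. Each ballot across the 32 lanes of the warp
+        // therefore produces, for one of those columns, a 32-bit word covering 32
+        // consecutive rows, which is the layout the column null masks expect.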
+ for (int i = 0, byte_mask = 1; i < 8 && relative_col + i < num_columns; + ++i, byte_mask <<= 1) { + auto validity_data = __ballot_sync(participation_mask, my_byte & byte_mask); + // lead thread in each warp writes data + if (threadIdx.x % warp_size == 0) { + auto const validity_write_offset = + validity_data_col_length * (relative_col + i) + relative_row / 8; + + if (rows_left <= 8) { + // write byte + this_shared_block[validity_write_offset] = validity_data & 0xFF; + } else if (rows_left <= 16) { + // write int16 + *reinterpret_cast(&this_shared_block[validity_write_offset]) = + validity_data & 0xFFFF; + } else if (rows_left <= 24) { + // write int16 and then int8 + *reinterpret_cast(&this_shared_block[validity_write_offset]) = + validity_data & 0xFFFF; + shared_data[validity_write_offset + 2] = (validity_data >> 16) & 0xFF; + } else { + // write int32 + *reinterpret_cast(&this_shared_block[validity_write_offset]) = + validity_data; + } + } + } + } + } + + // make sure entire block has finished copy + group.sync(); + + // now async memcpy the shared + for (int col = block.start_col + threadIdx.x; col <= block.end_col; col += blockDim.x) { + auto const relative_col = col - block.start_col; + auto const starting_address = output_nm[col] + word_index(block_start_row); + + cuda::memcpy_async( + starting_address, + &this_shared_block[validity_data_col_length * relative_col], + util::div_rounding_up_unsafe(num_block_rows, 8), + shared_block_barriers[validity_block % NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED]); + } + } + + // wait for last blocks of data to arrive + auto const num_blocks_to_wait = blocks_remaining > NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED + ? NUM_VALIDITY_BLOCKS_PER_KERNEL_LOADED + : blocks_remaining; + for (int validity_block = 0; validity_block < num_blocks_to_wait; ++validity_block) { + shared_block_barriers[validity_block].arrive_and_wait(); + } +} + +#endif // !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 + +/** + * @brief Calculate the dimensions of the kernel for fixed width only columns. + * + * @param [in] num_columns the number of columns being copied. + * @param [in] num_rows the number of rows being copied. + * @param [in] size_per_row the size each row takes up when padded. + * @param [out] blocks the size of the blocks for the kernel + * @param [out] threads the size of the threads for the kernel + * @return the size in bytes of shared memory needed for each block. + */ +static int calc_fixed_width_kernel_dims(const cudf::size_type num_columns, + const cudf::size_type num_rows, + const cudf::size_type size_per_row, + dim3& blocks, + dim3& threads) +{ + // We have found speed degrades when a thread handles more than 4 columns. + // Each block is 2 dimensional. The y dimension indicates the columns. + // We limit this to 32 threads in the y dimension so we can still + // have at least 32 threads in the x dimension (1 warp) which should + // result in better coalescing of memory operations. We also + // want to guarantee that we are processing a multiple of 32 threads + // in the x dimension because we use atomic operations at the block + // level when writing validity data out to main memory, and that would + // need to change if we split a word of validity data between blocks. 
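+  //
+  // As a worked example with hypothetical sizes: for num_columns = 9 and
+  // size_per_row = 50 bytes, y_block_size = 3, x_possible_block_size = 341, and
+  // max_block_size = 49152 / 50 = 983, which is clamped to 341 and rounded down to
+  // a block_size of 320 threads, giving 320 * 50 = 16000 bytes of shared memory.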
+  int y_block_size = (num_columns + 3) / 4;  // cudf::util::div_rounding_up_safe(num_columns, 4);
+  if (y_block_size > 32) { y_block_size = 32; }
+  int x_possible_block_size = 1024 / y_block_size;
+  // 48 KB is the default shared memory size per block.
+  // If someone configures the GPU to only have 16 KB this might not work.
+  int max_shared_size = 48 * 1024;
+  int max_block_size  = max_shared_size / size_per_row;
+  // If we don't have enough shared memory there is no point in having more threads
+  // per block that would just sit idle.
+  max_block_size = max_block_size > x_possible_block_size ? x_possible_block_size : max_block_size;
+  // Make sure that the x dimension is a multiple of 32. This not only helps
+  // coalesce memory access, it also lets us do a ballot sync for validity to write
+  // the data back out at the warp level. If x is a multiple of 32 then each thread in the y
+  // dimension is associated with one or more warps, which should correspond to the validity
+  // words directly.
+  int block_size = (max_block_size / 32) * 32;
+  CUDF_EXPECTS(block_size != 0, "Row size is too large to fit in shared memory");
+
+  int num_blocks = (num_rows + block_size - 1) / block_size;
+  if (num_blocks < 1) {
+    num_blocks = 1;
+  } else if (num_blocks > 10240) {
+    // The maximum number of blocks supported in the x dimension is 2 ^ 31 - 1,
+    // but in practice having too many can cause some overhead that I don't totally
+    // understand. Playing around with this, having as little as 600 blocks appears
+    // to be able to saturate memory on V100, so this is an order of magnitude higher
+    // to try and future-proof this a bit.
+    num_blocks = 10240;
+  }
+  blocks.x  = num_blocks;
+  blocks.y  = 1;
+  blocks.z  = 1;
+  threads.x = block_size;
+  threads.y = y_block_size;
+  threads.z = 1;
+  return size_per_row * block_size;
+}
+
+/**
+ * When converting to rows it is possible that the size of the table is too big to fit
+ * in a single column. This creates an output column for a subset of the rows in the
+ * table, starting at start_row and containing the next num_rows rows. Most of the
+ * parameters passed into this function are common between runs and should be calculated once.
+ */ +static std::unique_ptr fixed_width_convert_to_rows( + const cudf::size_type start_row, + const cudf::size_type num_rows, + const cudf::size_type num_columns, + const cudf::size_type size_per_row, + rmm::device_uvector& column_start, + rmm::device_uvector& column_size, + rmm::device_uvector& input_data, + rmm::device_uvector& input_nm, + const cudf::scalar& zero, + const cudf::scalar& scalar_size_per_row, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + int64_t const total_allocation = size_per_row * num_rows; + // We made a mistake in the split somehow + CUDF_EXPECTS(total_allocation < std::numeric_limits::max(), "Table is too large to fit!"); + + // Allocate and set the offsets row for the byte array + std::unique_ptr offsets = + cudf::detail::sequence(num_rows + 1, zero, scalar_size_per_row, stream); + + std::unique_ptr data = + cudf::make_numeric_column(cudf::data_type(cudf::type_id::INT8), + static_cast(total_allocation), + cudf::mask_state::UNALLOCATED, + stream, + mr); + + dim3 blocks; + dim3 threads; + int shared_size = + detail::calc_fixed_width_kernel_dims(num_columns, num_rows, size_per_row, blocks, threads); + + copy_to_rows_fixed_width_optimized<<>>( + start_row, + num_rows, + num_columns, + size_per_row, + column_start.data(), + column_size.data(), + input_data.data(), + input_nm.data(), + data->mutable_view().data()); + + return cudf::make_lists_column(num_rows, + std::move(offsets), + std::move(data), + 0, + rmm::device_buffer{0, rmm::cuda_stream_default, mr}, + stream, + mr); +} + +static inline bool are_all_fixed_width(std::vector const& schema) +{ + return std::all_of( + schema.begin(), schema.end(), [](const cudf::data_type& t) { return cudf::is_fixed_width(t); }); +} + +/** + * @brief Given a set of fixed width columns, calculate how the data will be laid out in memory. + * + * @param [in] schema the types of columns that need to be laid out. + * @param [out] column_start the byte offset where each column starts in the row. + * @param [out] column_size the size in bytes of the data for each columns in the row. + * @return the size in bytes each row needs. + */ +static inline int32_t compute_fixed_width_layout(std::vector const& schema, + std::vector& column_start, + std::vector& column_size) +{ + // We guarantee that the start of each column is 64-bit aligned so anything can go + // there, but to make the code simple we will still do an alignment for it. + int32_t at_offset = 0; + for (auto col = schema.begin(); col < schema.end(); col++) { + cudf::size_type s = cudf::size_of(*col); + column_size.emplace_back(s); + std::size_t allocation_needed = s; + std::size_t alignment_needed = allocation_needed; // They are the same for fixed width types + at_offset = align_offset(at_offset, alignment_needed); + column_start.emplace_back(at_offset); + at_offset += allocation_needed; + } + + // Now we need to add in space for validity + // Eventually we can think about nullable vs not nullable, but for now we will just always add + // it in + int32_t validity_bytes_needed = + (schema.size() + 7) / 8; // cudf::util::div_rounding_up_safe(schema.size(), 8); + // validity comes at the end and is byte aligned so we can pack more in. + at_offset += validity_bytes_needed; + // Now we need to pad the end so all rows are 64 bit aligned + return align_offset(at_offset, 8); // 8 bytes (64 bits) +} + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 + +/** + * @brief Compute information about a table such as bytes per row and offsets. 
+ * + * @tparam iterator iterator of column schema data + * @param begin starting iterator of column schema + * @param end ending iterator of column schema + * @param column_starts column start offsets + * @param column_sizes size in bytes of each column + * @return size of the fixed_width data portion of a row. + */ +template +static size_type compute_column_information(iterator begin, + iterator end, + std::vector& column_starts, + std::vector& column_sizes) +{ + size_type fixed_width_size_per_row = 0; + for (auto cv = begin; cv != end; ++cv) { + auto col_type = std::get<0>(*cv); + bool nested_type = col_type.id() == type_id::LIST || col_type.id() == type_id::STRING; + + // a list or string column will write a single uint64 + // of data here for offset/length + auto col_size = nested_type ? 8 : size_of(col_type); + + // align size for this type + std::size_t const alignment_needed = col_size; // They are the same for fixed width types + fixed_width_size_per_row = detail::align_offset(fixed_width_size_per_row, alignment_needed); + column_starts.push_back(fixed_width_size_per_row); + column_sizes.push_back(col_size); + fixed_width_size_per_row += col_size; + } + + auto validity_offset = fixed_width_size_per_row; + column_starts.push_back(validity_offset); + + return fixed_width_size_per_row; +} + +/** + * @brief Build `block_info` for the validity data to break up the work. + * + * @param num_columns number of columns in the table + * @param num_rows number of rows in the table + * @param shmem_limit_per_block size of shared memory available to a single gpu block + * @param row_batches batched row information for multiple output locations + * @return vector of `block_info` structs for validity data + */ +std::vector build_validity_block_infos( + size_type const& num_columns, + size_type const& num_rows, + size_type const& shmem_limit_per_block, + std::vector const& row_batches) +{ + auto const desired_rows_and_columns = (int)sqrt(shmem_limit_per_block); + auto const column_stride = align_offset( + [&]() { + if (desired_rows_and_columns > num_columns) { + // not many columns, group it into 8s and ship it off + return std::min(8, num_columns); + } else { + return util::round_down_safe(desired_rows_and_columns, 8); + } + }(), + 8); + // we fit as much as we can given the column stride + // note that an element in the table takes just 1 bit, but a row with a single + // element still takes 8 bytes! 
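+  // For example, with a hypothetical column_stride of 256 columns a row needs
+  // ceil(256 / 8) = 32 bytes of validity (already 8-byte aligned), so 24 KB of
+  // shared memory per block yields a row_stride of 24576 / 32 = 768 rows, capped
+  // at num_rows.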
+ auto const bytes_per_row = align_offset(util::div_rounding_up_unsafe(column_stride, 8), 8); + auto const row_stride = std::min(num_rows, shmem_limit_per_block / bytes_per_row); + + std::vector validity_block_infos; + for (int col = 0; col < num_columns; col += column_stride) { + int current_window_row_batch = 0; + int rows_left_in_batch = row_batches[current_window_row_batch].row_count; + int row = 0; + while (row < num_rows) { + if (rows_left_in_batch == 0) { + current_window_row_batch++; + rows_left_in_batch = row_batches[current_window_row_batch].row_count; + } + int const window_height = std::min(row_stride, rows_left_in_batch); + + validity_block_infos.emplace_back(detail::block_info{ + col, row, std::min(col + column_stride - 1, num_columns - 1), row + window_height - 1}); + row += window_height; + rows_left_in_batch -= window_height; + } + } + + return validity_block_infos; +} + +constexpr size_type max_batch_size = std::numeric_limits::max(); + +/** + * @brief Holds information about the batches of data to be processed + * + */ +struct batch_data { + device_uvector batch_row_offsets; + std::vector batch_row_boundaries; + std::vector row_batches; +}; + +template +struct row_size_functor { + RowSize _row_sizes; + size_type _num_rows; + row_size_functor(RowSize row_sizes) : _row_sizes(row_sizes){}; + + CUDA_DEVICE_CALLABLE + uint64_t operator()(int row_index) { return static_cast(_row_sizes[row_index]); } +}; + +/** + * @brief Builds batches of rows that will fit in the size limit of a column. + * + * @tparam RowSize iterator that gives the size of a specific row of the table. + * @param num_rows Total number of rows in the table + * @param row_sizes iterator that gives the size of a specific row of the table. + * @param stream stream to operate on for this work + * @param mr memory resource used to allocate any returned data + * @returns vector of size_type's that indicate row numbers for batch boundaries and a + * device_uvector of row offsets + */ + +template +batch_data build_batches(size_type num_rows, + RowSize row_sizes, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto uint64_row_sizes = + cudf::detail::make_counting_transform_iterator(0, row_size_functor(row_sizes)); + auto const total_size = + thrust::reduce(rmm::exec_policy(stream), uint64_row_sizes, uint64_row_sizes + num_rows); + auto const num_batches = static_cast( + util::div_rounding_up_safe(total_size, static_cast(max_batch_size))); + auto const num_offsets = num_batches + 1; + std::vector row_batches; + std::vector batch_row_boundaries; + device_uvector batch_row_offsets(num_rows, stream); + + // at most max gpu memory / 2GB iterations. 
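+  //
+  // The batching below works off a single inclusive scan of the row sizes: each
+  // iteration uses lower_bound on the cumulative sizes to find where the running
+  // total crosses max_batch_size, records that row as a batch boundary, and builds
+  // a per-batch exclusive scan of row sizes that becomes the offset column of the
+  // returned list column.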
+ batch_row_boundaries.reserve(num_offsets); + batch_row_boundaries.push_back(0); + size_type last_row_end = 0; + device_uvector cumulative_row_sizes(num_rows, stream); + thrust::inclusive_scan(rmm::exec_policy(stream), + uint64_row_sizes, + uint64_row_sizes + num_rows, + cumulative_row_sizes.begin()); + + while ((int)batch_row_boundaries.size() < num_offsets) { + // find the next max_batch_size boundary + size_type const row_end = + ((thrust::lower_bound(rmm::exec_policy(stream), + cumulative_row_sizes.begin(), + cumulative_row_sizes.begin() + (num_rows - last_row_end), + max_batch_size) - + cumulative_row_sizes.begin()) + + last_row_end); + + // build offset list for each row in this batch + auto const num_rows_in_batch = row_end - last_row_end; + + // build offset list for each row in this batch + auto const num_entries = row_end - last_row_end + 1; + device_uvector output_batch_row_offsets(num_entries, stream, mr); + + auto row_size_iter_bounded = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), [row_end, row_sizes, last_row_end] __device__(auto i) { + return i >= row_end ? 0 : row_sizes[i + last_row_end]; + }); + + thrust::exclusive_scan(rmm::exec_policy(stream), + row_size_iter_bounded, + row_size_iter_bounded + num_entries, + output_batch_row_offsets.begin()); + + auto const batch_bytes = output_batch_row_offsets.element(num_rows_in_batch, stream); + + // The output_batch_row_offsets vector is used as the offset column of the returned data. This + // needs to be individually allocated, but the kernel needs a contiguous array of offsets or + // more global lookups are necessary. + cudaMemcpy(batch_row_offsets.data() + last_row_end, + output_batch_row_offsets.data(), + num_rows_in_batch * sizeof(size_type), + cudaMemcpyDeviceToDevice); + + batch_row_boundaries.push_back(row_end); + row_batches.push_back({batch_bytes, num_rows_in_batch, std::move(output_batch_row_offsets)}); + + last_row_end = row_end; + } + + return {std::move(batch_row_offsets), batch_row_boundaries, std::move(row_batches)}; +} + +/** + * @brief Computes the number of blocks necessary given a window height and batch offsets + * + * @param batch_row_boundaries row boundaries for each batch + * @param desired_window_height height of each window in the table + * @param stream stream to use + * @return number of windows necessary + */ +int compute_block_counts(device_span const& batch_row_boundaries, + int desired_window_height, + rmm::cuda_stream_view stream) +{ + size_type const num_batches = batch_row_boundaries.size() - 1; + device_uvector num_blocks(num_batches, stream); + auto iter = thrust::make_counting_iterator(0); + thrust::transform( + rmm::exec_policy(stream), + iter, + iter + num_batches, + num_blocks.begin(), + [desired_window_height, + batch_row_boundaries = batch_row_boundaries.data()] __device__(auto batch_index) -> size_type { + return util::div_rounding_up_unsafe( + batch_row_boundaries[batch_index + 1] - batch_row_boundaries[batch_index], + desired_window_height); + }); + return thrust::reduce(rmm::exec_policy(stream), num_blocks.begin(), num_blocks.end()); +} + +/** + * @brief Builds the `block_info` structs for a given table. 
+ * + * @param blocks span of blocks to populate + * @param batch_row_boundaries boundary to row batches + * @param column_start starting column of the window + * @param column_end ending column of the window + * @param desired_window_height height of the window + * @param total_number_of_rows total number of rows in the table + * @param stream stream to use + * @return number of windows created + */ +size_type build_blocks( + device_span blocks, + device_uvector const& batch_row_boundaries, // comes from build_batches + int column_start, + int column_end, + int desired_window_height, + int total_number_of_rows, + rmm::cuda_stream_view stream) +{ + size_type const num_batches = batch_row_boundaries.size() - 1; + device_uvector num_blocks(num_batches, stream); + auto iter = thrust::make_counting_iterator(0); + thrust::transform( + rmm::exec_policy(stream), + iter, + iter + num_batches, + num_blocks.begin(), + [desired_window_height, + batch_row_boundaries = batch_row_boundaries.data()] __device__(auto batch_index) -> size_type { + return util::div_rounding_up_unsafe( + batch_row_boundaries[batch_index + 1] - batch_row_boundaries[batch_index], + desired_window_height); + }); + + size_type const total_blocks = + thrust::reduce(rmm::exec_policy(stream), num_blocks.begin(), num_blocks.end()); + + device_uvector block_starts(num_batches + 1, stream); + auto block_iter = cudf::detail::make_counting_transform_iterator( + 0, [num_blocks = num_blocks.data(), num_batches] __device__(auto i) { + return (i < num_batches) ? num_blocks[i] : 0; + }); + thrust::exclusive_scan(rmm::exec_policy(stream), + block_iter, + block_iter + num_batches + 1, + block_starts.begin()); // in blocks + + thrust::transform( + rmm::exec_policy(stream), + iter, + iter + total_blocks, + blocks.begin(), + [ =, + block_starts = block_starts.data(), + batch_row_boundaries = batch_row_boundaries.data()] __device__(size_type block_index) { + // what batch this block falls in + auto const batch_index_iter = + thrust::upper_bound(thrust::seq, block_starts, block_starts + num_batches, block_index); + auto const batch_index = std::distance(block_starts, batch_index_iter) - 1; + // local index within the block + int const local_block_index = block_index - block_starts[batch_index]; + // the start row for this batch. + int const batch_row_start = batch_row_boundaries[batch_index]; + // the start row for this block + int const block_row_start = batch_row_start + (local_block_index * desired_window_height); + // the end row for this block + int const max_row = std::min(total_number_of_rows - 1, + batch_index + 1 > num_batches + ? std::numeric_limits::max() + : static_cast(batch_row_boundaries[batch_index + 1]) - 1); + int const block_row_end = + std::min(batch_row_start + ((local_block_index + 1) * desired_window_height) - 1, max_row); + + // stuff the block + return block_info{ + column_start, block_row_start, column_end, block_row_end, static_cast(batch_index)}; + }); + + return total_blocks; +} + +/** + * @brief Determines what data should be operated on by each block for the incoming table. 
+ * + * @tparam WindowCallback Callback that receives the start and end columns of windows + * @param column_sizes vector of the size of each column + * @param column_starts vector of the offset of each column + * @param first_row_batch_size size of the first row batch to limit max window size since a window + * is unable to span batches + * @param total_number_of_rows total number of rows in the table + * @param shmem_limit_per_block shared memory allowed per block + * @param f callback function called when building a window + */ +template +void determine_windows(std::vector const& column_sizes, + std::vector const& column_starts, + size_type const first_row_batch_size, + size_type const total_number_of_rows, + size_type const& shmem_limit_per_block, + WindowCallback f) +{ + // block infos are organized with the windows going "down" the columns + // this provides the most coalescing of memory access + int current_window_width = 0; + int current_window_start_col = 0; + + // the ideal window height has lots of 8-byte reads and 8-byte writes. The optimal read/write + // would be memory cache line sized access, but since other blocks will read/write the edges + // this may not turn out to be overly important. For now, we will attempt to build a square + // window as far as byte sizes. x * y = shared_mem_size. Which translates to x^2 = + // shared_mem_size since we want them equal, so height and width are sqrt(shared_mem_size). The + // trick is that it's in bytes, not rows or columns. + size_type const optimal_square_len = 32; // size_type(sqrt(shmem_limit_per_block)); + int const window_height = + std::clamp(util::round_up_safe( + std::min(optimal_square_len / column_sizes[0], total_number_of_rows), 32), + 1, + first_row_batch_size); + + auto calc_admin_data_size = [](int num_cols) -> size_type { + // admin data is the column sizes and column start information. + // this is copied to shared memory as well and needs to be accounted for + // in the window calculation. + return num_cols * sizeof(size_type) + num_cols * sizeof(size_type); + }; + + int row_size = 0; + + // march each column and build the blocks of appropriate sizes + for (unsigned int col = 0; col < column_sizes.size(); ++col) { + auto const col_size = column_sizes[col]; + + // align size for this type + std::size_t alignment_needed = col_size; // They are the same for fixed width types + auto row_size_aligned = detail::align_offset(row_size, alignment_needed); + auto row_size_with_this_col = row_size_aligned + col_size; + auto row_size_with_end_pad = detail::align_offset(row_size_with_this_col, 8); + + if (row_size_with_end_pad * window_height + + calc_admin_data_size(col - current_window_start_col) > + shmem_limit_per_block) { + // too large, close this window, generate vertical blocks and restart + f(current_window_start_col, col == 0 ? 
col : col - 1, window_height); + + row_size = + detail::align_offset((column_starts[col] + column_sizes[col]) & 7, alignment_needed); + row_size += col_size; // alignment required for shared memory window boundary to match + // alignment of output row + current_window_start_col = col; + current_window_width = 0; + } else { + row_size = row_size_with_this_col; + current_window_width++; + } + } + + // build last set of blocks + if (current_window_width > 0) { + f(current_window_start_col, (int)column_sizes.size() - 1, window_height); + } +} + +#endif // #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 + +} // namespace detail + +std::vector> convert_to_rows(cudf::table_view const& tbl, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 + const size_type num_columns = tbl.num_columns(); + const size_type num_rows = tbl.num_rows(); + + int device_id; + CUDA_TRY(cudaGetDevice(&device_id)); + int total_shmem; + CUDA_TRY(cudaDeviceGetAttribute(&total_shmem, cudaDevAttrMaxSharedMemoryPerBlock, device_id)); + + cuda_event_timer setup_time, copy_kernel_time, validity_kernel_time, post_time; + + setup_time.start_timer("setup time", false); + + // TODO: why is this needed. kernel fails to launch if all memory is requested. + total_shmem -= 1024; + int shmem_limit_per_block = total_shmem / NUM_BLOCKS_PER_KERNEL_LOADED; + + // break up the work into blocks, which are a starting and ending row/col #. + // this window size is calculated based on the shared memory size available + // we want a single block to fill up the entire shared memory space available + // for the transpose-like conversion. + + // There are two different processes going on here. The GPU conversion of the data + // and the writing of the data into the list of byte columns that are a maximum of + // 2 gigs each due to offset maximum size. The GPU conversion portion has to understand + // this limitation because the column must own the data inside and as a result it must be + // a distinct allocation for that column. Copying the data into these final buffers would + // be prohibitively expensive, so care is taken to ensure the GPU writes to the proper buffer. + // The windows are broken at the boundaries of specific rows based on the row sizes up + // to that point. These are row batches and they are decided first before building the + // windows so the windows can be properly cut around them. 
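+  //
+  // The rest of this function therefore: gathers device pointers to the column data
+  // and null masks, computes per-column sizes and offsets, builds the row batches
+  // and their offsets, carves the table into block_info windows (separately for
+  // data and validity), launches copy_to_rows and copy_validity_to_rows, and wraps
+  // each batch buffer in a list column.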
+ + // Get the pointers to the input columnar data ready + std::vector input_data; + std::vector input_nm; + input_data.reserve(num_columns); + input_nm.reserve(num_columns); + std::transform( + tbl.begin(), + tbl.end(), + std::back_inserter(input_data), + [](cudf::column_view const& c) -> int8_t const* { return c.template data(); }); + std::transform( + tbl.begin(), tbl.end(), std::back_inserter(input_nm), [](auto c) { return c.null_mask(); }); + + auto dev_input_data = make_device_uvector_async(input_data, stream, mr); + auto dev_input_nm = make_device_uvector_async(input_nm, stream, mr); + + std::vector column_sizes; // byte size of each column + std::vector column_starts; // offset of column inside a row including alignment + column_sizes.reserve(num_columns); + column_starts.reserve(num_columns + 1); // we add a final offset for validity data start + + auto schema_column_iter = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), [&tbl](auto i) -> std::tuple { + return std::make_tuple(tbl.column(i).type(), tbl.column(i)); + }); + + size_type fixed_width_size_per_row = detail::compute_column_information( + schema_column_iter, schema_column_iter + num_columns, column_starts, column_sizes); + + auto dev_col_sizes = make_device_uvector_async(column_sizes, stream, mr); + auto dev_col_starts = make_device_uvector_async(column_starts, stream, mr); + + // total encoded row size. This includes fixed-width data, validity, and variable-width data. + auto row_size_iter = cudf::detail::make_counting_transform_iterator( + 0, [fixed_width_size_per_row, num_columns] __device__(auto i) { + auto const bytes_needed = + fixed_width_size_per_row + util::div_rounding_up_safe(num_columns, 8); + return detail::align_offset(bytes_needed, 8); + }); + + // fixed_width_size_per_row is the size of the fixed-width portion of a row. We need to then + // calculate the size of each row's variable-width data and validity as well. 
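+  //
+  // For example, with a hypothetical layout whose fixed-width columns pack into 20
+  // bytes and num_columns = 9, a row needs ceil(9 / 8) = 2 validity bytes, so each
+  // encoded row occupies align_offset(20 + 2, 8) = 24 bytes.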
+ auto validity_size = num_bitmask_words(num_columns) * 4; + + auto batch_info = detail::build_batches(num_rows, row_size_iter, stream, mr); + auto gpu_batch_row_boundaries = + make_device_uvector_async(batch_info.batch_row_boundaries, stream); + + // the first batch always exists unless we were sent an empty table + auto const first_batch_size = batch_info.row_batches[0].row_count; + + std::vector output_buffers; + std::vector output_data; + output_data.reserve(batch_info.row_batches.size()); + output_buffers.reserve(batch_info.row_batches.size()); + std::transform( + batch_info.row_batches.begin(), + batch_info.row_batches.end(), + std::back_inserter(output_buffers), + [&](auto const& batch) { return rmm::device_buffer(batch.num_bytes, stream, mr); }); + std::transform( + output_buffers.begin(), output_buffers.end(), std::back_inserter(output_data), [](auto& buf) { + return static_cast(buf.data()); + }); + + auto dev_output_data = make_device_uvector_async(output_data, stream, mr); + + int info_count = 0; + detail::determine_windows(column_sizes, + column_starts, + first_batch_size, + num_rows, + shmem_limit_per_block, + [&gpu_batch_row_boundaries, &info_count, &stream]( + int const start_col, int const end_col, int const window_height) { + int i = detail::compute_block_counts( + gpu_batch_row_boundaries, window_height, stream); + info_count += i; + }); + + // allocate space for blocks + device_uvector gpu_block_infos(info_count, stream); + int block_offset = 0; + + detail::determine_windows( + column_sizes, + column_starts, + first_batch_size, + num_rows, + shmem_limit_per_block, + [&gpu_batch_row_boundaries, &gpu_block_infos, num_rows, &block_offset, stream]( + int const start_col, int const end_col, int const window_height) { + block_offset += detail::build_blocks( + {gpu_block_infos.data() + block_offset, gpu_block_infos.size() - block_offset}, + gpu_batch_row_boundaries, + start_col, + end_col, + window_height, + num_rows, + stream); + }); + + // blast through the entire table and convert it + dim3 blocks(util::div_rounding_up_unsafe(gpu_block_infos.size(), NUM_BLOCKS_PER_KERNEL_TO_ROWS)); + dim3 threads(256); + + auto validity_block_infos = detail::build_validity_block_infos( + num_columns, num_rows, shmem_limit_per_block, batch_info.row_batches); + + auto dev_validity_block_infos = make_device_uvector_async(validity_block_infos, stream); + dim3 validity_blocks( + util::div_rounding_up_unsafe(validity_block_infos.size(), NUM_VALIDITY_BLOCKS_PER_KERNEL)); + dim3 validity_threads(std::min(validity_block_infos.size() * 32, 128lu)); + + cudaDeviceSynchronize(); + setup_time.stop_timer(); + + copy_kernel_time.start_timer("copy kernel", false); + detail::copy_to_rows<<>>( + num_rows, + num_columns, + shmem_limit_per_block, + gpu_block_infos, + dev_input_data.data(), + dev_col_sizes.data(), + dev_col_starts.data(), + batch_info.batch_row_offsets + .data(), // needs to be row offsets per batch, not overall JUST for output. 
+ reinterpret_cast(dev_output_data.data())); + + cudaDeviceSynchronize(); + copy_kernel_time.stop_timer(); + + validity_kernel_time.start_timer("validity_kernel", false); + detail::copy_validity_to_rows<<>>( + num_rows, + num_columns, + shmem_limit_per_block, + batch_info.batch_row_offsets.data(), + dev_output_data.data(), + column_starts.back(), + dev_validity_block_infos, + dev_input_nm.data()); + + cudaDeviceSynchronize(); + validity_kernel_time.stop_timer(); + + post_time.start_timer("post time", false); + + // split up the output buffer into multiple buffers based on row batch sizes + // and create list of byte columns + std::vector> ret; + for (int batch = 0; batch < (int)batch_info.row_batches.size(); ++batch) { + auto const offset_count = batch_info.row_batches[batch].row_offsets.size(); + auto offsets = std::make_unique(data_type{type_id::INT32}, + (size_type)offset_count, + batch_info.row_batches[batch].row_offsets.release()); + auto data = std::make_unique(data_type{type_id::INT8}, + batch_info.row_batches[batch].num_bytes, + std::move(output_buffers[batch])); + + ret.push_back(cudf::make_lists_column(batch_info.row_batches[batch].row_count, + std::move(offsets), + std::move(data), + 0, + rmm::device_buffer{0, rmm::cuda_stream_default, mr}, + stream, + mr)); + } + + post_time.stop_timer(); + + return ret; +#else + CUDF_FAIL("Column to row conversion optimization requires volta or later hardware."); + return {}; +#endif // #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 +} + +std::vector> convert_to_rows_fixed_width_optimized( + cudf::table_view const& tbl, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) +{ + const cudf::size_type num_columns = tbl.num_columns(); + + std::vector schema; + schema.resize(num_columns); + std::transform( + tbl.begin(), tbl.end(), schema.begin(), [](auto i) -> cudf::data_type { return i.type(); }); + + if (detail::are_all_fixed_width(schema)) { + std::vector column_start; + std::vector column_size; + + int32_t size_per_row = detail::compute_fixed_width_layout(schema, column_start, column_size); + auto dev_column_start = make_device_uvector_async(column_start, stream, mr); + auto dev_column_size = make_device_uvector_async(column_size, stream, mr); + + int32_t max_rows_per_batch = std::numeric_limits::max() / size_per_row; + // Make the number of rows per batch a multiple of 32 so we don't have to worry about + // splitting validity at a specific row offset. This might change in the future. 
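+    // For example, with a hypothetical size_per_row of 40 bytes the raw limit is
+    // 2147483647 / 40 = 53687091 rows, which rounds down to 53687072 rows per batch
+    // once forced to a multiple of 32.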
+ max_rows_per_batch = (max_rows_per_batch / 32) * 32; + + cudf::size_type num_rows = tbl.num_rows(); + + // Get the pointers to the input columnar data ready + std::vector input_data; + std::vector input_nm; + for (cudf::size_type column_number = 0; column_number < num_columns; column_number++) { + cudf::column_view cv = tbl.column(column_number); + input_data.emplace_back(cv.data()); + input_nm.emplace_back(cv.null_mask()); + } + auto dev_input_data = make_device_uvector_async(input_data, stream, mr); + auto dev_input_nm = make_device_uvector_async(input_nm, stream, mr); + + using ScalarType = cudf::scalar_type_t; + auto zero = cudf::make_numeric_scalar(cudf::data_type(cudf::type_id::INT32), stream.value()); + zero->set_valid_async(true, stream); + static_cast(zero.get())->set_value(0, stream); + + auto step = cudf::make_numeric_scalar(cudf::data_type(cudf::type_id::INT32), stream.value()); + step->set_valid_async(true, stream); + static_cast(step.get()) + ->set_value(static_cast(size_per_row), stream); + + std::vector> ret; + for (cudf::size_type row_start = 0; row_start < num_rows; row_start += max_rows_per_batch) { + cudf::size_type row_count = num_rows - row_start; + row_count = row_count > max_rows_per_batch ? max_rows_per_batch : row_count; + ret.emplace_back(detail::fixed_width_convert_to_rows(row_start, + row_count, + num_columns, + size_per_row, + dev_column_start, + dev_column_size, + dev_input_data, + dev_input_nm, + *zero, + *step, + stream, + mr)); + } + + return ret; + } else { + CUDF_FAIL("Only fixed width types are currently supported"); + } +} + +std::unique_ptr convert_from_rows(cudf::lists_column_view const& input, + std::vector const& schema, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 + // verify that the types are what we expect + cudf::column_view child = input.child(); + cudf::type_id list_type = child.type().id(); + CUDF_EXPECTS(list_type == cudf::type_id::INT8 || list_type == cudf::type_id::UINT8, + "Only a list of bytes is supported as input"); + + cudf::size_type num_columns = schema.size(); + cudf::size_type num_rows = input.parent().size(); + + int device_id; + CUDA_TRY(cudaGetDevice(&device_id)); + int total_shmem; + CUDA_TRY(cudaDeviceGetAttribute(&total_shmem, cudaDevAttrMaxSharedMemoryPerBlock, device_id)); + + cuda_event_timer setup_time, copy_kernel_time, validity_kernel_time; + + setup_time.start_timer("setup time", false); + + // TODO: why is this needed. kernel fails to launch if all memory is requested. + total_shmem -= 1024; + int shmem_limit_per_block = total_shmem / NUM_BLOCKS_PER_KERNEL_LOADED; + + std::vector column_starts; + std::vector column_sizes; + + auto iter = thrust::make_transform_iterator(thrust::make_counting_iterator(0), [&schema](auto i) { + return std::make_tuple(schema[i], nullptr); + }); + + size_type fixed_width_size_per_row = + detail::compute_column_information(iter, iter + num_columns, column_starts, column_sizes); + + size_type validity_size = num_bitmask_words(num_columns) * 4; + + size_type row_size = detail::align_offset(fixed_width_size_per_row + validity_size, 8); + + // Ideally we would check that the offsets are all the same, etc. 
but for now + // this is probably fine + CUDF_EXPECTS(row_size * num_rows == child.size(), "The layout of the data appears to be off"); + auto dev_col_starts = make_device_uvector_async(column_starts, stream, mr); + auto dev_col_sizes = make_device_uvector_async(column_sizes, stream, mr); + + // Allocate the columns we are going to write into + std::vector> output_columns; + std::vector output_data; + std::vector output_nm; + for (cudf::size_type i = 0; i < num_columns; i++) { + auto column = cudf::make_fixed_width_column( + schema[i], num_rows, cudf::mask_state::UNINITIALIZED, stream, mr); + auto mut = column->mutable_view(); + output_data.emplace_back(mut.data()); + output_nm.emplace_back(mut.null_mask()); + output_columns.emplace_back(std::move(column)); + } + + // build the row_batches from the passed in list column + std::vector row_batches; + row_batches.push_back( + {detail::row_batch{child.size(), num_rows, device_uvector(0, stream)}}); + + auto dev_output_data = make_device_uvector_async(output_data, stream, mr); + auto dev_output_nm = make_device_uvector_async(output_nm, stream, mr); + + // only ever get a single batch when going from rows, so boundaries + // are 0, num_rows + device_uvector gpu_batch_row_boundaries(2, stream); + + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(2), + gpu_batch_row_boundaries.begin(), + [num_rows] __device__(auto i) { return i == 0 ? 0 : num_rows; }); + + int info_count = 0; + detail::determine_windows(column_sizes, + column_starts, + num_rows, + num_rows, + shmem_limit_per_block, + [&gpu_batch_row_boundaries, &info_count, &stream]( + int const start_col, int const end_col, int const window_height) { + info_count += detail::compute_block_counts( + gpu_batch_row_boundaries, window_height, stream); + }); + + // allocate space for blocks + device_uvector gpu_block_infos(info_count, stream); + + int block_offset = 0; + detail::determine_windows( + column_sizes, + column_starts, + num_rows, + num_rows, + shmem_limit_per_block, + [&gpu_batch_row_boundaries, &gpu_block_infos, num_rows, &block_offset, stream]( + int const start_col, int const end_col, int const window_height) { + block_offset += detail::build_blocks( + {gpu_block_infos.data() + block_offset, gpu_block_infos.size() - block_offset}, + gpu_batch_row_boundaries, + start_col, + end_col, + window_height, + num_rows, + stream); + }); + + dim3 blocks( + util::div_rounding_up_unsafe(gpu_block_infos.size(), NUM_BLOCKS_PER_KERNEL_FROM_ROWS)); + dim3 threads(std::min(std::min(256, shmem_limit_per_block / 8), (int)child.size())); + + auto validity_block_infos = + detail::build_validity_block_infos(num_columns, num_rows, shmem_limit_per_block, row_batches); + + auto dev_validity_block_infos = make_device_uvector_async(validity_block_infos, stream); + + dim3 validity_blocks( + util::div_rounding_up_unsafe(validity_block_infos.size(), NUM_VALIDITY_BLOCKS_PER_KERNEL)); + + dim3 validity_threads(std::min(validity_block_infos.size() * 32, 128lu)); + + cudaDeviceSynchronize(); + setup_time.stop_timer(); + + copy_kernel_time.start_timer("copy kernel", false); + + detail::copy_from_rows<<>>( + num_rows, + num_columns, + shmem_limit_per_block, + input.offsets().data(), + dev_output_data.data(), + dev_col_sizes.data(), + dev_col_starts.data(), + gpu_block_infos, + child.data()); + + cudaDeviceSynchronize(); + copy_kernel_time.stop_timer(); + + validity_kernel_time.start_timer("validity kernel", false); + + detail:: + copy_validity_from_rows<<>>( 
+ num_rows, + num_columns, + shmem_limit_per_block, + input.offsets().data(), + dev_output_nm.data(), + column_starts.back(), + dev_validity_block_infos, + child.data()); + + cudaDeviceSynchronize(); + validity_kernel_time.stop_timer(); + + return std::make_unique(std::move(output_columns)); +#else + CUDF_FAIL("Row to column conversion optimization requires volta or later hardware."); + return {}; +#endif // #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 +} + +std::unique_ptr convert_from_rows_fixed_width_optimized( + cudf::lists_column_view const& input, + std::vector const& schema, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + // verify that the types are what we expect + cudf::column_view child = input.child(); + cudf::type_id list_type = child.type().id(); + CUDF_EXPECTS(list_type == cudf::type_id::INT8 || list_type == cudf::type_id::UINT8, + "Only a list of bytes is supported as input"); + + cudf::size_type num_columns = schema.size(); + + if (detail::are_all_fixed_width(schema)) { + std::vector column_start; + std::vector column_size; + + cudf::size_type num_rows = input.parent().size(); + int32_t size_per_row = detail::compute_fixed_width_layout(schema, column_start, column_size); + + // Ideally we would check that the offsets are all the same, etc. but for now + // this is probably fine + CUDF_EXPECTS(size_per_row * num_rows == child.size(), + "The layout of the data appears to be off"); + auto dev_column_start = make_device_uvector_async(column_start, stream); + auto dev_column_size = make_device_uvector_async(column_size, stream); + + // Allocate the columns we are going to write into + std::vector> output_columns; + std::vector output_data; + std::vector output_nm; + for (cudf::size_type i = 0; i < num_columns; i++) { + auto column = cudf::make_fixed_width_column( + schema[i], num_rows, cudf::mask_state::UNINITIALIZED, stream, mr); + auto mut = column->mutable_view(); + output_data.emplace_back(mut.data()); + output_nm.emplace_back(mut.null_mask()); + output_columns.emplace_back(std::move(column)); + } + + auto dev_output_data = make_device_uvector_async(output_data, stream, mr); + auto dev_output_nm = make_device_uvector_async(output_nm, stream, mr); + + dim3 blocks; + dim3 threads; + int shared_size = + detail::calc_fixed_width_kernel_dims(num_columns, num_rows, size_per_row, blocks, threads); + + detail::copy_from_rows_fixed_width_optimized<<>>( + num_rows, + num_columns, + size_per_row, + dev_column_start.data(), + dev_column_size.data(), + dev_output_data.data(), + dev_output_nm.data(), + child.data()); + + return std::make_unique(std::move(output_columns)); + } else { + CUDF_FAIL("Only fixed width types are currently supported"); + } +} + +// } // namespace java + +} // namespace cudf diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 27dd472b3f5..4ff703c4ac5 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -459,6 +459,8 @@ ConfigureTest( # * bin tests ---------------------------------------------------------------------------------- ConfigureTest(LABEL_BINS_TEST labeling/label_bins_tests.cpp) +ConfigureTest(ROW_CONVERSION row_conversion/row_conversion.cu) + # ################################################################################################## # enable testing ################################################################################ # ################################################################################################## diff --git 
a/cpp/tests/row_conversion/row_conversion.cu b/cpp/tests/row_conversion/row_conversion.cu new file mode 100644 index 00000000000..a317817976d --- /dev/null +++ b/cpp/tests/row_conversion/row_conversion.cu @@ -0,0 +1,1048 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +struct ColumnToRowTests : public cudf::test::BaseFixture { +}; +struct RowToColumnTests : public cudf::test::BaseFixture { +}; + +TEST_F(ColumnToRowTests, Single) +{ + cudf::test::fixed_width_column_wrapper a({-1}); + cudf::table_view in(std::vector{a}); + std::vector schema = {cudf::data_type{cudf::type_id::INT32}}; + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + auto new_rows = cudf::convert_to_rows(in); + + EXPECT_EQ(old_rows.size(), new_rows.size()); + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(ColumnToRowTests, SimpleString) +{ + cudf::test::fixed_width_column_wrapper a({-1, 0, 1, 0, -1}); + cudf::test::strings_column_wrapper b( + {"hello", "world", "this is a really long string to generate a longer row", "dlrow", "olleh"}); + cudf::table_view in(std::vector{a, b}); + std::vector schema = {cudf::data_type{cudf::type_id::INT32}}; + + auto new_rows = cudf::convert_to_rows(in); + + EXPECT_EQ(new_rows[0]->size(), 5); + cudf::test::print(*new_rows[0]); +} + +TEST_F(ColumnToRowTests, DoubleString) +{ + cudf::test::strings_column_wrapper a( + {"hello", "world", "this is a really long string to generate a longer row", "dlrow", "olleh"}); + cudf::test::fixed_width_column_wrapper b({0, 1, 2, 3, 4}); + cudf::test::strings_column_wrapper c({"world", + "hello", + "this string isn't as long", + "this one isn't so short though when you think about it", + "dlrow"}); + cudf::table_view in(std::vector{a, b, c}); + + auto new_rows = cudf::convert_to_rows(in); + + EXPECT_EQ(new_rows[0]->size(), 5); + cudf::test::print(*new_rows[0]); +} + +void print_rows(cudf::lists_column_view const& lcv, std::vector const& schema) +{ + for (auto s : schema) { + printf("%c", s.id() == cudf::type_id::STRING ? 'S' : 'I'); + } + printf("\n"); + + auto offsets = lcv.offsets(); + auto data = lcv.child(); + // auto zero_iter = thrust::make_counting_iterator(0); + auto const num_rows = lcv.size(); + + auto h_offsets = cudf::test::to_host(offsets); + auto h_data = cudf::test::to_host(data); + + auto print_row = [&](int r) { + printf("row %d(%d-%d): ", r, h_offsets.first[r], h_offsets.first[r + 1]); + auto this_schema = schema.begin(); + int next_grouping = this_schema->id() == cudf::type_id::INT32 ? 
4 : 8; + int idx = h_offsets.first[r]; + bool handled_validity = false; + while (idx < h_offsets.first[r + 1]) { + if (this_schema == schema.end() && !handled_validity) { + int validity_bytes = + cudf::util::div_rounding_up_unsafe(std::distance(schema.begin(), schema.end()), 8); + printf("V:"); + for (int i = 0; i < validity_bytes; ++i) { + printf("%02x", h_data.first[idx++]); + } + printf(" "); + handled_validity = true; + } else if (this_schema == schema.end() && handled_validity) { + // just print the data + int8_t byte = h_data.first[idx]; + idx++; + if (byte >= 'a' && byte <= 'z') { + printf("%c", byte); + } else { + printf("%02x", byte); + } + } else if (this_schema->id() == cudf::type_id::INT32) { + int32_t fb = *reinterpret_cast(&h_data.first[idx]); + printf("%08x ", fb); + idx += 4; + this_schema++; + } else if (this_schema->id() == cudf::type_id::STRING) { + int32_t of = *reinterpret_cast(&h_data.first[idx]); + idx += 4; + printf("%08x-", of); + int32_t ln = *reinterpret_cast(&h_data.first[idx]); + idx += 4; + printf("%08x ", ln); + this_schema++; + } + } + printf("\n"); + }; + + if (num_rows > 20) { + for (auto r = 0; r < 10; ++r) { + print_row(r); + } + for (auto r = num_rows - 10; r < num_rows; ++r) { + print_row(r); + } + } else { + for (auto r = 0; r < num_rows; ++r) { + print_row(r); + } + } +} + +TEST_F(ColumnToRowTests, BigStrings) +{ + char const* TEST_STRINGS[] = { + "These", + "are", + "the", + "test", + "strings", + "that", + "we", + "have", + "some are really long", + "and some are kinda short", + "They are all over on purpose with different sizes for the strings in order to test the code " + "on all different lengths of strings", + "a", + "good test", + "is required to produce reasonable confidence that this is working"}; + auto num_generator = + cudf::detail::make_counting_transform_iterator(0, [](auto i) -> int32_t { return rand(); }); + auto string_generator = + cudf::detail::make_counting_transform_iterator(0, [&](auto i) -> char const* { + return TEST_STRINGS[rand() % (sizeof(TEST_STRINGS) / sizeof(TEST_STRINGS[0]))]; + }); + + auto const num_rows = 50; + auto const num_cols = 50; + std::vector schema; + + std::vector cols; + std::vector views; + + for (auto col = 0; col < num_cols; ++col) { + if (rand() % 2) { + cols.emplace_back( + cudf::test::fixed_width_column_wrapper(num_generator, num_generator + num_rows)); + views.push_back(cols.back()); + schema.emplace_back(cudf::data_type{cudf::type_id::INT32}); + } else { + cols.emplace_back( + cudf::test::strings_column_wrapper(string_generator, string_generator + num_rows)); + views.push_back(cols.back()); + schema.emplace_back(cudf::type_id::STRING); + } + } + + cudf::table_view in(views); + auto new_rows = cudf::convert_to_rows(in); + + EXPECT_EQ(new_rows[0]->size(), num_rows); + + // pick it apart + print_rows(new_rows[0]->view(), schema); + // cudf::test::print(*new_rows[0]); +} + +TEST_F(ColumnToRowTests, ManyStrings) +{ + char const* TEST_STRINGS[] = { + "These", + "are", + "the", + "test", + "strings", + "that", + "we", + "have", + "some are really long", + "and some are kinda short", + "They are all over on purpose with different sizes for the strings in order to test the code " + "on all different lengths of strings", + "a", + "good test", + "is required to produce reasonable confidence that this is working", + "some strings", + "are split into multiple strings", + "some strings have all their data", + "lots of choices of strings and sizes is sure to test the offset calculation code to ensure " + "that 
even a really long string ends up in the correct spot for the final destination allowing " + "for even crazy run-on sentences to be inserted into the data"}; + auto num_generator = + cudf::detail::make_counting_transform_iterator(0, [](auto i) -> int32_t { return rand(); }); + auto string_generator = + cudf::detail::make_counting_transform_iterator(0, [&](auto i) -> char const* { + return TEST_STRINGS[rand() % (sizeof(TEST_STRINGS) / sizeof(TEST_STRINGS[0]))]; + }); + + auto const num_rows = 1000000; + auto const num_cols = 50; + std::vector schema; + + std::vector cols; + std::vector views; + + for (auto col = 0; col < num_cols; ++col) { + if (rand() % 2) { + cols.emplace_back( + cudf::test::fixed_width_column_wrapper(num_generator, num_generator + num_rows)); + views.push_back(cols.back()); + schema.emplace_back(cudf::data_type{cudf::type_id::INT32}); + } else { + cols.emplace_back( + cudf::test::strings_column_wrapper(string_generator, string_generator + num_rows)); + views.push_back(cols.back()); + schema.emplace_back(cudf::type_id::STRING); + } + } + + cudf::table_view in(views); + auto new_rows = cudf::convert_to_rows(in); + + EXPECT_EQ(new_rows[0]->size(), num_rows); + + // pick it apart + print_rows(new_rows[0]->view(), schema); + // cudf::test::print(*new_rows[0]); +} + +TEST_F(ColumnToRowTests, Simple) +{ + cudf::test::fixed_width_column_wrapper a({-1, 0, 1}); + cudf::table_view in(std::vector{a}); + std::vector schema = {cudf::data_type{cudf::type_id::INT32}}; + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + auto new_rows = cudf::convert_to_rows(in); + + EXPECT_EQ(old_rows.size(), new_rows.size()); + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(ColumnToRowTests, SingleByteCol) +{ + cudf::test::fixed_width_column_wrapper a({-1, 0, 1}); + cudf::table_view in(std::vector{a}); + std::vector schema = {cudf::data_type{cudf::type_id::INT8}}; + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + auto new_rows = cudf::convert_to_rows(in); + + EXPECT_EQ(old_rows.size(), new_rows.size()); + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(ColumnToRowTests, Tall) +{ + auto r = + cudf::detail::make_counting_transform_iterator(0, [](auto i) -> int8_t { return rand(); }); + cudf::test::fixed_width_column_wrapper a(r, r + (size_t)4000000); + cudf::table_view in(std::vector{a}); + std::vector schema = {cudf::data_type{cudf::type_id::INT8}}; + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + auto new_rows = cudf::convert_to_rows(in); + + EXPECT_EQ(old_rows.size(), new_rows.size()); + + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(ColumnToRowTests, Wide) +{ + std::vector> cols; + std::vector views; + std::vector schema; + 
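  // 256 one-row INT32 columns: a very wide, very short table, the opposite
  // shape of the Tall test above.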
+ for (int i = 0; i < 256; ++i) { + cols.push_back(cudf::test::fixed_width_column_wrapper({rand()})); + views.push_back(cols.back()); + schema.push_back(cudf::data_type{cudf::type_id::INT32}); + } + cudf::table_view in(views); + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + auto new_rows = cudf::convert_to_rows(in); + + EXPECT_EQ(old_rows.size(), new_rows.size()); + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(ColumnToRowTests, SingleByteWide) +{ + std::vector> cols; + std::vector views; + std::vector schema; + + for (int i = 0; i < 256; ++i) { + cols.push_back(cudf::test::fixed_width_column_wrapper({rand()})); + views.push_back(cols.back()); + + schema.push_back(cudf::data_type{cudf::type_id::INT8}); + } + cudf::table_view in(views); + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + auto new_rows = cudf::convert_to_rows(in); + + EXPECT_EQ(old_rows.size(), new_rows.size()); + + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(ColumnToRowTests, Non2Power) +{ + auto r = + cudf::detail::make_counting_transform_iterator(0, [](auto i) -> int32_t { return rand(); }); + std::vector> cols; + std::vector views; + std::vector schema; + + constexpr auto num_rows = 6 * 1024 + 557; + for (int i = 0; i < 131; ++i) { + cols.push_back(cudf::test::fixed_width_column_wrapper(r + num_rows * i, + r + num_rows * i + num_rows)); + views.push_back(cols.back()); + schema.push_back(cudf::data_type{cudf::type_id::INT32}); + } + cudf::table_view in(views); + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + auto new_rows = cudf::convert_to_rows(in); + + EXPECT_EQ(old_rows.size(), new_rows.size()); + + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(ColumnToRowTests, Big) +{ + auto r = + cudf::detail::make_counting_transform_iterator(0, [](auto i) -> int32_t { return rand(); }); + std::vector> cols; + std::vector views; + std::vector schema; + + // 28 columns of 1 million rows + constexpr auto num_rows = 1024 * 1024; + for (int i = 0; i < 28; ++i) { + cols.push_back(cudf::test::fixed_width_column_wrapper(r + num_rows * i, + r + num_rows * i + num_rows)); + views.push_back(cols.back()); + schema.push_back(cudf::data_type{cudf::type_id::INT32}); + } + cudf::table_view in(views); + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + auto new_rows = cudf::convert_to_rows(in); + + EXPECT_EQ(old_rows.size(), new_rows.size()); + + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); 
+ } +} + +TEST_F(ColumnToRowTests, Bigger) +{ + auto r = + cudf::detail::make_counting_transform_iterator(0, [](auto i) -> int32_t { return rand(); }); + std::vector> cols; + std::vector views; + std::vector schema; + + // 128 columns of 1 million rows + constexpr auto num_rows = 1024 * 1024; + for (int i = 0; i < 128; ++i) { + cols.push_back(cudf::test::fixed_width_column_wrapper(r + num_rows * i, + r + num_rows * i + num_rows)); + views.push_back(cols.back()); + schema.push_back(cudf::data_type{cudf::type_id::INT32}); + } + cudf::table_view in(views); + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + auto new_rows = cudf::convert_to_rows(in); + + EXPECT_EQ(old_rows.size(), new_rows.size()); + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(ColumnToRowTests, Biggest) +{ + auto r = + cudf::detail::make_counting_transform_iterator(0, [](auto i) -> int32_t { return rand(); }); + std::vector> cols; + std::vector views; + std::vector schema; + + // 128 columns of 2 million rows + constexpr auto num_rows = 2 * 1024 * 1024; + for (int i = 0; i < 128; ++i) { + cols.push_back(cudf::test::fixed_width_column_wrapper(r + num_rows * i, + r + num_rows * i + num_rows)); + views.push_back(cols.back()); + schema.push_back(cudf::data_type{cudf::type_id::INT32}); + } + cudf::table_view in(views); + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + auto new_rows = cudf::convert_to_rows(in); + + EXPECT_EQ(old_rows.size(), new_rows.size()); + + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(RowToColumnTests, Single) +{ + cudf::test::fixed_width_column_wrapper a({-1}); + cudf::table_view in(std::vector{a}); + + auto old_rows = cudf::convert_to_rows(in); + std::vector schema{cudf::data_type{cudf::type_id::INT32}}; + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(RowToColumnTests, Simple) +{ + cudf::test::fixed_width_column_wrapper a({-1, 0, 1}); + cudf::table_view in(std::vector{a}); + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + std::vector schema{cudf::data_type{cudf::type_id::INT32}}; + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(RowToColumnTests, Seeded) +{ + cudf::test::fixed_width_column_wrapper a( + {0x00, 0x28, 0xfc, 0x34, 0xce, 0x84, 0x9a, 0x30, 0x16, 0x4e, 0xd7, 0xc8}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + + cudf::table_view in(std::vector{a}); + std::vector schema = {cudf::data_type{cudf::type_id::INT8}}; + + auto old_rows = 
cudf::convert_to_rows_fixed_width_optimized(in); + auto new_rows = cudf::convert_to_rows(in); + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*new_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } + + cudf::test::fixed_width_column_wrapper b( + {0xc9, 0x9a, 0xee, 0x71, 0x85, 0xa2, 0xa9, 0x4a, 0x95, 0x81, 0x00, 0x15}); + + cudf::table_view inb(std::vector{b}); + std::vector schemab = {cudf::data_type{cudf::type_id::INT8}}; + + auto old_rowsb = cudf::convert_to_rows_fixed_width_optimized(inb); + auto new_rowsb = cudf::convert_to_rows(inb); + for (uint i = 0; i < old_rowsb.size(); ++i) { + auto old_tblb = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rowsb[i]), schema); + auto new_tblb = cudf::convert_from_rows(cudf::lists_column_view(*new_rowsb[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tblb, *new_tblb); + } +} + +TEST_F(RowToColumnTests, SingleByteCol) +{ + cudf::test::fixed_width_column_wrapper a({-1, 0, 1}); + cudf::table_view in(std::vector{a}); + std::vector schema = {cudf::data_type{cudf::type_id::INT8}}; + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(RowToColumnTests, Tall) +{ + auto r = + cudf::detail::make_counting_transform_iterator(0, [](auto i) -> int8_t { return rand(); }); + cudf::test::fixed_width_column_wrapper a(r, r + (size_t)4000000); + cudf::table_view in(std::vector{a}); + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + std::vector schema; + schema.reserve(in.num_columns()); + for (auto col = in.begin(); col < in.end(); ++col) { + schema.push_back(col->type()); + } + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(RowToColumnTests, Wide) +{ + std::vector> cols; + std::vector views; + + for (int i = 0; i < 256; ++i) { + cols.push_back(cudf::test::fixed_width_column_wrapper({i})); // rand()})); + views.push_back(cols.back()); + } + cudf::table_view in(views); + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + std::vector schema; + schema.reserve(in.num_columns()); + for (auto col = in.begin(); col < in.end(); ++col) { + schema.push_back(col->type()); + } + + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(RowToColumnTests, SingleByteWide) +{ + std::vector> cols; + std::vector views; + + for (int i = 0; i < 256; ++i) { + cols.push_back(cudf::test::fixed_width_column_wrapper({rand()})); + views.push_back(cols.back()); + } + cudf::table_view in(views); + + auto old_rows = 
cudf::convert_to_rows_fixed_width_optimized(in); + std::vector schema; + schema.reserve(in.num_columns()); + for (auto col = in.begin(); col < in.end(); ++col) { + schema.push_back(col->type()); + } + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(RowToColumnTests, AllTypes) +{ + std::vector> cols; + std::vector views; + std::vector schema{cudf::data_type{cudf::type_id::INT64}, + cudf::data_type{cudf::type_id::FLOAT64}, + cudf::data_type{cudf::type_id::INT8}, + cudf::data_type{cudf::type_id::BOOL8}, + cudf::data_type{cudf::type_id::FLOAT32}, + cudf::data_type{cudf::type_id::INT8}, + cudf::data_type{cudf::type_id::INT32}, + cudf::data_type{cudf::type_id::INT64}}; + + cudf::test::fixed_width_column_wrapper c0({3, 9, 4, 2, 20, 0}, {1, 1, 1, 1, 1, 0}); + cudf::test::fixed_width_column_wrapper c1({5.0, 9.5, 0.9, 7.23, 2.8, 0.0}, + {1, 1, 1, 1, 1, 0}); + cudf::test::fixed_width_column_wrapper c2({5, 1, 0, 2, 7, 0}, {1, 1, 1, 1, 1, 0}); + cudf::test::fixed_width_column_wrapper c3({true, false, false, true, false, false}, + {1, 1, 1, 1, 1, 0}); + cudf::test::fixed_width_column_wrapper c4({1.0f, 3.5f, 5.9f, 7.1f, 9.8f, 0.0f}, + {1, 1, 1, 1, 1, 0}); + cudf::test::fixed_width_column_wrapper c5({2, 3, 4, 5, 9, 0}, {1, 1, 1, 1, 1, 0}); + cudf::test::fixed_point_column_wrapper c6( + {-300, 500, 950, 90, 723, 0}, {1, 1, 1, 1, 1, 1, 1, 0}, numeric::scale_type{-2}); + cudf::test::fixed_point_column_wrapper c7( + {-80, 30, 90, 20, 200, 0}, {1, 1, 1, 1, 1, 1, 0}, numeric::scale_type{-1}); + + cudf::table_view in({c0, c1, c2, c3, c4, c5, c6, c7}); + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + auto new_rows = cudf::convert_to_rows(in); + + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*new_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(RowToColumnTests, AllTypesLarge) +{ + std::vector cols; + std::vector schema{}; + + // 10 columns of each type with 1024 entries + constexpr int num_rows{1024}; + + std::default_random_engine re; + std::uniform_real_distribution rand_double(std::numeric_limits::min(), + std::numeric_limits::max()); + std::uniform_int_distribution rand_int64(std::numeric_limits::min(), + std::numeric_limits::max()); + auto r = cudf::detail::make_counting_transform_iterator( + 0, [&](auto i) -> int64_t { return rand_int64(re); }); + auto d = cudf::detail::make_counting_transform_iterator( + 0, [&](auto i) -> double { return rand_double(re); }); + + auto all_valid = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return 1; }); + auto none_valid = cudf::detail::make_counting_transform_iterator(0, [](auto i) { return 0; }); + auto most_valid = cudf::detail::make_counting_transform_iterator( + 0, [](auto i) { return rand() % 2 == 0 ? 0 : 1; }); + auto few_valid = cudf::detail::make_counting_transform_iterator( + 0, [](auto i) { return rand() % 13 == 0 ? 
1 : 0; }); + + for (int i = 0; i < 10; ++i) { + cols.push_back(*cudf::test::fixed_width_column_wrapper(r, r + num_rows, all_valid) + .release() + .release()); + schema.push_back(cudf::data_type{cudf::type_id::INT8}); + } + + for (int i = 0; i < 10; ++i) { + cols.push_back(*cudf::test::fixed_width_column_wrapper(r, r + num_rows, few_valid) + .release() + .release()); + schema.push_back(cudf::data_type{cudf::type_id::INT16}); + } + + for (int i = 0; i < 10; ++i) { + if (i < 5) { + cols.push_back(*cudf::test::fixed_width_column_wrapper(r, r + num_rows, few_valid) + .release() + .release()); + } else { + cols.push_back(*cudf::test::fixed_width_column_wrapper(r, r + num_rows, none_valid) + .release() + .release()); + } + schema.push_back(cudf::data_type{cudf::type_id::INT32}); + } + + for (int i = 0; i < 10; ++i) { + cols.push_back(*cudf::test::fixed_width_column_wrapper(d, d + num_rows, most_valid) + .release() + .release()); + schema.push_back(cudf::data_type{cudf::type_id::FLOAT32}); + } + + for (int i = 0; i < 10; ++i) { + cols.push_back(*cudf::test::fixed_width_column_wrapper(d, d + num_rows, most_valid) + .release() + .release()); + schema.push_back(cudf::data_type{cudf::type_id::FLOAT64}); + } + + for (int i = 0; i < 10; ++i) { + cols.push_back(*cudf::test::fixed_width_column_wrapper(r, r + num_rows, few_valid) + .release() + .release()); + schema.push_back(cudf::data_type{cudf::type_id::BOOL8}); + } + + for (int i = 0; i < 10; ++i) { + cols.push_back( + *cudf::test::fixed_width_column_wrapper( + r, r + num_rows, all_valid) + .release() + .release()); + schema.push_back(cudf::data_type{cudf::type_id::TIMESTAMP_MILLISECONDS}); + } + + for (int i = 0; i < 10; ++i) { + cols.push_back( + *cudf::test::fixed_width_column_wrapper( + r, r + num_rows, most_valid) + .release() + .release()); + schema.push_back(cudf::data_type{cudf::type_id::TIMESTAMP_DAYS}); + } + + for (int i = 0; i < 10; ++i) { + cols.push_back(*cudf::test::fixed_point_column_wrapper( + r, r + num_rows, all_valid, numeric::scale_type{-2}) + .release() + .release()); + schema.push_back(cudf::data_type{cudf::type_id::DECIMAL32}); + } + + for (int i = 0; i < 10; ++i) { + cols.push_back(*cudf::test::fixed_point_column_wrapper( + r, r + num_rows, most_valid, numeric::scale_type{-1}) + .release() + .release()); + schema.push_back(cudf::data_type{cudf::type_id::DECIMAL64}); + } + + std::vector views(cols.begin(), cols.end()); + cudf::table_view in(views); + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + auto new_rows = cudf::convert_to_rows(in); + + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(RowToColumnTests, Non2Power) +{ + auto r = + cudf::detail::make_counting_transform_iterator(0, [](auto i) -> int32_t { return rand(); }); + std::vector> cols; + std::vector views; + std::vector schema; + + constexpr auto num_rows = 6 * 1024 + 557; + for (int i = 0; i < 131; ++i) { + cols.push_back(cudf::test::fixed_width_column_wrapper(r + num_rows * i, + r + num_rows * i + num_rows)); + views.push_back(cols.back()); + schema.push_back(cudf::data_type{cudf::type_id::INT32}); + } + cudf::table_view in(views); + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + 
cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(RowToColumnTests, Big) +{ + auto r = + cudf::detail::make_counting_transform_iterator(0, [](auto i) -> int32_t { return rand(); }); + std::vector> cols; + std::vector views; + std::vector schema; + + // 28 columns of 1 million rows + constexpr auto num_rows = 1024 * 1024; + for (int i = 0; i < 28; ++i) { + cols.push_back(cudf::test::fixed_width_column_wrapper(r + num_rows * i, + r + num_rows * i + num_rows)); + views.push_back(cols.back()); + schema.push_back(cudf::data_type{cudf::type_id::INT32}); + } + cudf::table_view in(views); + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(RowToColumnTests, Bigger) +{ + auto r = + cudf::detail::make_counting_transform_iterator(0, [](auto i) -> int32_t { return rand(); }); + std::vector> cols; + std::vector views; + std::vector schema; + + // 28 columns of 1 million rows + constexpr auto num_rows = 1024 * 1024; + for (int i = 0; i < 128; ++i) { + cols.push_back(cudf::test::fixed_width_column_wrapper(r + num_rows * i, + r + num_rows * i + num_rows)); + views.push_back(cols.back()); + schema.push_back(cudf::data_type{cudf::type_id::INT32}); + } + cudf::table_view in(views); + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +TEST_F(RowToColumnTests, Biggest) +{ + auto r = + cudf::detail::make_counting_transform_iterator(0, [](auto i) -> int32_t { return rand(); }); + std::vector> cols; + std::vector views; + std::vector schema; + + // 28 columns of 1 million rows + constexpr auto num_rows = 5 * 1024 * 1024; + for (int i = 0; i < 128; ++i) { + cols.push_back(cudf::test::fixed_width_column_wrapper(r + num_rows * i, + r + num_rows * i + num_rows)); + views.push_back(cols.back()); + schema.push_back(cudf::data_type{cudf::type_id::INT32}); + } + cudf::table_view in(views); + + auto old_rows = cudf::convert_to_rows_fixed_width_optimized(in); + + for (uint i = 0; i < old_rows.size(); ++i) { + auto old_tbl = + cudf::convert_from_rows_fixed_width_optimized(cudf::lists_column_view(*old_rows[i]), schema); + auto new_tbl = cudf::convert_from_rows(cudf::lists_column_view(*old_rows[i]), schema); + + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*old_tbl, *new_tbl); + } +} + +#define XSTR(x) STR(x) +#define STR(x) #x + +#ifdef __CUDA_ARCH__ +#pragma message "__CUDA_ARCH__ defined as: " XSTR(__CUDA_ARCH__) +#else +#pragma message "__CUDA_ARCH__ undefined!" 
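// With nvcc this translation unit is compiled once for the host and once per
// requested device architecture; __CUDA_ARCH__ is defined only during the
// device passes, so the message above always fires for the host pass. That is
// why ASYNC_MEMCPY_SUPPORTED below can only be trusted inside __device__ /
// __global__ code, which is what the MythConditional test below exercises.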
+#endif + +#if __CUDA_ARCH__ >= 800 +#define ASYNC_MEMCPY_SUPPORTED +#endif + +void func_with_device_specific_lambda() +{ + thrust::for_each(thrust::device, + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(1), + [] __device__(auto i) { +#ifdef ASYNC_MEMCPY_SUPPORTED + printf("Func:\t Ampere!\n"); +#else + printf("Func:\t Non-Ampere! (Turing!)\n"); +#endif + }); +} + +void func_with_outer_cuda_arch_check() +{ +#ifdef ASYNC_MEMCPY_SUPPORTED + thrust::for_each(thrust::device, + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(1), + [] __device__(auto i) { printf("OUTER:\t Ampere!\n"); }); +#else + thrust::for_each(thrust::device, + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(1), + [] __device__(auto i) { printf("OUTER:\t Non-Ampere! (Turing!)\n"); }); +#endif +} + +__device__ void device_print_arch() +{ +#ifdef ASYNC_MEMCPY_SUPPORTED + printf("Device:\t Ampere!\n"); +#else + printf("Device:\t Non-Ampere! (Turing!)\n"); +#endif +} + +void hop_to_device() +{ + thrust::for_each(thrust::device, + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(1), + [=] __device__(auto i) { device_print_arch(); }); +} + +__global__ void kernel() +{ +#ifdef ASYNC_MEMCPY_SUPPORTED + printf("Kernel:\t Ampere!\n"); +#else + printf("Kernel:\t Non-Ampere! (Turing!)\n"); +#endif +} + +void host() +{ +#ifdef ASYNC_MEMCPY_SUPPORTED + printf("HOST:\t Ampere!\n"); +#else + printf("HOST:\t Non-Ampere! (Turing!)\n"); +#endif +} + +TEST_F(RowToColumnTests, MythConditional) +{ + func_with_device_specific_lambda(); + kernel<<<1, 1>>>(); + hop_to_device(); + func_with_outer_cuda_arch_check(); + host(); + cudaDeviceSynchronize(); +} diff --git a/java/src/main/native/src/row_conversion.cu b/java/src/main/native/src/row_conversion.cu index 5a2aa44261d..c77594cfcb9 100644 --- a/java/src/main/native/src/row_conversion.cu +++ b/java/src/main/native/src/row_conversion.cu @@ -42,10 +42,14 @@ #include "row_conversion.hpp" -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 -#include +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 +#define ASYNC_MEMCPY_SUPPORTED #endif +#if !defined(__CUDA_ARCH__) || defined(ASYNC_MEMCPY_SUPPORTED) +#include +#endif // #if !defined(__CUDA_ARCH__) || defined(ASYNC_MEMCPY_SUPPORTED) + #include #include #include @@ -56,7 +60,6 @@ constexpr auto JCUDF_ROW_ALIGNMENT = 8; -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 constexpr auto NUM_TILES_PER_KERNEL_FROM_ROWS = 2; constexpr auto NUM_TILES_PER_KERNEL_TO_ROWS = 2; constexpr auto NUM_TILES_PER_KERNEL_LOADED = 2; @@ -65,19 +68,20 @@ constexpr auto NUM_VALIDITY_TILES_PER_KERNEL_LOADED = 2; constexpr auto MAX_BATCH_SIZE = std::numeric_limits::max(); -// needed to suppress warning about cuda::barrier -#pragma nv_diag_suppress static_var_with_dynamic_init -#endif - using namespace cudf; using detail::make_device_uvector_async; using rmm::device_uvector; + +#ifdef ASYNC_MEMCPY_SUPPORTED +using cuda::aligned_size_t; +#else +template using aligned_size_t = size_t; // Local stub for cuda::aligned_size_t. +#endif // ASYNC_MEMCPY_SUPPORTED + namespace cudf { namespace jni { namespace detail { -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 - /************************************************************************ * This module converts data from row-major to column-major and from column-major * to row-major. 
It is a transpose of the data of sorts, but there are a few @@ -274,8 +278,6 @@ struct fixed_width_row_offset_functor { size_type _fixed_width_only_row_size; }; -#endif // !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 - /** * @brief Copies data from row-based JCUDF format to column-based cudf format. * @@ -536,7 +538,11 @@ __global__ void copy_to_rows_fixed_width_optimized( } } -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 +#ifdef ASYNC_MEMCPY_SUPPORTED +#define MEMCPY(dst, src, size, barrier) cuda::memcpy_async(dst, src, size, barrier) +#else +#define MEMCPY(dst, src, size, barrier) memcpy(dst, src, size) +#endif // ASYNC_MEMCPY_SUPPORTED /** * @brief copy data from cudf columns into JCUDF format, which is row-based @@ -574,14 +580,15 @@ __global__ void copy_to_rows(const size_type num_rows, const size_type num_colum extern __shared__ int8_t shared_data[]; int8_t *shared[stages_count] = {shared_data, shared_data + shmem_used_per_tile}; +#ifdef ASYNC_MEMCPY_SUPPORTED __shared__ cuda::barrier tile_barrier[NUM_TILES_PER_KERNEL_LOADED]; if (group.thread_rank() == 0) { for (int i = 0; i < NUM_TILES_PER_KERNEL_LOADED; ++i) { init(&tile_barrier[i], group.size()); } } - group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED auto const tiles_remaining = std::min(static_cast(tile_infos.size()) - blockIdx.x * NUM_TILES_PER_KERNEL_TO_ROWS, @@ -599,12 +606,18 @@ __global__ void copy_to_rows(const size_type num_rows, const size_type num_colum auto const num_elements_in_tile = num_fetch_cols * num_fetch_rows; auto const fetch_tile_row_size = fetch_tile.get_shared_row_size(col_offsets, col_sizes); auto const starting_column_offset = col_offsets[fetch_tile.start_col]; +#ifdef ASYNC_MEMCPY_SUPPORTED auto &fetch_barrier = tile_barrier[fetch_index % NUM_TILES_PER_KERNEL_LOADED]; - // wait for the last use of the memory to be completed if (fetch_index >= NUM_TILES_PER_KERNEL_LOADED) { fetch_barrier.arrive_and_wait(); } +#else + // wait for the last use of the memory to be completed + if (fetch_index >= NUM_TILES_PER_KERNEL_LOADED) { + group.sync(); + } +#endif // ASYNC_MEMCPY_SUPPORTED // to do the copy we need to do n column copies followed by m element copies OR // we have to do m element copies followed by r row copies. 
When going from column @@ -633,27 +646,30 @@ __global__ void copy_to_rows(const size_type num_rows, const size_type num_colum // copy the element from global memory switch (col_size) { case 2: - cuda::memcpy_async(&shared_buffer_base[shared_offset], input_src, - cuda::aligned_size_t<2>(col_size), fetch_barrier); + MEMCPY(&shared_buffer_base[shared_offset], input_src, aligned_size_t<2>(col_size), + fetch_barrier); break; case 4: - cuda::memcpy_async(&shared_buffer_base[shared_offset], input_src, - cuda::aligned_size_t<4>(col_size), fetch_barrier); + MEMCPY(&shared_buffer_base[shared_offset], input_src, aligned_size_t<4>(col_size), + fetch_barrier); break; case 8: - cuda::memcpy_async(&shared_buffer_base[shared_offset], input_src, - cuda::aligned_size_t<8>(col_size), fetch_barrier); + MEMCPY(&shared_buffer_base[shared_offset], input_src, aligned_size_t<8>(col_size), + fetch_barrier); break; default: - cuda::memcpy_async(&shared_buffer_base[shared_offset], input_src, col_size, - fetch_barrier); + MEMCPY(&shared_buffer_base[shared_offset], input_src, col_size, fetch_barrier); break; } } } +#ifdef ASYNC_MEMCPY_SUPPORTED auto &processing_barrier = tile_barrier[processing_index % NUM_TILES_PER_KERNEL_LOADED]; processing_barrier.arrive_and_wait(); +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED auto const tile = tile_infos[blockIdx.x * NUM_TILES_PER_KERNEL_TO_ROWS + processing_index]; auto const tile_row_size = tile.get_shared_row_size(col_offsets, col_sizes); @@ -677,16 +693,19 @@ __global__ void copy_to_rows(const size_type num_rows, const size_type num_colum auto const input_src = &shared[processing_index % stages_count] [tile_row_size * relative_row + relative_chunk_offset]; - cuda::memcpy_async(output_dest, input_src, - cuda::aligned_size_t(bytes_per_chunk), - processing_barrier); + MEMCPY(output_dest, input_src, aligned_size_t{bytes_per_chunk}, + processing_barrier); } } +#ifdef ASYNC_MEMCPY_SUPPORTED // wait on the last copies to complete for (uint i = 0; i < std::min(stages_count, tiles_remaining); ++i) { tile_barrier[i].arrive_and_wait(); } +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED } /** @@ -727,6 +746,8 @@ copy_validity_to_rows(const size_type num_rows, const size_type num_columns, std::min(static_cast(tile_infos.size()) - blockIdx.x * NUM_VALIDITY_TILES_PER_KERNEL, static_cast(NUM_VALIDITY_TILES_PER_KERNEL)); +#ifdef ASYNC_MEMCPY_SUPPORTED + // Initialize cuda barriers for each tile. 
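  // Sketch of the dual-path pattern used at every wait point in this file
  // (illustration only, not an additional change):
  //
  //   #ifdef ASYNC_MEMCPY_SUPPORTED
  //     barrier.arrive_and_wait();  // wait only for the async copies tied to this tile
  //   #else
  //     group.sync();               // Pascal fallback: the whole thread block syncs
  //   #endif
  //
  // The barriers below exist only on the async path; the fallback needs no
  // per-tile state because a plain memcpy completes before the sync.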
__shared__ cuda::barrier shared_tile_barriers[NUM_VALIDITY_TILES_PER_KERNEL_LOADED]; if (group.thread_rank() == 0) { @@ -734,12 +755,16 @@ copy_validity_to_rows(const size_type num_rows, const size_type num_columns, init(&shared_tile_barriers[i], group.size()); } } - group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED for (int validity_tile = 0; validity_tile < tiles_remaining; ++validity_tile) { if (validity_tile >= NUM_VALIDITY_TILES_PER_KERNEL_LOADED) { +#ifdef ASYNC_MEMCPY_SUPPORTED shared_tile_barriers[validity_tile % NUM_VALIDITY_TILES_PER_KERNEL_LOADED].arrive_and_wait(); +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED } int8_t *this_shared_tile = shared_tiles[validity_tile % NUM_VALIDITY_TILES_PER_KERNEL_LOADED]; auto tile = tile_infos[blockIdx.x * NUM_VALIDITY_TILES_PER_KERNEL + validity_tile]; @@ -802,8 +827,10 @@ copy_validity_to_rows(const size_type num_rows, const size_type num_columns, auto const row_bytes = util::div_rounding_up_unsafe(num_tile_cols, CHAR_BIT); auto const chunks_per_row = util::div_rounding_up_unsafe(row_bytes, bytes_per_chunk); auto const total_chunks = chunks_per_row * tile.num_rows(); +#ifdef ASYNC_MEMCPY_SUPPORTED auto &processing_barrier = shared_tile_barriers[validity_tile % NUM_VALIDITY_TILES_PER_KERNEL_LOADED]; +#endif // ASYNC_MEMCPY_SUPPORTED auto const tail_bytes = row_bytes % bytes_per_chunk; auto const row_batch_start = tile.batch_number == 0 ? 0 : batch_row_boundaries[tile.batch_number]; @@ -820,19 +847,22 @@ copy_validity_to_rows(const size_type num_rows, const size_type num_columns, &this_shared_tile[validity_data_row_length * relative_row + relative_chunk_offset]; if (tail_bytes > 0 && col_chunk == chunks_per_row - 1) - cuda::memcpy_async(output_dest, input_src, tail_bytes, processing_barrier); + MEMCPY(output_dest, input_src, tail_bytes, processing_barrier); else - cuda::memcpy_async(output_dest, input_src, - cuda::aligned_size_t(bytes_per_chunk), - processing_barrier); + MEMCPY(output_dest, input_src, aligned_size_t(bytes_per_chunk), + processing_barrier); } } +#ifdef ASYNC_MEMCPY_SUPPORTED // wait for last tiles of data to arrive for (int validity_tile = 0; validity_tile < tiles_remaining % NUM_VALIDITY_TILES_PER_KERNEL_LOADED; ++validity_tile) { shared_tile_barriers[validity_tile].arrive_and_wait(); } +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED } /** @@ -873,14 +903,16 @@ __global__ void copy_from_rows(const size_type num_rows, const size_type num_col extern __shared__ int8_t shared_data[]; int8_t *shared[stages_count] = {shared_data, shared_data + shmem_used_per_tile}; +#ifdef ASYNC_MEMCPY_SUPPORTED + // Initialize cuda barriers for each tile. __shared__ cuda::barrier tile_barrier[NUM_TILES_PER_KERNEL_LOADED]; if (group.thread_rank() == 0) { for (int i = 0; i < NUM_TILES_PER_KERNEL_LOADED; ++i) { init(&tile_barrier[i], group.size()); } } - group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED auto tiles_remaining = std::min(static_cast(tile_infos.size()) - blockIdx.x * NUM_TILES_PER_KERNEL_FROM_ROWS, @@ -897,30 +929,38 @@ __global__ void copy_from_rows(const size_type num_rows, const size_type num_col auto const fetch_tile_start_row = fetch_tile.start_row; auto const starting_col_offset = col_offsets[fetch_tile.start_col]; auto const fetch_tile_row_size = fetch_tile.get_shared_row_size(col_offsets, col_sizes); - auto &fetch_barrier = tile_barrier[fetch_index % NUM_TILES_PER_KERNEL_LOADED]; auto const row_batch_start = fetch_tile.batch_number == 0 ? 
0 : batch_row_boundaries[fetch_tile.batch_number]; - +#ifdef ASYNC_MEMCPY_SUPPORTED + auto &fetch_barrier = tile_barrier[fetch_index % NUM_TILES_PER_KERNEL_LOADED]; // if we have fetched all buffers, we need to wait for processing // to complete on them before we can use them again if (fetch_index > NUM_TILES_PER_KERNEL_LOADED) { fetch_barrier.arrive_and_wait(); } +#else + if (fetch_index >= NUM_TILES_PER_KERNEL_LOADED) { + group.sync(); + } +#endif // ASYNC_MEMCPY_SUPPORTED for (auto row = fetch_tile_start_row + static_cast(threadIdx.x); row <= fetch_tile.end_row; row += blockDim.x) { auto shared_offset = (row - fetch_tile_start_row) * fetch_tile_row_size; // copy the data - cuda::memcpy_async(&shared[fetch_index % stages_count][shared_offset], - &input_data[row_offsets(row, row_batch_start) + starting_col_offset], - fetch_tile_row_size, fetch_barrier); + MEMCPY(&shared[fetch_index % stages_count][shared_offset], + &input_data[row_offsets(row, row_batch_start) + starting_col_offset], + fetch_tile_row_size, fetch_barrier); } } +#ifdef ASYNC_MEMCPY_SUPPORTED auto &processing_barrier = tile_barrier[processing_index % NUM_TILES_PER_KERNEL_LOADED]; - // ensure our data is ready processing_barrier.arrive_and_wait(); +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED auto const tile = tile_infos[blockIdx.x * NUM_TILES_PER_KERNEL_FROM_ROWS + processing_index]; auto const rows_in_tile = tile.num_rows(); @@ -948,15 +988,19 @@ __global__ void copy_from_rows(const size_type num_rows, const size_type num_col int8_t *shmem_src = &shared[processing_index % stages_count][shared_memory_offset]; int8_t *dst = &output_data[absolute_col][absolute_row * column_size]; - cuda::memcpy_async(dst, shmem_src, column_size, processing_barrier); + MEMCPY(dst, shmem_src, column_size, processing_barrier); } group.sync(); } +#ifdef ASYNC_MEMCPY_SUPPORTED // wait on the last copies to complete for (uint i = 0; i < std::min(stages_count, tiles_remaining); ++i) { tile_barrier[i].arrive_and_wait(); } +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED } /** @@ -997,6 +1041,8 @@ copy_validity_from_rows(const size_type num_rows, const size_type num_columns, std::min(static_cast(tile_infos.size()) - blockIdx.x * NUM_VALIDITY_TILES_PER_KERNEL, static_cast(NUM_VALIDITY_TILES_PER_KERNEL)); +#ifdef ASYNC_MEMCPY_SUPPORTED + // Initialize cuda barriers for each tile. 
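  // As in copy_validity_to_rows above, the copies here go through the MEMCPY
  // macro so a single call site serves both paths (expansion shown for
  // illustration only):
  //
  //   MEMCPY(dst, src, aligned_size_t<8>(n), barrier);
  //     // ASYNC_MEMCPY_SUPPORTED: cuda::memcpy_async(dst, src, cuda::aligned_size_t<8>(n), barrier)
  //     // otherwise:              memcpy(dst, src, n) -- aligned_size_t is the local size_t stub
  //     //                         and the barrier argument is ignored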
__shared__ cuda::barrier shared_tile_barriers[NUM_VALIDITY_TILES_PER_KERNEL_LOADED]; if (group.thread_rank() == 0) { @@ -1004,13 +1050,17 @@ copy_validity_from_rows(const size_type num_rows, const size_type num_columns, init(&shared_tile_barriers[i], group.size()); } } - group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED for (int validity_tile = 0; validity_tile < tiles_remaining; ++validity_tile) { if (validity_tile >= NUM_VALIDITY_TILES_PER_KERNEL_LOADED) { +#ifdef ASYNC_MEMCPY_SUPPORTED auto const validity_index = validity_tile % NUM_VALIDITY_TILES_PER_KERNEL_LOADED; shared_tile_barriers[validity_index].arrive_and_wait(); +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED } int8_t *this_shared_tile = shared_tiles[validity_tile % 2]; auto const tile = tile_infos[blockIdx.x * NUM_VALIDITY_TILES_PER_KERNEL + validity_tile]; @@ -1071,9 +1121,11 @@ copy_validity_from_rows(const size_type num_rows, const size_type num_columns, auto const col_bytes = util::div_rounding_up_unsafe(num_tile_rows, CHAR_BIT); auto const chunks_per_col = util::div_rounding_up_unsafe(col_bytes, bytes_per_chunk); auto const total_chunks = chunks_per_col * num_tile_cols; + auto const tail_bytes = col_bytes % bytes_per_chunk; +#ifdef ASYNC_MEMCPY_SUPPORTED auto &processing_barrier = shared_tile_barriers[validity_tile % NUM_VALIDITY_TILES_PER_KERNEL_LOADED]; - auto const tail_bytes = col_bytes % bytes_per_chunk; +#endif // ASYNC_MEMCPY_SUPPORTED for (auto i = threadIdx.x; i < total_chunks; i += blockDim.x) { // determine source address of my chunk @@ -1081,20 +1133,21 @@ copy_validity_from_rows(const size_type num_rows, const size_type num_columns, auto const row_chunk = i % chunks_per_col; auto const absolute_col = relative_col + tile_start_col; auto const relative_chunk_byte_offset = row_chunk * bytes_per_chunk; - auto const output_dest = output_nm[absolute_col] + word_index(tile_start_row) + row_chunk * 2; + auto output_dest = reinterpret_cast(output_nm[absolute_col] + + word_index(tile_start_row) + row_chunk * 2); auto const input_src = &this_shared_tile[validity_data_col_length * relative_col + relative_chunk_byte_offset]; if (tail_bytes > 0 && row_chunk == chunks_per_col - 1) { - cuda::memcpy_async(output_dest, input_src, tail_bytes, processing_barrier); + MEMCPY(output_dest, input_src, tail_bytes, processing_barrier); } else { - cuda::memcpy_async(output_dest, input_src, - cuda::aligned_size_t(bytes_per_chunk), - processing_barrier); + MEMCPY(output_dest, input_src, aligned_size_t(bytes_per_chunk), + processing_barrier); } } } +#ifdef ASYNC_MEMCPY_SUPPORTED // wait for last tiles of data to arrive auto const num_tiles_to_wait = tiles_remaining > NUM_VALIDITY_TILES_PER_KERNEL_LOADED ? NUM_VALIDITY_TILES_PER_KERNEL_LOADED : @@ -1102,10 +1155,11 @@ copy_validity_from_rows(const size_type num_rows, const size_type num_columns, for (int validity_tile = 0; validity_tile < num_tiles_to_wait; ++validity_tile) { shared_tile_barriers[validity_tile].arrive_and_wait(); } +#else + group.sync(); +#endif // ASYNC_MEMCPY_SUPPORTED } -#endif // !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 - /** * @brief Calculate the dimensions of the kernel for fixed width only columns. * @@ -1238,8 +1292,6 @@ static inline int32_t compute_fixed_width_layout(std::vector const &s return util::round_up_unsafe(at_offset, JCUDF_ROW_ALIGNMENT); } -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 - /** * @brief Compute information about a table such as bytes per row and offsets. 
* @@ -1617,9 +1669,12 @@ convert_to_rows(table_view const &tbl, batch_data &batch_info, offsetFunctor off CUDA_TRY( cudaDeviceGetAttribute(&total_shmem_in_bytes, cudaDevAttrMaxSharedMemoryPerBlock, device_id)); +#ifndef __CUDA_ARCH__ // __host__ code. // Need to reduce total shmem available by the size of barriers in the kernel's shared memory total_shmem_in_bytes -= sizeof(cuda::barrier) * NUM_TILES_PER_KERNEL_LOADED; +#endif // __CUDA_ARCH__ + auto const shmem_limit_per_tile = total_shmem_in_bytes / NUM_TILES_PER_KERNEL_LOADED; auto const num_rows = tbl.num_rows(); @@ -1722,14 +1777,12 @@ convert_to_rows(table_view const &tbl, batch_data &batch_info, offsetFunctor off return ret; } -#endif // #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 } // namespace detail std::vector> convert_to_rows(table_view const &tbl, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource *mr) { -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 auto const num_columns = tbl.num_columns(); auto const num_rows = tbl.num_rows(); @@ -1790,11 +1843,6 @@ std::vector> convert_to_rows(table_view const &tbl, return detail::convert_to_rows(tbl, batch_info, offset_functor, column_starts, column_sizes, fixed_width_size_per_row, stream, mr); } - -#else - CUDF_FAIL("Column to row conversion optimization requires volta or later hardware."); - return {}; -#endif // #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 } std::vector> @@ -1862,7 +1910,6 @@ std::unique_ptr
convert_from_rows(lists_column_view const &input, std::vector const &schema, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource *mr) { -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 // verify that the types are what we expect column_view child = input.child(); auto const list_type = child.type().id(); @@ -1878,9 +1925,12 @@ std::unique_ptr
convert_from_rows(lists_column_view const &input, CUDA_TRY( cudaDeviceGetAttribute(&total_shmem_in_bytes, cudaDevAttrMaxSharedMemoryPerBlock, device_id)); +#ifndef __CUDA_ARCH__ // __host__ code. // Need to reduce total shmem available by the size of barriers in the kernel's shared memory total_shmem_in_bytes -= sizeof(cuda::barrier) * NUM_TILES_PER_KERNEL_LOADED; +#endif // __CUDA_ARCH__ + int shmem_limit_per_tile = total_shmem_in_bytes / NUM_TILES_PER_KERNEL_LOADED; std::vector column_starts; @@ -1977,10 +2027,6 @@ std::unique_ptr
convert_from_rows(lists_column_view const &input, dev_output_nm.data(), column_starts.back(), dev_validity_tile_infos, child.data()); return std::make_unique
(std::move(output_columns)); -#else - CUDF_FAIL("Row to column conversion optimization requires volta or later hardware."); - return {}; -#endif // #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700 } std::unique_ptr
convert_from_rows_fixed_width_optimized(