From 39b8600444d421444458e0fdfd91dbad7eb4c62f Mon Sep 17 00:00:00 2001
From: Dave Baranec
Date: Mon, 22 Nov 2021 16:41:04 -0600
Subject: [PATCH 1/8] Add benchmarks for poorly balanced splits

---
 .../copying/contiguous_split_benchmark.cu | 60 +++++++++++--------
 1 file changed, 36 insertions(+), 24 deletions(-)

diff --git a/cpp/benchmarks/copying/contiguous_split_benchmark.cu b/cpp/benchmarks/copying/contiguous_split_benchmark.cu
index 506d676d196..bd135d31a59 100644
--- a/cpp/benchmarks/copying/contiguous_split_benchmark.cu
+++ b/cpp/benchmarks/copying/contiguous_split_benchmark.cu
@@ -33,11 +33,13 @@ void BM_contiguous_split_common(benchmark::State& state,
                                 int64_t num_splits,
                                 int64_t bytes_total)
 {
-  // generate splits
-  cudf::size_type split_stride = num_rows / num_splits;
-  std::vector<cudf::size_type> splits;
-  for (int idx = 0; idx < num_rows; idx += split_stride) {
-    splits.push_back(std::min(idx + split_stride, static_cast<cudf::size_type>(num_rows)));
+  // generate splits
+  std::vector<cudf::size_type> splits;
+  if (num_splits > 0) {
+    cudf::size_type split_stride = num_rows / num_splits;
+    for (int idx = 0; idx < num_rows; idx += split_stride) {
+      splits.push_back(std::min(idx + split_stride, static_cast<cudf::size_type>(num_rows)));
+    }
   }
 
   std::vector<std::unique_ptr<cudf::column>> columns(src_cols.size());
@@ -53,7 +55,8 @@ void BM_contiguous_split_common(benchmark::State& state,
     auto result = cudf::contiguous_split(src_table, splits);
   }
 
-  state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * bytes_total);
+  // it's 2x bytes_total because we're both reading and writing.
+  state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * bytes_total * 2);
 }
 
 class ContiguousSplit : public cudf::benchmark {
@@ -61,13 +64,13 @@ class ContiguousSplit : public cudf::benchmark {
 void BM_contiguous_split(benchmark::State& state)
 {
-  int64_t total_desired_bytes = state.range(0);
-  cudf::size_type num_cols    = state.range(1);
-  cudf::size_type num_splits  = state.range(2);
-  bool include_validity       = state.range(3) == 0 ? false : true;
+  int64_t const total_desired_bytes = state.range(0);
+  cudf::size_type const num_cols    = state.range(1);
+  cudf::size_type const num_splits  = state.range(2);
+  bool const include_validity       = state.range(3) == 0 ? false : true;
 
   cudf::size_type el_size = 4;  // ints and floats
-  int64_t num_rows = total_desired_bytes / (num_cols * el_size);
+  int64_t const num_rows = total_desired_bytes / (num_cols * el_size);
 
   // generate input table
   srand(31337);
@@ -84,9 +87,9 @@
       cudf::test::fixed_width_column_wrapper<int>(rand_elements, rand_elements + num_rows);
     }
   }
-
-  size_t total_bytes = total_desired_bytes;
-  if (include_validity) { total_bytes += num_rows / (sizeof(cudf::bitmask_type) * 8); }
+
+  int64_t const total_bytes = total_desired_bytes +
+    (include_validity ? min(int64_t{1}, (num_rows / 32)) * 4 * num_cols : 0);
 
   BM_contiguous_split_common(state, src_cols, num_rows, num_splits, total_bytes);
 }
@@ -102,17 +105,17 @@ int rand_range(int r)
 
 void BM_contiguous_split_strings(benchmark::State& state)
 {
-  int64_t total_desired_bytes = state.range(0);
-  cudf::size_type num_cols    = state.range(1);
-  cudf::size_type num_splits  = state.range(2);
-  bool include_validity       = state.range(3) == 0 ? false : true;
+  int64_t const total_desired_bytes = state.range(0);
+  cudf::size_type const num_cols    = state.range(1);
+  cudf::size_type const num_splits  = state.range(2);
+  bool const include_validity       = state.range(3) == 0 ? false : true;
 
-  const int64_t string_len = 8;
+  int64_t const string_len = 8;
   std::vector<std::string> h_strings{
     "aaaaaaaa", "bbbbbbbb", "cccccccc", "dddddddd", "eeeeeeee", "ffffffff", "gggggggg", "hhhhhhhh"};
 
-  int64_t col_len_bytes = total_desired_bytes / num_cols;
-  int64_t num_rows      = col_len_bytes / string_len;
+  int64_t const col_len_bytes = total_desired_bytes / num_cols;
+  int64_t const num_rows      = col_len_bytes / string_len;
 
   // generate input table
   srand(31337);
@@ -133,8 +136,9 @@
     }
   }
 
-  size_t total_bytes = total_desired_bytes + (num_rows * sizeof(cudf::size_type));
-  if (include_validity) { total_bytes += num_rows / (sizeof(cudf::bitmask_type) * 8); }
+  int64_t const total_bytes = total_desired_bytes +
+    ((num_rows + 1) * sizeof(cudf::size_type)) +
+    (include_validity ? min(int64_t{1}, (num_rows / 32)) * 4 * num_cols : 0);
 
   BM_contiguous_split_common(state, src_cols, num_rows, num_splits, total_bytes);
 }
@@ -157,12 +161,16 @@ CSBM_BENCHMARK_DEFINE(6Gb10ColsValidity, (int64_t)6 * 1024 * 1024 * 1024, 10, 25
 CSBM_BENCHMARK_DEFINE(4Gb512ColsNoValidity, (int64_t)4 * 1024 * 1024 * 1024, 512, 256, 0);
 CSBM_BENCHMARK_DEFINE(4Gb512ColsValidity, (int64_t)4 * 1024 * 1024 * 1024, 512, 256, 1);
 CSBM_BENCHMARK_DEFINE(4Gb10ColsNoValidity, (int64_t)4 * 1024 * 1024 * 1024, 10, 256, 0);
-CSBM_BENCHMARK_DEFINE(46b10ColsValidity, (int64_t)4 * 1024 * 1024 * 1024, 10, 256, 1);
+CSBM_BENCHMARK_DEFINE(4Gb10ColsValidity, (int64_t)4 * 1024 * 1024 * 1024, 10, 256, 1);
+CSBM_BENCHMARK_DEFINE(4Gb4ColsNoSplits, (int64_t)1 * 1024 * 1024 * 1024, 4, 0, 1);
+CSBM_BENCHMARK_DEFINE(4Gb4ColsValidityNoSplits, (int64_t)1 * 1024 * 1024 * 1024, 4, 0, 1);
 CSBM_BENCHMARK_DEFINE(1Gb512ColsNoValidity, (int64_t)1 * 1024 * 1024 * 1024, 512, 256, 0);
 CSBM_BENCHMARK_DEFINE(1Gb512ColsValidity, (int64_t)1 * 1024 * 1024 * 1024, 512, 256, 1);
 CSBM_BENCHMARK_DEFINE(1Gb10ColsNoValidity, (int64_t)1 * 1024 * 1024 * 1024, 10, 256, 0);
 CSBM_BENCHMARK_DEFINE(1Gb10ColsValidity, (int64_t)1 * 1024 * 1024 * 1024, 10, 256, 1);
+CSBM_BENCHMARK_DEFINE(1Gb1ColNoSplits, (int64_t)1 * 1024 * 1024 * 1024, 1, 0, 1);
+CSBM_BENCHMARK_DEFINE(1Gb1ColValidityNoSplits, (int64_t)1 * 1024 * 1024 * 1024, 1, 0, 1);
 
 #define CSBM_STRINGS_BENCHMARK_DEFINE(name, size, num_columns, num_splits, validity) \
   BENCHMARK_DEFINE_F(ContiguousSplitStrings, name)(::benchmark::State & state)       \
@@ -179,8 +187,12 @@ CSBM_STRINGS_BENCHMARK_DEFINE(4Gb512ColsNoValidity, (int64_t)4 * 1024 * 1024 * 1
 CSBM_STRINGS_BENCHMARK_DEFINE(4Gb512ColsValidity, (int64_t)4 * 1024 * 1024 * 1024, 512, 256, 1);
 CSBM_STRINGS_BENCHMARK_DEFINE(4Gb10ColsNoValidity, (int64_t)4 * 1024 * 1024 * 1024, 10, 256, 0);
 CSBM_STRINGS_BENCHMARK_DEFINE(4Gb10ColsValidity, (int64_t)4 * 1024 * 1024 * 1024, 10, 256, 1);
+CSBM_STRINGS_BENCHMARK_DEFINE(4Gb4ColsNoSplits, (int64_t)1 * 1024 * 1024 * 1024, 4, 0, 0);
+CSBM_STRINGS_BENCHMARK_DEFINE(4Gb4ColsValidityNoSplits, (int64_t)1 * 1024 * 1024 * 1024, 4, 0, 1);
 CSBM_STRINGS_BENCHMARK_DEFINE(1Gb512ColsNoValidity, (int64_t)1 * 1024 * 1024 * 1024, 512, 256, 0);
 CSBM_STRINGS_BENCHMARK_DEFINE(1Gb512ColsValidity, (int64_t)1 * 1024 * 1024 * 1024, 512, 256, 1);
 CSBM_STRINGS_BENCHMARK_DEFINE(1Gb10ColsNoValidity, (int64_t)1 * 1024 * 1024 * 1024, 10, 256, 0);
 CSBM_STRINGS_BENCHMARK_DEFINE(1Gb10ColsValidity, (int64_t)1 * 1024 * 1024 * 1024, 10, 256, 1);
+CSBM_STRINGS_BENCHMARK_DEFINE(1Gb1ColNoSplits, (int64_t)1 * 1024 * 1024 * 1024, 1, 0, 0);
+CSBM_STRINGS_BENCHMARK_DEFINE(1Gb1ColValidityNoSplits, (int64_t)1 * 1024 * 1024 * 1024, 1, 0, 1);
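For reference, the split generation this patch benchmarks reduces to the following standalone sketch. It mirrors the benchmark's loop only; the helper name make_split_points and the bare int row type are illustrative assumptions, not part of the patch:

  #include <algorithm>
  #include <vector>

  // Evenly spaced split points: stride is num_rows / num_splits, and each point
  // is clamped to num_rows. num_splits == 0 yields an empty vector, which
  // contiguous_split treats as "copy the whole table as a single partition".
  std::vector<int> make_split_points(int num_rows, int num_splits)
  {
    std::vector<int> splits;
    if (num_splits > 0) {
      int const split_stride = num_rows / num_splits;
      for (int idx = 0; idx < num_rows; idx += split_stride) {
        splits.push_back(std::min(idx + split_stride, num_rows));
      }
    }
    return splits;
  }

For example, 1000 rows with 4 splits produces {250, 500, 750, 1000}. Note also that SetBytesProcessed reports 2x bytes_total because every byte is read once from the source buffers and written once to the destination buffers.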
From c42dac4e5ba71312bddfeeaccf883a4e9323b56f Mon Sep 17 00:00:00 2001
From: Dave Baranec
Date: Mon, 22 Nov 2021 17:01:47 -0600
Subject: [PATCH 2/8] Add repartitioning code that balances the number of bytes
 to be copied across all the SMs as evenly as it can. Previously, the copy was
 vulnerable to low numbers of splits/columns and to large discrepancies between
 column sizes (e.g., large string columns combined with columns of shorts).

---
 .../copying/contiguous_split_benchmark.cu |  25 +-
 cpp/src/copying/contiguous_split.cu       | 249 +++++++++++++++---
 2 files changed, 225 insertions(+), 49 deletions(-)

diff --git a/cpp/benchmarks/copying/contiguous_split_benchmark.cu b/cpp/benchmarks/copying/contiguous_split_benchmark.cu
index bd135d31a59..16f34c0f047 100644
--- a/cpp/benchmarks/copying/contiguous_split_benchmark.cu
+++ b/cpp/benchmarks/copying/contiguous_split_benchmark.cu
@@ -33,9 +33,9 @@ void BM_contiguous_split_common(benchmark::State& state,
                                 int64_t num_splits,
                                 int64_t bytes_total)
 {
-  // generate splits
-  std::vector<cudf::size_type> splits;
-  if(num_splits > 0){
+  // generate splits
+  std::vector<cudf::size_type> splits;
+  if (num_splits > 0) {
     cudf::size_type split_stride = num_rows / num_splits;
     for (int idx = 0; idx < num_rows; idx += split_stride) {
       splits.push_back(std::min(idx + split_stride, static_cast<cudf::size_type>(num_rows)));
     }
@@ -55,7 +55,7 @@ void BM_contiguous_split_common(benchmark::State& state,
     auto result = cudf::contiguous_split(src_table, splits);
   }
 
-  // it's 2x bytes_total because we're both reading and writing.
+  // it's 2x bytes_total because we're both reading and writing.
   state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * bytes_total * 2);
 }
 
@@ -70,7 +70,7 @@ void BM_contiguous_split(benchmark::State& state)
   bool const include_validity = state.range(3) == 0 ? false : true;
 
   cudf::size_type el_size = 4;  // ints and floats
-  int64_t const num_rows = total_desired_bytes / (num_cols * el_size);
+  int64_t const num_rows = total_desired_bytes / (num_cols * el_size);
 
   // generate input table
   srand(31337);
@@ -87,9 +87,11 @@
       cudf::test::fixed_width_column_wrapper<int>(rand_elements, rand_elements + num_rows);
     }
  }
-
-  int64_t const total_bytes = total_desired_bytes +
-    (include_validity ? min(int64_t{1}, (num_rows / 32)) * 4 * num_cols : 0);
+
+  int64_t const total_bytes =
+    total_desired_bytes +
+    (include_validity ? (max(int64_t{1}, (num_rows / 32)) * sizeof(cudf::bitmask_type) * num_cols)
+                      : 0);
 
   BM_contiguous_split_common(state, src_cols, num_rows, num_splits, total_bytes);
 }
@@ -136,9 +138,10 @@ void BM_contiguous_split_strings(benchmark::State& state)
     }
   }
 
-  int64_t const total_bytes = total_desired_bytes +
-    ((num_rows + 1) * sizeof(cudf::size_type)) +
-    (include_validity ? min(int64_t{1}, (num_rows / 32)) * 4 * num_cols : 0);
+  int64_t const total_bytes =
+    total_desired_bytes + ((num_rows + 1) * sizeof(cudf::offset_type)) +
+    (include_validity ? (max(int64_t{1}, (num_rows / 32)) * sizeof(cudf::bitmask_type) * num_cols)
+                      : 0);
 
   BM_contiguous_split_common(state, src_cols, num_rows, num_splits, total_bytes);
 }
diff --git a/cpp/src/copying/contiguous_split.cu b/cpp/src/copying/contiguous_split.cu
index a9194ceea93..893f4d18e6a 100644
--- a/cpp/src/copying/contiguous_split.cu
+++ b/cpp/src/copying/contiguous_split.cu
@@ -30,6 +30,7 @@
 #include
 #include
+#include
 #include
 #include
 
@@ -89,16 +90,21 @@ struct src_buf_info {
  * M partitions, then we have N*M destination buffers.
  */
 struct dst_buf_info {
+  // constant across all copy commands for this buffer
   std::size_t buf_size;  // total size of buffer, including padding
   int num_elements;      // # of elements to be copied
   int element_size;      // size of each element in bytes
-  int num_rows;       // # of rows (which may be different from num_elements in the case of validity
-                      // or offset buffers)
-  int src_row_index;  // row index to start reading from my associated source buffer
+  int num_rows;  // # of rows to be copied (which may be different from num_elements in the case
+                 // of validity or offset buffers)
+
+  int src_element_index;   // element index to start reading from my associated source buffer
   std::size_t dst_offset;  // my offset into the per-partition allocation
   int value_shift;         // amount to shift values down by (for offset buffers)
   int bit_shift;           // # of bits to shift right by (for validity buffers)
-  size_type valid_count;
+  size_type valid_count;   // validity count for this block of work
+
+  int src_buf_index;  // source buffer index
+  int dst_buf_index;  // destination buffer index
 };
 
@@ -123,7 +129,7 @@ struct dst_buf_info {
 * @param t Thread index
 * @param num_elements Number of elements to copy
 * @param element_size Size of each element in bytes
- * @param src_row_index Row index to start copying at
+ * @param src_element_index Element index to start copying at
 * @param stride Size of the kernel block
 * @param value_shift Shift incoming 4-byte offset values down by this amount
 * @param bit_shift Shift incoming data right by this many bits
@@ -136,14 +142,14 @@ __device__ void copy_buffer(uint8_t* __restrict__ dst,
                             int t,
                             std::size_t num_elements,
                             std::size_t element_size,
-                            std::size_t src_row_index,
+                            std::size_t src_element_index,
                             uint32_t stride,
                             int value_shift,
                             int bit_shift,
                             std::size_t num_rows,
                             size_type* valid_count)
 {
-  src += (src_row_index * element_size);
+  src += (src_element_index * element_size);
 
   size_type thread_valid_count = 0;
 
@@ -240,38 +246,36 @@ __device__ void copy_buffer(uint8_t* __restrict__ dst,
 }
 
 /**
- * @brief Kernel which copies a single buffer from a set of partitioned
- * column buffers.
+ * @brief Kernel which copies data from multiple source buffers to multiple
+ * destination buffers.
 *
 * When doing a contiguous_split on X columns comprising N total internal buffers
 * with M splits, we end up having to copy N*M source/destination buffer pairs.
+ * These logical copies are further subdivided to distribute the amount of work
+ * to be done as evenly as possible across the multiprocessors on the device.
 * This kernel is arranged such that each block copies 1 source/destination pair.
- * This function retrieves the relevant buffers and then calls copy_buffer to perform
- * the actual copy.
 *
- * @param num_src_bufs Total number of source buffers (N)
- * @param src_bufs Input source buffers (N)
- * @param dst_bufs Destination buffers (N*M)
+ * @param src_bufs Input source buffers
+ * @param dst_bufs Destination buffers
 * @param buf_info Information on the range of values to be copied for each destination buffer.
 */
template <int block_size>
-__global__ void copy_partition(int num_src_bufs,
-                               uint8_t const** src_bufs,
-                               uint8_t** dst_bufs,
-                               dst_buf_info* buf_info)
+__global__ void copy_partitions(uint8_t const** src_bufs,
+                                uint8_t** dst_bufs,
+                                dst_buf_info* buf_info)
 {
-  int const partition_index   = blockIdx.x / num_src_bufs;
-  int const src_buf_index     = blockIdx.x % num_src_bufs;
-  std::size_t const buf_index = (partition_index * num_src_bufs) + src_buf_index;
+  auto const buf_index     = blockIdx.x;
+  auto const src_buf_index = buf_info[buf_index].src_buf_index;
+  auto const dst_buf_index = buf_info[buf_index].dst_buf_index;
 
   // copy, shifting offsets and validity bits as needed
   copy_buffer<block_size>(
-    dst_bufs[partition_index] + buf_info[buf_index].dst_offset,
+    dst_bufs[dst_buf_index] + buf_info[buf_index].dst_offset,
     src_bufs[src_buf_index],
     threadIdx.x,
     buf_info[buf_index].num_elements,
     buf_info[buf_index].element_size,
-    buf_info[buf_index].src_row_index,
+    buf_info[buf_index].src_element_index,
    blockDim.x,
    buf_info[buf_index].value_shift,
    buf_info[buf_index].bit_shift,
@@ -742,6 +746,32 @@ struct dst_offset_output_iterator {
   reference __device__ dereference(dst_buf_info* c) { return c->dst_offset; }
 };
 
+/**
+ * @brief Output iterator for writing values to the valid_count field of the
+ * dst_buf_info struct
+ */
+struct dst_valid_count_output_iterator {
+  dst_buf_info* c;
+  using value_type        = size_type;
+  using difference_type   = size_type;
+  using pointer           = size_type*;
+  using reference         = size_type&;
+  using iterator_category = thrust::output_device_iterator_tag;
+
+  dst_valid_count_output_iterator operator+ __host__ __device__(int i)
+  {
+    return dst_valid_count_output_iterator{c + i};
+  }
+
+  void operator++ __host__ __device__() { c++; }
+
+  reference operator[] __device__(int i) { return dereference(c + i); }
+  reference operator* __device__() { return dereference(c); }
+
+ private:
+  reference __device__ dereference(dst_buf_info* c) { return c->valid_count; }
+};
+
 /**
 * @brief Functor for computing size of data elements for a given cudf type.
 *
@@ -762,6 +792,150 @@ struct size_of_helper {
   }
 };
 
+/**
+ * @brief Functor for returning the number of chunks an input buffer is being
+ * subdivided into during the repartitioning step.
+ *
+ * Note: column types which themselves inherently have no data (strings, lists,
+ * structs) return 0.
+ */
+struct num_chunks_func {
+  thrust::pair<size_t, size_t> const* chunks;
+  __device__ size_t operator()(size_type i) const { return thrust::get<0>(chunks[i]); }
+};
+
+void copy_data(size_t total_bytes,
+               int num_bufs,
+               int num_src_bufs,
+               uint8_t const** d_src_bufs,
+               uint8_t** d_dst_bufs,
+               dst_buf_info* _d_dst_buf_info,
+               rmm::cuda_stream_view stream)
+{
+  // ideally we'd like to give each SM a similar amount of work to do so that a.) we keep all
+  // of them saturated and b.) we don't have any long-running outliers.
+  // our incoming dst_buf_info data is the exact description of what the output should look like.
+  // let's do some examination of what's being copied and potentially break things up into
+  // more pieces to parallelize better.
+  int device;
+  cudaGetDevice(&device);
+  cudaDeviceProp prop;
+  cudaGetDeviceProperties(&prop, device);
+
+  // distribute the # of chunks to be copied roughly evenly among the SMs we have
+  rmm::device_uvector<thrust::pair<size_t, size_t>> chunks(num_bufs, stream);
+
+  auto const num_sms = prop.multiProcessorCount;
+  thrust::transform(
+    rmm::exec_policy(stream),
+    _d_dst_buf_info,
+    _d_dst_buf_info + num_bufs,
+    chunks.begin(),
+    [total_bytes_f = static_cast<float>(total_bytes), num_sms] __device__(dst_buf_info const& buf) {
+      // how many chunks do we want to subdivide this buffer into
+      size_t const bytes = buf.num_elements * buf.element_size;
+      // can happen for things like lists and strings (the root columns store no data)
+      if (bytes == 0) { return thrust::pair<size_t, size_t>{1, 0}; }
+      float const fraction          = static_cast<float>(bytes) / total_bytes_f;
+      size_t const ideal_num_chunks = max(size_t{1}, static_cast<size_t>(fraction * num_sms));
+
+      // make sure chunks are padded/aligned to 64 bytes.
+      size_t const chunk_size =
+        max(size_t{split_align}, _round_up_safe(bytes / ideal_num_chunks, split_align));
+      size_t const num_chunks = _round_up_safe(bytes, chunk_size) / chunk_size;
+      return thrust::pair<size_t, size_t>{num_chunks, chunk_size};
+    });
+
+  rmm::device_uvector<std::size_t> chunk_offsets(num_bufs + 1, stream);
+  auto buf_count_iter = cudf::detail::make_counting_transform_iterator(
+    0, [num_bufs, num_chunks = num_chunks_func{chunks.begin()}] __device__(size_type i) {
+      return i == num_bufs ? 0 : num_chunks(i);
+    });
+  thrust::exclusive_scan(rmm::exec_policy(stream),
+                         buf_count_iter,
+                         buf_count_iter + num_bufs + 1,
+                         chunk_offsets.begin(),
+                         0);
+
+  auto out_to_in_index = [chunk_offsets = chunk_offsets.begin(), num_bufs] __device__(size_type i) {
+    return static_cast<size_type>(
+             thrust::upper_bound(thrust::seq, chunk_offsets, chunk_offsets + num_bufs + 1, i) -
+             chunk_offsets) -
+           1;
+  };
+
+  // apply the chunking.
+  auto num_chunks =
+    cudf::detail::make_counting_transform_iterator(0, num_chunks_func{chunks.begin()});
+  size_type new_buf_count =
+    thrust::reduce(rmm::exec_policy(stream), num_chunks, num_chunks + chunks.size());
+  rmm::device_uvector<dst_buf_info> d_dst_buf_info(new_buf_count, stream);
+  auto iter = thrust::make_counting_iterator(0);
+  thrust::for_each(rmm::exec_policy(stream),
+                   iter,
+                   iter + new_buf_count,
+                   [_d_dst_buf_info,
+                    d_dst_buf_info = d_dst_buf_info.begin(),
+                    chunks         = chunks.begin(),
+                    chunk_offsets  = chunk_offsets.begin(),
+                    num_bufs,
+                    num_src_bufs,
+                    out_to_in_index] __device__(size_type i) {
+                     size_type const in_buf_index = out_to_in_index(i);
+                     size_type const chunk_index  = i - chunk_offsets[in_buf_index];
+                     auto const chunk_size        = thrust::get<1>(chunks[in_buf_index]);
+                     dst_buf_info const& in       = _d_dst_buf_info[in_buf_index];
+
+                     // adjust info
+                     dst_buf_info& out = d_dst_buf_info[i];
+                     out.element_size  = in.element_size;
+                     out.value_shift   = in.value_shift;
+                     out.bit_shift     = in.bit_shift;
+                     out.valid_count =
+                       in.valid_count;  // valid count will be set to 1 if this is a validity buffer
+                     out.src_buf_index = in.src_buf_index;
+                     out.dst_buf_index = in.dst_buf_index;
+
+                     size_type const elements_per_chunk =
+                       out.element_size == 0 ? 0 : chunk_size / out.element_size;
+                     out.num_elements = ((chunk_index + 1) * elements_per_chunk) > in.num_elements
+                                          ? in.num_elements - (chunk_index * elements_per_chunk)
+                                          : elements_per_chunk;
+
+                     size_type const rows_per_chunk =
+                       out.valid_count > 0 ? elements_per_chunk * 32 : elements_per_chunk;
+                     out.num_rows = ((chunk_index + 1) * rows_per_chunk) > in.num_rows
+                                      ? in.num_rows - (chunk_index * rows_per_chunk)
+                                      : rows_per_chunk;
+
+                     out.src_element_index =
+                       in.src_element_index + (chunk_index * elements_per_chunk);
+                     out.dst_offset = in.dst_offset + (chunk_index * chunk_size);
+
+                     // out.bytes and out.buf_size are unneeded here because they are only used to
+                     // calculate real output buffer sizes. the data we are generating here is
+                     // purely intermediate for the purposes of doing more uniform copying of data
+                     // underneath the final structure of the output
+                   });
+
+  // perform the copy
+  constexpr size_type block_size = 256;
+  copy_partitions<block_size><<<new_buf_count, block_size, 0, stream.value()>>>(
+    d_src_bufs, d_dst_bufs, d_dst_buf_info.data());
+
+  // postprocess valid_counts
+  auto keys = cudf::detail::make_counting_transform_iterator(
+    0, [out_to_in_index] __device__(size_type i) { return out_to_in_index(i); });
+  auto values = thrust::make_transform_iterator(
+    d_dst_buf_info.begin(), [] __device__(dst_buf_info const& info) { return info.valid_count; });
+  thrust::reduce_by_key(rmm::exec_policy(stream),
+                        keys,
+                        keys + new_buf_count,
+                        values,
+                        thrust::make_discard_iterator(),
+                        dst_valid_count_output_iterator{_d_dst_buf_info});
+}
+
 };  // anonymous namespace
 
 namespace detail {
@@ -933,9 +1107,9 @@ std::vector<packed_table> contiguous_split(cudf::table_view const& input,
         }
       }
 
-      // final row indices and row count
-      int const out_row_index = src_info.is_validity ? row_start / 32 : row_start;
-      int const num_rows      = row_end - row_start;
+      // final element indices and row count
+      int const out_element_index = src_info.is_validity ? row_start / 32 : row_start;
+      int const num_rows          = row_end - row_start;
       // if I am an offsets column, all my values need to be shifted
       int const value_shift = src_info.offsets == nullptr ? 0 : src_info.offsets[row_start];
       // if I am a validity column, we may need to shift bits
@@ -953,15 +1127,17 @@
       std::size_t const bytes =
        static_cast<std::size_t>(num_elements) * static_cast<std::size_t>(element_size);
 
-      return dst_buf_info{_round_up_safe(bytes, 64),
+      return dst_buf_info{_round_up_safe(bytes, split_align),
                           num_elements,
                           element_size,
                           num_rows,
-                          out_row_index,
+                          out_element_index,
                           0,
                           value_shift,
                           bit_shift,
-                          src_info.is_validity ? 1 : 0};
+                          src_info.is_validity ? 1 : 0,
+                          src_buf_index,
+                          split_index};
     });
 
   // compute total size of each partition
@@ -1006,10 +1182,12 @@
   // allocate output partition buffers
   std::vector<rmm::device_buffer> out_buffers;
   out_buffers.reserve(num_partitions);
+  size_t total_bytes = 0;
   std::transform(h_buf_sizes,
                  h_buf_sizes + num_partitions,
                  std::back_inserter(out_buffers),
-                 [stream, mr](std::size_t bytes) {
+                 [stream, mr, &total_bytes](std::size_t bytes) {
+                   total_bytes += bytes;
                    return rmm::device_buffer{bytes, stream, mr};
                  });
 
@@ -1043,12 +1221,8 @@
   CUDA_TRY(cudaMemcpyAsync(
    d_src_bufs, h_src_bufs, src_bufs_size + dst_bufs_size, cudaMemcpyHostToDevice, stream.value()));
 
-  // copy. 1 block per buffer
-  {
-    constexpr size_type block_size = 256;
-    copy_partition<block_size><<<num_bufs, block_size, 0, stream.value()>>>(
-      num_src_bufs, d_src_bufs, d_dst_bufs, d_dst_buf_info);
-  }
+  // perform the copy.
+  copy_data(total_bytes, num_bufs, num_src_bufs, d_src_bufs, d_dst_bufs, d_dst_buf_info, stream);
 
   // DtoH dst info (to retrieve null counts)
   CUDA_TRY(cudaMemcpyAsync(
@@ -1078,7 +1252,6 @@
     cols.clear();
   }
-
   return result;
 }
 
@@ -1092,4 +1265,4 @@ std::vector<packed_table> contiguous_split(cudf::table_view const& input,
   return cudf::detail::contiguous_split(input, splits, rmm::cuda_stream_default, mr);
 }
 
-};  // namespace cudf
+};  // namespace cudf
\ No newline at end of file

From 4fba4cb44cda708ee8a8608b4c93548747ba6f0a Mon Sep 17 00:00:00 2001
From: Dave Baranec
Date: Mon, 29 Nov 2021 14:08:50 -0600
Subject: [PATCH 3/8] Add a test that forces repartitioning of validity buffer
 copies.

---
 cpp/tests/copying/split_tests.cpp | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/cpp/tests/copying/split_tests.cpp b/cpp/tests/copying/split_tests.cpp
index f7714ce9ac7..ea98200b309 100644
--- a/cpp/tests/copying/split_tests.cpp
+++ b/cpp/tests/copying/split_tests.cpp
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -1315,6 +1316,32 @@ TEST_F(ContiguousSplitUntypedTest, ProgressiveSizes)
   }
 }
 
+TEST_F(ContiguousSplitUntypedTest, ValidityRepartition)
+{
+  // it is tricky to actually get the internal repartitioning/load-balancing code to add new splits
+  // inside a validity buffer. Under almost all situations, the fraction of bytes that validity
+  // represents is so small compared to the bytes for all other data, that those buffers end up not
+  // getting subdivided. this test forces it to happen by using a small, single column of int8's,
+  // which keeps the overall fraction that validity takes up large enough to cause a repartition.
+  srand(0);
+  auto rvalids = cudf::detail::make_counting_transform_iterator(0, [](auto i) {
+    return static_cast<float>(rand()) / static_cast<float>(RAND_MAX) < 0.5f ? 0 : 1;
+  });
+  cudf::size_type const num_rows = 2000000;
+  auto col = cudf::sequence(num_rows, cudf::numeric_scalar<int8_t>{0});
+  col->set_null_mask(cudf::test::detail::make_null_mask(rvalids, rvalids + num_rows));
+
+  cudf::table_view t({*col});
+  auto result   = cudf::contiguous_split(t, {num_rows / 2});
+  auto expected = cudf::split(t, {num_rows / 2});
+  CUDF_EXPECTS(result.size() == expected.size(),
+               "Mismatch in split results in ValidityRepartition test");
+
+  for (size_t idx = 0; idx < result.size(); idx++) {
+    CUDF_TEST_EXPECT_TABLES_EQUAL(result[idx].table, expected[idx]);
+  }
+}
+
 // contiguous split with strings
 struct ContiguousSplitStringTableTest : public SplitTest<std::string> {
 };

From a09fa9a92f96c6299c72e7371610d09b80d742b8 Mon Sep 17 00:00:00 2001
From: Dave Baranec
Date: Mon, 10 Jan 2022 11:26:20 -0600
Subject: [PATCH 4/8] Change how chunk subpartitioning works to be simpler and
 more uniform. Simply subdivide each buffer to be copied into 1 MB chunks.
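To make the 1 MB rule concrete before the diff: a minimal host-side sketch of the chunk-count computation, assuming a round-up division equivalent to cudf::util::round_up_unsafe (the helper names below are illustrative, not the patch's):

  #include <algorithm>
  #include <cstddef>

  // Round v up to the next multiple of mult. Assumes mult > 0 and no overflow,
  // matching the "unsafe" variant's contract.
  std::size_t round_up(std::size_t v, std::size_t mult)
  {
    return ((v + mult - 1) / mult) * mult;
  }

  // Number of copy chunks for a buffer of `bytes` bytes.
  std::size_t chunk_count(std::size_t bytes)
  {
    constexpr std::size_t desired_chunk_size = 1 * 1024 * 1024;  // 1 MB
    if (bytes == 0) { return 1; }  // nested root columns (lists/strings) hold no data
    return std::max<std::size_t>(1, round_up(bytes, desired_chunk_size) / desired_chunk_size);
  }

A 2.5 MB buffer therefore becomes three chunks; the copy kernel clamps the last chunk's element count so it never reads past the end of the buffer.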
---
 cpp/src/copying/contiguous_split.cu | 70 +++++++++++----------------
 1 file changed, 30 insertions(+), 40 deletions(-)

diff --git a/cpp/src/copying/contiguous_split.cu b/cpp/src/copying/contiguous_split.cu
index c1ed913c6f3..5dafb666623 100644
--- a/cpp/src/copying/contiguous_split.cu
+++ b/cpp/src/copying/contiguous_split.cu
@@ -804,47 +804,39 @@ struct num_chunks_func {
   __device__ size_t operator()(size_type i) const { return thrust::get<0>(chunks[i]); }
 };
 
-void copy_data(size_t total_bytes,
-               int num_bufs,
+void copy_data(int num_bufs,
                int num_src_bufs,
                uint8_t const** d_src_bufs,
                uint8_t** d_dst_bufs,
                dst_buf_info* _d_dst_buf_info,
                rmm::cuda_stream_view stream)
 {
-  // ideally we'd like to give each SM a similar amount of work to do so that a.) we keep all
-  // of them saturated and b.) we don't have any long-running outliers.
-  // our incoming dst_buf_info data is the exact description of what the output should look like.
-  // let's do some examination of what's being copied and potentially break things up into
-  // more pieces to parallelize better.
-  int device;
-  cudaGetDevice(&device);
-  cudaDeviceProp prop;
-  cudaGetDeviceProperties(&prop, device);
-
-  // distribute the # of chunks to be copied roughly evenly among the SMs we have
+  // since we parallelize at one block per copy, we are vulnerable to situations where we
+  // have small numbers of copies to do (a combination of small numbers of splits and/or columns).
+  // so we will take the actual set of outgoing source/destination buffers and further partition
+  // them into much smaller chunks in order to drive up the number of blocks and overall occupancy.
+  auto const desired_chunk_size = size_t{1 * 1024 * 1024};
   rmm::device_uvector<thrust::pair<size_t, size_t>> chunks(num_bufs, stream);
-
-  auto const num_sms = prop.multiProcessorCount;
-  thrust::transform(
-    rmm::exec_policy(stream),
-    _d_dst_buf_info,
-    _d_dst_buf_info + num_bufs,
-    chunks.begin(),
-    [total_bytes_f = static_cast<float>(total_bytes), num_sms] __device__(dst_buf_info const& buf) {
-      // how many chunks do we want to subdivide this buffer into
-      size_t const bytes = buf.num_elements * buf.element_size;
-      // can happen for things like lists and strings (the root columns store no data)
-      if (bytes == 0) { return thrust::pair<size_t, size_t>{1, 0}; }
-      float const fraction          = static_cast<float>(bytes) / total_bytes_f;
-      size_t const ideal_num_chunks = max(size_t{1}, static_cast<size_t>(fraction * num_sms));
-
-      // make sure chunks are padded/aligned to 64 bytes.
-      size_t const chunk_size =
-        max(size_t{split_align}, util::round_up_unsafe(bytes / ideal_num_chunks, split_align));
-      size_t const num_chunks = util::round_up_unsafe(bytes, chunk_size) / chunk_size;
-      return thrust::pair<size_t, size_t>{num_chunks, chunk_size};
-    });
+  thrust::transform(rmm::exec_policy(stream),
+                    _d_dst_buf_info,
+                    _d_dst_buf_info + num_bufs,
+                    chunks.begin(),
+                    [desired_chunk_size] __device__(dst_buf_info const& buf) {
+                      // how many chunks do we want to subdivide this buffer into
+                      size_t const bytes = buf.num_elements * buf.element_size;
+
+                      // can happen for things like lists and strings (the root columns store no
+                      // data)
+                      if (bytes == 0) { return thrust::pair<size_t, size_t>{1, 0}; }
+                      size_t const num_chunks =
+                        max(size_t{1},
+                            util::round_up_unsafe(bytes, desired_chunk_size) / desired_chunk_size);
+
+                      // NOTE: leaving chunk size as a separate parameter for future tuning
+                      // possibilities, even though in the current implemenetation it will be a
+                      // constant.
+                      return thrust::pair<size_t, size_t>{num_chunks, desired_chunk_size};
+                    });
 
   rmm::device_uvector<std::size_t> chunk_offsets(num_bufs + 1, stream);
   auto buf_count_iter = cudf::detail::make_counting_transform_iterator(
@@ -865,9 +857,9 @@ void copy_data(size_t total_bytes,
   };
 
   // apply the chunking.
-  auto num_chunks =
+  auto const num_chunks =
     cudf::detail::make_counting_transform_iterator(0, num_chunks_func{chunks.begin()});
-  size_type new_buf_count =
+  size_type const new_buf_count =
     thrust::reduce(rmm::exec_policy(stream), num_chunks, num_chunks + chunks.size());
   rmm::device_uvector<dst_buf_info> d_dst_buf_info(new_buf_count, stream);
   auto iter = thrust::make_counting_iterator(0);
@@ -1182,12 +1174,10 @@
   // allocate output partition buffers
   std::vector<rmm::device_buffer> out_buffers;
   out_buffers.reserve(num_partitions);
-  size_t total_bytes = 0;
   std::transform(h_buf_sizes,
                  h_buf_sizes + num_partitions,
                  std::back_inserter(out_buffers),
-                 [stream, mr, &total_bytes](std::size_t bytes) {
-                   total_bytes += bytes;
+                 [stream, mr](std::size_t bytes) {
                    return rmm::device_buffer{bytes, stream, mr};
                  });
 
@@ -1222,7 +1212,7 @@
   // perform the copy.
-  copy_data(total_bytes, num_bufs, num_src_bufs, d_src_bufs, d_dst_bufs, d_dst_buf_info, stream);
+  copy_data(num_bufs, num_src_bufs, d_src_bufs, d_dst_bufs, d_dst_buf_info, stream);
 
   // DtoH dst info (to retrieve null counts)
   CUDA_TRY(cudaMemcpyAsync(

From f453196bcbea4723babf710e58e81bb5e81a38da Mon Sep 17 00:00:00 2001
From: Dave Baranec
Date: Mon, 10 Jan 2022 13:54:12 -0600
Subject: [PATCH 5/8] PR review changes.

---
 .../copying/contiguous_split_benchmark.cu | 13 ++++--
 cpp/src/copying/contiguous_split.cu       | 45 +++++++++----------
 2 files changed, 30 insertions(+), 28 deletions(-)

diff --git a/cpp/benchmarks/copying/contiguous_split_benchmark.cu b/cpp/benchmarks/copying/contiguous_split_benchmark.cu
index 16f34c0f047..e2c81fca859 100644
--- a/cpp/benchmarks/copying/contiguous_split_benchmark.cu
+++ b/cpp/benchmarks/copying/contiguous_split_benchmark.cu
@@ -36,10 +36,15 @@ void BM_contiguous_split_common(benchmark::State& state,
   // generate splits
   std::vector<cudf::size_type> splits;
   if (num_splits > 0) {
-    cudf::size_type split_stride = num_rows / num_splits;
-    for (int idx = 0; idx < num_rows; idx += split_stride) {
-      splits.push_back(std::min(idx + split_stride, static_cast<cudf::size_type>(num_rows)));
-    }
+    cudf::size_type const split_stride = num_rows / num_splits;
+    auto iter                          = thrust::make_counting_iterator(1);
+    splits.reserve(num_splits);
+    std::transform(iter,
+                   iter + num_splits,
+                   std::back_inserter(splits),
+                   [split_stride, num_rows](cudf::size_type i) {
+                     return std::min(i * split_stride, static_cast<cudf::size_type>(num_rows));
+                   });
   }
 
   std::vector<std::unique_ptr<cudf::column>> columns(src_cols.size());
diff --git a/cpp/src/copying/contiguous_split.cu b/cpp/src/copying/contiguous_split.cu
index 5dafb666623..d059ce44fb3 100644
--- a/cpp/src/copying/contiguous_split.cu
+++ b/cpp/src/copying/contiguous_split.cu
@@ -732,10 +732,7 @@ struct dst_offset_output_iterator {
   using reference         = std::size_t&;
   using iterator_category = thrust::output_device_iterator_tag;
 
-  dst_offset_output_iterator operator+ __host__ __device__(int i)
-  {
-    return dst_offset_output_iterator{c + i};
-  }
+  dst_offset_output_iterator operator+ __host__ __device__(int i) { return {c + i}; }
 
   void operator++ __host__ __device__() { c++; }
 
@@ -817,26 +814,26 @@ void copy_data(int num_bufs,
   // them into much smaller chunks in order to drive up the number of blocks and overall occupancy.
   auto const desired_chunk_size = size_t{1 * 1024 * 1024};
   rmm::device_uvector<thrust::pair<size_t, size_t>> chunks(num_bufs, stream);
-  thrust::transform(rmm::exec_policy(stream),
-                    _d_dst_buf_info,
-                    _d_dst_buf_info + num_bufs,
-                    chunks.begin(),
-                    [desired_chunk_size] __device__(dst_buf_info const& buf) {
-                      // how many chunks do we want to subdivide this buffer into
-                      size_t const bytes = buf.num_elements * buf.element_size;
-
-                      // can happen for things like lists and strings (the root columns store no
-                      // data)
-                      if (bytes == 0) { return thrust::pair<size_t, size_t>{1, 0}; }
-                      size_t const num_chunks =
-                        max(size_t{1},
-                            util::round_up_unsafe(bytes, desired_chunk_size) / desired_chunk_size);
-
-                      // NOTE: leaving chunk size as a separate parameter for future tuning
-                      // possibilities, even though in the current implemenetation it will be a
-                      // constant.
-                      return thrust::pair<size_t, size_t>{num_chunks, desired_chunk_size};
-                    });
+  thrust::transform(
+    rmm::exec_policy(stream),
+    _d_dst_buf_info,
+    _d_dst_buf_info + num_bufs,
+    chunks.begin(),
+    [desired_chunk_size] __device__(dst_buf_info const& buf) -> thrust::pair<size_t, size_t> {
+      // how many chunks do we want to subdivide this buffer into
+      size_t const bytes = buf.num_elements * buf.element_size;
+
+      // can happen for things like lists and strings (the root columns store no
+      // data)
+      if (bytes == 0) { return {1, 0}; }
+      size_t const num_chunks =
+        max(size_t{1}, util::round_up_unsafe(bytes, desired_chunk_size) / desired_chunk_size);
+
+      // NOTE: leaving chunk size as a separate parameter for future tuning
+      // possibilities, even though in the current implemenetation it will be a
+      // constant.
+      return {num_chunks, desired_chunk_size};
+    });
 
   rmm::device_uvector<std::size_t> chunk_offsets(num_bufs + 1, stream);
   auto buf_count_iter = cudf::detail::make_counting_transform_iterator(

From 11c4113d9881e17d4a34353bdc9b9cbb0378b2b5 Mon Sep 17 00:00:00 2001
From: Dave Baranec
Date: Tue, 11 Jan 2022 10:55:47 -0600
Subject: [PATCH 6/8] PR review changes.

---
 cpp/benchmarks/copying/contiguous_split_benchmark.cu | 5 +++--
 cpp/src/copying/contiguous_split.cu                  | 2 ++
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/cpp/benchmarks/copying/contiguous_split_benchmark.cu b/cpp/benchmarks/copying/contiguous_split_benchmark.cu
index e2c81fca859..065af52b3a9 100644
--- a/cpp/benchmarks/copying/contiguous_split_benchmark.cu
+++ b/cpp/benchmarks/copying/contiguous_split_benchmark.cu
@@ -37,7 +37,8 @@ void BM_contiguous_split_common(benchmark::State& state,
   std::vector<cudf::size_type> splits;
   if (num_splits > 0) {
     cudf::size_type const split_stride = num_rows / num_splits;
-    auto iter                          = thrust::make_counting_iterator(1);
+    // start after the first element.
+    auto iter = thrust::make_counting_iterator(1);
     splits.reserve(num_splits);
     std::transform(iter,
                    iter + num_splits,
                    std::back_inserter(splits),
                    [split_stride, num_rows](cudf::size_type i) {
                      return std::min(i * split_stride, static_cast<cudf::size_type>(num_rows));
                    });
@@ -117,7 +118,7 @@ void BM_contiguous_split_strings(benchmark::State& state)
   cudf::size_type const num_splits = state.range(2);
   bool const include_validity      = state.range(3) == 0 ? false : true;
 
-  int64_t const string_len = 8;
+  constexpr int64_t string_len = 8;
   std::vector<std::string> h_strings{
     "aaaaaaaa", "bbbbbbbb", "cccccccc", "dddddddd", "eeeeeeee", "ffffffff", "gggggggg", "hhhhhhhh"};
 
diff --git a/cpp/src/copying/contiguous_split.cu b/cpp/src/copying/contiguous_split.cu
index d059ce44fb3..6f8368dba69 100644
--- a/cpp/src/copying/contiguous_split.cu
+++ b/cpp/src/copying/contiguous_split.cu
@@ -892,6 +892,8 @@ void copy_data(int num_bufs,
                                           : elements_per_chunk;
 
                      size_type const rows_per_chunk =
+                       // if this is a validity buffer, each element is a bitmask_type, which
+                       // corresponds to 32 rows.
                        out.valid_count > 0 ? elements_per_chunk * 32 : elements_per_chunk;
                      out.num_rows = ((chunk_index + 1) * rows_per_chunk) > in.num_rows
                                       ? in.num_rows - (chunk_index * rows_per_chunk)
                                       : rows_per_chunk;

From 283fd54e7b9f8a8be574cf3c67928495a5d408d4 Mon Sep 17 00:00:00 2001
From: Dave Baranec
Date: Tue, 11 Jan 2022 13:28:41 -0600
Subject: [PATCH 7/8] Comments and cleanup from PR reviews.

---
 cpp/src/copying/contiguous_split.cu | 112 ++++++++++++++--------------
 1 file changed, 58 insertions(+), 54 deletions(-)

diff --git a/cpp/src/copying/contiguous_split.cu b/cpp/src/copying/contiguous_split.cu
index 6f8368dba69..f4b3950d5fa 100644
--- a/cpp/src/copying/contiguous_split.cu
+++ b/cpp/src/copying/contiguous_split.cu
@@ -808,8 +808,8 @@ void copy_data(int num_bufs,
                dst_buf_info* _d_dst_buf_info,
                rmm::cuda_stream_view stream)
 {
-  // since we parallelize at one block per copy, we are vulnerable to situations where we
-  // have small numbers of copies to do (a combination of small numbers of splits and/or columns).
+  // Since we parallelize at one block per copy, we are vulnerable to situations where we
+  // have small numbers of copies to do (a combination of small numbers of splits and/or columns),
   // so we will take the actual set of outgoing source/destination buffers and further partition
   // them into much smaller chunks in order to drive up the number of blocks and overall occupancy.
   auto const desired_chunk_size = size_t{1 * 1024 * 1024};
@@ -820,17 +820,19 @@ void copy_data(int num_bufs,
     [desired_chunk_size] __device__(dst_buf_info const& buf) -> thrust::pair<size_t, size_t> {
-      // how many chunks do we want to subdivide this buffer into
+      // Total bytes for this incoming partition
       size_t const bytes = buf.num_elements * buf.element_size;
 
-      // can happen for things like lists and strings (the root columns store no
-      // data)
+      // This clause handles nested data types (e.g. list or string) that store no data in the row
+      // columns, only in their children.
       if (bytes == 0) { return {1, 0}; }
+
+      // The number of chunks we want to subdivide this buffer into
       size_t const num_chunks =
         max(size_t{1}, util::round_up_unsafe(bytes, desired_chunk_size) / desired_chunk_size);
 
       // NOTE: leaving chunk size as a separate parameter for future tuning
-      // possibilities, even though in the current implemenetation it will be a
+      // possibilities, even though in the current implementation it will be a
       // constant.
       return {num_chunks, desired_chunk_size};
     });
 
   rmm::device_uvector<std::size_t> chunk_offsets(num_bufs + 1, stream);
   auto buf_count_iter = cudf::detail::make_counting_transform_iterator(
@@ -860,54 +862,56 @@ void copy_data(int num_bufs,
     thrust::reduce(rmm::exec_policy(stream), num_chunks, num_chunks + chunks.size());
   rmm::device_uvector<dst_buf_info> d_dst_buf_info(new_buf_count, stream);
   auto iter = thrust::make_counting_iterator(0);
-  thrust::for_each(rmm::exec_policy(stream),
-                   iter,
-                   iter + new_buf_count,
-                   [_d_dst_buf_info,
-                    d_dst_buf_info = d_dst_buf_info.begin(),
-                    chunks         = chunks.begin(),
-                    chunk_offsets  = chunk_offsets.begin(),
-                    num_bufs,
-                    num_src_bufs,
-                    out_to_in_index] __device__(size_type i) {
-                     size_type const in_buf_index = out_to_in_index(i);
-                     size_type const chunk_index  = i - chunk_offsets[in_buf_index];
-                     auto const chunk_size        = thrust::get<1>(chunks[in_buf_index]);
-                     dst_buf_info const& in       = _d_dst_buf_info[in_buf_index];
-
-                     // adjust info
-                     dst_buf_info& out = d_dst_buf_info[i];
-                     out.element_size  = in.element_size;
-                     out.value_shift   = in.value_shift;
-                     out.bit_shift     = in.bit_shift;
-                     out.valid_count =
-                       in.valid_count;  // valid count will be set to 1 if this is a validity buffer
-                     out.src_buf_index = in.src_buf_index;
-                     out.dst_buf_index = in.dst_buf_index;
-
-                     size_type const elements_per_chunk =
-                       out.element_size == 0 ? 0 : chunk_size / out.element_size;
-                     out.num_elements = ((chunk_index + 1) * elements_per_chunk) > in.num_elements
-                                          ? in.num_elements - (chunk_index * elements_per_chunk)
-                                          : elements_per_chunk;
-
-                     size_type const rows_per_chunk =
-                       // if this is a validity buffer, each element is a bitmask_type, which
-                       // corresponds to 32 rows.
-                       out.valid_count > 0 ? elements_per_chunk * 32 : elements_per_chunk;
-                     out.num_rows = ((chunk_index + 1) * rows_per_chunk) > in.num_rows
-                                      ? in.num_rows - (chunk_index * rows_per_chunk)
-                                      : rows_per_chunk;
-
-                     out.src_element_index =
-                       in.src_element_index + (chunk_index * elements_per_chunk);
-                     out.dst_offset = in.dst_offset + (chunk_index * chunk_size);
-
-                     // out.bytes and out.buf_size are unneeded here because they are only used to
-                     // calculate real output buffer sizes. the data we are generating here is
-                     // purely intermediate for the purposes of doing more uniform copying of data
-                     // underneath the final structure of the output
-                   });
+  thrust::for_each(
+    rmm::exec_policy(stream),
+    iter,
+    iter + new_buf_count,
+    [_d_dst_buf_info,
+     d_dst_buf_info = d_dst_buf_info.begin(),
+     chunks         = chunks.begin(),
+     chunk_offsets  = chunk_offsets.begin(),
+     num_bufs,
+     num_src_bufs,
+     out_to_in_index] __device__(size_type i) {
+      size_type const in_buf_index = out_to_in_index(i);
+      size_type const chunk_index  = i - chunk_offsets[in_buf_index];
+      auto const chunk_size        = thrust::get<1>(chunks[in_buf_index]);
+      dst_buf_info const& in       = _d_dst_buf_info[in_buf_index];
+
+      // adjust info
+      dst_buf_info& out = d_dst_buf_info[i];
+      out.element_size  = in.element_size;
+      out.value_shift   = in.value_shift;
+      out.bit_shift     = in.bit_shift;
+      out.valid_count   = in.valid_count;  // valid count will be set to 1 if this is a validity buffer
+      out.src_buf_index = in.src_buf_index;
+      out.dst_buf_index = in.dst_buf_index;
+
+      size_type const elements_per_chunk = out.element_size == 0 ? 0 : chunk_size / out.element_size;
+      out.num_elements = ((chunk_index + 1) * elements_per_chunk) > in.num_elements
+                           ? in.num_elements - (chunk_index * elements_per_chunk)
+                           : elements_per_chunk;
+
+      size_type const rows_per_chunk =
+        // if this is a validity buffer, each element is a bitmask_type, which
+        // corresponds to 32 rows.
+        out.valid_count > 0
+          ? elements_per_chunk * static_cast<size_type>(detail::size_in_bits<bitmask_type>())
+          : elements_per_chunk;
+      out.num_rows = ((chunk_index + 1) * rows_per_chunk) > in.num_rows
+                       ? in.num_rows - (chunk_index * rows_per_chunk)
+                       : rows_per_chunk;
+
+      out.src_element_index = in.src_element_index + (chunk_index * elements_per_chunk);
+      out.dst_offset        = in.dst_offset + (chunk_index * chunk_size);
+
+      // out.bytes and out.buf_size are unneeded here because they are only used to
+      // calculate real output buffer sizes. the data we are generating here is
+      // purely intermediate for the purposes of doing more uniform copying of data
+      // underneath the final structure of the output
+    });

From 717ab924e60aa28f25fc1859ce56f646cd893f17 Mon Sep 17 00:00:00 2001
From: Dave Baranec
Date: Tue, 11 Jan 2022 14:46:24 -0600
Subject: [PATCH 8/8] Copyright date updates.

---
 cpp/benchmarks/copying/contiguous_split_benchmark.cu | 2 +-
 cpp/src/copying/contiguous_split.cu                  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/cpp/benchmarks/copying/contiguous_split_benchmark.cu b/cpp/benchmarks/copying/contiguous_split_benchmark.cu
index 065af52b3a9..55e1360efc8 100644
--- a/cpp/benchmarks/copying/contiguous_split_benchmark.cu
+++ b/cpp/benchmarks/copying/contiguous_split_benchmark.cu
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
diff --git a/cpp/src/copying/contiguous_split.cu b/cpp/src/copying/contiguous_split.cu
index f4b3950d5fa..f8c0006ed45 100644
--- a/cpp/src/copying/contiguous_split.cu
+++ b/cpp/src/copying/contiguous_split.cu
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
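For reference, the chunk-to-buffer mapping (out_to_in_index) that this series relies on is an upper_bound over an exclusive scan of per-buffer chunk counts. A host-side analogue with std algorithms standing in for their thrust counterparts (the names and the main() driver are illustrative, not patch code):

  #include <algorithm>
  #include <cstddef>
  #include <numeric>
  #include <vector>

  // chunk_offsets has num_bufs + 1 entries (exclusive scan of per-buffer chunk
  // counts, with the total count last). Maps flattened chunk index i back to
  // the destination buffer it belongs to.
  int out_to_in_index(std::vector<std::size_t> const& chunk_offsets, std::size_t i)
  {
    auto const it = std::upper_bound(chunk_offsets.begin(), chunk_offsets.end(), i);
    return static_cast<int>(it - chunk_offsets.begin()) - 1;
  }

  int main()
  {
    std::vector<std::size_t> counts{2, 1, 3};  // chunks per destination buffer
    std::vector<std::size_t> offsets(counts.size() + 1, 0);
    std::partial_sum(counts.begin(), counts.end(), offsets.begin() + 1);  // {0, 2, 3, 6}
    return out_to_in_index(offsets, 4) == 2 ? 0 : 1;  // chunk 4 lives in buffer 2
  }

The same mapping drives the final reduce_by_key, which folds the per-chunk valid_count values back into a single count per original destination buffer.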