diff --git a/cpp/src/copying/contiguous_split.cu b/cpp/src/copying/contiguous_split.cu index e082624b04d..8fc41fc5a27 100644 --- a/cpp/src/copying/contiguous_split.cu +++ b/cpp/src/copying/contiguous_split.cu @@ -99,8 +99,6 @@ struct dst_buf_info { size_type valid_count; }; -constexpr size_type copy_block_size = 512; - /** * @brief Copy a single buffer of column data, shifting values (for offset columns), * and validity (for validity buffers) as necessary. @@ -130,6 +128,7 @@ constexpr size_type copy_block_size = 512; * @param num_rows Number of rows being copied * @param valid_count Optional pointer to a value to store count of set bits */ +template __device__ void copy_buffer(uint8_t* __restrict__ dst, uint8_t* __restrict__ src, int t, @@ -217,7 +216,7 @@ __device__ void copy_buffer(uint8_t* __restrict__ dst, if (num_bytes == 0) { if (!t) { *valid_count = 0; } } else { - using BlockReduce = cub::BlockReduce; + using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; size_type block_valid_count{BlockReduce(temp_storage).Sum(thread_valid_count)}; if (!t) { @@ -253,6 +252,7 @@ __device__ void copy_buffer(uint8_t* __restrict__ dst, * @param dst_bufs Desination buffers (N*M) * @param buf_info Information on the range of values to be copied for each destination buffer. */ +template __global__ void copy_partition(int num_src_bufs, int num_partitions, uint8_t** src_bufs, @@ -264,17 +264,18 @@ __global__ void copy_partition(int num_src_bufs, size_t const buf_index = (partition_index * num_src_bufs) + src_buf_index; // copy, shifting offsets and validity bits as needed - copy_buffer(dst_bufs[partition_index] + buf_info[buf_index].dst_offset, - src_bufs[src_buf_index], - threadIdx.x, - buf_info[buf_index].num_elements, - buf_info[buf_index].element_size, - buf_info[buf_index].src_row_index, - blockDim.x, - buf_info[buf_index].value_shift, - buf_info[buf_index].bit_shift, - buf_info[buf_index].num_rows, - buf_info[buf_index].valid_count > 0 ? &buf_info[buf_index].valid_count : nullptr); + copy_buffer( + dst_bufs[partition_index] + buf_info[buf_index].dst_offset, + src_bufs[src_buf_index], + threadIdx.x, + buf_info[buf_index].num_elements, + buf_info[buf_index].element_size, + buf_info[buf_index].src_row_index, + blockDim.x, + buf_info[buf_index].value_shift, + buf_info[buf_index].bit_shift, + buf_info[buf_index].num_rows, + buf_info[buf_index].valid_count > 0 ? &buf_info[buf_index].valid_count : nullptr); } // The block of functions below are all related: @@ -1019,7 +1020,8 @@ std::vector contiguous_split(cudf::table_view const& input, // copy. 1 block per buffer { - copy_partition<<>>( + constexpr size_type block_size = 512; + copy_partition<<>>( num_src_bufs, num_partitions, d_src_bufs, d_dst_bufs, d_dst_buf_info); }