Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memcheck read error in libcudf contiguous_split #9067

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions cpp/src/copying/contiguous_split.cu
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ struct dst_buf_info {
*/
template <int block_size>
__device__ void copy_buffer(uint8_t* __restrict__ dst,
uint8_t* __restrict__ src,
uint8_t const* __restrict__ src,
int t,
std::size_t num_elements,
std::size_t element_size,
Expand Down Expand Up @@ -193,11 +193,12 @@ __device__ void copy_buffer(uint8_t* __restrict__ dst,
// and will never both be true at the same time.
if (value_shift || bit_shift) {
std::size_t idx = (num_bytes - remainder) / 4;
uint32_t v = remainder > 0 ? (reinterpret_cast<uint32_t*>(src)[idx] - value_shift) : 0;
uint32_t v = remainder > 0 ? (reinterpret_cast<uint32_t const*>(src)[idx] - value_shift) : 0;
while (remainder) {
uint32_t const next =
remainder > 0 ? (reinterpret_cast<uint32_t*>(src)[idx + 1] - value_shift) : 0;
uint32_t const val = (v >> bit_shift) | (next << (32 - bit_shift));
uint32_t const next = bit_shift > 0 || remainder > 4
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
? (reinterpret_cast<uint32_t const*>(src)[idx + 1] - value_shift)
: 0;
uint32_t const val = (v >> bit_shift) | (next << (32 - bit_shift));
if (valid_count) { thread_valid_count += __popc(val); }
reinterpret_cast<uint32_t*>(dst)[idx] = val;
v = next;
Expand All @@ -207,7 +208,7 @@ __device__ void copy_buffer(uint8_t* __restrict__ dst,
} else {
while (remainder) {
std::size_t const idx = num_bytes - remainder--;
uint32_t const val = reinterpret_cast<uint8_t*>(src)[idx];
uint32_t const val = reinterpret_cast<uint8_t const*>(src)[idx];
if (valid_count) { thread_valid_count += __popc(val); }
reinterpret_cast<uint8_t*>(dst)[idx] = val;
}
Expand Down Expand Up @@ -255,7 +256,7 @@ __device__ void copy_buffer(uint8_t* __restrict__ dst,
*/
template <int block_size>
__global__ void copy_partition(int num_src_bufs,
uint8_t** src_bufs,
uint8_t const** src_bufs,
uint8_t** dst_bufs,
dst_buf_info* buf_info)
{
Expand Down Expand Up @@ -349,13 +350,13 @@ OutputIter setup_src_buf_data(InputIter begin, InputIter end, OutputIter out_buf
{
std::for_each(begin, end, [&out_buf](column_view const& col) {
if (col.nullable()) {
*out_buf = reinterpret_cast<uint8_t*>(const_cast<bitmask_type*>(col.null_mask()));
*out_buf = reinterpret_cast<uint8_t const*>(col.null_mask());
out_buf++;
}
// NOTE: we're always returning the base pointer here. column-level offset is accounted
// for later. Also, for some column types (string, list, struct) this pointer will be null
// because there is no associated data with the root column.
*out_buf = const_cast<uint8_t*>(col.head<uint8_t>());
*out_buf = col.head<uint8_t>();
out_buf++;

out_buf = setup_src_buf_data(col.child_begin(), col.child_end(), out_buf);
Expand Down Expand Up @@ -1020,14 +1021,14 @@ std::vector<packed_table> contiguous_split(cudf::table_view const& input,
cudf::util::round_up_safe(num_partitions * sizeof(uint8_t*), split_align);
// host-side
std::vector<uint8_t> h_src_and_dst_buffers(src_bufs_size + dst_bufs_size);
uint8_t** h_src_bufs = reinterpret_cast<uint8_t**>(h_src_and_dst_buffers.data());
uint8_t const** h_src_bufs = reinterpret_cast<uint8_t const**>(h_src_and_dst_buffers.data());
uint8_t** h_dst_bufs = reinterpret_cast<uint8_t**>(h_src_and_dst_buffers.data() + src_bufs_size);
// device-side
rmm::device_buffer d_src_and_dst_buffers(src_bufs_size + dst_bufs_size + offset_stack_size,
stream,
rmm::mr::get_current_device_resource());
uint8_t** d_src_bufs = reinterpret_cast<uint8_t**>(d_src_and_dst_buffers.data());
uint8_t** d_dst_bufs = reinterpret_cast<uint8_t**>(
uint8_t const** d_src_bufs = reinterpret_cast<uint8_t const**>(d_src_and_dst_buffers.data());
uint8_t** d_dst_bufs = reinterpret_cast<uint8_t**>(
reinterpret_cast<uint8_t*>(d_src_and_dst_buffers.data()) + src_bufs_size);

// setup src buffers
Expand Down