Skip to content

Commit

Permalink
Fix memcheck read error in libcudf contiguous_split (#9067)
Browse files Browse the repository at this point in the history
Reference #8883 

The `cudf::contiguous_split` was failing on memcheck using the `compute-sanitizer` with a 4-byte out-of-bounds read. This was traced to the `copy_buffer` device function that was reading 1 past the end of the input buffer when performing a value-shift. The ternary check was incorrectly protecting the out-of-bounds read. The logic is corrected by this PR.

Also, I fixed some `const` removal casts from the same source file by adding appropriate `const` qualifiers to the input data variables.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Robert Maynard (https://github.com/robertmaynard)
  - Nghia Truong (https://github.com/ttnghia)

URL: #9067
  • Loading branch information
davidwendt authored Aug 23, 2021
1 parent 6cd0167 commit e42464c
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions cpp/src/copying/contiguous_split.cu
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ struct dst_buf_info {
*/
template <int block_size>
__device__ void copy_buffer(uint8_t* __restrict__ dst,
uint8_t* __restrict__ src,
uint8_t const* __restrict__ src,
int t,
std::size_t num_elements,
std::size_t element_size,
Expand Down Expand Up @@ -193,11 +193,12 @@ __device__ void copy_buffer(uint8_t* __restrict__ dst,
// and will never both be true at the same time.
if (value_shift || bit_shift) {
std::size_t idx = (num_bytes - remainder) / 4;
uint32_t v = remainder > 0 ? (reinterpret_cast<uint32_t*>(src)[idx] - value_shift) : 0;
uint32_t v = remainder > 0 ? (reinterpret_cast<uint32_t const*>(src)[idx] - value_shift) : 0;
while (remainder) {
uint32_t const next =
remainder > 0 ? (reinterpret_cast<uint32_t*>(src)[idx + 1] - value_shift) : 0;
uint32_t const val = (v >> bit_shift) | (next << (32 - bit_shift));
uint32_t const next = bit_shift > 0 || remainder > 4
? (reinterpret_cast<uint32_t const*>(src)[idx + 1] - value_shift)
: 0;
uint32_t const val = (v >> bit_shift) | (next << (32 - bit_shift));
if (valid_count) { thread_valid_count += __popc(val); }
reinterpret_cast<uint32_t*>(dst)[idx] = val;
v = next;
Expand All @@ -207,7 +208,7 @@ __device__ void copy_buffer(uint8_t* __restrict__ dst,
} else {
while (remainder) {
std::size_t const idx = num_bytes - remainder--;
uint32_t const val = reinterpret_cast<uint8_t*>(src)[idx];
uint32_t const val = reinterpret_cast<uint8_t const*>(src)[idx];
if (valid_count) { thread_valid_count += __popc(val); }
reinterpret_cast<uint8_t*>(dst)[idx] = val;
}
Expand Down Expand Up @@ -255,7 +256,7 @@ __device__ void copy_buffer(uint8_t* __restrict__ dst,
*/
template <int block_size>
__global__ void copy_partition(int num_src_bufs,
uint8_t** src_bufs,
uint8_t const** src_bufs,
uint8_t** dst_bufs,
dst_buf_info* buf_info)
{
Expand Down Expand Up @@ -349,13 +350,13 @@ OutputIter setup_src_buf_data(InputIter begin, InputIter end, OutputIter out_buf
{
std::for_each(begin, end, [&out_buf](column_view const& col) {
if (col.nullable()) {
*out_buf = reinterpret_cast<uint8_t*>(const_cast<bitmask_type*>(col.null_mask()));
*out_buf = reinterpret_cast<uint8_t const*>(col.null_mask());
out_buf++;
}
// NOTE: we're always returning the base pointer here. column-level offset is accounted
// for later. Also, for some column types (string, list, struct) this pointer will be null
// because there is no associated data with the root column.
*out_buf = const_cast<uint8_t*>(col.head<uint8_t>());
*out_buf = col.head<uint8_t>();
out_buf++;

out_buf = setup_src_buf_data(col.child_begin(), col.child_end(), out_buf);
Expand Down Expand Up @@ -1020,14 +1021,14 @@ std::vector<packed_table> contiguous_split(cudf::table_view const& input,
cudf::util::round_up_safe(num_partitions * sizeof(uint8_t*), split_align);
// host-side
std::vector<uint8_t> h_src_and_dst_buffers(src_bufs_size + dst_bufs_size);
uint8_t** h_src_bufs = reinterpret_cast<uint8_t**>(h_src_and_dst_buffers.data());
uint8_t const** h_src_bufs = reinterpret_cast<uint8_t const**>(h_src_and_dst_buffers.data());
uint8_t** h_dst_bufs = reinterpret_cast<uint8_t**>(h_src_and_dst_buffers.data() + src_bufs_size);
// device-side
rmm::device_buffer d_src_and_dst_buffers(src_bufs_size + dst_bufs_size + offset_stack_size,
stream,
rmm::mr::get_current_device_resource());
uint8_t** d_src_bufs = reinterpret_cast<uint8_t**>(d_src_and_dst_buffers.data());
uint8_t** d_dst_bufs = reinterpret_cast<uint8_t**>(
uint8_t const** d_src_bufs = reinterpret_cast<uint8_t const**>(d_src_and_dst_buffers.data());
uint8_t** d_dst_bufs = reinterpret_cast<uint8_t**>(
reinterpret_cast<uint8_t*>(d_src_and_dst_buffers.data()) + src_bufs_size);

// setup src buffers
Expand Down

0 comments on commit e42464c

Please sign in to comment.