Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix an out-of-bounds read in validity copying in contiguous_split. #9842

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions cpp/src/copying/contiguous_split.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ struct src_buf_info {
int _offset_stack_pos,
int _parent_offsets_index,
bool _is_validity,
int _column_offset)
size_type _column_offset)
robertmaynard marked this conversation as resolved.
Show resolved Hide resolved
: type(_type),
offsets(_offsets),
offset_stack_pos(_offset_stack_pos),
Expand Down Expand Up @@ -194,11 +194,16 @@ __device__ void copy_buffer(uint8_t* __restrict__ dst,
if (value_shift || bit_shift) {
std::size_t idx = (num_bytes - remainder) / 4;
uint32_t v = remainder > 0 ? (reinterpret_cast<uint32_t const*>(src)[idx] - value_shift) : 0;
auto const have_trailing_bits = ((num_elements * 32) - num_rows) < bit_shift;
while (remainder) {
uint32_t const next = bit_shift > 0 || remainder > 4
? (reinterpret_cast<uint32_t const*>(src)[idx + 1] - value_shift)
: 0;
uint32_t const val = (v >> bit_shift) | (next << (32 - bit_shift));
// if we're at the very last word of a validity copy, we do not always need to read the next
// word to get the final trailing bits.
auto const read_trailing_bits = bit_shift > 0 && remainder == 4 && have_trailing_bits;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the 4 here sizeof(uint32_t)? Not requesting a change, just trying to understand the logic.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. We're reading (up to) the trailing 15 bytes of the buffer to be copied. In the case where the buffer happens to be validity, the elements are all bitmask_type words.

The fundamental issue is that if we're copying from some arbitrary row, we have to shift the bits of any validity around. So imagine we're reading 1 bit starting at row 31. That's all within the "final" word - we just need to shift that bit up by 31 for the output. But let's say we want to read 2 bits starting at row 31. The first bit comes from the word we just read, but the 2nd bit comes from the next word (idx + 1) - so we have to read it.

The initial bug here was that we were doing this when we shouldn't have been.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies in advance for the drive-by review comment. Is there any reason not to define this variable? I see multiple magic numbers throughout this code (lots of 4s and 32s) and it seems like a constexpr auto uint32_size = sizeof(uint32_t); would help avoid questions like this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drive-by review comment

😂

I think it would be good, if not a bit out of scope of the PR. Up to the author IMO.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think something like this definitely adds clarity.

constexpr size_type rows_per_element = 32;
auto const have_trailing_bits = ((num_elements * rows_per_element) - num_rows) < bit_shift;

But I also think this makes more sense as is, since this reads as pretty standard bit-shifting stuff.
uint32_t const val = (v >> bit_shift) | (next << (32 - bit_shift));

Changing it to this would obfuscate I think.
uint32_t const val = (v >> bit_shift) | (next << (rows_per_bitmask - bit_shift));

uint32_t const next = (read_trailing_bits || remainder > 4)
? (reinterpret_cast<uint32_t const*>(src)[idx + 1] - value_shift)
: 0;

uint32_t const val = (v >> bit_shift) | (next << (32 - bit_shift));
if (valid_count) { thread_valid_count += __popc(val); }
reinterpret_cast<uint32_t*>(dst)[idx] = val;
v = next;
Expand Down
15 changes: 15 additions & 0 deletions cpp/tests/copying/split_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1315,6 +1315,21 @@ TEST_F(ContiguousSplitUntypedTest, ProgressiveSizes)
}
}

TEST_F(ContiguousSplitUntypedTest, ValidityEdgeCase)
{
// tests an edge case where the splits cause the final validity data to be copied
// to be < 32 full bits, making sure we don't unintentionally read past the end of the input
auto col = cudf::make_numeric_column(
cudf::data_type{cudf::type_id::INT32}, 512, cudf::mask_state::ALL_VALID);
auto result = cudf::contiguous_split(cudf::table_view{{*col}}, {510});
auto expected = cudf::split(cudf::table_view{{*col}}, {510});

EXPECT_EQ(expected.size(), result.size());
for (unsigned long index = 0; index < result.size(); index++) {
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected[index].column(0), result[index].table.column(0));
}
}

// contiguous split with strings
struct ContiguousSplitStringTableTest : public SplitTest<std::string> {
};
Expand Down