Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix an out-of-bounds read in validity copying in contiguous_split. #9842

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions cpp/src/copying/contiguous_split.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ struct src_buf_info {
int _offset_stack_pos,
int _parent_offsets_index,
bool _is_validity,
int _column_offset)
size_type _column_offset)
robertmaynard marked this conversation as resolved.
Show resolved Hide resolved
: type(_type),
offsets(_offsets),
offset_stack_pos(_offset_stack_pos),
Expand Down Expand Up @@ -195,10 +195,15 @@ __device__ void copy_buffer(uint8_t* __restrict__ dst,
std::size_t idx = (num_bytes - remainder) / 4;
uint32_t v = remainder > 0 ? (reinterpret_cast<uint32_t const*>(src)[idx] - value_shift) : 0;
while (remainder) {
uint32_t const next = bit_shift > 0 || remainder > 4
// if we're doing a validity copy, do we need to read an extra bitmask word to OR it's
// relevant bits in?
nvdbaranec marked this conversation as resolved.
Show resolved Hide resolved
auto const have_extra_rows =
bit_shift > 0 && remainder == 4 ? (num_elements * 32) - num_rows < bit_shift : false;
nvdbaranec marked this conversation as resolved.
Show resolved Hide resolved
uint32_t const next = (have_extra_rows || remainder > 4)
? (reinterpret_cast<uint32_t const*>(src)[idx + 1] - value_shift)
: 0;
uint32_t const val = (v >> bit_shift) | (next << (32 - bit_shift));

uint32_t const val = (v >> bit_shift) | (next << (32 - bit_shift));
if (valid_count) { thread_valid_count += __popc(val); }
reinterpret_cast<uint32_t*>(dst)[idx] = val;
v = next;
Expand Down
15 changes: 15 additions & 0 deletions cpp/tests/copying/split_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1315,6 +1315,21 @@ TEST_F(ContiguousSplitUntypedTest, ProgressiveSizes)
}
}

TEST_F(ContiguousSplitUntypedTest, ValidityEdgeCase)
{
// tests an edge case where the splits cause the final validity data to be copied
// to be < 32 full bits, making sure we don't unintentionally read past the end of the input
auto col = cudf::make_numeric_column(
cudf::data_type{cudf::type_id::INT32}, 512, cudf::mask_state::ALL_VALID);
auto result = cudf::contiguous_split(cudf::table_view{{*col}}, {510});
auto expected = cudf::split(cudf::table_view{{*col}}, {510});

EXPECT_EQ(expected.size(), result.size());
for (unsigned long index = 0; index < result.size(); index++) {
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected[index].column(0), result[index].table.column(0));
}
}

// contiguous split with strings
struct ContiguousSplitStringTableTest : public SplitTest<std::string> {
};
Expand Down