Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix BitReader::get_batch zero extension (#1708) #1722

Merged
merged 2 commits into from
May 23, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 43 additions & 33 deletions parquet/src/util/bit_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -568,40 +568,35 @@ impl BitReader {
}
}

unsafe {
let in_buf = &self.buffer.data()[self.byte_offset..];
let mut in_ptr = in_buf as *const [u8] as *const u8 as *const u32;
if size_of::<T>() == 4 {
while values_to_read - i >= 32 {
let out_ptr = &mut batch[i..] as *mut [T] as *mut T as *mut u32;
in_ptr = unpack32(in_ptr, out_ptr, num_bits);
self.byte_offset += 4 * num_bits;
i += 32;
}
} else {
let mut out_buf = [0u32; 32];
let out_ptr = &mut out_buf as &mut [u32] as *mut [u32] as *mut u32;
while values_to_read - i >= 32 {
in_ptr = unpack32(in_ptr, out_ptr, num_bits);
self.byte_offset += 4 * num_bits;
for n in 0..32 {
// We need to copy from smaller size to bigger size to avoid
// overwriting other memory regions.
if size_of::<T>() > size_of::<u32>() {
std::ptr::copy_nonoverlapping(
out_buf[n..].as_ptr() as *const u32,
&mut batch[i] as *mut T as *mut u32,
1,
);
} else {
std::ptr::copy_nonoverlapping(
out_buf[n..].as_ptr() as *const T,
&mut batch[i] as *mut T,
1,
);
}
i += 1;
let in_buf = &self.buffer.data()[self.byte_offset..];
let mut in_ptr = in_buf as *const [u8] as *const u8 as *const u32;
if size_of::<T>() == 4 {
while values_to_read - i >= 32 {
let out_ptr = &mut batch[i..] as *mut [T] as *mut T as *mut u32;
in_ptr = unsafe { unpack32(in_ptr, out_ptr, num_bits) };
self.byte_offset += 4 * num_bits;
i += 32;
}
} else {
let mut out_buf = [0u32; 32];
let out_ptr = &mut out_buf as &mut [u32] as *mut [u32] as *mut u32;
while values_to_read - i >= 32 {
in_ptr = unsafe { unpack32(in_ptr, out_ptr, num_bits) };
self.byte_offset += 4 * num_bits;

for out in out_buf {
// Zero-allocate buffer
let mut out_bytes = T::Buffer::default();
let in_bytes = out.to_le_bytes();

{
let out_bytes = out_bytes.as_mut();
let len = out_bytes.len().min(in_bytes.len());
(&mut out_bytes[..len]).copy_from_slice(&in_bytes[..len]);
}

batch[i] = T::from_le_bytes(out_bytes);
i += 1;
}
}
}
Expand Down Expand Up @@ -1193,4 +1188,19 @@ mod tests {
);
});
}

#[test]
fn test_get_batch_zero_extend() {
let to_read = vec![0xFF; 4];
let mut reader = BitReader::new(ByteBufferPtr::new(to_read));

// Create a non-zeroed output buffer
let mut output = [u64::MAX; 32];
reader.get_batch(&mut output, 1);

for v in output {
// Values should be read correctly
assert_eq!(v, 1);
}
}
}