Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-10561: [Rust] Simplified Buffer's write and write_bytes and fixed undefined behavior #8645

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions rust/arrow/src/array/array_primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use std::any::Any;
use std::borrow::Borrow;
use std::convert::From;
use std::fmt;
use std::io::Write;
use std::iter::{FromIterator, IntoIterator};
use std::mem;
use std::sync::Arc;
Expand Down Expand Up @@ -309,9 +308,9 @@ impl<T: ArrowPrimitiveType, Ptr: Borrow<Option<<T as ArrowPrimitiveType>::Native
iter.enumerate().for_each(|(i, item)| {
if let Some(a) = item.borrow() {
bit_util::set_bit(null_slice, i);
val_buf.write_all(a.to_byte_slice()).unwrap();
val_buf.extend_from_slice(a.to_byte_slice());
} else {
val_buf.write_all(&null).unwrap();
val_buf.extend_from_slice(&null);
}
});

Expand Down Expand Up @@ -406,11 +405,9 @@ impl<T: ArrowTimestampType> PrimitiveArray<T> {
for (i, v) in data.iter().enumerate() {
if let Some(n) = v {
bit_util::set_bit(null_slice, i);
// unwrap() in the following should be safe here since we've
// made sure enough space is allocated for the values.
val_buf.write_all(&n.to_byte_slice()).unwrap();
val_buf.extend_from_slice(&n.to_byte_slice());
} else {
val_buf.write_all(&null).unwrap();
val_buf.extend_from_slice(&null);
}
}
}
Expand Down
44 changes: 13 additions & 31 deletions rust/arrow/src/array/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
use std::any::Any;
use std::collections::HashMap;
use std::fmt;
use std::io::Write;
use std::marker::PhantomData;
use std::mem;
use std::{convert::TryInto, sync::Arc};
Expand Down Expand Up @@ -325,7 +324,7 @@ impl<T: ArrowPrimitiveType> BufferBuilderTrait<T> for BufferBuilder<T> {
}
self.len += 1;
} else {
self.write_bytes(v.to_byte_slice(), 1)?;
self.write_bytes(v.to_byte_slice(), 1);
}
Ok(())
}
Expand All @@ -335,18 +334,18 @@ impl<T: ArrowPrimitiveType> BufferBuilderTrait<T> for BufferBuilder<T> {
self.reserve(n);
if T::DATA_TYPE == DataType::Boolean {
if n != 0 && v != T::default_value() {
unsafe {
bit_util::set_bits_raw(
let data = unsafe {
std::slice::from_raw_parts_mut(
self.buffer.raw_data_mut(),
self.len,
self.len + n,
self.buffer.capacity(),
)
}
};
(self.len..self.len + n).for_each(|i| bit_util::set_bit(data, i))
}
self.len += n;
} else {
for _ in 0..n {
self.write_bytes(v.to_byte_slice(), 1)?;
self.write_bytes(v.to_byte_slice(), 1);
}
}
Ok(())
Expand All @@ -371,7 +370,7 @@ impl<T: ArrowPrimitiveType> BufferBuilderTrait<T> for BufferBuilder<T> {
}
Ok(())
} else {
self.write_bytes(slice.to_byte_slice(), array_slots)
Ok(self.write_bytes(slice.to_byte_slice(), array_slots))
}
}

Expand All @@ -397,18 +396,9 @@ impl<T: ArrowPrimitiveType> BufferBuilder<T> {
/// Writes a byte slice to the underlying buffer and updates the `len`, i.e. the
/// number array elements in the builder. Also, converts the `io::Result`
/// required by the `Write` trait to the Arrow `Result` type.
fn write_bytes(&mut self, bytes: &[u8], len_added: usize) -> Result<()> {
let write_result = self.buffer.write(bytes);
// `io::Result` has many options one of which we use, so pattern matching is
// overkill here
if write_result.is_err() {
Err(ArrowError::MemoryError(
"Could not write to Buffer, not big enough".to_string(),
))
} else {
self.len += len_added;
Ok(())
}
fn write_bytes(&mut self, bytes: &[u8], len_added: usize) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jorgecarleitao My guess would be that the issue is related to this len_added parameter. In the filter kernel this was used for additional padding, most other users probably interpreted this as the length of the bytes array. I would suggest removing this parameter, since you already implemented a workaround in the filter kernel.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @jhorstmann here that we should remove len_added (maybe as another PR) -- the length that is actually added is the bytes.len() rather than len_added so having the caller have to provide both leaves the opportunity for additional latent bugs

self.buffer.extend_from_slice(bytes);
self.len += len_added;
}
}

Expand Down Expand Up @@ -525,7 +515,7 @@ impl<T: ArrowPrimitiveType> ArrayBuilder for PrimitiveBuilder<T> {
let sliced = array.buffers()[0].data();
// slice into data by factoring (offset and length) * byte width
self.values_builder
.write_bytes(&sliced[(offset * mul)..((len + offset) * mul)], len)?;
.write_bytes(&sliced[(offset * mul)..((len + offset) * mul)], len);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't reproduce the issue yet, but this line looks a bit suspicious. The first parameter has a larger len (in bytes) than the second len parameter indicates (number of T elements).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice 👁️ -- I agree that removing the len parameter entirely would be the best course of action here

}

for i in 0..len {
Expand Down Expand Up @@ -2600,21 +2590,13 @@ mod tests {
fn test_write_bytes_i32() {
let mut b = Int32BufferBuilder::new(4);
let bytes = [8, 16, 32, 64].to_byte_slice();
b.write_bytes(bytes, 4).unwrap();
b.write_bytes(bytes, 4);
assert_eq!(4, b.len());
assert_eq!(16, b.capacity());
let buffer = b.finish();
assert_eq!(16, buffer.len());
}

#[test]
#[should_panic(expected = "Could not write to Buffer, not big enough")]
fn test_write_too_many_bytes() {
let mut b = Int32BufferBuilder::new(0);
let bytes = [8, 16, 32, 64].to_byte_slice();
b.write_bytes(bytes, 4).unwrap();
}

#[test]
fn test_boolean_array_builder_append_slice() {
let arr1 =
Expand Down
2 changes: 2 additions & 0 deletions rust/arrow/src/array/equal_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,8 @@ mod tests {
"#,
)
.unwrap();
println!("{:?}", arrow_array);
println!("{:?}", json_array);
assert!(arrow_array.eq(&json_array));
assert!(json_array.eq(&arrow_array));

Expand Down
80 changes: 20 additions & 60 deletions rust/arrow/src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use packed_simd::u8x64;
use std::cmp;
use std::convert::AsRef;
use std::fmt::{Debug, Formatter};
use std::io::{Error as IoError, ErrorKind, Result as IoResult, Write};
use std::mem;
use std::ops::{BitAnd, BitOr, Not};
use std::ptr::NonNull;
Expand Down Expand Up @@ -416,9 +415,7 @@ where
let rem = op(left_chunks.remainder_bits(), right_chunks.remainder_bits());
// we are counting its starting from the least significant bit, to to_le_bytes should be correct
let rem = &rem.to_le_bytes()[0..remainder_bytes];
result
.write_all(rem)
.expect("not enough capacity in buffer");
result.extend_from_slice(rem);

result.freeze()
}
Expand Down Expand Up @@ -451,9 +448,7 @@ where
let rem = op(left_chunks.remainder_bits());
// we are counting its starting from the least significant bit, to to_le_bytes should be correct
let rem = &rem.to_le_bytes()[0..remainder_bytes];
result
.write_all(rem)
.expect("not enough capacity in buffer");
result.extend_from_slice(rem);

result.freeze()
}
Expand Down Expand Up @@ -773,21 +768,16 @@ impl MutableBuffer {
}
}

/// Writes a byte slice to the underlying buffer and updates the `len`, i.e. the
/// number array elements in the buffer. Also, converts the `io::Result`
/// required by the `Write` trait to the Arrow `Result` type.
pub fn write_bytes(&mut self, bytes: &[u8], len_added: usize) -> Result<()> {
let write_result = self.write(bytes);
// `io::Result` has many options one of which we use, so pattern matching is
// overkill here
if write_result.is_err() {
Err(ArrowError::IoError(
"Could not write to Buffer, not big enough".to_string(),
))
} else {
self.len += len_added;
Ok(())
/// Extends the buffer from a byte slice, incrementing its capacity if needed.
pub fn extend_from_slice(&mut self, bytes: &[u8]) {
let remaining_capacity = self.capacity - self.len;
if bytes.len() > remaining_capacity {
self.reserve(self.len + bytes.len());
}
unsafe {
memory::memcpy(self.data.add(self.len), bytes.as_ptr(), bytes.len());
}
self.len += bytes.len();
}
}

Expand All @@ -811,24 +801,6 @@ impl PartialEq for MutableBuffer {
}
}

impl Write for MutableBuffer {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
let remaining_capacity = self.capacity - self.len;
if buf.len() > remaining_capacity {
return Err(IoError::new(ErrorKind::Other, "Buffer not big enough"));
}
unsafe {
memory::memcpy(self.data.add(self.len), buf.as_ptr(), buf.len());
self.len += buf.len();
Ok(buf.len())
}
}

fn flush(&mut self) -> IoResult<()> {
Ok(())
}
}

unsafe impl Sync for MutableBuffer {}
unsafe impl Send for MutableBuffer {}

Expand All @@ -855,8 +827,7 @@ mod tests {

// Different capacities should still preserve equality
let mut buf2 = MutableBuffer::new(65);
buf2.write_all(&[0, 1, 2, 3, 4])
.expect("write should be OK");
buf2.extend_from_slice(&[0, 1, 2, 3, 4]);

let buf2 = buf2.freeze();
assert_eq!(buf1, buf2);
Expand Down Expand Up @@ -994,33 +965,23 @@ mod tests {
}

#[test]
fn test_mutable_write() {
fn test_mutable_extend_from_slice() {
let mut buf = MutableBuffer::new(100);
buf.write_all(b"hello").expect("Ok");
buf.extend_from_slice(b"hello");
assert_eq!(5, buf.len());
assert_eq!(b"hello", buf.data());

buf.write_all(b" world").expect("Ok");
buf.extend_from_slice(b" world");
assert_eq!(11, buf.len());
assert_eq!(b"hello world", buf.data());

buf.clear();
assert_eq!(0, buf.len());
buf.write_all(b"hello arrow").expect("Ok");
buf.extend_from_slice(b"hello arrow");
assert_eq!(11, buf.len());
assert_eq!(b"hello arrow", buf.data());
}

#[test]
#[should_panic(expected = "Buffer not big enough")]
fn test_mutable_write_overflow() {
let mut buf = MutableBuffer::new(1);
assert_eq!(64, buf.capacity());
for _ in 0..10 {
buf.write_all(&[0, 0, 0, 0, 0, 0, 0, 0]).unwrap();
}
}

#[test]
fn test_mutable_reserve() {
let mut buf = MutableBuffer::new(1);
Expand Down Expand Up @@ -1066,8 +1027,7 @@ mod tests {
#[test]
fn test_mutable_freeze() {
let mut buf = MutableBuffer::new(1);
buf.write_all(b"aaaa bbbb cccc dddd")
.expect("write should be OK");
buf.extend_from_slice(b"aaaa bbbb cccc dddd");
assert_eq!(19, buf.len());
assert_eq!(64, buf.capacity());
assert_eq!(b"aaaa bbbb cccc dddd", buf.data());
Expand All @@ -1083,11 +1043,11 @@ mod tests {
let mut buf = MutableBuffer::new(1);
let mut buf2 = MutableBuffer::new(1);

buf.write_all(&[0xaa])?;
buf2.write_all(&[0xaa, 0xbb])?;
buf.extend_from_slice(&[0xaa]);
buf2.extend_from_slice(&[0xaa, 0xbb]);
assert!(buf != buf2);

buf.write_all(&[0xbb])?;
buf.extend_from_slice(&[0xbb]);
assert_eq!(buf, buf2);

buf2.reserve(65);
Expand Down
10 changes: 4 additions & 6 deletions rust/arrow/src/compute/kernels/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ where
T: ArrowNumericType,
F: Fn(T::Simd, T::Simd) -> T::SimdMask,
{
use std::io::Write;
use std::mem;

let len = left.len();
Expand All @@ -283,7 +282,7 @@ where
let simd_right = T::load(right.value_slice(i, lanes));
let simd_result = op(simd_left, simd_right);
T::bitmask(&simd_result, |b| {
result.write(b).unwrap();
result.extend_from_slice(b);
});
}

Expand All @@ -293,7 +292,7 @@ where
let simd_result = op(simd_left, simd_right);
let rem_buffer_size = (rem as f32 / 8f32).ceil() as usize;
T::bitmask(&simd_result, |b| {
result.write(&b[0..rem_buffer_size]).unwrap();
result.extend_from_slice(&b[0..rem_buffer_size]);
});
}

Expand Down Expand Up @@ -321,7 +320,6 @@ where
T: ArrowNumericType,
F: Fn(T::Simd, T::Simd) -> T::SimdMask,
{
use std::io::Write;
use std::mem;

let len = left.len();
Expand All @@ -336,7 +334,7 @@ where
let simd_left = T::load(left.value_slice(i, lanes));
let simd_result = op(simd_left, simd_right);
T::bitmask(&simd_result, |b| {
result.write(b).unwrap();
result.extend_from_slice(b);
});
}

Expand All @@ -345,7 +343,7 @@ where
let simd_result = op(simd_left, simd_right);
let rem_buffer_size = (rem as f32 / 8f32).ceil() as usize;
T::bitmask(&simd_result, |b| {
result.write(&b[0..rem_buffer_size]).unwrap();
result.extend_from_slice(&b[0..rem_buffer_size]);
});
}

Expand Down
3 changes: 2 additions & 1 deletion rust/arrow/src/compute/kernels/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ impl FilterContext {
let mut u64_buffer = MutableBuffer::new(filter_bytes.len());
// add to the resulting len so is is a multiple of the size of u64
let pad_addional_len = (8 - filter_bytes.len() % 8) % 8;
u64_buffer.write_bytes(filter_bytes, pad_addional_len)?;
u64_buffer.extend_from_slice(filter_bytes);
u64_buffer.extend_from_slice(&vec![0; pad_addional_len]);
let mut filter_u64 = u64_buffer.typed_data_mut::<u64>().to_owned();

// mask of any bits outside of the given len
Expand Down
Loading