Skip to content

Commit

Permalink
fix: bool IPC
Browse files Browse the repository at this point in the history
  • Loading branch information
crepererum committed Jan 9, 2023
1 parent eae993f commit d94018d
Showing 1 changed file with 95 additions and 0 deletions.
95 changes: 95 additions & 0 deletions arrow-ipc/src/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,38 @@ fn write_array_data(
compression_codec,
)?;
}
} else if matches!(data_type, DataType::Boolean) {
// Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes).
// The array data may not start at the physical boundary of the underlying buffer, so we need to shift bits around.
assert!(array_data.buffers().len() == 1);
let buffer = &array_data.buffers()[0];

let byte_offset = array_data.offset() / 8;
let byte_len = array_data.len() / 8 + usize::from(array_data.len() % 8 != 0);
let bslice = &buffer.as_slice()[byte_offset..(byte_offset + byte_len)];
let bit_offset = array_data.offset() % 8;
let bslice_shifted = if bit_offset != 0 {
let mut bslice_shifted = Vec::<u8>::with_capacity(bslice.len());
let mut prev = 0u8;
for (i, b) in bslice.iter().enumerate() {
// first byte is special: we throw away `bit_offset` bits
if i > 0 {
// use `bit_offset` bits from the previous element and `8 - bit_offset` bits from the current one
bslice_shifted.push((b << (8 - bit_offset)) | prev);
}

// throw `bit_offset` lower bits and keep the rest
prev = b >> bit_offset;
}
bslice_shifted.push(prev);
Some(bslice_shifted)
} else {
// no need to copy the data
None
};
let bslice = bslice_shifted.as_deref().unwrap_or(bslice);

offset = write_buffer(bslice, buffers, arrow_data, offset, compression_codec)?;
} else {
for buffer in array_data.buffers() {
offset =
Expand Down Expand Up @@ -1312,6 +1344,7 @@ fn pad_to_8(len: u32) -> usize {
mod tests {
use super::*;

use std::io::Cursor;
use std::io::Seek;
use std::sync::Arc;

Expand Down Expand Up @@ -1926,4 +1959,66 @@ mod tests {
read_array.iter().collect::<Vec<_>>()
);
}

#[test]
fn encode_bools_slice() {
// Test case for https://github.com/apache/arrow-rs/issues/3496
assert_bool_roundtrip([true, false], 1, 1);

// slice somewhere in the middle
assert_bool_roundtrip(
[
true, false, true, true, false, false, true, true, true, false, false,
false, true, true, true, true, false, false, false, false, true, true,
true, true, true, false, false, false, false, false,
],
13,
17,
);

// start at byte boundary, end in the middle
assert_bool_roundtrip(
[
true, false, true, true, false, false, true, true, true, false, false,
false,
],
8,
2,
);

// start and stop and byte boundary
assert_bool_roundtrip(
[
true, false, true, true, false, false, true, true, true, false, false,
false, true, true, true, true, true, false, false, false, false, false,
],
8,
8,
);
}

fn assert_bool_roundtrip<const N: usize>(
bools: [bool; N],
offset: usize,
length: usize,
) {
let val_bool_field = Field::new("val", DataType::Boolean, false);

let schema = Arc::new(Schema::new(vec![val_bool_field]));

let bools = BooleanArray::from(bools.to_vec());

let batch =
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(bools)]).unwrap();
let batch = batch.slice(offset, length);

let mut writer = StreamWriter::try_new(Vec::<u8>::new(), &schema).unwrap();
writer.write(&batch).unwrap();
writer.finish().unwrap();
let data = writer.into_inner().unwrap();

let mut reader = StreamReader::try_new(Cursor::new(data), None).unwrap();
let batch2 = reader.next().unwrap().unwrap();
assert_eq!(batch, batch2);
}
}

0 comments on commit d94018d

Please sign in to comment.