Skip to content

Commit

Permalink
Fix IPCWriter for Sliced BooleanArray (#3498)
Browse files Browse the repository at this point in the history
* fix: bool IPC

Fixes #3496.

* refactor: simplify code

* refactor: `assert!` -> `assert_eq!`
  • Loading branch information
crepererum authored Jan 10, 2023
1 parent fb36dd9 commit cada9ba
Showing 1 changed file with 72 additions and 1 deletion.
73 changes: 72 additions & 1 deletion arrow-ipc/src/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,7 @@ fn write_array_data(
)
{
// Truncate values
assert!(array_data.buffers().len() == 1);
assert_eq!(array_data.buffers().len(), 1);

let buffer = &array_data.buffers()[0];
let layout = layout(data_type);
Expand Down Expand Up @@ -1231,6 +1231,14 @@ fn write_array_data(
compression_codec,
)?;
}
} else if matches!(data_type, DataType::Boolean) {
// Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes).
// The array data may not start at the physical boundary of the underlying buffer, so we need to shift bits around.
assert_eq!(array_data.buffers().len(), 1);

let buffer = &array_data.buffers()[0];
let buffer = buffer.bit_slice(array_data.offset(), array_data.len());
offset = write_buffer(&buffer, buffers, arrow_data, offset, compression_codec)?;
} else {
for buffer in array_data.buffers() {
offset =
Expand Down Expand Up @@ -1312,6 +1320,7 @@ fn pad_to_8(len: u32) -> usize {
mod tests {
use super::*;

use std::io::Cursor;
use std::io::Seek;
use std::sync::Arc;

Expand Down Expand Up @@ -1926,4 +1935,66 @@ mod tests {
read_array.iter().collect::<Vec<_>>()
);
}

#[test]
fn encode_bools_slice() {
// Test case for https://github.com/apache/arrow-rs/issues/3496
assert_bool_roundtrip([true, false], 1, 1);

// slice somewhere in the middle
assert_bool_roundtrip(
[
true, false, true, true, false, false, true, true, true, false, false,
false, true, true, true, true, false, false, false, false, true, true,
true, true, true, false, false, false, false, false,
],
13,
17,
);

// start at byte boundary, end in the middle
assert_bool_roundtrip(
[
true, false, true, true, false, false, true, true, true, false, false,
false,
],
8,
2,
);

// start and stop and byte boundary
assert_bool_roundtrip(
[
true, false, true, true, false, false, true, true, true, false, false,
false, true, true, true, true, true, false, false, false, false, false,
],
8,
8,
);
}

fn assert_bool_roundtrip<const N: usize>(
bools: [bool; N],
offset: usize,
length: usize,
) {
let val_bool_field = Field::new("val", DataType::Boolean, false);

let schema = Arc::new(Schema::new(vec![val_bool_field]));

let bools = BooleanArray::from(bools.to_vec());

let batch =
RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(bools)]).unwrap();
let batch = batch.slice(offset, length);

let mut writer = StreamWriter::try_new(Vec::<u8>::new(), &schema).unwrap();
writer.write(&batch).unwrap();
writer.finish().unwrap();
let data = writer.into_inner().unwrap();

let mut reader = StreamReader::try_new(Cursor::new(data), None).unwrap();
let batch2 = reader.next().unwrap().unwrap();
assert_eq!(batch, batch2);
}
}

0 comments on commit cada9ba

Please sign in to comment.