From cada9ba33803a48a3145ab333fe1cf6410999d89 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Tue, 10 Jan 2023 13:55:07 +0100 Subject: [PATCH] Fix IPCWriter for Sliced BooleanArray (#3498) * fix: bool IPC Fixes #3496. * refactor: simplify code * refactor: `assert!` -> `assert_eq!` --- arrow-ipc/src/writer.rs | 73 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index ed5e53a959c0..d7cc83aabddb 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -1202,7 +1202,7 @@ fn write_array_data( ) { // Truncate values - assert!(array_data.buffers().len() == 1); + assert_eq!(array_data.buffers().len(), 1); let buffer = &array_data.buffers()[0]; let layout = layout(data_type); @@ -1231,6 +1231,14 @@ fn write_array_data( compression_codec, )?; } + } else if matches!(data_type, DataType::Boolean) { + // Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes). + // The array data may not start at the physical boundary of the underlying buffer, so we need to shift bits around. + assert_eq!(array_data.buffers().len(), 1); + + let buffer = &array_data.buffers()[0]; + let buffer = buffer.bit_slice(array_data.offset(), array_data.len()); + offset = write_buffer(&buffer, buffers, arrow_data, offset, compression_codec)?; } else { for buffer in array_data.buffers() { offset = @@ -1312,6 +1320,7 @@ fn pad_to_8(len: u32) -> usize { mod tests { use super::*; + use std::io::Cursor; use std::io::Seek; use std::sync::Arc; @@ -1926,4 +1935,66 @@ mod tests { read_array.iter().collect::>() ); } + + #[test] + fn encode_bools_slice() { + // Test case for https://github.com/apache/arrow-rs/issues/3496 + assert_bool_roundtrip([true, false], 1, 1); + + // slice somewhere in the middle + assert_bool_roundtrip( + [ + true, false, true, true, false, false, true, true, true, false, false, + false, true, true, true, true, false, false, false, false, true, true, + true, true, true, false, false, false, false, false, + ], + 13, + 17, + ); + + // start at byte boundary, end in the middle + assert_bool_roundtrip( + [ + true, false, true, true, false, false, true, true, true, false, false, + false, + ], + 8, + 2, + ); + + // start and stop and byte boundary + assert_bool_roundtrip( + [ + true, false, true, true, false, false, true, true, true, false, false, + false, true, true, true, true, true, false, false, false, false, false, + ], + 8, + 8, + ); + } + + fn assert_bool_roundtrip( + bools: [bool; N], + offset: usize, + length: usize, + ) { + let val_bool_field = Field::new("val", DataType::Boolean, false); + + let schema = Arc::new(Schema::new(vec![val_bool_field])); + + let bools = BooleanArray::from(bools.to_vec()); + + let batch = + RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(bools)]).unwrap(); + let batch = batch.slice(offset, length); + + let mut writer = StreamWriter::try_new(Vec::::new(), &schema).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + let data = writer.into_inner().unwrap(); + + let mut reader = StreamReader::try_new(Cursor::new(data), None).unwrap(); + let batch2 = reader.next().unwrap().unwrap(); + assert_eq!(batch, batch2); + } }