From d94018d963596ef1d12d63b84cf7d9f18a4dcb30 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Mon, 9 Jan 2023 13:01:32 +0100 Subject: [PATCH] fix: bool IPC Fixes #3496. --- arrow-ipc/src/writer.rs | 95 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index ed5e53a959c0..74c578e468d2 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -1231,6 +1231,38 @@ fn write_array_data( compression_codec, )?; } + } else if matches!(data_type, DataType::Boolean) { + // Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes). + // The array data may not start at the physical boundary of the underlying buffer, so we need to shift bits around. + assert!(array_data.buffers().len() == 1); + let buffer = &array_data.buffers()[0]; + + let byte_offset = array_data.offset() / 8; + let byte_len = array_data.len() / 8 + usize::from(array_data.len() % 8 != 0); + let bslice = &buffer.as_slice()[byte_offset..(byte_offset + byte_len)]; + let bit_offset = array_data.offset() % 8; + let bslice_shifted = if bit_offset != 0 { + let mut bslice_shifted = Vec::::with_capacity(bslice.len()); + let mut prev = 0u8; + for (i, b) in bslice.iter().enumerate() { + // first byte is special: we throw away `bit_offset` bits + if i > 0 { + // use `bit_offset` bits from the previous element and `8 - bit_offset` bits from the current one + bslice_shifted.push((b << (8 - bit_offset)) | prev); + } + + // throw `bit_offset` lower bits and keep the rest + prev = b >> bit_offset; + } + bslice_shifted.push(prev); + Some(bslice_shifted) + } else { + // no need to copy the data + None + }; + let bslice = bslice_shifted.as_deref().unwrap_or(bslice); + + offset = write_buffer(bslice, buffers, arrow_data, offset, compression_codec)?; } else { for buffer in array_data.buffers() { offset = @@ -1312,6 +1344,7 @@ fn pad_to_8(len: u32) -> usize { mod tests { use super::*; + use std::io::Cursor; use std::io::Seek; use std::sync::Arc; @@ -1926,4 +1959,66 @@ mod tests { read_array.iter().collect::>() ); } + + #[test] + fn encode_bools_slice() { + // Test case for https://github.com/apache/arrow-rs/issues/3496 + assert_bool_roundtrip([true, false], 1, 1); + + // slice somewhere in the middle + assert_bool_roundtrip( + [ + true, false, true, true, false, false, true, true, true, false, false, + false, true, true, true, true, false, false, false, false, true, true, + true, true, true, false, false, false, false, false, + ], + 13, + 17, + ); + + // start at byte boundary, end in the middle + assert_bool_roundtrip( + [ + true, false, true, true, false, false, true, true, true, false, false, + false, + ], + 8, + 2, + ); + + // start and stop and byte boundary + assert_bool_roundtrip( + [ + true, false, true, true, false, false, true, true, true, false, false, + false, true, true, true, true, true, false, false, false, false, false, + ], + 8, + 8, + ); + } + + fn assert_bool_roundtrip( + bools: [bool; N], + offset: usize, + length: usize, + ) { + let val_bool_field = Field::new("val", DataType::Boolean, false); + + let schema = Arc::new(Schema::new(vec![val_bool_field])); + + let bools = BooleanArray::from(bools.to_vec()); + + let batch = + RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(bools)]).unwrap(); + let batch = batch.slice(offset, length); + + let mut writer = StreamWriter::try_new(Vec::::new(), &schema).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + let data = writer.into_inner().unwrap(); + + let mut reader = StreamReader::try_new(Cursor::new(data), None).unwrap(); + let batch2 = reader.next().unwrap().unwrap(); + assert_eq!(batch, batch2); + } }