diff --git a/arrow-buffer/src/buffer/immutable.rs b/arrow-buffer/src/buffer/immutable.rs
index 8d1a46583fca..7254313b7da5 100644
--- a/arrow-buffer/src/buffer/immutable.rs
+++ b/arrow-buffer/src/buffer/immutable.rs
@@ -23,7 +23,7 @@ use std::sync::Arc;
 use crate::alloc::{Allocation, Deallocation, ALIGNMENT};
 use crate::util::bit_chunk_iterator::{BitChunks, UnalignedBitChunk};
 use crate::BufferBuilder;
-use crate::{bytes::Bytes, native::ArrowNativeType};
+use crate::{bit_util, bytes::Bytes, native::ArrowNativeType};
 
 use super::ops::bitwise_unary_op_helper;
 use super::{MutableBuffer, ScalarBuffer};
@@ -265,7 +265,7 @@ impl Buffer {
     /// otherwise a new buffer is allocated and filled with a copy of the bits in the range.
     pub fn bit_slice(&self, offset: usize, len: usize) -> Self {
         if offset % 8 == 0 {
-            return self.slice(offset / 8);
+            return self.slice_with_length(offset / 8, bit_util::ceil(len, 8));
         }
 
         bitwise_unary_op_helper(self, offset, len, |a| a)
@@ -860,4 +860,37 @@ mod tests {
         let iter_len = usize::MAX / std::mem::size_of::<u64>() + 1;
         let _ = Buffer::from_iter(std::iter::repeat(0_u64).take(iter_len));
     }
+
+    #[test]
+    fn bit_slice_length_preserved() {
+        // Create a boring buffer
+        let buf = Buffer::from_iter(std::iter::repeat(true).take(64));
+
+        let assert_preserved = |offset: usize, len: usize| {
+            let new_buf = buf.bit_slice(offset, len);
+            assert_eq!(new_buf.len(), bit_util::ceil(len, 8));
+
+            // if the offset is not byte-aligned, we have to create a deep copy to a new buffer
+            // (since the `offset` value inside a Buffer is byte-granular, not bit-granular), so
+            // checking the offset should always return 0 if so. If the offset IS byte-aligned, we
+            // want to make sure it doesn't unnecessarily create a deep copy.
+            if offset % 8 == 0 {
+                assert_eq!(new_buf.ptr_offset(), offset / 8);
+            } else {
+                assert_eq!(new_buf.ptr_offset(), 0);
+            }
+        };
+
+        // go through every available value for offset
+        for o in 0..=64 {
+            // and go through every length that could accompany that offset - we can't have a
+            // situation where offset + len > 64, because that would go past the end of the buffer,
+            // so we use the map to ensure it's in range.
+            for l in (o..=64).map(|l| l - o) {
+                // and we just want to make sure every one of these keeps its offset and length
+                // when needed
+                assert_preserved(o, l);
+            }
+        }
+    }
 }
diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs
index ae3475c7c7d7..e65295e8750f 100644
--- a/arrow-flight/src/encode.rs
+++ b/arrow-flight/src/encode.rs
@@ -1741,7 +1741,7 @@ mod tests {
 
         let batch = RecordBatch::try_from_iter(vec![("a1", Arc::new(array) as _)]).unwrap();
 
-        verify_encoded_split(batch, 160).await;
+        verify_encoded_split(batch, 48).await;
     }
 
     #[tokio::test]
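
For context, a minimal sketch of the behavior this patch changes, written outside the crate's own test module. It assumes the public `arrow_buffer::Buffer` API behaves as the new in-crate test exercises it (`from_iter` over `bool` packs bits, `len()` reports bytes); it is an illustration, not part of the change itself.

```rust
use arrow_buffer::Buffer;

fn main() {
    // 64 set bits packed into an 8-byte buffer, mirroring the new test above.
    let buf = Buffer::from_iter(std::iter::repeat(true).take(64));

    // Byte-aligned offset (8 bits in) with a 16-bit length.
    let sliced = buf.bit_slice(8, 16);

    // Before this patch, the byte-aligned branch returned `self.slice(offset / 8)`,
    // i.e. every remaining byte (7 here), silently dropping the requested length.
    // With `slice_with_length(offset / 8, bit_util::ceil(len, 8))` the result is
    // exactly ceil(16 / 8) = 2 bytes and still shares the parent allocation.
    assert_eq!(sliced.len(), 2);
}
```

The lowered expectation in the arrow-flight test (160 → 48) is presumably downstream of the same fix: bit-sliced validity buffers no longer drag the remainder of their parent buffer into the size that `verify_encoded_split` checks.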