diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index 5cf208c8eb5c..53224fd28998 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -32,13 +32,33 @@ use std::sync::Arc;
 use arrow_array::*;
 use arrow_buffer::{ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, ScalarBuffer};
-use arrow_data::ArrayData;
+use arrow_data::{ArrayData, ArrayDataBuilder};
 use arrow_schema::*;
 
 use crate::compression::CompressionCodec;
 use crate::{Block, FieldNode, Message, MetadataVersion, CONTINUATION_MARKER};
 use DataType::*;
 
+/// Build an array from the builder, optionally skipping validations.
+///
+/// # Safety
+/// If `skip_validation` is true, the function will build an `ArrayData` without performing the
+/// usual validations. This can lead to undefined behavior if the data is not correctly formatted.
+///
+/// Set `skip_validation` to true only if you are certain the data is valid and comes from a
+/// trusted source.
+unsafe fn build_array_internal(
+    builder: ArrayDataBuilder,
+    require_alignment: bool,
+    skip_validation: bool,
+) -> Result<ArrayData, ArrowError> {
+    unsafe {
+        builder
+            .align_buffers(!require_alignment)
+            .skip_validation(skip_validation)
+            .build()
+    }
+}
+
 /// Read a buffer based on offset and length
 /// From <https://github.com/apache/arrow/blob/6a936c4ff5007045e86f65f1a6b6c3c955ad5103/format/Message.fbs#L58>
 /// Each constituent buffer is first compressed with the indicated
@@ -62,6 +82,345 @@ fn read_buffer(
     }
 }
 
+struct ArrayCreator {
+    require_alignment: bool,
+    skip_validation: bool,
+}
+
+impl ArrayCreator {
+    fn new() -> Self {
+        Self {
+            require_alignment: true,
+            skip_validation: false,
+        }
+    }
+
+    fn require_alignment(mut self, require_alignment: bool) -> Self {
+        self.require_alignment = require_alignment;
+        self
+    }
+
+    unsafe fn skip_validation(mut self, skip_validation: bool) -> Self {
+        self.skip_validation = skip_validation;
+        self
+    }
+
+    fn create(
+        &self,
+        reader: &mut ArrayReader,
+        field: &Field,
+        variadic_counts: &mut VecDeque<i64>,
+    ) -> Result<ArrayRef, ArrowError> {
+        self.create_array(reader, field, variadic_counts)
+    }
+}
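+
+// A minimal sketch of how the reader is expected to drive `ArrayCreator`
+// (mirroring `read_record_batch_impl` below); `reader`, `field`, and
+// `variadic_counts` stand in for the values owned by the record-batch loop:
+//
+//     let child = ArrayCreator::new()
+//         .require_alignment(require_alignment)
+//         // SAFETY: only set when the caller vouches for the input data.
+//         .skip_validation(skip_validation)
+//         .create(&mut reader, field, &mut variadic_counts)?;
+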
+impl ArrayCreator {
+    /// Coordinates reading arrays based on data types.
+    ///
+    /// `variadic_counts` encodes the number of buffers to read for variadic types (e.g., Utf8View, BinaryView).
+    /// When we encounter such a type, we pop from the front of the queue to get the number of buffers to read.
+    ///
+    /// Notes:
+    /// * In the IPC format, null buffers are always set, but may be empty. We discard them if an array has 0 nulls
+    /// * Numeric values inside list arrays are often stored as 64-bit values regardless of their data type size.
+    ///   We thus:
+    ///     - check if the bit width of non-64-bit numbers is 64, and
+    ///     - read the buffer as 64-bit (signed integer or float), and
+    ///     - cast the 64-bit array to the appropriate data type
+    fn create_array(
+        &self,
+        reader: &mut ArrayReader,
+        field: &Field,
+        variadic_counts: &mut VecDeque<i64>,
+    ) -> Result<ArrayRef, ArrowError> {
+        let data_type = field.data_type();
+
+        match data_type {
+            Utf8 | Binary | LargeBinary | LargeUtf8 => self.create_primitive_array(
+                reader.next_node(field)?,
+                data_type,
+                &[
+                    reader.next_buffer()?,
+                    reader.next_buffer()?,
+                    reader.next_buffer()?,
+                ],
+            ),
+            BinaryView | Utf8View => {
+                let count = variadic_counts
+                    .pop_front()
+                    .ok_or(ArrowError::IpcError(format!(
+                        "Missing variadic count for {data_type} column"
+                    )))?;
+                let count = count + 2; // view buffer and null buffer
+                let buffers = (0..count)
+                    .map(|_| reader.next_buffer())
+                    .collect::<Result<Vec<_>, _>>()?;
+                self.create_primitive_array(reader.next_node(field)?, data_type, &buffers)
+            }
+            FixedSizeBinary(_) => self.create_primitive_array(
+                reader.next_node(field)?,
+                data_type,
+                &[reader.next_buffer()?, reader.next_buffer()?],
+            ),
+            List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => {
+                let list_node = reader.next_node(field)?;
+                let list_buffers = [reader.next_buffer()?, reader.next_buffer()?];
+                let values = self.create_array(reader, list_field, variadic_counts)?;
+                self.create_list_array(list_node, data_type, &list_buffers, values)
+            }
+            FixedSizeList(ref list_field, _) => {
+                let list_node = reader.next_node(field)?;
+                let list_buffers = [reader.next_buffer()?];
+                let values = self.create_array(reader, list_field, variadic_counts)?;
+                self.create_list_array(list_node, data_type, &list_buffers, values)
+            }
+            Struct(struct_fields) => {
+                let struct_node = reader.next_node(field)?;
+                let null_buffer = reader.next_buffer()?;
+
+                // read the arrays for each field
+                let mut struct_arrays = vec![];
+                // TODO investigate whether just knowing the number of buffers could
+                // still work
+                for struct_field in struct_fields {
+                    let child = self.create_array(reader, struct_field, variadic_counts)?;
+                    struct_arrays.push(child);
+                }
+                let null_count = struct_node.null_count() as usize;
+                let struct_array = if struct_arrays.is_empty() {
+                    // `StructArray::from` can't infer the correct row count
+                    // if we have zero fields
+                    let len = struct_node.length() as usize;
+                    StructArray::new_empty_fields(
+                        len,
+                        (null_count > 0).then(|| BooleanBuffer::new(null_buffer, 0, len).into()),
+                    )
+                } else if null_count > 0 {
+                    // create struct array from fields, arrays and null data
+                    let len = struct_node.length() as usize;
+                    let nulls = BooleanBuffer::new(null_buffer, 0, len).into();
+                    StructArray::try_new(struct_fields.clone(), struct_arrays, Some(nulls))?
+                } else {
+                    StructArray::try_new(struct_fields.clone(), struct_arrays, None)?
+                };
+                Ok(Arc::new(struct_array))
+            }
+            RunEndEncoded(run_ends_field, values_field) => {
+                let run_node = reader.next_node(field)?;
+                let run_ends = self.create_array(reader, run_ends_field, variadic_counts)?;
+                let values = self.create_array(reader, values_field, variadic_counts)?;
+
+                let run_array_length = run_node.length() as usize;
+                let builder = ArrayData::builder(data_type.clone())
+                    .len(run_array_length)
+                    .offset(0)
+                    .add_child_data(run_ends.into_data())
+                    .add_child_data(values.into_data());
+
+                let array_data = unsafe {
+                    build_array_internal(builder, self.require_alignment, self.skip_validation)?
+                };
+
+                Ok(make_array(array_data))
+            }
+            // Create dictionary array from RecordBatch
+            Dictionary(_, _) => {
+                let index_node = reader.next_node(field)?;
+                let index_buffers = [reader.next_buffer()?, reader.next_buffer()?];
+
+                #[allow(deprecated)]
+                let dict_id = field.dict_id().ok_or_else(|| {
+                    ArrowError::ParseError(format!("Field {field} does not have dict id"))
+                })?;
+
+                let value_array = reader.dictionaries_by_id.get(&dict_id).ok_or_else(|| {
+                    ArrowError::ParseError(format!(
+                        "Cannot find a dictionary batch with dict id: {dict_id}"
+                    ))
+                })?;
+
+                self.create_dictionary_array(
+                    index_node,
+                    data_type,
+                    &index_buffers,
+                    value_array.clone(),
+                )
+            }
+            Union(fields, mode) => {
+                let union_node = reader.next_node(field)?;
+                let len = union_node.length() as usize;
+
+                // In V4, union types have a validity bitmap
+                // In V5 and later, union types have no validity bitmap
+                if reader.version < MetadataVersion::V5 {
+                    reader.next_buffer()?;
+                }
+
+                let type_ids: ScalarBuffer<i8> =
+                    reader.next_buffer()?.slice_with_length(0, len).into();
+
+                let value_offsets = match mode {
+                    UnionMode::Dense => {
+                        let offsets: ScalarBuffer<i32> =
+                            reader.next_buffer()?.slice_with_length(0, len * 4).into();
+                        Some(offsets)
+                    }
+                    UnionMode::Sparse => None,
+                };
+
+                let mut children = Vec::with_capacity(fields.len());
+
+                for (_id, field) in fields.iter() {
+                    let child = self.create_array(reader, field, variadic_counts)?;
+                    children.push(child);
+                }
+
+                let array = UnionArray::try_new(fields.clone(), type_ids, value_offsets, children)?;
+                Ok(Arc::new(array))
+            }
+            Null => {
+                let node = reader.next_node(field)?;
+                let length = node.length();
+                let null_count = node.null_count();
+
+                if length != null_count {
+                    return Err(ArrowError::SchemaError(format!(
+                        "Field {field} of NullArray has unequal null_count {null_count} and len {length}"
+                    )));
+                }
+
+                let builder = ArrayData::builder(data_type.clone())
+                    .len(length as usize)
+                    .offset(0);
+
+                let array_data = unsafe {
+                    build_array_internal(builder, self.require_alignment, self.skip_validation)?
+                };
+
+                // no buffer increases
+                Ok(Arc::new(NullArray::from(array_data)))
+            }
+            _ => self.create_primitive_array(
+                reader.next_node(field)?,
+                data_type,
+                &[reader.next_buffer()?, reader.next_buffer()?],
+            ),
+        }
+    }
+
+    /// If `self.skip_validation` is set, the `ArrayData` is built without the usual
+    /// validations, which can lead to undefined behavior if the data is not correctly formatted.
+    fn create_primitive_array(
+        &self,
+        field_node: &FieldNode,
+        data_type: &DataType,
+        buffers: &[Buffer],
+    ) -> Result<ArrayRef, ArrowError> {
+        let length = field_node.length() as usize;
+        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
+        let builder = match data_type {
+            Utf8 | Binary | LargeBinary | LargeUtf8 => {
+                // read 3 buffers: null buffer (optional), offsets buffer and data buffer
+                ArrayData::builder(data_type.clone())
+                    .len(length)
+                    .buffers(buffers[1..3].to_vec())
+                    .null_bit_buffer(null_buffer)
+            }
+            BinaryView | Utf8View => ArrayData::builder(data_type.clone())
+                .len(length)
+                .buffers(buffers[1..].to_vec())
+                .null_bit_buffer(null_buffer),
+            _ if data_type.is_primitive() || matches!(data_type, Boolean | FixedSizeBinary(_)) => {
+                // read 2 buffers: null buffer (optional) and data buffer
+                ArrayData::builder(data_type.clone())
+                    .len(length)
+                    .add_buffer(buffers[1].clone())
+                    .null_bit_buffer(null_buffer)
+            }
+            t => unreachable!("Data type {:?} either unsupported or not primitive", t),
+        };
+
+        let array_data =
+            unsafe { build_array_internal(builder, self.require_alignment, self.skip_validation)? };
+
+        Ok(make_array(array_data))
+    }
+
+    /// Reads the correct number of buffers based on list type and null_count, and creates a
+    /// list array ref
+    pub fn create_list_array(
+        &self,
+        field_node: &FieldNode,
+        data_type: &DataType,
+        buffers: &[Buffer],
+        child_array: ArrayRef,
+    ) -> Result<ArrayRef, ArrowError> {
+        let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
+        let length = field_node.length() as usize;
+        let child_data = child_array.into_data();
+        let builder = match data_type {
+            List(_) | LargeList(_) | Map(_, _) => ArrayData::builder(data_type.clone())
+                .len(length)
+                .add_buffer(buffers[1].clone())
+                .add_child_data(child_data)
+                .null_bit_buffer(null_buffer),
+
+            FixedSizeList(_, _) => ArrayData::builder(data_type.clone())
+                .len(length)
+                .add_child_data(child_data)
+                .null_bit_buffer(null_buffer),
+
+            _ => unreachable!("Cannot create list or map array from {:?}", data_type),
+        };
+
+        let array_data =
+            unsafe { build_array_internal(builder, self.require_alignment, self.skip_validation)? };
+
+        Ok(make_array(array_data))
+    }
+
+    /// Reads the correct number of buffers based on list type and null_count, and creates a
+    /// dictionary array ref
+    ///
+    /// # Safety
+    /// If `self.skip_validation` is set, the `ArrayData` is built without the usual
+    /// validations. This can lead to undefined behavior if the data is not correctly
+    /// formatted. Skip validation only if you are certain the data is valid.
+    ///
+    /// Notes:
+    /// * If `self.skip_validation` is set, `self.require_alignment` is ignored.
+    fn create_dictionary_array(
+        &self,
+        field_node: &FieldNode,
+        data_type: &DataType,
+        buffers: &[Buffer],
+        value_array: ArrayRef,
+    ) -> Result<ArrayRef, ArrowError> {
+        if let Dictionary(_, _) = *data_type {
+            let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
+            let builder = ArrayData::builder(data_type.clone())
+                .len(field_node.length() as usize)
+                .add_buffer(buffers[1].clone())
+                .add_child_data(value_array.into_data())
+                .null_bit_buffer(null_buffer);
+
+            let array_data = unsafe {
+                build_array_internal(builder, self.require_alignment, self.skip_validation)?
+            };
+            Ok(make_array(array_data))
+        } else {
+            unreachable!("Cannot create dictionary array from {:?}", data_type)
+        }
+    }
+}
+
+impl Default for ArrayCreator {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 /// Coordinates reading arrays based on data types.
 ///
 /// `variadic_counts` encodes the number of buffers to read for variadic types (e.g., Utf8View, BinaryView)
@@ -81,6 +440,7 @@ fn create_array(
     require_alignment: bool,
 ) -> Result<ArrayRef, ArrowError> {
     let data_type = field.data_type();
+
     match data_type {
         Utf8 | Binary | LargeBinary | LargeUtf8 => create_primitive_array(
             reader.next_node(field)?,
@@ -177,13 +537,13 @@ fn create_array(
             let values = create_array(reader, values_field, variadic_counts, require_alignment)?;
 
             let run_array_length = run_node.length() as usize;
-            let array_data = ArrayData::builder(data_type.clone())
+            let builder = ArrayData::builder(data_type.clone())
                 .len(run_array_length)
                 .offset(0)
                 .add_child_data(run_ends.into_data())
-                .add_child_data(values.into_data())
-                .align_buffers(!require_alignment)
-                .build()?;
+                .add_child_data(values.into_data());
+
+            let array_data = unsafe { build_array_internal(builder, require_alignment, false)? };
 
             Ok(make_array(array_data))
         }
@@ -253,11 +613,11 @@ fn create_array(
                 )));
             }
 
-            let array_data = ArrayData::builder(data_type.clone())
+            let builder = ArrayData::builder(data_type.clone())
                 .len(length as usize)
-                .offset(0)
-                .align_buffers(!require_alignment)
-                .build()?;
+                .offset(0);
+
+            let array_data = unsafe { build_array_internal(builder, require_alignment, false)? };
 
             // no buffer increases
             Ok(Arc::new(NullArray::from(array_data)))
@@ -303,7 +663,7 @@ fn create_primitive_array(
         t => unreachable!("Data type {:?} either unsupported or not primitive", t),
     };
 
-    let array_data = builder.align_buffers(!require_alignment).build()?;
+    let array_data = unsafe { build_array_internal(builder, require_alignment, false)? };
 
     Ok(make_array(array_data))
 }
@@ -335,13 +695,21 @@ fn create_list_array(
         _ => unreachable!("Cannot create list or map array from {:?}", data_type),
     };
 
-    let array_data = builder.align_buffers(!require_alignment).build()?;
+    let array_data = unsafe { build_array_internal(builder, require_alignment, false)? };
 
     Ok(make_array(array_data))
 }
 
 /// Reads the correct number of buffers based on list type and null_count, and creates a
-/// list array ref
+/// dictionary array ref
+///
+/// Buffers are realigned (copied) as needed unless `require_alignment` is set,
+/// in which case misaligned buffers produce an error instead.
+///
+/// Notes:
+/// * This function always validates the resulting `ArrayData`; the
+///   validation-skipping path is only reachable through the internal
+///   [`ArrayCreator`].
 fn create_dictionary_array(
     field_node: &FieldNode,
     data_type: &DataType,
     buffers: &[Buffer],
@@ -351,14 +719,13 @@ fn create_dictionary_array(
 ) -> Result<ArrayRef, ArrowError> {
     if let Dictionary(_, _) = *data_type {
         let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
-        let array_data = ArrayData::builder(data_type.clone())
+        let builder = ArrayData::builder(data_type.clone())
             .len(field_node.length() as usize)
             .add_buffer(buffers[1].clone())
            .add_child_data(value_array.into_data())
-            .null_bit_buffer(null_buffer)
-            .align_buffers(!require_alignment)
-            .build()?;
+            .null_bit_buffer(null_buffer);
+
+        let array_data = unsafe { build_array_internal(builder, require_alignment, false)? };
         Ok(make_array(array_data))
     } else {
         unreachable!("Cannot create dictionary array from {:?}", data_type)
@@ -492,6 +859,29 @@ pub fn read_record_batch(
     dictionaries_by_id: &HashMap<i64, ArrayRef>,
     projection: Option<&[usize]>,
     metadata: &MetadataVersion,
+) -> Result<RecordBatch, ArrowError> {
+    unsafe {
+        read_record_batch_impl(
+            buf,
+            batch,
+            schema,
+            dictionaries_by_id,
+            projection,
+            metadata,
+            false,
+            false,
+        )
+    }
+}
+
+/// Same as `read_record_batch`, but skips validation of the decoded arrays.
+///
+/// # Safety
+/// The IPC data in `buf` must be well-formed; reading malformed data without
+/// validation can lead to undefined behavior.
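+///
+/// # Example
+/// ```ignore
+/// // A sketch (not compiled): `ipc_batch`, `schema`, `dicts`, and `version` are
+/// // assumed to come from an already-parsed IPC message, as for `read_record_batch`.
+/// // SAFETY: `buf` was produced by a trusted Arrow writer.
+/// let batch = unsafe {
+///     read_record_batch_unchecked(&buf, ipc_batch, schema, &dicts, None, &version)?
+/// };
+/// ```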
+pub unsafe fn read_record_batch_unchecked(
+    buf: &Buffer,
+    batch: crate::RecordBatch,
+    schema: SchemaRef,
+    dictionaries_by_id: &HashMap<i64, ArrayRef>,
+    projection: Option<&[usize]>,
+    metadata: &MetadataVersion,
 ) -> Result<RecordBatch, ArrowError> {
     read_record_batch_impl(
@@ -501,6 +891,7 @@ pub fn read_record_batch(
         projection,
         metadata,
         false,
+        true,
     )
 }
 
@@ -513,10 +904,39 @@ pub fn read_dictionary(
     dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
     metadata: &MetadataVersion,
 ) -> Result<(), ArrowError> {
-    read_dictionary_impl(buf, batch, schema, dictionaries_by_id, metadata, false)
+    unsafe {
+        read_dictionary_impl(
+            buf,
+            batch,
+            schema,
+            dictionaries_by_id,
+            metadata,
+            false,
+            false,
+        )
+    }
+}
+
+/// Same as `read_dictionary`, but skips validation of the decoded dictionary.
+///
+/// # Safety
+/// The IPC data in `buf` must be well-formed; reading malformed data without
+/// validation can lead to undefined behavior.
+pub unsafe fn read_dictionary_unchecked(
+    buf: &Buffer,
+    batch: crate::DictionaryBatch,
+    schema: &Schema,
+    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
+    metadata: &MetadataVersion,
+) -> Result<(), ArrowError> {
+    read_dictionary_impl(
+        buf,
+        batch,
+        schema,
+        dictionaries_by_id,
+        metadata,
+        false,
+        true,
+    )
 }
 
-fn read_record_batch_impl(
+unsafe fn read_record_batch_impl(
     buf: &Buffer,
     batch: crate::RecordBatch,
     schema: SchemaRef,
@@ -524,6 +944,7 @@ fn read_record_batch_impl(
     projection: Option<&[usize]>,
     metadata: &MetadataVersion,
     require_alignment: bool,
+    skip_validation: bool,
 ) -> Result<RecordBatch, ArrowError> {
     let buffers = batch.buffers().ok_or_else(|| {
         ArrowError::IpcError("Unable to get buffers from IPC RecordBatch".to_string())
@@ -557,8 +978,10 @@ fn read_record_batch_impl(
         for (idx, field) in schema.fields().iter().enumerate() {
             // Create array for projected field
             if let Some(proj_idx) = projection.iter().position(|p| p == &idx) {
-                let child =
-                    create_array(&mut reader, field, &mut variadic_counts, require_alignment)?;
+                let child = ArrayCreator::new()
+                    .require_alignment(require_alignment)
+                    .skip_validation(skip_validation)
+                    .create(&mut reader, field, &mut variadic_counts)?;
                 arrays.push((proj_idx, child));
             } else {
                 reader.skip_field(field, &mut variadic_counts)?;
@@ -575,7 +998,10 @@ fn read_record_batch_impl(
         let mut children = vec![];
         // keep track of index as lists require more than one node
         for field in schema.fields() {
-            let child = create_array(&mut reader, field, &mut variadic_counts, require_alignment)?;
+            let child = ArrayCreator::new()
+                .require_alignment(require_alignment)
+                .skip_validation(skip_validation)
+                .create(&mut reader, field, &mut variadic_counts)?;
             children.push(child);
         }
         assert!(variadic_counts.is_empty());
@@ -583,13 +1009,14 @@
     }
 }
 
-fn read_dictionary_impl(
+unsafe fn read_dictionary_impl(
     buf: &Buffer,
     batch: crate::DictionaryBatch,
     schema: &Schema,
     dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
     metadata: &MetadataVersion,
     require_alignment: bool,
+    skip_validation: bool,
 ) -> Result<(), ArrowError> {
     if batch.isDelta() {
         return Err(ArrowError::InvalidArgumentError(
@@ -621,6 +1048,7 @@ fn read_dictionary_impl(
                 None,
                 metadata,
                 require_alignment,
+                skip_validation,
             )?;
             Some(record_batch.column(0).clone())
         }
@@ -795,8 +1223,12 @@ impl FileDecoder {
         Ok(message)
     }
 
-    /// Read the dictionary with the given block and data buffer
-    pub fn read_dictionary(&mut self, block: &Block, buf: &Buffer) -> Result<(), ArrowError> {
+    unsafe fn read_dictionary_internal(
+        &mut self,
+        block: &Block,
+        buf: &Buffer,
+        skip_validation: bool,
+    ) -> Result<(), ArrowError> {
         let message = self.read_message(buf)?;
         match message.header_type() {
             crate::MessageHeader::DictionaryBatch => {
@@ -808,6 +1240,7 @@ impl FileDecoder {
                     &mut self.dictionaries,
                     &message.version(),
                     self.require_alignment,
+                    skip_validation,
                 )
             }
             t => Err(ArrowError::ParseError(format!(
@@ -816,11 +1249,30 @@ impl FileDecoder {
        }
    }
 
-    /// Read the RecordBatch with the given block and data buffer
-    pub fn read_record_batch(
+    /// Read the dictionary with the given block and data buffer
+    pub fn read_dictionary(&mut self, block: &Block, buf: &Buffer) -> Result<(), ArrowError> {
+        unsafe { self.read_dictionary_internal(block, buf, false) }
+    }
+
+    /// Read the dictionary with the given block and data buffer
+    ///
+    /// # Safety
+    /// Skips the usual `ArrayData` validations, which can lead to undefined
+    /// behavior if the data is not correctly formatted. Use this function only
+    /// if you trust the data source.
+    pub unsafe fn read_dictionary_unchecked(
+        &mut self,
+        block: &Block,
+        buf: &Buffer,
+    ) -> Result<(), ArrowError> {
+        self.read_dictionary_internal(block, buf, true)
+    }
+
+    unsafe fn read_record_batch_internal(
         &self,
         block: &Block,
         buf: &Buffer,
+        skip_validation: bool,
     ) -> Result<Option<RecordBatch>, ArrowError> {
         let message = self.read_message(buf)?;
         match message.header_type() {
@@ -840,6 +1292,7 @@ impl FileDecoder {
                     self.projection.as_deref(),
                     &message.version(),
                     self.require_alignment,
+                    skip_validation,
                 )
                 .map(Some)
             }
@@ -849,6 +1302,29 @@ impl FileDecoder {
             ))),
         }
     }
+
+    /// Read the RecordBatch with the given block and data buffer
+    pub fn read_record_batch(
+        &self,
+        block: &Block,
+        buf: &Buffer,
+    ) -> Result<Option<RecordBatch>, ArrowError> {
+        unsafe { self.read_record_batch_internal(block, buf, false) }
+    }
+
+    /// Same as `read_record_batch`, but skips validation of the decoded arrays.
+    ///
+    /// # Safety
+    /// Skips the usual `ArrayData` validations, which can lead to undefined
+    /// behavior if the data is not correctly formatted. Use this function only
+    /// if you trust the data source.
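+    ///
+    /// # Example
+    /// ```ignore
+    /// // A sketch (not compiled): `block` comes from the file footer and `buf`
+    /// // from `read_block`, exactly as for `read_record_batch`.
+    /// // SAFETY: the file was written by a trusted process.
+    /// let maybe_batch = unsafe { decoder.read_record_batch_unchecked(block, &buf)? };
+    /// ```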
+    pub unsafe fn read_record_batch_unchecked(
+        &self,
+        block: &Block,
+        buf: &Buffer,
+    ) -> Result<Option<RecordBatch>, ArrowError> {
+        self.read_record_batch_internal(block, buf, true)
+    }
 }
 
 /// Build an Arrow [`FileReader`] with custom options.
@@ -921,8 +1397,11 @@ impl FileReaderBuilder {
         self
     }
 
-    /// Build [`FileReader`] with given reader.
-    pub fn build<R: Read + Seek>(self, mut reader: R) -> Result<FileReader<R>, ArrowError> {
+    unsafe fn build_internal<R: Read + Seek>(
+        self,
+        mut reader: R,
+        skip_validation: bool,
+    ) -> Result<FileReader<R>, ArrowError> {
         // Space for ARROW_MAGIC (6 bytes) and length (4 bytes)
         let mut buffer = [0; 10];
         reader.seek(SeekFrom::End(-10))?;
@@ -978,7 +1457,11 @@ impl FileReaderBuilder {
         if let Some(dictionaries) = footer.dictionaries() {
             for block in dictionaries {
                 let buf = read_block(&mut reader, block)?;
-                decoder.read_dictionary(block, &buf)?;
+                if skip_validation {
+                    decoder.read_dictionary_unchecked(block, &buf)?;
+                } else {
+                    decoder.read_dictionary(block, &buf)?;
+                }
             }
         }
 
@@ -991,6 +1474,20 @@ impl FileReaderBuilder {
             custom_metadata,
         })
     }
+
+    /// Build [`FileReader`] with given reader.
+    pub fn build<R: Read + Seek>(self, reader: R) -> Result<FileReader<R>, ArrowError> {
+        unsafe { self.build_internal(reader, false) }
+    }
+
+    /// Build [`FileReader`] with the given reader, skipping validation of the
+    /// dictionaries read at build time. Record batches read from the resulting
+    /// reader are still validated unless an unvalidated read method is used.
+    ///
+    /// # Safety
+    /// The dictionary data in the file must be valid; decoding malformed data
+    /// without validation can lead to undefined behavior.
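+    ///
+    /// # Example
+    /// ```ignore
+    /// // A sketch (not compiled): `file` is any `Read + Seek` source holding a
+    /// // trusted Arrow IPC file.
+    /// // SAFETY: the file comes from a trusted writer, so validation is skipped.
+    /// let reader = unsafe { FileReaderBuilder::new().build_unvalidated(file)? };
+    /// ```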
+    pub unsafe fn build_unvalidated<R: Read + Seek>(
+        self,
+        reader: R,
+    ) -> Result<FileReader<R>, ArrowError> {
+        self.build_internal(reader, true)
+    }
 }
 
 /// Arrow File reader
@@ -1055,6 +1552,22 @@ impl<R: Read + Seek> FileReader<R> {
         builder.build(reader)
     }
 
+    /// Try to create a new file reader, skipping validation.
+    ///
+    /// This is useful when the content is trusted and the caller wants to avoid
+    /// the overhead of validating it.
+    ///
+    /// # Safety
+    /// The file content must be valid IPC data; reading malformed data without
+    /// validation can lead to undefined behavior.
+    pub unsafe fn try_new_unvalidated(
+        reader: R,
+        projection: Option<Vec<usize>>,
+    ) -> Result<Self, ArrowError> {
+        let builder = FileReaderBuilder {
+            projection,
+            ..Default::default()
+        };
+        builder.build_unvalidated(reader)
+    }
+
     /// Return user defined customized metadata
     pub fn custom_metadata(&self) -> &HashMap<String, String> {
         &self.custom_metadata
@@ -1094,6 +1607,20 @@ impl<R: Read + Seek> FileReader<R> {
         self.decoder.read_record_batch(block, &buffer)
     }
 
+    unsafe fn maybe_next_unvalidated(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
+        let block = &self.blocks[self.current_block];
+        self.current_block += 1;
+
+        // read length
+        let buffer = read_block(&mut self.reader, block)?;
+        self.decoder.read_record_batch_unchecked(block, &buffer)
+    }
+
+    /// Returns an iterator that reads record batches without validating them.
+    ///
+    /// # Safety
+    /// Every batch in the file must be valid IPC data; iterating malformed data
+    /// without validation can lead to undefined behavior.
+    pub unsafe fn into_unvalidated_iterator(self) -> UnvalidatedFileReader<R> {
+        UnvalidatedFileReader { reader: self }
+    }
+
     /// Gets a reference to the underlying reader.
     ///
     /// It is inadvisable to directly read from the underlying reader.
@@ -1128,6 +1655,30 @@ impl<R: Read + Seek> RecordBatchReader for FileReader<R> {
     }
 }
 
+/// An iterator over the record batches (without validation) in an Arrow file
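+///
+/// # Example
+/// ```ignore
+/// // A sketch (not compiled): iterate a trusted file without validating batches.
+/// // SAFETY: the file content is trusted, e.g. written by this same process.
+/// let reader = unsafe { FileReader::try_new_unvalidated(file, None)? };
+/// for batch in unsafe { reader.into_unvalidated_iterator() } {
+///     println!("{:?}", batch?);
+/// }
+/// ```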
+pub struct UnvalidatedFileReader<R: Read + Seek> {
+    reader: FileReader<R>,
+}
+
+impl<R: Read + Seek> Iterator for UnvalidatedFileReader<R> {
+    type Item = Result<RecordBatch, ArrowError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.reader.current_block < self.reader.total_blocks {
+            // Use the unsafe `maybe_next_unvalidated` to skip validation
+            unsafe {
+                match self.reader.maybe_next_unvalidated() {
+                    Ok(Some(batch)) => Some(Ok(batch)),
+                    Ok(None) => None, // end of the file
+                    Err(e) => Some(Err(e)),
+                }
+            }
+        } else {
+            None
+        }
+    }
+}
+
 /// Arrow Stream reader
 pub struct StreamReader<R: Read> {
     /// Stream reader
@@ -1249,7 +1800,10 @@ impl<R: Read> StreamReader<R> {
         self.finished
     }
 
-    fn maybe_next(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
+    unsafe fn maybe_next_internal(
+        &mut self,
+        skip_validation: bool,
+    ) -> Result<Option<RecordBatch>, ArrowError> {
         if self.finished {
             return Ok(None);
         }
@@ -1314,6 +1868,7 @@ impl<R: Read> StreamReader<R> {
                     self.projection.as_ref().map(|x| x.0.as_ref()),
                     &message.version(),
                     false,
+                    skip_validation,
                 )
                 .map(Some)
             }
@@ -1334,6 +1889,7 @@ impl<R: Read> StreamReader<R> {
                     &mut self.dictionaries_by_id,
                     &message.version(),
                     false,
+                    skip_validation,
                 )?;
 
                 // read the next message until we encounter a RecordBatch
@@ -1346,6 +1902,14 @@ impl<R: Read> StreamReader<R> {
         }
     }
 
+    fn maybe_next(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
+        unsafe { self.maybe_next_internal(false) }
+    }
+
+    unsafe fn maybe_next_unvalidated(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
+        self.maybe_next_internal(true)
+    }
+
     /// Gets a reference to the underlying reader.
     ///
     /// It is inadvisable to directly read from the underlying reader.
@@ -1359,6 +1923,11 @@ impl<R: Read> StreamReader<R> {
     pub fn get_mut(&mut self) -> &mut R {
         &mut self.reader
     }
+
+    /// Returns an iterator that reads record batches without validating them.
+    ///
+    /// # Safety
+    /// Every batch in the stream must be valid IPC data; iterating malformed
+    /// data without validation can lead to undefined behavior.
+    pub unsafe fn into_unvalidated_iterator(self) -> UnvalidatedStreamReader<R> {
+        UnvalidatedStreamReader { reader: self }
+    }
 }
 
 impl<R: Read> Iterator for StreamReader<R> {
@@ -1375,6 +1944,25 @@ impl<R: Read> RecordBatchReader for StreamReader<R> {
     }
 }
 
+/// An iterator over the record batches (without validation) in an Arrow stream
+pub struct UnvalidatedStreamReader<R: Read> {
+    reader: StreamReader<R>,
+}
+
+impl<R: Read> Iterator for UnvalidatedStreamReader<R> {
+    type Item = Result<RecordBatch, ArrowError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        unsafe {
+            match self.reader.maybe_next_unvalidated() {
+                Ok(Some(batch)) => Some(Ok(batch)),
+                Ok(None) => None, // end of the stream
+                Err(e) => Some(Err(e)),
+            }
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use crate::writer::{unslice_run_array, DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
@@ -2156,16 +2744,19 @@ mod tests {
         assert_ne!(b.as_ptr().align_offset(8), 0);
 
         let ipc_batch = message.header_as_record_batch().unwrap();
-        let roundtrip = read_record_batch_impl(
-            &b,
-            ipc_batch,
-            batch.schema(),
-            &Default::default(),
-            None,
-            &message.version(),
-            false,
-        )
-        .unwrap();
+        let roundtrip = unsafe {
+            read_record_batch_impl(
+                &b,
+                ipc_batch,
+                batch.schema(),
+                &Default::default(),
+                None,
+                &message.version(),
+                false,
+                false,
+            )
+            .unwrap()
+        };
         assert_eq!(batch, roundtrip);
     }
 
@@ -2194,15 +2785,18 @@ mod tests {
         assert_ne!(b.as_ptr().align_offset(8), 0);
 
         let ipc_batch = message.header_as_record_batch().unwrap();
-        let result = read_record_batch_impl(
-            &b,
-            ipc_batch,
-            batch.schema(),
-            &Default::default(),
-            None,
-            &message.version(),
-            true,
-        );
+        let result = unsafe {
+            read_record_batch_impl(
+                &b,
+                ipc_batch,
+                batch.schema(),
+                &Default::default(),
+                None,
+                &message.version(),
+                true,
+                false,
+            )
+        };
 
         let error = result.unwrap_err();
         assert_eq!(
diff --git a/arrow-ipc/src/reader/stream.rs b/arrow-ipc/src/reader/stream.rs
index 9b0eea9b6198..6ee86b720880 100644
--- a/arrow-ipc/src/reader/stream.rs
+++ b/arrow-ipc/src/reader/stream.rs
@@ -102,33 +102,11 @@ impl StreamDecoder {
         self
     }
 
-    /// Try to read the next [`RecordBatch`] from the provided [`Buffer`]
-    ///
-    /// [`Buffer::advance`] will be called on `buffer` for any consumed bytes.
-    ///
-    /// The push-based interface facilitates integration with sources that yield arbitrarily
-    /// delimited bytes ranges, such as a chunked byte stream received from object storage
-    ///
-    /// ```
-    /// # use arrow_array::RecordBatch;
-    /// # use arrow_buffer::Buffer;
-    /// # use arrow_ipc::reader::StreamDecoder;
-    /// # use arrow_schema::ArrowError;
-    /// #
-    /// fn print_stream(src: impl Iterator<Item = Buffer>) -> Result<(), ArrowError> {
-    ///     let mut decoder = StreamDecoder::new();
-    ///     for mut x in src {
-    ///         while !x.is_empty() {
-    ///             if let Some(x) = decoder.decode(&mut x)? {
-    ///                 println!("{x:?}");
-    ///             }
-    ///         }
-    ///     }
-    ///     decoder.finish().unwrap();
-    ///     Ok(())
-    /// }
-    /// ```
-    pub fn decode(&mut self, buffer: &mut Buffer) -> Result<Option<RecordBatch>, ArrowError> {
+    unsafe fn decode_internal(
+        &mut self,
+        buffer: &mut Buffer,
+        skip_validation: bool,
+    ) -> Result<Option<RecordBatch>, ArrowError> {
         while !buffer.is_empty() {
             match &mut self.state {
                 DecoderState::Header {
@@ -211,15 +189,18 @@ impl StreamDecoder {
                         let schema = self.schema.clone().ok_or_else(|| {
                             ArrowError::IpcError("Missing schema".to_string())
                         })?;
-                        let batch = read_record_batch_impl(
-                            &body,
-                            batch,
-                            schema,
-                            &self.dictionaries,
-                            None,
-                            &version,
-                            self.require_alignment,
-                        )?;
+                        let batch = unsafe {
+                            read_record_batch_impl(
+                                &body,
+                                batch,
+                                schema,
+                                &self.dictionaries,
+                                None,
+                                &version,
+                                self.require_alignment,
+                                skip_validation,
+                            )?
+                        };
                         self.state = DecoderState::default();
                         return Ok(Some(batch));
                     }
@@ -228,14 +209,17 @@ impl StreamDecoder {
                     let schema = self.schema.as_deref().ok_or_else(|| {
                         ArrowError::IpcError("Missing schema".to_string())
                     })?;
-                    read_dictionary_impl(
-                        &body,
-                        dictionary,
-                        schema,
-                        &mut self.dictionaries,
-                        &version,
-                        self.require_alignment,
-                    )?;
+                    unsafe {
+                        read_dictionary_impl(
+                            &body,
+                            dictionary,
+                            schema,
+                            &mut self.dictionaries,
+                            &version,
+                            self.require_alignment,
+                            skip_validation,
+                        )?;
+                    }
                     self.state = DecoderState::default();
                 }
                 MessageHeader::NONE => {
@@ -256,6 +240,48 @@ impl StreamDecoder {
         Ok(None)
     }
 
+    /// Try to read the next [`RecordBatch`] from the provided [`Buffer`]
+    ///
+    /// [`Buffer::advance`] will be called on `buffer` for any consumed bytes.
+    ///
+    /// The push-based interface facilitates integration with sources that yield arbitrarily
+    /// delimited bytes ranges, such as a chunked byte stream received from object storage
+    ///
+    /// ```
+    /// # use arrow_array::RecordBatch;
+    /// # use arrow_buffer::Buffer;
+    /// # use arrow_ipc::reader::StreamDecoder;
+    /// # use arrow_schema::ArrowError;
+    /// #
+    /// fn print_stream(src: impl Iterator<Item = Buffer>) -> Result<(), ArrowError> {
+    ///     let mut decoder = StreamDecoder::new();
+    ///     for mut x in src {
+    ///         while !x.is_empty() {
+    ///             if let Some(x) = decoder.decode(&mut x)? {
+    ///                 println!("{x:?}");
+    ///             }
+    ///         }
+    ///     }
+    ///     decoder.finish().unwrap();
+    ///     Ok(())
+    /// }
+    /// ```
+    pub fn decode(&mut self, buffer: &mut Buffer) -> Result<Option<RecordBatch>, ArrowError> {
+        unsafe { self.decode_internal(buffer, false) }
+    }
+
+    /// Try to read the next [`RecordBatch`] from the provided [`Buffer`] without
+    /// validating the decoded arrays. This is useful when the data is known to be
+    /// valid and the cost of validation should be avoided.
+    ///
+    /// # Safety
+    /// The data must be valid IPC data; decoding malformed data without
+    /// validation can lead to undefined behavior.
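+    ///
+    /// # Example
+    /// ```ignore
+    /// // A sketch (not compiled): mirrors the `decode` example above, but skips
+    /// // validation of the decoded arrays.
+    /// // SAFETY: `buffer` holds IPC bytes from a trusted writer.
+    /// while !buffer.is_empty() {
+    ///     if let Some(batch) = unsafe { decoder.decode_unvalidated(&mut buffer)? } {
+    ///         println!("{batch:?}");
+    ///     }
+    /// }
+    /// ```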
+    pub unsafe fn decode_unvalidated(
+        &mut self,
+        buffer: &mut Buffer,
+    ) -> Result<Option<RecordBatch>, ArrowError> {
+        unsafe { self.decode_internal(buffer, true) }
+    }
+
     /// Signal the end of stream
     ///
     /// Returns an error if any partial data remains in the stream
diff --git a/parquet-testing b/parquet-testing
index f4d7ed772a62..4439a223a315 160000
--- a/parquet-testing
+++ b/parquet-testing
@@ -1 +1 @@
-Subproject commit f4d7ed772a62a95111db50fbcad2460833e8c882
+Subproject commit 4439a223a315cf874746d3b5da25e6a6b2a2b16e