diff --git a/arrow-array/src/array/struct_array.rs b/arrow-array/src/array/struct_array.rs index fa43062b77bf..457cbdb6ae87 100644 --- a/arrow-array/src/array/struct_array.rs +++ b/arrow-array/src/array/struct_array.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::{make_array, Array, ArrayRef, RecordBatch}; -use arrow_buffer::{buffer_bin_or, Buffer, NullBuffer}; +use crate::{make_array, new_null_array, Array, ArrayRef, RecordBatch}; +use arrow_buffer::{buffer_bin_or, BooleanBuffer, Buffer, NullBuffer}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Field, Fields, SchemaBuilder}; use std::sync::Arc; @@ -77,10 +77,92 @@ pub struct StructArray { len: usize, data_type: DataType, nulls: Option, - pub(crate) fields: Vec, + fields: Vec, } impl StructArray { + /// Create a new [`StructArray`] from the provided parts + /// + /// # Panics + /// + /// Panics if + /// + /// * fields.len() != arrays.len() + /// * fields[i].data_type() != arrays[i].data_type() + /// * arrays[i].len() != arrays[j].len() + /// * arrays[i].len() != nulls.len() + /// * !fields[i].is_nullable() && !nulls.contains(arrays[i].nulls()) + /// + pub fn new(fields: Fields, arrays: Vec, nulls: Option) -> Self { + assert_eq!(fields.len(), arrays.len()); + let len = arrays.first().map(|x| x.len()).unwrap_or_default(); + + if let Some(n) = nulls.as_ref() { + assert_eq!(n.len(), len); + } + + for (f, a) in fields.iter().zip(&arrays) { + assert_eq!(f.data_type(), a.data_type(), "{f}"); + assert_eq!(a.len(), len, "{f}"); + + if let Some(a) = a.nulls() { + let nulls_valid = f.is_nullable() + || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default(); + assert!(nulls_valid, "{f}"); + } + } + + Self { + len, + data_type: DataType::Struct(fields), + nulls: nulls.filter(|n| n.null_count() > 0), + fields: arrays, + } + } + + /// Create a new [`StructArray`] of length `len` where all values are null + pub fn new_null(fields: Fields, len: usize) -> Self { + let arrays = fields + .iter() + .map(|f| new_null_array(f.data_type(), len)) + .collect(); + + Self { + len, + data_type: DataType::Struct(fields), + nulls: Some(NullBuffer::new_null(len)), + fields: arrays, + } + } + + /// Create a new [`StructArray`] from the provided parts without validation + /// + /// # Safety + /// + /// Safe if [`Self::new`] would not panic with the given arguments + pub unsafe fn new_unchecked( + fields: Fields, + arrays: Vec, + nulls: Option, + ) -> Self { + let len = arrays.first().map(|x| x.len()).unwrap_or_default(); + Self { + len, + data_type: DataType::Struct(fields), + nulls, + fields: arrays, + } + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (Fields, Vec, Option) { + let f = match self.data_type { + DataType::Struct(f) => f, + _ => unreachable!(), + }; + (f, self.fields, self.nulls) + } + /// Returns the field at `pos`. pub fn column(&self, pos: usize) -> &ArrayRef { &self.fields[pos] @@ -359,37 +441,10 @@ impl std::fmt::Debug for StructArray { impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray { fn from(pair: (Vec<(Field, ArrayRef)>, Buffer)) -> Self { - let capacity = pair.0.len(); - let mut len = None; - let mut schema = SchemaBuilder::with_capacity(capacity); - let mut child_data = Vec::with_capacity(capacity); - for (field, array) in pair.0 { - // Check the length of the child arrays - assert_eq!( - *len.get_or_insert(array.len()), - array.len(), - "all child arrays of a StructArray must have the same length" - ); - // Check data types of child arrays - assert_eq!( - field.data_type(), - array.data_type(), - "the field data types must match the array data in a StructArray" - ); - schema.push(field); - child_data.push(array.to_data()); - } - let field_types = schema.finish().fields; - let array_data = ArrayData::builder(DataType::Struct(field_types)) - .null_bit_buffer(Some(pair.1)) - .child_data(child_data) - .len(len.unwrap_or_default()); - let array_data = unsafe { array_data.build_unchecked() }; - - // We must validate nullability - array_data.validate_nulls().unwrap(); - - Self::from(array_data) + let len = pair.0.first().map(|x| x.1.len()).unwrap_or_default(); + let (fields, arrays): (SchemaBuilder, Vec<_>) = pair.0.into_iter().unzip(); + let nulls = NullBuffer::new(BooleanBuffer::new(pair.1, 0, len)); + Self::new(fields.finish().fields, arrays, Some(nulls)) } } diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs index ee61d2da6597..8fb08111c846 100644 --- a/arrow-array/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -467,17 +467,12 @@ impl Default for RecordBatchOptions { } impl From for RecordBatch { fn from(value: StructArray) -> Self { - assert_eq!( - value.null_count(), - 0, - "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation" - ); let row_count = value.len(); - let schema = Arc::new(Schema::new(value.fields().clone())); - let columns = value.fields; + let (fields, columns, nulls) = value.into_parts(); + assert_eq!(nulls.map(|n| n.null_count()).unwrap_or_default(), 0, "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"); RecordBatch { - schema, + schema: Arc::new(Schema::new(fields)), row_count, columns, } diff --git a/arrow-buffer/src/buffer/null.rs b/arrow-buffer/src/buffer/null.rs index f088e7fa62e9..cdb0c2aeb824 100644 --- a/arrow-buffer/src/buffer/null.rs +++ b/arrow-buffer/src/buffer/null.rs @@ -67,6 +67,13 @@ impl NullBuffer { } } + /// Returns true if all nulls in `other` also exist in self + pub fn contains(&self, other: &NullBuffer) -> bool { + let lhs = self.inner().bit_chunks().iter_padded(); + let rhs = other.inner().bit_chunks().iter_padded(); + lhs.zip(rhs).all(|(l, r)| (l & !r) == 0) + } + /// Returns the length of this [`NullBuffer`] #[inline] pub fn len(&self) -> usize { diff --git a/arrow-json/src/reader/struct_array.rs b/arrow-json/src/reader/struct_array.rs index 013f862c51ad..464893b2bfe6 100644 --- a/arrow-json/src/reader/struct_array.rs +++ b/arrow-json/src/reader/struct_array.rs @@ -113,19 +113,11 @@ impl ArrayDecoder for StructArrayDecoder { for (c, f) in child_data.iter().zip(fields) { // Sanity check assert_eq!(c.len(), pos.len()); + if let Some(a) = c.nulls() { + let nulls_valid = f.is_nullable() + || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default(); - if !f.is_nullable() && c.null_count() != 0 { - // Need to verify nulls - let valid = match nulls.as_ref() { - Some(nulls) => { - let lhs = nulls.inner().bit_chunks().iter_padded(); - let rhs = c.nulls().unwrap().inner().bit_chunks().iter_padded(); - lhs.zip(rhs).all(|(l, r)| (l & !r) == 0) - } - None => false, - }; - - if !valid { + if !nulls_valid { return Err(ArrowError::JsonError(format!("Encountered unmasked nulls in non-nullable StructArray child: {f}"))); } }