From 8d166a14467ac8e59a47174de676971f9f896e78 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Tue, 25 Apr 2023 14:04:27 -0400 Subject: [PATCH] Add StructArray Constructors (#3879) (#4064) * Add StructArray Constructors (#3879) * Fix doc * Add try_new * Update other constructors --- arrow-array/src/array/struct_array.rs | 302 ++++++++++++++------------ arrow-array/src/record_batch.rs | 11 +- arrow-buffer/src/buffer/null.rs | 7 + arrow-json/src/reader/struct_array.rs | 16 +- 4 files changed, 173 insertions(+), 163 deletions(-) diff --git a/arrow-array/src/array/struct_array.rs b/arrow-array/src/array/struct_array.rs index fa43062b77bf..a18f38c082c9 100644 --- a/arrow-array/src/array/struct_array.rs +++ b/arrow-array/src/array/struct_array.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::{make_array, Array, ArrayRef, RecordBatch}; -use arrow_buffer::{buffer_bin_or, Buffer, NullBuffer}; +use crate::{make_array, new_null_array, Array, ArrayRef, RecordBatch}; +use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Field, Fields, SchemaBuilder}; use std::sync::Arc; @@ -77,10 +77,136 @@ pub struct StructArray { len: usize, data_type: DataType, nulls: Option, - pub(crate) fields: Vec, + fields: Vec, } impl StructArray { + /// Create a new [`StructArray`] from the provided parts, panicking on failure + /// + /// # Panics + /// + /// Panics if [`Self::try_new`] returns an error + pub fn new(fields: Fields, arrays: Vec, nulls: Option) -> Self { + Self::try_new(fields, arrays, nulls).unwrap() + } + + /// Create a new [`StructArray`] from the provided parts, returning an error on failure + /// + /// # Errors + /// + /// Errors if + /// + /// * `fields.len() != arrays.len()` + /// * `fields[i].data_type() != arrays[i].data_type()` + /// * `arrays[i].len() != arrays[j].len()` + /// * `arrays[i].len() != nulls.len()` + /// * `!fields[i].is_nullable() && !nulls.contains(arrays[i].nulls())` + pub fn try_new( + fields: Fields, + arrays: Vec, + nulls: Option, + ) -> Result { + if fields.len() != arrays.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect number of arrays for StructArray fields, expected {} got {}", + fields.len(), + arrays.len() + ))); + } + let len = arrays.first().map(|x| x.len()).unwrap_or_default(); + + if let Some(n) = nulls.as_ref() { + if n.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect number of nulls for StructArray, expected {len} got {}", + n.len(), + ))); + } + } + + for (f, a) in fields.iter().zip(&arrays) { + if f.data_type() != a.data_type() { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect datatype for StructArray field {:?}, expected {} got {}", + f.name(), + f.data_type(), + a.data_type() + ))); + } + + if a.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect array length for StructArray field {:?}, expected {} got {}", + f.name(), + len, + a.len() + ))); + } + + if let Some(a) = a.nulls() { + let nulls_valid = f.is_nullable() + || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default(); + + if !nulls_valid { + return Err(ArrowError::InvalidArgumentError(format!( + "Found unmasked nulls for non-nullable StructArray field {:?}", + f.name() + ))); + } + } + } + + Ok(Self { + len, + data_type: DataType::Struct(fields), + nulls: nulls.filter(|n| n.null_count() > 0), + fields: arrays, + }) + } + + /// Create a new [`StructArray`] of length `len` where all values are null + pub fn new_null(fields: Fields, len: usize) -> Self { + let arrays = fields + .iter() + .map(|f| new_null_array(f.data_type(), len)) + .collect(); + + Self { + len, + data_type: DataType::Struct(fields), + nulls: Some(NullBuffer::new_null(len)), + fields: arrays, + } + } + + /// Create a new [`StructArray`] from the provided parts without validation + /// + /// # Safety + /// + /// Safe if [`Self::new`] would not panic with the given arguments + pub unsafe fn new_unchecked( + fields: Fields, + arrays: Vec, + nulls: Option, + ) -> Self { + let len = arrays.first().map(|x| x.len()).unwrap_or_default(); + Self { + len, + data_type: DataType::Struct(fields), + nulls, + fields: arrays, + } + } + + /// Deconstruct this array into its constituent parts + pub fn into_parts(self) -> (Fields, Vec, Option) { + let f = match self.data_type { + DataType::Struct(f) => f, + _ => unreachable!(), + }; + (f, self.fields, self.nulls) + } + /// Returns the field at `pos`. pub fn column(&self, pos: usize) -> &ArrayRef { &self.fields[pos] @@ -183,66 +309,18 @@ impl TryFrom> for StructArray { type Error = ArrowError; /// builds a StructArray from a vector of names and arrays. - /// This errors if the values have a different length. - /// An entry is set to Null when all values are null. fn try_from(values: Vec<(&str, ArrayRef)>) -> Result { - let values_len = values.len(); - - // these will be populated - let mut fields = Vec::with_capacity(values_len); - let mut child_data = Vec::with_capacity(values_len); - - // len: the size of the arrays. - let mut len: Option = None; - // null: the null mask of the arrays. - let mut null: Option = None; - for (field_name, array) in values { - let child_datum = array.to_data(); - let child_datum_len = child_datum.len(); - if let Some(len) = len { - if len != child_datum_len { - return Err(ArrowError::InvalidArgumentError( - format!("Array of field \"{field_name}\" has length {child_datum_len}, but previous elements have length {len}. - All arrays in every entry in a struct array must have the same length.") - )); - } - } else { - len = Some(child_datum_len) - } - fields.push(Arc::new(Field::new( - field_name, - array.data_type().clone(), - child_datum.nulls().is_some(), - ))); - - if let Some(child_nulls) = child_datum.nulls() { - null = Some(if let Some(null_buffer) = &null { - buffer_bin_or( - null_buffer, - 0, - child_nulls.buffer(), - child_nulls.offset(), - child_datum_len, - ) - } else { - child_nulls.inner().sliced() - }); - } else if null.is_some() { - // when one of the fields has no nulls, then there is no null in the array - null = None; - } - child_data.push(child_datum); - } - let len = len.unwrap(); - - let builder = ArrayData::builder(DataType::Struct(fields.into())) - .len(len) - .null_bit_buffer(null) - .child_data(child_data); - - let array_data = unsafe { builder.build_unchecked() }; - - Ok(StructArray::from(array_data)) + let (schema, arrays): (SchemaBuilder, _) = values + .into_iter() + .map(|(name, array)| { + ( + Field::new(name, array.data_type().clone(), array.nulls().is_some()), + array, + ) + }) + .unzip(); + + StructArray::try_new(schema.finish().fields, arrays, None) } } @@ -303,38 +381,8 @@ impl Array for StructArray { impl From> for StructArray { fn from(v: Vec<(Field, ArrayRef)>) -> Self { - let iter = v.into_iter(); - let capacity = iter.size_hint().0; - - let mut len = None; - let mut schema = SchemaBuilder::with_capacity(capacity); - let mut child_data = Vec::with_capacity(capacity); - for (field, array) in iter { - // Check the length of the child arrays - assert_eq!( - *len.get_or_insert(array.len()), - array.len(), - "all child arrays of a StructArray must have the same length" - ); - // Check data types of child arrays - assert_eq!( - field.data_type(), - array.data_type(), - "the field data types must match the array data in a StructArray" - ); - schema.push(field); - child_data.push(array.to_data()); - } - let field_types = schema.finish().fields; - let array_data = ArrayData::builder(DataType::Struct(field_types)) - .child_data(child_data) - .len(len.unwrap_or_default()); - let array_data = unsafe { array_data.build_unchecked() }; - - // We must validate nullability - array_data.validate_nulls().unwrap(); - - Self::from(array_data) + let (schema, arrays): (SchemaBuilder, _) = v.into_iter().unzip(); + StructArray::new(schema.finish().fields, arrays, None) } } @@ -359,37 +407,10 @@ impl std::fmt::Debug for StructArray { impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray { fn from(pair: (Vec<(Field, ArrayRef)>, Buffer)) -> Self { - let capacity = pair.0.len(); - let mut len = None; - let mut schema = SchemaBuilder::with_capacity(capacity); - let mut child_data = Vec::with_capacity(capacity); - for (field, array) in pair.0 { - // Check the length of the child arrays - assert_eq!( - *len.get_or_insert(array.len()), - array.len(), - "all child arrays of a StructArray must have the same length" - ); - // Check data types of child arrays - assert_eq!( - field.data_type(), - array.data_type(), - "the field data types must match the array data in a StructArray" - ); - schema.push(field); - child_data.push(array.to_data()); - } - let field_types = schema.finish().fields; - let array_data = ArrayData::builder(DataType::Struct(field_types)) - .null_bit_buffer(Some(pair.1)) - .child_data(child_data) - .len(len.unwrap_or_default()); - let array_data = unsafe { array_data.build_unchecked() }; - - // We must validate nullability - array_data.validate_nulls().unwrap(); - - Self::from(array_data) + let len = pair.0.first().map(|x| x.1.len()).unwrap_or_default(); + let (fields, arrays): (SchemaBuilder, Vec<_>) = pair.0.into_iter().unzip(); + let nulls = NullBuffer::new(BooleanBuffer::new(pair.1, 0, len)); + Self::new(fields.finish().fields, arrays, Some(nulls)) } } @@ -512,12 +533,7 @@ mod tests { let struct_data = arr.into_data(); assert_eq!(4, struct_data.len()); - assert_eq!(1, struct_data.null_count()); - assert_eq!( - // 00001011 - &[11_u8], - struct_data.nulls().unwrap().validity() - ); + assert_eq!(0, struct_data.null_count()); let expected_string_data = ArrayData::builder(DataType::Utf8) .len(4) @@ -549,20 +565,20 @@ mod tests { let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)])); - let arr = - StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]); + let err = + StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]) + .unwrap_err() + .to_string(); - match arr { - Err(ArrowError::InvalidArgumentError(e)) => { - assert!(e.starts_with("Array of field \"f2\" has length 4, but previous elements have length 3.")); - } - _ => panic!("This test got an unexpected error type"), - }; + assert_eq!( + err, + "Invalid argument error: Incorrect array length for StructArray field \"f2\", expected 3 got 4" + ) } #[test] #[should_panic( - expected = "the field data types must match the array data in a StructArray" + expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean" )] fn test_struct_array_from_mismatched_types_single() { drop(StructArray::from(vec![( @@ -574,7 +590,7 @@ mod tests { #[test] #[should_panic( - expected = "the field data types must match the array data in a StructArray" + expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean" )] fn test_struct_array_from_mismatched_types_multiple() { drop(StructArray::from(vec![ @@ -679,7 +695,7 @@ mod tests { #[test] #[should_panic( - expected = "all child arrays of a StructArray must have the same length" + expected = "Incorrect array length for StructArray field \\\"c\\\", expected 1 got 2" )] fn test_invalid_struct_child_array_lengths() { drop(StructArray::from(vec![ @@ -702,7 +718,7 @@ mod tests { #[test] #[should_panic( - expected = "non-nullable child of type Int32 contains nulls not present in parent Struct" + expected = "Found unmasked nulls for non-nullable StructArray field \\\"c\\\"" )] fn test_struct_array_from_mismatched_nullability() { drop(StructArray::from(vec![( diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs index ee61d2da6597..8fb08111c846 100644 --- a/arrow-array/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -467,17 +467,12 @@ impl Default for RecordBatchOptions { } impl From for RecordBatch { fn from(value: StructArray) -> Self { - assert_eq!( - value.null_count(), - 0, - "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation" - ); let row_count = value.len(); - let schema = Arc::new(Schema::new(value.fields().clone())); - let columns = value.fields; + let (fields, columns, nulls) = value.into_parts(); + assert_eq!(nulls.map(|n| n.null_count()).unwrap_or_default(), 0, "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"); RecordBatch { - schema, + schema: Arc::new(Schema::new(fields)), row_count, columns, } diff --git a/arrow-buffer/src/buffer/null.rs b/arrow-buffer/src/buffer/null.rs index f088e7fa62e9..cdb0c2aeb824 100644 --- a/arrow-buffer/src/buffer/null.rs +++ b/arrow-buffer/src/buffer/null.rs @@ -67,6 +67,13 @@ impl NullBuffer { } } + /// Returns true if all nulls in `other` also exist in self + pub fn contains(&self, other: &NullBuffer) -> bool { + let lhs = self.inner().bit_chunks().iter_padded(); + let rhs = other.inner().bit_chunks().iter_padded(); + lhs.zip(rhs).all(|(l, r)| (l & !r) == 0) + } + /// Returns the length of this [`NullBuffer`] #[inline] pub fn len(&self) -> usize { diff --git a/arrow-json/src/reader/struct_array.rs b/arrow-json/src/reader/struct_array.rs index 6c6f1457bfc2..707b56d50eef 100644 --- a/arrow-json/src/reader/struct_array.rs +++ b/arrow-json/src/reader/struct_array.rs @@ -120,19 +120,11 @@ impl ArrayDecoder for StructArrayDecoder { for (c, f) in child_data.iter().zip(fields) { // Sanity check assert_eq!(c.len(), pos.len()); + if let Some(a) = c.nulls() { + let nulls_valid = f.is_nullable() + || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default(); - if !f.is_nullable() && c.null_count() != 0 { - // Need to verify nulls - let valid = match nulls.as_ref() { - Some(nulls) => { - let lhs = nulls.inner().bit_chunks().iter_padded(); - let rhs = c.nulls().unwrap().inner().bit_chunks().iter_padded(); - lhs.zip(rhs).all(|(l, r)| (l & !r) == 0) - } - None => false, - }; - - if !valid { + if !nulls_valid { return Err(ArrowError::JsonError(format!("Encountered unmasked nulls in non-nullable StructArray child: {f}"))); } }