Skip to content

Commit

Permalink
Add StructArray Constructors (#3879) (#4064)
Browse files Browse the repository at this point in the history
* Add StructArray Constructors (#3879)

* Fix doc

* Add try_new

* Update other constructors
  • Loading branch information
tustvold authored Apr 25, 2023
1 parent be33ec5 commit 8d166a1
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 163 deletions.
302 changes: 159 additions & 143 deletions arrow-array/src/array/struct_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use crate::{make_array, Array, ArrayRef, RecordBatch};
use arrow_buffer::{buffer_bin_or, Buffer, NullBuffer};
use crate::{make_array, new_null_array, Array, ArrayRef, RecordBatch};
use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer};
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType, Field, Fields, SchemaBuilder};
use std::sync::Arc;
Expand Down Expand Up @@ -77,10 +77,136 @@ pub struct StructArray {
len: usize,
data_type: DataType,
nulls: Option<NullBuffer>,
pub(crate) fields: Vec<ArrayRef>,
fields: Vec<ArrayRef>,
}

impl StructArray {
/// Build a [`StructArray`] from its fields, child arrays and optional null buffer
///
/// This is the panicking counterpart of [`Self::try_new`]: the arguments are
/// forwarded unchanged and the result is unwrapped.
///
/// # Panics
///
/// Panics if [`Self::try_new`] returns an error
pub fn new(fields: Fields, arrays: Vec<ArrayRef>, nulls: Option<NullBuffer>) -> Self {
    let result = Self::try_new(fields, arrays, nulls);
    result.unwrap()
}

/// Create a new [`StructArray`] from the provided parts, returning an error on failure
///
/// The length of the resulting array is taken from the first child array, or is
/// `0` when `arrays` is empty. A `nulls` buffer that contains no nulls is
/// discarded, so the resulting array carries no null buffer in that case.
///
/// # Errors
///
/// Errors if
///
/// * `fields.len() != arrays.len()`
/// * `fields[i].data_type() != arrays[i].data_type()`
/// * `arrays[i].len() != arrays[j].len()`
/// * `arrays[i].len() != nulls.len()`
/// * `!fields[i].is_nullable()`, `arrays[i]` contains at least one null, and
///   `!nulls.contains(arrays[i].nulls())`
pub fn try_new(
    fields: Fields,
    arrays: Vec<ArrayRef>,
    nulls: Option<NullBuffer>,
) -> Result<Self, ArrowError> {
    if fields.len() != arrays.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Incorrect number of arrays for StructArray fields, expected {} got {}",
            fields.len(),
            arrays.len()
        )));
    }
    // Every child must share the length of the first one; no children means length 0
    let len = arrays.first().map(|x| x.len()).unwrap_or_default();

    if let Some(n) = nulls.as_ref() {
        if n.len() != len {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Incorrect number of nulls for StructArray, expected {len} got {}",
                n.len(),
            )));
        }
    }

    for (f, a) in fields.iter().zip(&arrays) {
        if f.data_type() != a.data_type() {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Incorrect datatype for StructArray field {:?}, expected {} got {}",
                f.name(),
                f.data_type(),
                a.data_type()
            )));
        }

        if a.len() != len {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Incorrect array length for StructArray field {:?}, expected {} got {}",
                f.name(),
                len,
                a.len()
            )));
        }

        if let Some(a) = a.nulls() {
            // A child may carry a null buffer that marks nothing as null. Such a
            // child holds no nulls, so it is valid for a non-nullable field even
            // when the struct itself has no null buffer to mask it.
            let nulls_valid = f.is_nullable()
                || a.null_count() == 0
                || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default();

            if !nulls_valid {
                return Err(ArrowError::InvalidArgumentError(format!(
                    "Found unmasked nulls for non-nullable StructArray field {:?}",
                    f.name()
                )));
            }
        }
    }

    Ok(Self {
        len,
        data_type: DataType::Struct(fields),
        // Drop a null buffer that does not mark any row as null
        nulls: nulls.filter(|n| n.null_count() > 0),
        fields: arrays,
    })
}

/// Create a [`StructArray`] with `len` rows in which every row is null
///
/// One all-null child array of length `len` is created per entry in `fields`,
/// and the struct's own null buffer marks all `len` rows as null.
pub fn new_null(fields: Fields, len: usize) -> Self {
    let mut children = Vec::with_capacity(fields.len());
    for field in fields.iter() {
        children.push(new_null_array(field.data_type(), len));
    }

    let data_type = DataType::Struct(fields);
    let nulls = Some(NullBuffer::new_null(len));
    Self {
        len,
        data_type,
        nulls,
        fields: children,
    }
}

/// Assemble a [`StructArray`] from the provided parts, skipping all validation
///
/// The array length is the length of the first child array, or `0` when
/// `arrays` is empty. `nulls` is stored exactly as given.
///
/// # Safety
///
/// Safe if [`Self::new`] would not panic with the given arguments
pub unsafe fn new_unchecked(
    fields: Fields,
    arrays: Vec<ArrayRef>,
    nulls: Option<NullBuffer>,
) -> Self {
    let len = match arrays.first() {
        Some(child) => child.len(),
        None => 0,
    };
    let data_type = DataType::Struct(fields);
    Self {
        len,
        data_type,
        nulls,
        fields: arrays,
    }
}

/// Deconstruct this array into its constituent parts
pub fn into_parts(self) -> (Fields, Vec<ArrayRef>, Option<NullBuffer>) {
let f = match self.data_type {
DataType::Struct(f) => f,
_ => unreachable!(),
};
(f, self.fields, self.nulls)
}

/// Returns the field at `pos`.
pub fn column(&self, pos: usize) -> &ArrayRef {
&self.fields[pos]
Expand Down Expand Up @@ -183,66 +309,18 @@ impl TryFrom<Vec<(&str, ArrayRef)>> for StructArray {
type Error = ArrowError;

/// builds a StructArray from a vector of names and arrays.
/// This errors if the values have a different length.
/// An entry is set to Null when all values are null.
fn try_from(values: Vec<(&str, ArrayRef)>) -> Result<Self, ArrowError> {
let values_len = values.len();

// these will be populated
let mut fields = Vec::with_capacity(values_len);
let mut child_data = Vec::with_capacity(values_len);

// len: the size of the arrays.
let mut len: Option<usize> = None;
// null: the null mask of the arrays.
let mut null: Option<Buffer> = None;
for (field_name, array) in values {
let child_datum = array.to_data();
let child_datum_len = child_datum.len();
if let Some(len) = len {
if len != child_datum_len {
return Err(ArrowError::InvalidArgumentError(
format!("Array of field \"{field_name}\" has length {child_datum_len}, but previous elements have length {len}.
All arrays in every entry in a struct array must have the same length.")
));
}
} else {
len = Some(child_datum_len)
}
fields.push(Arc::new(Field::new(
field_name,
array.data_type().clone(),
child_datum.nulls().is_some(),
)));

if let Some(child_nulls) = child_datum.nulls() {
null = Some(if let Some(null_buffer) = &null {
buffer_bin_or(
null_buffer,
0,
child_nulls.buffer(),
child_nulls.offset(),
child_datum_len,
)
} else {
child_nulls.inner().sliced()
});
} else if null.is_some() {
// when one of the fields has no nulls, then there is no null in the array
null = None;
}
child_data.push(child_datum);
}
let len = len.unwrap();

let builder = ArrayData::builder(DataType::Struct(fields.into()))
.len(len)
.null_bit_buffer(null)
.child_data(child_data);

let array_data = unsafe { builder.build_unchecked() };

Ok(StructArray::from(array_data))
let (schema, arrays): (SchemaBuilder, _) = values
.into_iter()
.map(|(name, array)| {
(
Field::new(name, array.data_type().clone(), array.nulls().is_some()),
array,
)
})
.unzip();

StructArray::try_new(schema.finish().fields, arrays, None)
}
}

Expand Down Expand Up @@ -303,38 +381,8 @@ impl Array for StructArray {

impl From<Vec<(Field, ArrayRef)>> for StructArray {
fn from(v: Vec<(Field, ArrayRef)>) -> Self {
let iter = v.into_iter();
let capacity = iter.size_hint().0;

let mut len = None;
let mut schema = SchemaBuilder::with_capacity(capacity);
let mut child_data = Vec::with_capacity(capacity);
for (field, array) in iter {
// Check the length of the child arrays
assert_eq!(
*len.get_or_insert(array.len()),
array.len(),
"all child arrays of a StructArray must have the same length"
);
// Check data types of child arrays
assert_eq!(
field.data_type(),
array.data_type(),
"the field data types must match the array data in a StructArray"
);
schema.push(field);
child_data.push(array.to_data());
}
let field_types = schema.finish().fields;
let array_data = ArrayData::builder(DataType::Struct(field_types))
.child_data(child_data)
.len(len.unwrap_or_default());
let array_data = unsafe { array_data.build_unchecked() };

// We must validate nullability
array_data.validate_nulls().unwrap();

Self::from(array_data)
let (schema, arrays): (SchemaBuilder, _) = v.into_iter().unzip();
StructArray::new(schema.finish().fields, arrays, None)
}
}

Expand All @@ -359,37 +407,10 @@ impl std::fmt::Debug for StructArray {

impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray {
fn from(pair: (Vec<(Field, ArrayRef)>, Buffer)) -> Self {
let capacity = pair.0.len();
let mut len = None;
let mut schema = SchemaBuilder::with_capacity(capacity);
let mut child_data = Vec::with_capacity(capacity);
for (field, array) in pair.0 {
// Check the length of the child arrays
assert_eq!(
*len.get_or_insert(array.len()),
array.len(),
"all child arrays of a StructArray must have the same length"
);
// Check data types of child arrays
assert_eq!(
field.data_type(),
array.data_type(),
"the field data types must match the array data in a StructArray"
);
schema.push(field);
child_data.push(array.to_data());
}
let field_types = schema.finish().fields;
let array_data = ArrayData::builder(DataType::Struct(field_types))
.null_bit_buffer(Some(pair.1))
.child_data(child_data)
.len(len.unwrap_or_default());
let array_data = unsafe { array_data.build_unchecked() };

// We must validate nullability
array_data.validate_nulls().unwrap();

Self::from(array_data)
let len = pair.0.first().map(|x| x.1.len()).unwrap_or_default();
let (fields, arrays): (SchemaBuilder, Vec<_>) = pair.0.into_iter().unzip();
let nulls = NullBuffer::new(BooleanBuffer::new(pair.1, 0, len));
Self::new(fields.finish().fields, arrays, Some(nulls))
}
}

Expand Down Expand Up @@ -512,12 +533,7 @@ mod tests {

let struct_data = arr.into_data();
assert_eq!(4, struct_data.len());
assert_eq!(1, struct_data.null_count());
assert_eq!(
// 00001011
&[11_u8],
struct_data.nulls().unwrap().validity()
);
assert_eq!(0, struct_data.null_count());

let expected_string_data = ArrayData::builder(DataType::Utf8)
.len(4)
Expand Down Expand Up @@ -549,20 +565,20 @@ mod tests {
let ints: ArrayRef =
Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)]));

let arr =
StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]);
let err =
StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())])
.unwrap_err()
.to_string();

match arr {
Err(ArrowError::InvalidArgumentError(e)) => {
assert!(e.starts_with("Array of field \"f2\" has length 4, but previous elements have length 3."));
}
_ => panic!("This test got an unexpected error type"),
};
assert_eq!(
err,
"Invalid argument error: Incorrect array length for StructArray field \"f2\", expected 3 got 4"
)
}

#[test]
#[should_panic(
expected = "the field data types must match the array data in a StructArray"
expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean"
)]
fn test_struct_array_from_mismatched_types_single() {
drop(StructArray::from(vec![(
Expand All @@ -574,7 +590,7 @@ mod tests {

#[test]
#[should_panic(
expected = "the field data types must match the array data in a StructArray"
expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean"
)]
fn test_struct_array_from_mismatched_types_multiple() {
drop(StructArray::from(vec![
Expand Down Expand Up @@ -679,7 +695,7 @@ mod tests {

#[test]
#[should_panic(
expected = "all child arrays of a StructArray must have the same length"
expected = "Incorrect array length for StructArray field \\\"c\\\", expected 1 got 2"
)]
fn test_invalid_struct_child_array_lengths() {
drop(StructArray::from(vec![
Expand All @@ -702,7 +718,7 @@ mod tests {

#[test]
#[should_panic(
expected = "non-nullable child of type Int32 contains nulls not present in parent Struct"
expected = "Found unmasked nulls for non-nullable StructArray field \\\"c\\\""
)]
fn test_struct_array_from_mismatched_nullability() {
drop(StructArray::from(vec![(
Expand Down
Loading

0 comments on commit 8d166a1

Please sign in to comment.