Skip to content

Commit

Permalink
Add StructArray Constructors (apache#3879)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Apr 12, 2023
1 parent dd7dc10 commit a327788
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 54 deletions.
123 changes: 89 additions & 34 deletions arrow-array/src/array/struct_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use crate::{make_array, Array, ArrayRef, RecordBatch};
use arrow_buffer::{buffer_bin_or, Buffer, NullBuffer};
use crate::{make_array, new_null_array, Array, ArrayRef, RecordBatch};
use arrow_buffer::{buffer_bin_or, BooleanBuffer, Buffer, NullBuffer};
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType, Field, Fields, SchemaBuilder};
use std::sync::Arc;
Expand Down Expand Up @@ -77,10 +77,92 @@ pub struct StructArray {
len: usize,
data_type: DataType,
nulls: Option<NullBuffer>,
pub(crate) fields: Vec<ArrayRef>,
fields: Vec<ArrayRef>,
}

impl StructArray {
/// Create a new [`StructArray`] from the provided parts
///
/// `fields` describes each child column, `arrays` holds the child arrays in
/// the same order, and `nulls` is the optional validity buffer of the struct
/// itself. The length of the resulting array is the length of the first child,
/// or `0` when there are no children.
///
/// A `nulls` buffer that reports a null count of zero is dropped and stored
/// as `None`.
///
/// # Panics
///
/// Panics if
///
/// * fields.len() != arrays.len()
/// * fields[i].data_type() != arrays[i].data_type()
/// * arrays[i].len() != arrays[j].len()
/// * arrays[i].len() != nulls.len()
/// * !fields[i].is_nullable() && arrays[i] has at least one null
///   && !nulls.contains(arrays[i].nulls())
///
pub fn new(fields: Fields, arrays: Vec<ArrayRef>, nulls: Option<NullBuffer>) -> Self {
    assert_eq!(fields.len(), arrays.len());
    let len = arrays.first().map(|x| x.len()).unwrap_or_default();

    if let Some(n) = nulls.as_ref() {
        assert_eq!(n.len(), len);
    }

    for (f, a) in fields.iter().zip(&arrays) {
        assert_eq!(f.data_type(), a.data_type(), "{f}");
        assert_eq!(a.len(), len, "{f}");

        if let Some(a) = a.nulls() {
            // A child validity buffer that holds no nulls can never violate a
            // non-nullable field, so it must be accepted even when the struct
            // itself has no null buffer to mask it.
            let nulls_valid = f.is_nullable()
                || a.null_count() == 0
                || nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default();
            assert!(nulls_valid, "{f}");
        }
    }

    Self {
        len,
        data_type: DataType::Struct(fields),
        // Normalise an all-valid buffer to `None`
        nulls: nulls.filter(|n| n.null_count() > 0),
        fields: arrays,
    }
}

/// Create a new [`StructArray`] of length `len` where all values are null
pub fn new_null(fields: Fields, len: usize) -> Self {
let arrays = fields
.iter()
.map(|f| new_null_array(f.data_type(), len))
.collect();

Self {
len,
data_type: DataType::Struct(fields),
nulls: Some(NullBuffer::new_null(len)),
fields: arrays,
}
}

/// Assemble a [`StructArray`] from its parts, skipping every consistency check
///
/// The length is taken from the first child array, or is `0` when `arrays`
/// is empty. `nulls` is stored exactly as given.
///
/// # Safety
///
/// Safe if [`Self::new`] would not panic with the given arguments
pub unsafe fn new_unchecked(
    fields: Fields,
    arrays: Vec<ArrayRef>,
    nulls: Option<NullBuffer>,
) -> Self {
    let len = match arrays.first() {
        Some(child) => child.len(),
        None => 0,
    };
    let data_type = DataType::Struct(fields);

    Self {
        len,
        data_type,
        nulls,
        fields: arrays,
    }
}

/// Deconstruct this array into its constituent parts
pub fn into_parts(self) -> (Fields, Vec<ArrayRef>, Option<NullBuffer>) {
let f = match self.data_type {
DataType::Struct(f) => f,
_ => unreachable!(),
};
(f, self.fields, self.nulls)
}

/// Returns the field at `pos`.
pub fn column(&self, pos: usize) -> &ArrayRef {
&self.fields[pos]
Expand Down Expand Up @@ -359,37 +441,10 @@ impl std::fmt::Debug for StructArray {

// Conversion from a list of (field, child array) pairs plus a raw validity
// buffer into a `StructArray`.
//
// NOTE(review): this span is a rendered diff hunk without +/- markers, so two
// bodies for `from` appear back to back. Going by the hunk header
// (`-359,37 +441,10`), the first body (ending at `Self::from(array_data)`)
// looks like the removed implementation and the last four statements look
// like its replacement - confirm against the real file.
impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray {
fn from(pair: (Vec<(Field, ArrayRef)>, Buffer)) -> Self {
// --- apparent removed body: validates by hand, then goes through ArrayData ---
let capacity = pair.0.len();
let mut len = None;
let mut schema = SchemaBuilder::with_capacity(capacity);
let mut child_data = Vec::with_capacity(capacity);
for (field, array) in pair.0 {
// Check the length of the child arrays
assert_eq!(
*len.get_or_insert(array.len()),
array.len(),
"all child arrays of a StructArray must have the same length"
);
// Check data types of child arrays
assert_eq!(
field.data_type(),
array.data_type(),
"the field data types must match the array data in a StructArray"
);
schema.push(field);
child_data.push(array.to_data());
}
let field_types = schema.finish().fields;
let array_data = ArrayData::builder(DataType::Struct(field_types))
.null_bit_buffer(Some(pair.1))
.child_data(child_data)
.len(len.unwrap_or_default());
let array_data = unsafe { array_data.build_unchecked() };

// We must validate nullability
array_data.validate_nulls().unwrap();

Self::from(array_data)
// --- apparent replacement body: delegates all validation to `Self::new` ---
// Length comes from the first child array, or 0 when there are no pairs
let len = pair.0.first().map(|x| x.1.len()).unwrap_or_default();
// Split the pairs into a schema builder (fields) and the child arrays
let (fields, arrays): (SchemaBuilder, Vec<_>) = pair.0.into_iter().unzip();
// Wrap the raw buffer as a validity mask covering `len` bits from offset 0
let nulls = NullBuffer::new(BooleanBuffer::new(pair.1, 0, len));
Self::new(fields.finish().fields, arrays, Some(nulls))
}
}

Expand Down
11 changes: 3 additions & 8 deletions arrow-array/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,17 +467,12 @@ impl Default for RecordBatchOptions {
}
impl From<StructArray> for RecordBatch {
fn from(value: StructArray) -> Self {
assert_eq!(
value.null_count(),
0,
"Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
);
let row_count = value.len();
let schema = Arc::new(Schema::new(value.fields().clone()));
let columns = value.fields;
let (fields, columns, nulls) = value.into_parts();
assert_eq!(nulls.map(|n| n.null_count()).unwrap_or_default(), 0, "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation");

RecordBatch {
schema,
schema: Arc::new(Schema::new(fields)),
row_count,
columns,
}
Expand Down
7 changes: 7 additions & 0 deletions arrow-buffer/src/buffer/null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ impl NullBuffer {
}
}

/// Returns true if all nulls in `other` also exist in self
///
/// When `other` reports a null count of zero there is nothing to look for,
/// so the result is `true` without inspecting any bits.
pub fn contains(&self, other: &NullBuffer) -> bool {
    // Vacuously true, and avoids scanning both buffers
    if other.null_count() == 0 {
        return true;
    }

    let lhs = self.inner().bit_chunks().iter_padded();
    let rhs = other.inner().bit_chunks().iter_padded();
    // `l & !r` is non-zero exactly when some bit is set in `self` but unset
    // in `other`, i.e. `other` has a null at a position where `self` does not
    lhs.zip(rhs).all(|(l, r)| (l & !r) == 0)
}

/// Returns the length of this [`NullBuffer`]
#[inline]
pub fn len(&self) -> usize {
Expand Down
16 changes: 4 additions & 12 deletions arrow-json/src/reader/struct_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,11 @@ impl ArrayDecoder for StructArrayDecoder {
for (c, f) in child_data.iter().zip(fields) {
// Sanity check
assert_eq!(c.len(), pos.len());
if let Some(a) = c.nulls() {
let nulls_valid = f.is_nullable()
|| nulls.as_ref().map(|n| n.contains(a)).unwrap_or_default();

if !f.is_nullable() && c.null_count() != 0 {
// Need to verify nulls
let valid = match nulls.as_ref() {
Some(nulls) => {
let lhs = nulls.inner().bit_chunks().iter_padded();
let rhs = c.nulls().unwrap().inner().bit_chunks().iter_padded();
lhs.zip(rhs).all(|(l, r)| (l & !r) == 0)
}
None => false,
};

if !valid {
if !nulls_valid {
return Err(ArrowError::JsonError(format!("Encountered unmasked nulls in non-nullable StructArray child: {f}")));
}
}
Expand Down

0 comments on commit a327788

Please sign in to comment.