Skip to content

Commit

Permalink
Use MutableArrayData
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Nov 17, 2023
1 parent ad87f8e commit a245f07
Showing 1 changed file with 54 additions and 141 deletions.
195 changes: 54 additions & 141 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ use crate::parse::{
string_to_datetime, Parser,
};
use arrow_array::{builder::*, cast::*, temporal_conversions::*, timezone::Tz, types::*, *};
use arrow_buffer::{i256, ArrowNativeType, Buffer, NullBuffer, NullBufferBuilder, OffsetBuffer};
use arrow_buffer::{i256, ArrowNativeType, Buffer, OffsetBuffer};
use arrow_data::transform::MutableArrayData;
use arrow_data::ArrayData;
use arrow_schema::*;
use arrow_select::take::take;
Expand Down Expand Up @@ -3229,155 +3230,62 @@ fn cast_list_to_fixed_size_list<OffsetSize>(
where
OffsetSize: OffsetSizeTrait,
{
let can_zero_copy_values = match validate_consistent_lengths::<OffsetSize>(array, size) {
Ok(nulls_have_correct_length) => nulls_have_correct_length,
Err(_) if cast_options.safe => false,
Err(err) => return Err(err),
};
let cap = array.len() * size as usize;

let data = array.to_data();
let underlying_array = make_array(data.child_data()[0].clone());
let mut cast_array =
cast_with_options(underlying_array.as_ref(), field.data_type(), cast_options)?;

// Some of the lists aren't the correct size, so we need to call take on
// the values array to make them the correct size.
if !can_zero_copy_values {
let (take_indices, null_buffer) = build_take_indices(array, size);
cast_array = take(cast_array.as_ref(), take_indices.as_ref(), None)?;
if field.is_nullable() {
cast_array = make_array(
cast_array
.to_data()
.into_builder()
.nulls(null_buffer)
.build()?,
);
let mut nulls = (cast_options.safe || array.null_count() != 0).then(|| {
let mut buffer = BooleanBufferBuilder::new(array.len());
match array.nulls() {
Some(n) => buffer.append_buffer(n.inner()),
None => buffer.append_n(array.len(), true),
}
}

let dest_type = DataType::FixedSizeList(field.clone(), size);
let array_data = data
.into_builder()
.buffers(vec![])
.data_type(dest_type)
.child_data(vec![cast_array.into_data()])
.build()?;

let list = FixedSizeListArray::from(array_data);
Ok(Arc::new(list) as ArrayRef)
}
buffer
});

/// Validate that the values of the list array all have the given length.
///
/// Returns an error if the lengths are inconsistent.
///
/// This function also checks if the null slots have the correct length, and
/// returns whether this is true. This can be used later for zero-copy
/// optimizations.
fn validate_consistent_lengths<OffsetSize>(
array: &GenericListArray<OffsetSize>,
size: i32,
) -> Result<bool, ArrowError>
where
OffsetSize: OffsetSizeTrait,
{
let mut nulls_have_correct_length = true;

let size = OffsetSize::from_usize(size as usize).unwrap();
let mut offsets_iter = array.offsets().iter();
let mut expected_offset = *offsets_iter.next().expect("empty offsets buffer") + size;

// Iterator over (index, ending offset)
let offsets_iter = offsets_iter.enumerate();

if array.null_count() > 0 {
for (i, offset) in offsets_iter {
if *offset != expected_offset {
nulls_have_correct_length = false;
if array.is_valid(i) {
return Err(ArrowError::InvalidArgumentError(format!(
"Cannot cast to FixedSizeList({:?}): value at index {} has length {:?}",
size,
i,
*offset - expected_offset + size
)));
// Nulls in FixedSizeListArray take up space and so we must pad the values
let values = array.values().to_data();
let mut mutable = MutableArrayData::new(vec![&values], cast_options.safe, cap);
let mut last_pos = 0;
for (idx, w) in array.offsets().windows(2).enumerate() {
let start_pos = w[0].as_usize();
let end_pos = w[1].as_usize();
let len = end_pos - start_pos;

if len != size as usize {
if cast_options.safe || array.is_null(idx) {
// Pad with nulls
if last_pos != start_pos {
mutable.extend(0, last_pos, start_pos);
}
expected_offset = *offset + size;
mutable.extend_nulls(size as _);
nulls.as_mut().unwrap().set_bit(idx, false);
last_pos = end_pos
} else {
expected_offset += size;
}
}
} else {
for (i, offset) in offsets_iter {
if *offset != expected_offset {
return Err(ArrowError::InvalidArgumentError(format!(
"Cannot cast to FixedSizeList({:?}): value at index {} has length {:?}",
size,
i,
*offset - expected_offset + size
return Err(ArrowError::CastError(format!(
"Cannot cast to FixedSizeList({size}): value at index {idx} has length {len}",
)));
}
expected_offset += size;
}
}
Ok(nulls_have_correct_length)
}

/// Build take indices and null buffer for values array.
///
/// The take indices are such that when applied to the values array, it will
/// provide a new values array where all values have the same length.
///
/// If there are null slots in the list array, then the values take indices is
/// filled with 0, so that they will be populated with the first value. This
/// means that if only the null slots are improperly sized, then the null buffer
/// returned will be None.
///
/// If there are valid slots in the list array that are too long, they will be
/// truncated. If there are valid slots that are too short, then they will have
/// their corresponding slots in the null buffer set to null. Thus, if there are
/// any valid slots in the list array that are too short, then the null buffer
/// returned will be Some.
fn build_take_indices<OffsetSize>(
array: &GenericListArray<OffsetSize>,
list_size: i32,
) -> (ArrayRef, Option<NullBuffer>)
where
OffsetSize: OffsetSizeTrait,
{
let list_size = list_size as usize;
let mut indices = Vec::with_capacity(array.len() * list_size);
let mut null_buffer = NullBufferBuilder::new(array.len() * list_size);

for i in 0..array.len() {
if array.is_null(i) {
indices.extend(std::iter::repeat(0_i64).take(list_size));
null_buffer.append_n_non_nulls(list_size);
} else {
let start = array.value_offsets()[i].as_usize() as i64;
let end = array.value_offsets()[i + 1].as_usize() as i64;
let actual_size = (end - start) as usize;
match actual_size.cmp(&(list_size)) {
Ordering::Equal => {
indices.extend(start..end);
null_buffer.append_n_non_nulls(list_size);
}
Ordering::Greater => {
indices.extend(start..start + list_size as i64);
null_buffer.append_n_non_nulls(list_size);
}
Ordering::Less => {
indices.extend(start..end);
indices.extend(std::iter::repeat(0_i64).take(list_size - actual_size));
null_buffer.append_n_non_nulls(actual_size);
null_buffer.append_n_nulls(list_size - actual_size);
}
let values = match last_pos {
0 => array.values().slice(0, cap), // All slices were the correct length
_ => {
if mutable.len() != cap {
// Remaining slices were all correct length
let remaining = cap - mutable.len();
mutable.extend(0, last_pos, last_pos + remaining)
}
make_array(mutable.freeze())
}
}
};

// Cast the inner values if necessary
let values = cast_with_options(values.as_ref(), field.data_type(), cast_options)?;

(Arc::new(Int64Array::from(indices)), null_buffer.finish())
// Construct the FixedSizeListArray
let nulls = nulls.map(|mut x| x.finish().into());
let array = FixedSizeListArray::new(field.clone(), size, values, nulls);
Ok(Arc::new(array))
}

/// Cast the container type of List/Largelist array but not the inner types.
Expand Down Expand Up @@ -7591,6 +7499,11 @@ mod tests {
vec![1, 2, 3, 0, 0, 0, 4, 5, 6, 0, 0, 0],
vec![3, 3, 3, 3],
),
(
// Mixed nulls
vec![1, 2, 3, 4, 5, 6, 0, 0, 0],
vec![3, 0, 3, 3],
),
];
let null_buffer = NullBuffer::from(vec![true, false, true, false]);

Expand All @@ -7604,10 +7517,10 @@ mod tests {
3,
)) as ArrayRef;

for (values, offsets) in cases.iter() {
for (values, lengths) in cases.iter() {
let array = Arc::new(ListArray::new(
field.clone(),
OffsetBuffer::from_lengths(offsets.clone()),
OffsetBuffer::from_lengths(lengths.clone()),
Arc::new(Int32Array::from(values.clone())),
Some(null_buffer.clone()),
)) as ArrayRef;
Expand Down Expand Up @@ -7649,8 +7562,8 @@ mod tests {
let expected = Arc::new(FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
vec![
Some(vec![Some(1), Some(2), Some(3)]),
Some(vec![Some(4), Some(5), None]), // Too short -> filled with null
Some(vec![Some(6), Some(7), Some(8)]), // Too long -> Truncated
None, // Too short -> filled with null
None, // Too long -> Truncated
],
3,
)) as ArrayRef;
Expand Down

0 comments on commit a245f07

Please sign in to comment.