From dc75a280b46149140eca8dd5e18d31cbadf04716 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 17 Nov 2023 10:09:34 -0800 Subject: [PATCH] feat: cast (Large)List to FixedSizeList (#5081) * feat: cast (Large)List to FixedSizeList * fix: support 'safe' casting of list to FSL * fix: if target is non-null, use non-null sentinel value * Use MutableArrayData * Docs --------- Co-authored-by: Raphael Taylor-Davies --- arrow-cast/src/cast.rs | 264 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 263 insertions(+), 1 deletion(-) diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index e44133f81b4a..dd3e271afb0d 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -48,6 +48,7 @@ use crate::parse::{ }; use arrow_array::{builder::*, cast::*, temporal_conversions::*, timezone::Tz, types::*, *}; use arrow_buffer::{i256, ArrowNativeType, Buffer, OffsetBuffer}; +use arrow_data::transform::MutableArrayData; use arrow_data::ArrayData; use arrow_schema::*; use arrow_select::take::take; @@ -138,6 +139,9 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (List(list_from) | LargeList(list_from), Utf8 | LargeUtf8) => { can_cast_types(list_from.data_type(), to_type) } + (List(list_from) | LargeList(list_from), FixedSizeList(list_to, _)) => { + can_cast_types(list_from.data_type(), list_to.data_type()) + } (List(_), _) => false, (FixedSizeList(list_from,_), List(list_to)) => { list_from.data_type() == list_to.data_type() @@ -279,6 +283,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { /// in integer casts return null /// * Numeric to boolean: 0 returns `false`, any other value returns `true` /// * List to List: the underlying data type is cast +/// * List to FixedSizeList: the underlying data type is cast. If safe is true and a list element +/// has the wrong length it will be replaced with NULL, otherwise an error will be returned /// * Primitive to List: a list array with 1 value per slot is created /// * Date32 and Date64: precision lost when going to higher interval /// * Time32 and Time64: precision lost when going to higher interval @@ -799,6 +805,14 @@ pub fn cast_with_options( cast_list_container::(array, cast_options) } } + (List(_), FixedSizeList(field, size)) => { + let array = array.as_list::(); + cast_list_to_fixed_size_list::(array, field, *size, cast_options) + } + (LargeList(_), FixedSizeList(field, size)) => { + let array = array.as_list::(); + cast_list_to_fixed_size_list::(array, field, *size, cast_options) + } (List(_) | LargeList(_), _) => match to_type { Utf8 => value_to_string::(array, cast_options), LargeUtf8 => value_to_string::(array, cast_options), @@ -824,7 +838,6 @@ pub fn cast_with_options( cast_fixed_size_list_to_list::(array) } } - (_, List(ref to)) => cast_values_to_list::(array, to, cast_options), (_, LargeList(ref to)) => cast_values_to_list::(array, to, cast_options), (Decimal128(_, s1), Decimal128(p2, s2)) => { @@ -3206,6 +3219,76 @@ where Ok(Arc::new(list)) } +fn cast_list_to_fixed_size_list( + array: &GenericListArray, + field: &Arc, + size: i32, + cast_options: &CastOptions, +) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + let cap = array.len() * size as usize; + + let mut nulls = (cast_options.safe || array.null_count() != 0).then(|| { + let mut buffer = BooleanBufferBuilder::new(array.len()); + match array.nulls() { + Some(n) => buffer.append_buffer(n.inner()), + None => buffer.append_n(array.len(), true), + } + buffer + }); + + // Nulls in FixedSizeListArray take up space and so we must pad the values + let values = array.values().to_data(); + let mut mutable = MutableArrayData::new(vec![&values], cast_options.safe, cap); + // The end position in values of the last incorrectly-sized list slice + let mut last_pos = 0; + for (idx, w) in array.offsets().windows(2).enumerate() { + let start_pos = w[0].as_usize(); + let end_pos = w[1].as_usize(); + let len = end_pos - start_pos; + + if len != size as usize { + if cast_options.safe || array.is_null(idx) { + if last_pos != start_pos { + // Extend with valid slices + mutable.extend(0, last_pos, start_pos); + } + // Pad this slice with nulls + mutable.extend_nulls(size as _); + nulls.as_mut().unwrap().set_bit(idx, false); + // Set last_pos to the end of this slice's values + last_pos = end_pos + } else { + return Err(ArrowError::CastError(format!( + "Cannot cast to FixedSizeList({size}): value at index {idx} has length {len}", + ))); + } + } + } + + let values = match last_pos { + 0 => array.values().slice(0, cap), // All slices were the correct length + _ => { + if mutable.len() != cap { + // Remaining slices were all correct length + let remaining = cap - mutable.len(); + mutable.extend(0, last_pos, last_pos + remaining) + } + make_array(mutable.freeze()) + } + }; + + // Cast the inner values if necessary + let values = cast_with_options(values.as_ref(), field.data_type(), cast_options)?; + + // Construct the FixedSizeListArray + let nulls = nulls.map(|mut x| x.finish().into()); + let array = FixedSizeListArray::new(field.clone(), size, values, nulls); + Ok(Arc::new(array)) +} + /// Cast the container type of List/Largelist array but not the inner types. /// This function can leave the value data intact and only has to cast the offset dtypes. fn cast_list_container( @@ -3274,6 +3357,8 @@ where #[cfg(test)] mod tests { + use arrow_buffer::NullBuffer; + use super::*; macro_rules! generate_cast_test_case { @@ -7374,6 +7459,183 @@ mod tests { assert_eq!(&expected.value(2), &actual.value(2)); } + #[test] + fn test_cast_list_to_fsl() { + // There four noteworthy cases we should handle: + // 1. No nulls + // 2. Nulls that are always empty + // 3. Nulls that have varying lengths + // 4. Nulls that are correctly sized (same as target list size) + + // Non-null case + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let values = vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5), Some(6)]), + ]; + let array = Arc::new(ListArray::from_iter_primitive::( + values.clone(), + )) as ArrayRef; + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + values, 3, + )) as ArrayRef; + let actual = cast(array.as_ref(), &DataType::FixedSizeList(field.clone(), 3)).unwrap(); + assert_eq!(expected.as_ref(), actual.as_ref()); + + // Null cases + // Array is [[1, 2, 3], null, [4, 5, 6], null] + let cases = [ + ( + // Zero-length nulls + vec![1, 2, 3, 4, 5, 6], + vec![3, 0, 3, 0], + ), + ( + // Varying-length nulls + vec![1, 2, 3, 0, 0, 4, 5, 6, 0], + vec![3, 2, 3, 1], + ), + ( + // Correctly-sized nulls + vec![1, 2, 3, 0, 0, 0, 4, 5, 6, 0, 0, 0], + vec![3, 3, 3, 3], + ), + ( + // Mixed nulls + vec![1, 2, 3, 4, 5, 6, 0, 0, 0], + vec![3, 0, 3, 3], + ), + ]; + let null_buffer = NullBuffer::from(vec![true, false, true, false]); + + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5), Some(6)]), + None, + ], + 3, + )) as ArrayRef; + + for (values, lengths) in cases.iter() { + let array = Arc::new(ListArray::new( + field.clone(), + OffsetBuffer::from_lengths(lengths.clone()), + Arc::new(Int32Array::from(values.clone())), + Some(null_buffer.clone()), + )) as ArrayRef; + let actual = cast(array.as_ref(), &DataType::FixedSizeList(field.clone(), 3)).unwrap(); + assert_eq!(expected.as_ref(), actual.as_ref()); + } + } + + #[test] + fn test_cast_list_to_fsl_safety() { + let values = vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + Some(vec![Some(6), Some(7), Some(8), Some(9)]), + Some(vec![Some(3), Some(4), Some(5)]), + ]; + let array = Arc::new(ListArray::from_iter_primitive::( + values.clone(), + )) as ArrayRef; + + let res = cast_with_options( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3), + &CastOptions { + safe: false, + ..Default::default() + }, + ); + assert!(res.is_err()); + assert!(format!("{:?}", res) + .contains("Cannot cast to FixedSizeList(3): value at index 1 has length 2")); + + // When safe=true (default), the cast will fill nulls for lists that are + // too short and truncate lists that are too long. + let res = cast( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 3), + ) + .unwrap(); + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, // Too short -> replaced with null + None, // Too long -> replaced with null + Some(vec![Some(3), Some(4), Some(5)]), + ], + 3, + )) as ArrayRef; + assert_eq!(expected.as_ref(), res.as_ref()); + } + + #[test] + fn test_cast_large_list_to_fsl() { + let values = vec![Some(vec![Some(1), Some(2)]), Some(vec![Some(3), Some(4)])]; + let array = Arc::new(LargeListArray::from_iter_primitive::( + values.clone(), + )) as ArrayRef; + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + values, 2, + )) as ArrayRef; + let actual = cast( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, true)), 2), + ) + .unwrap(); + assert_eq!(expected.as_ref(), actual.as_ref()); + } + + #[test] + fn test_cast_list_to_fsl_subcast() { + let array = Arc::new(LargeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3), Some(i32::MAX)]), + ], + )) as ArrayRef; + let expected = Arc::new(FixedSizeListArray::from_iter_primitive::( + vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3), Some(i32::MAX as i64)]), + ], + 2, + )) as ArrayRef; + let actual = cast( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int64, true)), 2), + ) + .unwrap(); + assert_eq!(expected.as_ref(), actual.as_ref()); + + let res = cast_with_options( + array.as_ref(), + &DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int16, true)), 2), + &CastOptions { + safe: false, + ..Default::default() + }, + ); + assert!(res.is_err()); + assert!(format!("{:?}", res).contains("Can't cast value 2147483647 to type Int16")); + } + + #[test] + fn test_cast_list_to_fsl_empty() { + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let array = new_empty_array(&DataType::List(field.clone())); + + let target_type = DataType::FixedSizeList(field.clone(), 3); + let expected = new_empty_array(&target_type); + + let actual = cast(array.as_ref(), &target_type).unwrap(); + assert_eq!(expected.as_ref(), actual.as_ref()); + } + fn make_list_array() -> ListArray { // Construct a value array let value_data = ArrayData::builder(DataType::Int32)