diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 48878aa9bd99..8820ca9942fc 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -21,6 +21,7 @@ use std::borrow::Borrow; use std::cmp::Ordering; use std::collections::HashSet; use std::convert::{Infallible, TryInto}; +use std::hash::Hash; use std::str::FromStr; use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; @@ -142,13 +143,13 @@ pub enum ScalarValue { /// Fixed size list scalar. /// /// The array must be a FixedSizeListArray with length 1. - FixedSizeList(ArrayRef), + FixedSizeList(Arc), /// Represents a single element of a [`ListArray`] as an [`ArrayRef`] /// /// The array must be a ListArray with length 1. - List(ArrayRef), + List(Arc), /// The array must be a LargeListArray with length 1. - LargeList(ArrayRef), + LargeList(Arc), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), /// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01 @@ -360,45 +361,13 @@ impl PartialOrd for ScalarValue { (FixedSizeBinary(_, _), _) => None, (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), (LargeBinary(_), _) => None, - (List(arr1), List(arr2)) - | (FixedSizeList(arr1), FixedSizeList(arr2)) - | (LargeList(arr1), LargeList(arr2)) => { - // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1 - assert_eq!(arr1.len(), 1); - assert_eq!(arr2.len(), 1); - - if arr1.data_type() != arr2.data_type() { - return None; - } - - fn first_array_for_list(arr: &ArrayRef) -> ArrayRef { - if let Some(arr) = arr.as_list_opt::() { - arr.value(0) - } else if let Some(arr) = arr.as_list_opt::() { - arr.value(0) - } else if let Some(arr) = arr.as_fixed_size_list_opt() { - arr.value(0) - } else { - unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen") - } - } - - let arr1 = first_array_for_list(arr1); - let arr2 = first_array_for_list(arr2); - - let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; - let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; - - for j in 0..lt_res.len() { - if lt_res.is_valid(j) && lt_res.value(j) { - return Some(Ordering::Less); - } - if eq_res.is_valid(j) && !eq_res.value(j) { - return Some(Ordering::Greater); - } - } - - Some(Ordering::Equal) + // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1 + (List(arr1), List(arr2)) => partial_cmp_list(arr1.as_ref(), arr2.as_ref()), + (FixedSizeList(arr1), FixedSizeList(arr2)) => { + partial_cmp_list(arr1.as_ref(), arr2.as_ref()) + } + (LargeList(arr1), LargeList(arr2)) => { + partial_cmp_list(arr1.as_ref(), arr2.as_ref()) } (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), @@ -464,6 +433,44 @@ impl PartialOrd for ScalarValue { } } +/// List/LargeList/FixedSizeList scalars always have a single element +/// array. This function returns that array +fn first_array_for_list(arr: &dyn Array) -> ArrayRef { + assert_eq!(arr.len(), 1); + if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_list_opt::() { + arr.value(0) + } else if let Some(arr) = arr.as_fixed_size_list_opt() { + arr.value(0) + } else { + unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen") + } +} + +/// Compares two List/LargeList/FixedSizeList scalars +fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option { + if arr1.data_type() != arr2.data_type() { + return None; + } + let arr1 = first_array_for_list(arr1); + let arr2 = first_array_for_list(arr2); + + let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; + + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); + } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); + } + } + + Some(Ordering::Equal) +} + impl Eq for ScalarValue {} //Float wrapper over f32/f64. Just because we cannot build std::hash::Hash for floats directly we have to do it through type wrapper @@ -517,14 +524,14 @@ impl std::hash::Hash for ScalarValue { Binary(v) => v.hash(state), FixedSizeBinary(_, v) => v.hash(state), LargeBinary(v) => v.hash(state), - List(arr) | LargeList(arr) | FixedSizeList(arr) => { - let arrays = vec![arr.to_owned()]; - let hashes_buffer = &mut vec![0; arr.len()]; - let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); - let hashes = - create_hashes(&arrays, &random_state, hashes_buffer).unwrap(); - // Hash back to std::hash::Hasher - hashes.hash(state); + List(arr) => { + hash_list(arr.to_owned() as ArrayRef, state); + } + LargeList(arr) => { + hash_list(arr.to_owned() as ArrayRef, state); + } + FixedSizeList(arr) => { + hash_list(arr.to_owned() as ArrayRef, state); } Date32(v) => v.hash(state), Date64(v) => v.hash(state), @@ -557,6 +564,15 @@ impl std::hash::Hash for ScalarValue { } } +fn hash_list(arr: ArrayRef, state: &mut H) { + let arrays = vec![arr.to_owned()]; + let hashes_buffer = &mut vec![0; arr.len()]; + let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); + let hashes = create_hashes(&arrays, &random_state, hashes_buffer).unwrap(); + // Hash back to std::hash::Hasher + hashes.hash(state); +} + /// Return a reference to the values array and the index into it for a /// dictionary array /// @@ -942,9 +958,9 @@ impl ScalarValue { ScalarValue::Binary(_) => DataType::Binary, ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz), ScalarValue::LargeBinary(_) => DataType::LargeBinary, - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), + ScalarValue::List(arr) => arr.data_type().to_owned(), + ScalarValue::LargeList(arr) => arr.data_type().to_owned(), + ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::Time32Second(_) => DataType::Time32(TimeUnit::Second), @@ -1147,9 +1163,9 @@ impl ScalarValue { ScalarValue::LargeBinary(v) => v.is_none(), // arr.len() should be 1 for a list scalar, but we don't seem to // enforce that anywhere, so we still check against array length. - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), + ScalarValue::List(arr) => arr.len() == arr.null_count(), + ScalarValue::LargeList(arr) => arr.len() == arr.null_count(), + ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), ScalarValue::Date32(v) => v.is_none(), ScalarValue::Date64(v) => v.is_none(), ScalarValue::Time32Second(v) => v.is_none(), @@ -1695,17 +1711,16 @@ impl ScalarValue { /// ScalarValue::Int32(Some(2)) /// ]; /// - /// let array = ScalarValue::new_list(&scalars, &DataType::Int32); - /// let result = as_list_array(&array).unwrap(); + /// let result = ScalarValue::new_list(&scalars, &DataType::Int32); /// /// let expected = ListArray::from_iter_primitive::( /// vec![ /// Some(vec![Some(1), None, Some(2)]) /// ]); /// - /// assert_eq!(result, &expected); + /// assert_eq!(*result, expected); /// ``` - pub fn new_list(values: &[ScalarValue], data_type: &DataType) -> ArrayRef { + pub fn new_list(values: &[ScalarValue], data_type: &DataType) -> Arc { let values = if values.is_empty() { new_empty_array(data_type) } else { @@ -1730,17 +1745,19 @@ impl ScalarValue { /// ScalarValue::Int32(Some(2)) /// ]; /// - /// let array = ScalarValue::new_large_list(&scalars, &DataType::Int32); - /// let result = as_large_list_array(&array).unwrap(); + /// let result = ScalarValue::new_large_list(&scalars, &DataType::Int32); /// /// let expected = LargeListArray::from_iter_primitive::( /// vec![ /// Some(vec![Some(1), None, Some(2)]) /// ]); /// - /// assert_eq!(result, &expected); + /// assert_eq!(*result, expected); /// ``` - pub fn new_large_list(values: &[ScalarValue], data_type: &DataType) -> ArrayRef { + pub fn new_large_list( + values: &[ScalarValue], + data_type: &DataType, + ) -> Arc { let values = if values.is_empty() { new_empty_array(data_type) } else { @@ -1876,14 +1893,14 @@ impl ScalarValue { .collect::(), ), }, - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => { - let arrays = std::iter::repeat(arr.as_ref()) - .take(size) - .collect::>(); - arrow::compute::concat(arrays.as_slice()) - .map_err(|e| arrow_datafusion_err!(e))? + ScalarValue::List(arr) => { + Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? + } + ScalarValue::LargeList(arr) => { + Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? + } + ScalarValue::FixedSizeList(arr) => { + Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) @@ -2040,6 +2057,11 @@ impl ScalarValue { } } + fn list_to_array_of_size(arr: &dyn Array, size: usize) -> Result { + let arrays = std::iter::repeat(arr).take(size).collect::>(); + arrow::compute::concat(arrays.as_slice()).map_err(|e| arrow_datafusion_err!(e)) + } + /// Retrieve ScalarValue for each row in `array` /// /// Example @@ -2433,11 +2455,14 @@ impl ScalarValue { ScalarValue::LargeBinary(val) => { eq_array_primitive!(array, index, LargeBinaryArray, val)? } - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => { - let right = array.slice(index, 1); - arr == &right + ScalarValue::List(arr) => { + Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) + } + ScalarValue::LargeList(arr) => { + Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) + } + ScalarValue::FixedSizeList(arr) => { + Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) } ScalarValue::Date32(val) => { eq_array_primitive!(array, index, Date32Array, val)? @@ -2515,6 +2540,11 @@ impl ScalarValue { }) } + fn eq_array_list(arr1: &ArrayRef, arr2: &ArrayRef, index: usize) -> bool { + let right = arr2.slice(index, 1); + arr1 == &right + } + /// Estimate size if bytes including `Self`. For values with internal containers such as `String` /// includes the allocated size (`capacity`) rather than the current length (`len`) pub fn size(&self) -> usize { @@ -2561,9 +2591,9 @@ impl ScalarValue { | ScalarValue::LargeBinary(b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), + ScalarValue::List(arr) => arr.get_array_memory_size(), + ScalarValue::LargeList(arr) => arr.get_array_memory_size(), + ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), ScalarValue::Struct(vals, fields) => { vals.as_ref() .map(|vals| { @@ -2865,14 +2895,19 @@ impl TryFrom<&DataType> for ScalarValue { Box::new(value_type.as_ref().try_into()?), ), // `ScalaValue::List` contains single element `ListArray`. - DataType::List(field) => ScalarValue::List(new_null_array( - &DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - ))), - 1, - )), + DataType::List(field) => ScalarValue::List( + new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + ))), + 1, + ) + .as_list::() + .to_owned() + .into(), + ), DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()), DataType::Null => ScalarValue::Null, _ => { @@ -2937,16 +2972,9 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => { - // ScalarValue List should always have a single element - assert_eq!(arr.len(), 1); - let options = FormatOptions::default().with_display_error(true); - let formatter = ArrayFormatter::try_new(arr, &options).unwrap(); - let value_formatter = formatter.value(0); - write!(f, "{value_formatter}")? - } + ScalarValue::List(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, + ScalarValue::LargeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, + ScalarValue::FixedSizeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, ScalarValue::Date32(e) => format_option!(f, e)?, ScalarValue::Date64(e) => format_option!(f, e)?, ScalarValue::Time32Second(e) => format_option!(f, e)?, @@ -2979,6 +3007,16 @@ impl fmt::Display for ScalarValue { } } +fn fmt_list(arr: ArrayRef, f: &mut fmt::Formatter) -> fmt::Result { + // ScalarValue List, LargeList, FixedSizeList should always have a single element + assert_eq!(arr.len(), 1); + let options = FormatOptions::default().with_display_error(true); + let formatter = + ArrayFormatter::try_new(arr.as_ref() as &dyn Array, &options).unwrap(); + let value_formatter = formatter.value(0); + write!(f, "{value_formatter}") +} + impl fmt::Debug for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -3182,15 +3220,14 @@ mod tests { ScalarValue::from("data-fusion"), ]; - let array = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); + let result = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); let expected = array_into_list_array(Arc::new(StringArray::from(vec![ "rust", "arrow", "data-fusion", ]))); - let result = as_list_array(&array); - assert_eq!(result, &expected); + assert_eq!(*result, expected); } fn build_list( @@ -3226,9 +3263,9 @@ mod tests { }; if O::IS_LARGE { - ScalarValue::LargeList(arr) + ScalarValue::LargeList(arr.as_list::().to_owned().into()) } else { - ScalarValue::List(arr) + ScalarValue::List(arr.as_list::().to_owned().into()) } }) .collect() @@ -3311,18 +3348,16 @@ mod tests { ])); let fsl_array: ArrayRef = - Arc::new(FixedSizeListArray::from_iter_primitive::( - vec![ - Some(vec![Some(0), Some(1), Some(2)]), - None, - Some(vec![Some(3), None, Some(5)]), - ], - 3, - )); + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + ])); for arr in [list_array, fsl_array] { for i in 0..arr.len() { - let scalar = ScalarValue::List(arr.slice(i, 1)); + let scalar = + ScalarValue::List(arr.slice(i, 1).as_list::().to_owned().into()); assert!(scalar.eq_array(&arr, i).unwrap()); } } @@ -3676,8 +3711,7 @@ mod tests { #[test] fn scalar_list_null_to_array() { - let list_array_ref = ScalarValue::new_list(&[], &DataType::UInt64); - let list_array = as_list_array(&list_array_ref); + let list_array = ScalarValue::new_list(&[], &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 0); @@ -3685,8 +3719,7 @@ mod tests { #[test] fn scalar_large_list_null_to_array() { - let list_array_ref = ScalarValue::new_large_list(&[], &DataType::UInt64); - let list_array = as_large_list_array(&list_array_ref); + let list_array = ScalarValue::new_large_list(&[], &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 0); @@ -3699,8 +3732,7 @@ mod tests { ScalarValue::UInt64(None), ScalarValue::UInt64(Some(101)), ]; - let list_array_ref = ScalarValue::new_list(&values, &DataType::UInt64); - let list_array = as_list_array(&list_array_ref); + let list_array = ScalarValue::new_list(&values, &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 3); @@ -3720,8 +3752,7 @@ mod tests { ScalarValue::UInt64(None), ScalarValue::UInt64(Some(101)), ]; - let list_array_ref = ScalarValue::new_large_list(&values, &DataType::UInt64); - let list_array = as_large_list_array(&list_array_ref); + let list_array = ScalarValue::new_large_list(&values, &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 3); @@ -3959,10 +3990,15 @@ mod tests { let data_type = &data_type; let scalar: ScalarValue = data_type.try_into().unwrap(); - let expected = ScalarValue::List(new_null_array( - &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - 1, - )); + let expected = ScalarValue::List( + new_null_array( + &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + 1, + ) + .as_list::() + .to_owned() + .into(), + ); assert_eq!(expected, scalar) } @@ -3977,14 +4013,19 @@ mod tests { let data_type = &data_type; let scalar: ScalarValue = data_type.try_into().unwrap(); - let expected = ScalarValue::List(new_null_array( - &DataType::List(Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - true, - ))), - 1, - )); + let expected = ScalarValue::List( + new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ))), + 1, + ) + .as_list::() + .to_owned() + .into(), + ); assert_eq!(expected, scalar) } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 7d09aec7e748..dbdfb856a71c 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -27,7 +27,7 @@ use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::SimplifyInfo; use arrow::{ - array::new_null_array, + array::{new_null_array, AsArray}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; @@ -396,7 +396,7 @@ impl<'a> ConstEvaluator<'a> { a.len() ) } else if as_list_array(&a).is_ok() || as_large_list_array(&a).is_ok() { - Ok(ScalarValue::List(a)) + Ok(ScalarValue::List(a.as_list().to_owned().into())) } else { // Non-ListArray ScalarValue::try_from_array(&a, 0) diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 1efae424cc69..2d263a42e0ff 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -186,7 +186,6 @@ mod tests { use arrow::array::{ArrayRef, Int32Array}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; - use arrow_array::cast::as_list_array; use arrow_array::types::Int32Type; use arrow_array::{Array, ListArray}; use arrow_buffer::OffsetBuffer; @@ -196,10 +195,7 @@ mod tests { // arrow::compute::sort cann't sort ListArray directly, so we need to sort the inner primitive array and wrap it back into ListArray. fn sort_list_inner(arr: ScalarValue) -> ScalarValue { let arr = match arr { - ScalarValue::List(arr) => { - let list_arr = as_list_array(&arr); - list_arr.value(0) - } + ScalarValue::List(arr) => arr.value(0), _ => { panic!("Expected ScalarValue::List, got {:?}", arr) } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index f7c13948b2dc..021c33fb94a7 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -292,7 +292,7 @@ where let arr = Arc::new(PrimitiveArray::::from_iter_values( self.values.iter().cloned(), )) as ArrayRef; - let list = Arc::new(array_into_list_array(arr)) as ArrayRef; + let list = Arc::new(array_into_list_array(arr)); Ok(vec![ScalarValue::List(list)]) } @@ -378,7 +378,7 @@ where let arr = Arc::new(PrimitiveArray::::from_iter_values( self.values.iter().map(|v| v.0), )) as ArrayRef; - let list = Arc::new(array_into_list_array(arr)) as ArrayRef; + let list = Arc::new(array_into_list_array(arr)); Ok(vec![ScalarValue::List(list)]) } diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs b/datafusion/physical-expr/src/aggregate/tdigest.rs index 90f5244f477d..78708df94c25 100644 --- a/datafusion/physical-expr/src/aggregate/tdigest.rs +++ b/datafusion/physical-expr/src/aggregate/tdigest.rs @@ -28,7 +28,6 @@ //! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h use arrow::datatypes::DataType; -use arrow_array::cast::as_list_array; use arrow_array::types::Float64Type; use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; @@ -606,11 +605,10 @@ impl TDigest { let centroids: Vec<_> = match &state[5] { ScalarValue::List(arr) => { - let list_array = as_list_array(arr); - let arr = list_array.values(); + let array = arr.values(); let f64arr = - as_primitive_array::(arr).expect("expected f64 array"); + as_primitive_array::(array).expect("expected f64 array"); f64arr .values() .chunks(2) diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 36c5b44f00b9..3f48be0c4d1f 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -27,6 +27,7 @@ use crate::protobuf::{ OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; use arrow::{ + array::AsArray, buffer::Buffer, datatypes::{ i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, @@ -722,9 +723,15 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; let arr = record_batch.column(0); match value { - Value::ListValue(_) => Self::List(arr.to_owned()), - Value::LargeListValue(_) => Self::LargeList(arr.to_owned()), - Value::FixedSizeListValue(_) => Self::FixedSizeList(arr.to_owned()), + Value::ListValue(_) => { + Self::List(arr.as_list::().to_owned().into()) + } + Value::LargeListValue(_) => { + Self::LargeList(arr.as_list::().to_owned().into()) + } + Value::FixedSizeListValue(_) => { + Self::FixedSizeList(arr.as_fixed_size_list().to_owned().into()) + } _ => unreachable!(), } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index a162b2389cd1..b7d1ef225251 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -32,6 +32,7 @@ use crate::protobuf::{ OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; use arrow::{ + array::ArrayRef, datatypes::{ DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionMode, @@ -1159,54 +1160,15 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { } // ScalarValue::List and ScalarValue::FixedSizeList are serialized using // Arrow IPC messages as a single column RecordBatch - ScalarValue::List(arr) - | ScalarValue::LargeList(arr) - | ScalarValue::FixedSizeList(arr) => { + ScalarValue::List(arr) => { + encode_scalar_list_value(arr.to_owned() as ArrayRef, val) + } + ScalarValue::LargeList(arr) => { // Wrap in a "field_name" column - let batch = RecordBatch::try_from_iter(vec![( - "field_name", - arr.to_owned(), - )]) - .map_err(|e| { - Error::General( format!("Error creating temporary batch while encoding ScalarValue::List: {e}")) - })?; - - let gen = IpcDataGenerator {}; - let mut dict_tracker = DictionaryTracker::new(false); - let (_, encoded_message) = gen - .encoded_batch(&batch, &mut dict_tracker, &Default::default()) - .map_err(|e| { - Error::General(format!( - "Error encoding ScalarValue::List as IPC: {e}" - )) - })?; - - let schema: protobuf::Schema = batch.schema().try_into()?; - - let scalar_list_value = protobuf::ScalarListValue { - ipc_message: encoded_message.ipc_message, - arrow_data: encoded_message.arrow_data, - schema: Some(schema), - }; - - match val { - ScalarValue::List(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::ListValue( - scalar_list_value, - )), - }), - ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::LargeListValue( - scalar_list_value, - )), - }), - ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::FixedSizeListValue( - scalar_list_value, - )), - }), - _ => unreachable!(), - } + encode_scalar_list_value(arr.to_owned() as ArrayRef, val) + } + ScalarValue::FixedSizeList(arr) => { + encode_scalar_list_value(arr.to_owned() as ArrayRef, val) } ScalarValue::Date32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s)) @@ -1723,3 +1685,47 @@ fn create_proto_scalar protobuf::scalar_value::Value>( Ok(protobuf::ScalarValue { value: Some(value) }) } + +fn encode_scalar_list_value( + arr: ArrayRef, + val: &ScalarValue, +) -> Result { + let batch = RecordBatch::try_from_iter(vec![("field_name", arr)]).map_err(|e| { + Error::General(format!( + "Error creating temporary batch while encoding ScalarValue::List: {e}" + )) + })?; + + let gen = IpcDataGenerator {}; + let mut dict_tracker = DictionaryTracker::new(false); + let (_, encoded_message) = gen + .encoded_batch(&batch, &mut dict_tracker, &Default::default()) + .map_err(|e| { + Error::General(format!("Error encoding ScalarValue::List as IPC: {e}")) + })?; + + let schema: protobuf::Schema = batch.schema().try_into()?; + + let scalar_list_value = protobuf::ScalarListValue { + ipc_message: encoded_message.ipc_message, + arrow_data: encoded_message.arrow_data, + schema: Some(schema), + }; + + match val { + ScalarValue::List(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::ListValue(scalar_list_value)), + }), + ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::LargeListValue( + scalar_list_value, + )), + }), + ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::FixedSizeListValue( + scalar_list_value, + )), + }), + _ => unreachable!(), + } +}