diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 242d784edc9d..2d47b3e31472 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -26,12 +26,13 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; use crate::cast::{ as_decimal128_array, as_decimal256_array, as_dictionary_array, - as_fixed_size_binary_array, as_fixed_size_list_array, as_list_array, as_struct_array, + as_fixed_size_binary_array, as_fixed_size_list_array, as_struct_array, }; use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err}; -use arrow::buffer::NullBuffer; +use crate::hash_utils::create_hashes; +use crate::utils::wrap_into_list_array; +use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::compute::kernels::numeric::*; -use arrow::compute::nullif; use arrow::datatypes::{i256, FieldRef, Fields, SchemaBuilder}; use arrow::{ array::*, @@ -45,6 +46,7 @@ use arrow::{ DECIMAL128_MAX_PRECISION, }, }; +use arrow_array::cast::as_list_array; use arrow_array::{ArrowNativeTypeOp, Scalar}; /// Represents a dynamically typed, nullable single value. @@ -95,8 +97,8 @@ pub enum ScalarValue { LargeBinary(Option>), /// Fixed size list of nested ScalarValue Fixedsizelist(Option>, FieldRef, i32), - /// List of nested ScalarValue - List(Option>, FieldRef), + /// Represents a single element of a [`ListArray`] as an [`ArrayRef`] + List(ArrayRef), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), /// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01 @@ -198,8 +200,8 @@ impl PartialEq for ScalarValue { v1.eq(v2) && t1.eq(t2) && l1.eq(l2) } (Fixedsizelist(_, _, _), _) => false, - (List(v1, t1), List(v2, t2)) => v1.eq(v2) && t1.eq(t2), - (List(_, _), _) => false, + (List(v1), List(v2)) => v1.eq(v2), + (List(_), _) => false, (Date32(v1), Date32(v2)) => v1.eq(v2), (Date32(_), _) => false, (Date64(v1), Date64(v2)) => v1.eq(v2), @@ -316,14 +318,37 @@ impl PartialOrd for ScalarValue { } } (Fixedsizelist(_, _, _), _) => None, - (List(v1, t1), List(v2, t2)) => { - if t1.eq(t2) { - v1.partial_cmp(v2) + (List(arr1), List(arr2)) => { + if arr1.data_type() == arr2.data_type() { + let list_arr1 = as_list_array(arr1); + let list_arr2 = as_list_array(arr2); + if list_arr1.len() != list_arr2.len() { + return None; + } + for i in 0..list_arr1.len() { + let arr1 = list_arr1.value(i); + let arr2 = list_arr2.value(i); + + let lt_res = + arrow::compute::kernels::cmp::lt(&arr1, &arr2).unwrap(); + let eq_res = + arrow::compute::kernels::cmp::eq(&arr1, &arr2).unwrap(); + + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); + } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); + } + } + } + Some(Ordering::Equal) } else { None } } - (List(_, _), _) => None, + (List(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), (Date32(_), _) => None, (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), @@ -441,9 +466,14 @@ impl std::hash::Hash for ScalarValue { t.hash(state); l.hash(state); } - List(v, t) => { - v.hash(state); - t.hash(state); + List(arr) => { + let arrays = vec![arr.to_owned()]; + let hashes_buffer = &mut vec![0; arr.len()]; + let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); + let hashes = + create_hashes(&arrays, &random_state, hashes_buffer).unwrap(); + // Hash back to std::hash::Hasher + hashes.hash(state); } Date32(v) => v.hash(state), Date64(v) => v.hash(state), @@ -570,28 +600,6 @@ macro_rules! typed_cast { }}; } -// keep until https://github.com/apache/arrow-rs/issues/2054 is finished -macro_rules! build_list { - ($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{ - match $VALUES { - // the return on the macro is necessary, to short-circuit and return ArrayRef - None => { - return new_null_array( - &DataType::List(Arc::new(Field::new( - "item", - DataType::$SCALAR_TY, - true, - ))), - $SIZE, - ) - } - Some(values) => { - build_values_list!($VALUE_BUILDER_TY, $SCALAR_TY, values, $SIZE) - } - } - }}; -} - macro_rules! build_timestamp_list { ($TIME_UNIT:expr, $TIME_ZONE:expr, $VALUES:expr, $SIZE:expr) => {{ match $VALUES { @@ -792,11 +800,6 @@ impl ScalarValue { ScalarValue::IntervalMonthDayNano(Some(val)) } - /// Create a new nullable ScalarValue::List with the specified child_type - pub fn new_list(scalars: Option>, child_type: DataType) -> Self { - Self::List(scalars, Arc::new(Field::new("item", child_type, true))) - } - /// Create a zero value in the given type. pub fn new_zero(datatype: &DataType) -> Result { assert!(datatype.is_primitive()); @@ -953,11 +956,7 @@ impl ScalarValue { Arc::new(Field::new("item", field.data_type().clone(), true)), *length, ), - ScalarValue::List(_, field) => DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - ))), + ScalarValue::List(arr) => arr.data_type().to_owned(), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::Time32Second(_) => DataType::Time32(TimeUnit::Second), @@ -1105,7 +1104,7 @@ impl ScalarValue { ScalarValue::FixedSizeBinary(_, v) => v.is_none(), ScalarValue::LargeBinary(v) => v.is_none(), ScalarValue::Fixedsizelist(v, ..) => v.is_none(), - ScalarValue::List(v, _) => v.is_none(), + ScalarValue::List(arr) => arr.len() == arr.null_count(), ScalarValue::Date32(v) => v.is_none(), ScalarValue::Date64(v) => v.is_none(), ScalarValue::Time32Second(v) => v.is_none(), @@ -1307,17 +1306,18 @@ impl ScalarValue { ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{ Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( scalars.into_iter().map(|x| match x { - ScalarValue::List(xs, _) => xs.map(|x| { - x.iter().map(|x| match x { - ScalarValue::$SCALAR_TY(i) => *i, - sv => panic!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected {:?}, got {:?}", - data_type, sv - ), - }) - .collect::>>() - }), + ScalarValue::List(arr) => { + if arr.as_any().downcast_ref::().is_some() { + None + } else { + let list_arr = as_list_array(&arr); + let primitive_arr = + list_arr.values().as_primitive::<$ARRAY_TY>(); + Some( + primitive_arr.into_iter().collect::>>(), + ) + } + } sv => panic!( "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", @@ -1329,33 +1329,28 @@ impl ScalarValue { } macro_rules! build_array_list_string { - ($BUILDER:ident, $SCALAR_TY:ident) => {{ + ($BUILDER:ident, $STRING_ARRAY:ident) => {{ let mut builder = ListBuilder::new($BUILDER::new()); for scalar in scalars.into_iter() { match scalar { - ScalarValue::List(Some(xs), _) => { - for s in xs { - match s { - ScalarValue::$SCALAR_TY(Some(val)) => { - builder.values().append_value(val); - } - ScalarValue::$SCALAR_TY(None) => { - builder.values().append_null(); - } - sv => { - return _internal_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected Utf8, got {:?}", - sv - ) - } + ScalarValue::List(arr) => { + if arr.as_any().downcast_ref::().is_some() { + builder.append(false); + continue; + } + + let list_arr = as_list_array(&arr); + let string_arr = $STRING_ARRAY(list_arr.values()); + + for v in string_arr.iter() { + if let Some(v) = v { + builder.values().append_value(v); + } else { + builder.values().append_null(); } } builder.append(true); } - ScalarValue::List(None, _) => { - builder.append(false); - } sv => { return _internal_err!( "Inconsistent types in ScalarValue::iter_to_array. \ @@ -1474,14 +1469,14 @@ impl ScalarValue { build_array_list_primitive!(Float64Type, Float64, f64) } DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { - build_array_list_string!(StringBuilder, Utf8) + build_array_list_string!(StringBuilder, as_string_array) } DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { - build_array_list_string!(LargeStringBuilder, LargeUtf8) + build_array_list_string!(LargeStringBuilder, as_largestring_array) } DataType::List(_) => { // Fallback case handling homogeneous lists with any ScalarValue element type - let list_array = ScalarValue::iter_to_array_list(scalars, &data_type)?; + let list_array = ScalarValue::iter_to_array_list(scalars)?; Arc::new(list_array) } DataType::Struct(fields) => { @@ -1529,7 +1524,7 @@ impl ScalarValue { .collect::>>()?; let array = StructArray::from(field_values); - nullif(&array, &null_mask_builder.finish())? + arrow::compute::nullif(&array, &null_mask_builder.finish())? } DataType::Dictionary(key_type, value_type) => { // create the values array @@ -1654,42 +1649,66 @@ impl ScalarValue { Ok(array) } - fn iter_to_array_list( - scalars: impl IntoIterator, + /// This function does not contains nulls but empty array instead. + fn iter_to_array_list_without_nulls( + values: &[ScalarValue], data_type: &DataType, ) -> Result> { - let mut offsets = Int32Array::builder(0); - offsets.append_value(0); + let mut elements: Vec = vec![]; + let mut offsets = vec![]; - let mut elements: Vec = Vec::new(); - let mut valid = BooleanBufferBuilder::new(0); - let mut flat_len = 0i32; - for scalar in scalars { - if let ScalarValue::List(values, field) = scalar { - match values { - Some(values) => { - let element_array = if !values.is_empty() { - ScalarValue::iter_to_array(values)? - } else { - arrow::array::new_empty_array(field.data_type()) - }; + if values.is_empty() { + offsets.push(0); + } else { + let arr = ScalarValue::iter_to_array(values.to_vec())?; + offsets.push(arr.len()); + elements.push(arr); + } - // Add new offset index - flat_len += element_array.len() as i32; - offsets.append_value(flat_len); + // Concatenate element arrays to create single flat array + let flat_array = if elements.is_empty() { + new_empty_array(data_type) + } else { + let element_arrays: Vec<&dyn Array> = + elements.iter().map(|a| a.as_ref()).collect(); + arrow::compute::concat(&element_arrays)? + }; - elements.push(element_array); + let list_array = ListArray::new( + Arc::new(Field::new("item", flat_array.data_type().to_owned(), true)), + OffsetBuffer::::from_lengths(offsets), + flat_array, + None, + ); - // Element is valid - valid.append(true); - } - None => { - // Repeat previous offset index - offsets.append_value(flat_len); + Ok(list_array) + } - // Element is null - valid.append(false); - } + /// This function build with nulls with nulls buffer. + fn iter_to_array_list( + scalars: impl IntoIterator, + ) -> Result> { + let mut elements: Vec = vec![]; + let mut valid = BooleanBufferBuilder::new(0); + let mut offsets = vec![]; + + for scalar in scalars { + if let ScalarValue::List(arr) = scalar { + // i.e. NullArray(1) + if arr.as_any().downcast_ref::().is_some() { + // Repeat previous offset index + offsets.push(0); + + // Element is null + valid.append(false); + } else { + let list_arr = as_list_array(&arr); + let arr = list_arr.values().to_owned(); + offsets.push(arr.len()); + elements.push(arr); + + // Element is valid + valid.append(true); } } else { return _internal_err!( @@ -1701,20 +1720,21 @@ impl ScalarValue { // Concatenate element arrays to create single flat array let element_arrays: Vec<&dyn Array> = elements.iter().map(|a| a.as_ref()).collect(); + let flat_array = match arrow::compute::concat(&element_arrays) { Ok(flat_array) => flat_array, Err(err) => return Err(DataFusionError::ArrowError(err)), }; - // Build ListArray using ArrayData so we can specify a flat inner array, and offset indices - let offsets_array = offsets.finish(); - let array_data = ArrayDataBuilder::new(data_type.clone()) - .len(offsets_array.len() - 1) - .nulls(Some(NullBuffer::new(valid.finish()))) - .add_buffer(offsets_array.values().inner().clone()) - .add_child_data(flat_array.to_data()); + let buffer = valid.finish(); + + let list_array = ListArray::new( + Arc::new(Field::new("item", flat_array.data_type().clone(), true)), + OffsetBuffer::::from_lengths(offsets), + flat_array, + Some(NullBuffer::new(buffer)), + ); - let list_array = ListArray::from(array_data.build()?); Ok(list_array) } @@ -1751,6 +1771,80 @@ impl ScalarValue { .unwrap() } + /// Converts `Vec` to ListArray, simplified version of ScalarValue::to_array + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::{ListArray, Int32Array}; + /// use arrow::datatypes::{DataType, Int32Type}; + /// use datafusion_common::cast::as_list_array; + /// + /// let scalars = vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(None), + /// ScalarValue::Int32(Some(2)) + /// ]; + /// + /// let array = ScalarValue::new_list(&scalars, &DataType::Int32); + /// let result = as_list_array(&array).unwrap(); + /// + /// let expected = ListArray::from_iter_primitive::( + /// vec![ + /// Some(vec![Some(1), None, Some(2)]) + /// ]); + /// + /// assert_eq!(result, &expected); + /// ``` + pub fn new_list(values: &[ScalarValue], data_type: &DataType) -> ArrayRef { + Arc::new(match data_type { + DataType::Boolean => build_values_list!(BooleanBuilder, Boolean, values, 1), + DataType::Int8 => build_values_list!(Int8Builder, Int8, values, 1), + DataType::Int16 => build_values_list!(Int16Builder, Int16, values, 1), + DataType::Int32 => build_values_list!(Int32Builder, Int32, values, 1), + DataType::Int64 => build_values_list!(Int64Builder, Int64, values, 1), + DataType::UInt8 => build_values_list!(UInt8Builder, UInt8, values, 1), + DataType::UInt16 => build_values_list!(UInt16Builder, UInt16, values, 1), + DataType::UInt32 => build_values_list!(UInt32Builder, UInt32, values, 1), + DataType::UInt64 => build_values_list!(UInt64Builder, UInt64, values, 1), + DataType::Utf8 => build_values_list!(StringBuilder, Utf8, values, 1), + DataType::LargeUtf8 => { + build_values_list!(LargeStringBuilder, LargeUtf8, values, 1) + } + DataType::Float32 => build_values_list!(Float32Builder, Float32, values, 1), + DataType::Float64 => build_values_list!(Float64Builder, Float64, values, 1), + DataType::Timestamp(unit, tz) => { + let values = Some(values); + build_timestamp_list!(unit.clone(), tz.clone(), values, 1) + } + DataType::List(_) | DataType::Struct(_) => { + ScalarValue::iter_to_array_list_without_nulls(values, data_type).unwrap() + } + DataType::Decimal128(precision, scale) => { + let mut vals = vec![]; + for value in values.iter() { + if let ScalarValue::Decimal128(v, _, _) = value { + vals.push(v.to_owned()) + } + } + + let arr = Decimal128Array::from(vals) + .with_precision_and_scale(*precision, *scale) + .unwrap(); + wrap_into_list_array(Arc::new(arr)) + } + + DataType::Null => { + let arr = new_null_array(&DataType::Null, values.len()); + wrap_into_list_array(arr) + } + _ => panic!( + "Unsupported data type {:?} for ScalarValue::list_to_array", + data_type + ), + }) + } + /// Converts a scalar value into an array of `size` rows. pub fn to_array_of_size(&self, size: usize) -> ArrayRef { match self { @@ -1873,35 +1967,12 @@ impl ScalarValue { ScalarValue::Fixedsizelist(..) => { unimplemented!("FixedSizeList is not supported yet") } - ScalarValue::List(values, field) => Arc::new(match field.data_type() { - DataType::Boolean => build_list!(BooleanBuilder, Boolean, values, size), - DataType::Int8 => build_list!(Int8Builder, Int8, values, size), - DataType::Int16 => build_list!(Int16Builder, Int16, values, size), - DataType::Int32 => build_list!(Int32Builder, Int32, values, size), - DataType::Int64 => build_list!(Int64Builder, Int64, values, size), - DataType::UInt8 => build_list!(UInt8Builder, UInt8, values, size), - DataType::UInt16 => build_list!(UInt16Builder, UInt16, values, size), - DataType::UInt32 => build_list!(UInt32Builder, UInt32, values, size), - DataType::UInt64 => build_list!(UInt64Builder, UInt64, values, size), - DataType::Utf8 => build_list!(StringBuilder, Utf8, values, size), - DataType::Float32 => build_list!(Float32Builder, Float32, values, size), - DataType::Float64 => build_list!(Float64Builder, Float64, values, size), - DataType::Timestamp(unit, tz) => { - build_timestamp_list!(unit.clone(), tz.clone(), values, size) - } - &DataType::LargeUtf8 => { - build_list!(LargeStringBuilder, LargeUtf8, values, size) - } - _ => ScalarValue::iter_to_array_list( - repeat(self.clone()).take(size), - &DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - ))), - ) - .unwrap(), - }), + ScalarValue::List(arr) => { + let arrays = std::iter::repeat(arr.as_ref()) + .take(size) + .collect::>(); + arrow::compute::concat(arrays.as_slice()).unwrap() + } ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) } @@ -2057,6 +2128,71 @@ impl ScalarValue { } } + /// Retrieve ScalarValue for each row in `array` + /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::ListArray; + /// use arrow::datatypes::{DataType, Int32Type}; + /// + /// let list_arr = ListArray::from_iter_primitive::(vec![ + /// Some(vec![Some(1), Some(2), Some(3)]), + /// None, + /// Some(vec![Some(4), Some(5)]) + /// ]); + /// + /// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap(); + /// + /// let expected = vec![ + /// vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(Some(2)), + /// ScalarValue::Int32(Some(3)), + /// ], + /// vec![], + /// vec![ScalarValue::Int32(Some(4)), ScalarValue::Int32(Some(5))] + /// ]; + /// + /// assert_eq!(scalar_vec, expected); + /// ``` + pub fn convert_array_to_scalar_vec(array: &dyn Array) -> Result>> { + let mut scalars = Vec::with_capacity(array.len()); + + for index in 0..array.len() { + let scalar_values = match array.data_type() { + DataType::List(_) => { + let list_array = as_list_array(array); + match list_array.is_null(index) { + true => Vec::new(), + false => { + let nested_array = list_array.value(index); + ScalarValue::convert_array_to_scalar_vec(&nested_array)? + .into_iter() + .flatten() + .collect() + } + } + } + _ => { + let scalar = ScalarValue::try_from_array(array, index)?; + vec![scalar] + } + }; + scalars.push(scalar_values); + } + Ok(scalars) + } + + // TODO: Support more types after other ScalarValue is wrapped with ArrayRef + /// Get raw data (inner array) inside ScalarValue + pub fn raw_data(&self) -> Result { + match self { + ScalarValue::List(arr) => Ok(arr.to_owned()), + _ => _internal_err!("ScalarValue is not a list"), + } + } + /// Converts a value in `array` at `index` into a ScalarValue pub fn try_from_array(array: &dyn Array, index: usize) -> Result { // handle NULL value @@ -2094,18 +2230,29 @@ impl ScalarValue { DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8), DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, LargeUtf8), DataType::List(nested_type) => { - let list_array = as_list_array(array)?; - let value = match list_array.is_null(index) { - true => None, + let list_array = as_list_array(array); + let arr = match list_array.is_null(index) { + true => new_null_array(nested_type.data_type(), 0), + false => { + let nested_array = list_array.value(index); + Arc::new(wrap_into_list_array(nested_array)) + } + }; + + ScalarValue::List(arr) + } + // TODO: There is no test for FixedSizeList now, add it later + DataType::FixedSizeList(nested_type, _len) => { + let list_array = as_fixed_size_list_array(array)?; + let arr = match list_array.is_null(index) { + true => new_null_array(nested_type.data_type(), 0), false => { let nested_array = list_array.value(index); - let scalar_vec = (0..nested_array.len()) - .map(|i| ScalarValue::try_from_array(&nested_array, i)) - .collect::>>()?; - Some(scalar_vec) + Arc::new(wrap_into_list_array(nested_array)) } }; - ScalarValue::new_list(value, nested_type.data_type().clone()) + + ScalarValue::List(arr) } DataType::Date32 => { typed_cast!(array, index, Date32Array, Date32) @@ -2194,20 +2341,6 @@ impl ScalarValue { } Self::Struct(Some(field_values), fields.clone()) } - DataType::FixedSizeList(nested_type, _len) => { - let list_array = as_fixed_size_list_array(array)?; - let value = match list_array.is_null(index) { - true => None, - false => { - let nested_array = list_array.value(index); - let scalar_vec = (0..nested_array.len()) - .map(|i| ScalarValue::try_from_array(&nested_array, i)) - .collect::>>()?; - Some(scalar_vec) - } - }; - ScalarValue::new_list(value, nested_type.data_type().clone()) - } DataType::FixedSizeBinary(_) => { let array = as_fixed_size_binary_array(array)?; let size = match array.data_type() { @@ -2383,7 +2516,7 @@ impl ScalarValue { eq_array_primitive!(array, index, LargeBinaryArray, val) } ScalarValue::Fixedsizelist(..) => unimplemented!(), - ScalarValue::List(_, _) => unimplemented!(), + ScalarValue::List(_) => unimplemented!("ListArr"), ScalarValue::Date32(val) => { eq_array_primitive!(array, index, Date32Array, val) } @@ -2504,14 +2637,14 @@ impl ScalarValue { | ScalarValue::LargeBinary(b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } - ScalarValue::Fixedsizelist(vals, field, _) - | ScalarValue::List(vals, field) => { + ScalarValue::Fixedsizelist(vals, field, _) => { vals.as_ref() .map(|vals| Self::size_of_vec(vals) - std::mem::size_of_val(vals)) .unwrap_or_default() // `field` is boxed, so it is NOT already included in `self` + field.size() } + ScalarValue::List(arr) => arr.get_array_memory_size(), ScalarValue::Struct(vals, fields) => { vals.as_ref() .map(|vals| { @@ -2735,8 +2868,8 @@ impl TryFrom<&DataType> for ScalarValue { type Error = DataFusionError; /// Create a Null instance of ScalarValue for this datatype - fn try_from(datatype: &DataType) -> Result { - Ok(match datatype { + fn try_from(data_type: &DataType) -> Result { + Ok(match data_type { DataType::Boolean => ScalarValue::Boolean(None), DataType::Float64 => ScalarValue::Float64(None), DataType::Float32 => ScalarValue::Float32(None), @@ -2806,14 +2939,13 @@ impl TryFrom<&DataType> for ScalarValue { index_type.clone(), Box::new(value_type.as_ref().try_into()?), ), - DataType::List(ref nested_type) => { - ScalarValue::new_list(None, nested_type.data_type().clone()) - } + DataType::List(_) => ScalarValue::List(new_null_array(&DataType::Null, 0)), + DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()), DataType::Null => ScalarValue::Null, _ => { return _not_impl_err!( - "Can't create a scalar from data_type \"{datatype:?}\"" + "Can't create a scalar from data_type \"{data_type:?}\"" ); } }) @@ -2868,7 +3000,7 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, - ScalarValue::Fixedsizelist(e, ..) | ScalarValue::List(e, _) => match e { + ScalarValue::Fixedsizelist(e, ..) => match e { Some(l) => write!( f, "{}", @@ -2879,6 +3011,12 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, + ScalarValue::List(arr) => write!( + f, + "{}", + arrow::util::pretty::pretty_format_columns("col", &[arr.to_owned()]) + .unwrap() + )?, ScalarValue::Date32(e) => format_option!(f, e)?, ScalarValue::Date64(e) => format_option!(f, e)?, ScalarValue::Time32Second(e) => format_option!(f, e)?, @@ -2954,7 +3092,7 @@ impl fmt::Debug for ScalarValue { ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({self})"), ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{self}\")"), ScalarValue::Fixedsizelist(..) => write!(f, "FixedSizeList([{self}])"), - ScalarValue::List(_, _) => write!(f, "List([{self}])"), + ScalarValue::List(arr) => write!(f, "List([{arr:?}])"), ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"), ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"), ScalarValue::Time32Second(_) => write!(f, "Time32Second(\"{self}\")"), @@ -3060,6 +3198,104 @@ mod tests { use super::*; + #[test] + fn test_to_array_of_size_for_list() { + let arr = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + None, + Some(2), + ])]); + + let sv = ScalarValue::List(Arc::new(arr)); + let actual_arr = sv.to_array_of_size(2); + let actual_list_arr = as_list_array(&actual_arr); + + let arr = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), None, Some(2)]), + Some(vec![Some(1), None, Some(2)]), + ]); + + assert_eq!(&arr, actual_list_arr); + } + + #[test] + fn test_list_to_array_string() { + let scalars = vec![ + ScalarValue::Utf8(Some(String::from("rust"))), + ScalarValue::Utf8(Some(String::from("arrow"))), + ScalarValue::Utf8(Some(String::from("data-fusion"))), + ]; + + let array = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); + + let expected = wrap_into_list_array(Arc::new(StringArray::from(vec![ + "rust", + "arrow", + "data-fusion", + ]))); + let result = as_list_array(&array); + assert_eq!(result, &expected); + } + + #[test] + fn iter_to_array_primitive_test() { + let scalars = vec![ + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )), + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(4), + Some(5), + ])]), + )), + ]; + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let list_array = as_list_array(&array); + let expected = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + ]); + assert_eq!(list_array, &expected); + } + + #[test] + fn iter_to_array_string_test() { + let arr1 = + wrap_into_list_array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let arr2 = + wrap_into_list_array(Arc::new(StringArray::from(vec!["rust", "world"]))); + + let scalars = vec![ + ScalarValue::List(Arc::new(arr1)), + ScalarValue::List(Arc::new(arr2)), + ]; + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let result = as_list_array(&array); + + // build expected array + let string_builder = StringBuilder::with_capacity(5, 25); + let mut list_of_string_builder = ListBuilder::new(string_builder); + + list_of_string_builder.values().append_value("foo"); + list_of_string_builder.values().append_value("bar"); + list_of_string_builder.values().append_value("baz"); + list_of_string_builder.append(true); + + list_of_string_builder.values().append_value("rust"); + list_of_string_builder.values().append_value("world"); + list_of_string_builder.append(true); + let expected = list_of_string_builder.finish(); + + assert_eq!(result, &expected); + } + #[test] fn scalar_add_trait_test() -> Result<()> { let float_value = ScalarValue::Float64(Some(123.)); @@ -3304,6 +3540,81 @@ mod tests { Ok(()) } + #[test] + fn test_list_partial_cmp() { + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Equal)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(10), + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(30), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(10), + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(10), + Some(2), + Some(30), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Less)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![ + Some(vec![Some(10), Some(2), Some(3)]), + None, + Some(vec![Some(10), Some(2), Some(3)]), + ]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![ + Some(vec![Some(10), Some(2), Some(3)]), + None, + Some(vec![Some(10), Some(2), Some(3)]), + ]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Equal)); + } + #[test] fn scalar_value_to_array_u64() -> Result<()> { let value = ScalarValue::UInt64(Some(13u64)); @@ -3340,31 +3651,22 @@ mod tests { #[test] fn scalar_list_null_to_array() { - let list_array_ref = ScalarValue::List( - None, - Arc::new(Field::new("item", DataType::UInt64, false)), - ) - .to_array(); - let list_array = as_list_array(&list_array_ref).unwrap(); + let list_array_ref = ScalarValue::new_list(&[], &DataType::UInt64); + let list_array = as_list_array(&list_array_ref); - assert!(list_array.is_null(0)); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 0); } #[test] fn scalar_list_to_array() -> Result<()> { - let list_array_ref = ScalarValue::List( - Some(vec![ - ScalarValue::UInt64(Some(100)), - ScalarValue::UInt64(None), - ScalarValue::UInt64(Some(101)), - ]), - Arc::new(Field::new("item", DataType::UInt64, false)), - ) - .to_array(); - - let list_array = as_list_array(&list_array_ref)?; + let values = vec![ + ScalarValue::UInt64(Some(100)), + ScalarValue::UInt64(None), + ScalarValue::UInt64(Some(101)), + ]; + let list_array_ref = ScalarValue::new_list(&values, &DataType::UInt64); + let list_array = as_list_array(&list_array_ref); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 3); @@ -3864,55 +4166,6 @@ mod tests { assert_eq!(Int64(Some(33)).partial_cmp(&Int32(Some(33))), None); assert_eq!(Int32(Some(33)).partial_cmp(&Int64(Some(33))), None); - assert_eq!( - List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - ) - .partial_cmp(&List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - )), - Some(Ordering::Equal) - ); - - assert_eq!( - List( - Some(vec![Int32(Some(10)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - ) - .partial_cmp(&List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - )), - Some(Ordering::Greater) - ); - - assert_eq!( - List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - ) - .partial_cmp(&List( - Some(vec![Int32(Some(10)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - )), - Some(Ordering::Less) - ); - - // For different data type, `partial_cmp` returns None. - assert_eq!( - List( - Some(vec![Int64(Some(1)), Int64(Some(5))]), - Arc::new(Field::new("item", DataType::Int64, false)), - ) - .partial_cmp(&List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - Arc::new(Field::new("item", DataType::Int32, false)), - )), - None - ); - assert_eq!( ScalarValue::from(vec![ ("A", ScalarValue::from(1.0)), @@ -4125,24 +4378,26 @@ mod tests { )); // Define primitive list scalars - let l0 = ScalarValue::List( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - Arc::new(Field::new("item", DataType::Int32, false)), - ); - - let l1 = ScalarValue::List( - Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - Arc::new(Field::new("item", DataType::Int32, false)), - ); - - let l2 = ScalarValue::List( - Some(vec![ScalarValue::from(6i32)]), - Arc::new(Field::new("item", DataType::Int32, false)), - ); + let l0 = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + let l1 = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(4), + Some(5), + ])]), + )); + let l2 = ScalarValue::List(Arc::new(ListArray::from_iter_primitive::< + Int32Type, + _, + _, + >(vec![Some(vec![Some(6)])]))); // Define struct scalars let s0 = ScalarValue::from(vec![ @@ -4182,15 +4437,19 @@ mod tests { assert_eq!(array, &expected); // Define list-of-structs scalars - let nl0 = - ScalarValue::new_list(Some(vec![s0.clone(), s1.clone()]), s0.data_type()); - let nl1 = ScalarValue::new_list(Some(vec![s2]), s0.data_type()); + let nl0_array = ScalarValue::iter_to_array(vec![s0.clone(), s1.clone()]).unwrap(); + let nl0 = ScalarValue::List(Arc::new(wrap_into_list_array(nl0_array))); + + let nl1_array = ScalarValue::iter_to_array(vec![s2.clone()]).unwrap(); + let nl1 = ScalarValue::List(Arc::new(wrap_into_list_array(nl1_array))); + + let nl2_array = ScalarValue::iter_to_array(vec![s1.clone()]).unwrap(); + let nl2 = ScalarValue::List(Arc::new(wrap_into_list_array(nl2_array))); - let nl2 = ScalarValue::new_list(Some(vec![s1]), s0.data_type()); // iter_to_array for list-of-struct let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); - let array = as_list_array(&array).unwrap(); + let array = as_list_array(&array); // Construct expected array with array builders let field_a_builder = StringBuilder::with_capacity(4, 1024); @@ -4313,48 +4572,63 @@ mod tests { #[test] fn test_nested_lists() { // Define inner list scalars - let l1 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(4), + Some(5), + ])]); + let l1 = ListArray::new( + Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + )), + OffsetBuffer::::from_lengths([1, 1]), + arrow::compute::concat(&[&a1, &a2]).unwrap(), + None, ); - let l2 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ScalarValue::from(6i32)]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(6)])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(7), + Some(8), + ])]); + let l2 = ListArray::new( + Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + )), + OffsetBuffer::::from_lengths([1, 1]), + arrow::compute::concat(&[&a1, &a2]).unwrap(), + None, ); - let l3 = ScalarValue::new_list( - Some(vec![ScalarValue::new_list( - Some(vec![ScalarValue::from(9i32)]), - DataType::Int32, - )]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(9)])]); + let l3 = ListArray::new( + Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + )), + OffsetBuffer::::from_lengths([1]), + arrow::compute::concat(&[&a1]).unwrap(), + None, ); - let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); - let array = as_list_array(&array).unwrap(); + let array = ScalarValue::iter_to_array(vec![ + ScalarValue::List(Arc::new(l1)), + ScalarValue::List(Arc::new(l2)), + ScalarValue::List(Arc::new(l3)), + ]) + .unwrap(); + let array = as_list_array(&array); // Construct expected array with array builders let inner_builder = Int32Array::builder(8); @@ -4904,12 +5178,11 @@ mod tests { #[test] fn test_build_timestamp_millisecond_list() { let values = vec![ScalarValue::TimestampMillisecond(Some(1), None)]; - let ts_list = ScalarValue::new_list( - Some(values), - DataType::Timestamp(TimeUnit::Millisecond, None), + let arr = ScalarValue::new_list( + &values, + &DataType::Timestamp(TimeUnit::Millisecond, None), ); - let list = ts_list.to_array_of_size(1); - assert_eq!(1, list.len()); + assert_eq!(1, arr.len()); } fn get_random_timestamps(sample_size: u64) -> Vec { diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index b7c80aa9ac44..b2f71e86f21e 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -19,10 +19,12 @@ use crate::{DataFusionError, Result, ScalarValue}; use arrow::array::{ArrayRef, PrimitiveArray}; +use arrow::buffer::OffsetBuffer; use arrow::compute; use arrow::compute::{partition, SortColumn, SortOptions}; -use arrow::datatypes::{SchemaRef, UInt32Type}; +use arrow::datatypes::{Field, SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; +use arrow_array::ListArray; use sqlparser::ast::Ident; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; @@ -334,6 +336,18 @@ pub fn longest_consecutive_prefix>( count } +/// Wrap an array into a single element `ListArray`. +/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` +pub fn wrap_into_list_array(arr: ArrayRef) -> ListArray { + let offsets = OffsetBuffer::from_lengths([arr.len()]); + ListArray::new( + Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + offsets, + arr, + None, + ) +} + /// An extension trait for smart pointers. Provides an interface to get a /// raw pointer to the data (with metadata stripped away). /// diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 63d5e58090eb..03864e9efef8 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -47,25 +47,23 @@ async fn csv_query_array_agg_distinct() -> Result<()> { let column = actual[0].column(0); assert_eq!(column.len(), 1); - if let ScalarValue::List(Some(mut v), _) = ScalarValue::try_from_array(column, 0)? { - // workaround lack of Ord of ScalarValue - let cmp = |a: &ScalarValue, b: &ScalarValue| { - a.partial_cmp(b).expect("Can compare ScalarValues") - }; - v.sort_by(cmp); - assert_eq!( - *v, - vec![ - ScalarValue::UInt32(Some(1)), - ScalarValue::UInt32(Some(2)), - ScalarValue::UInt32(Some(3)), - ScalarValue::UInt32(Some(4)), - ScalarValue::UInt32(Some(5)) - ] - ); - } else { - unreachable!(); - } + let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&column)?; + let mut scalars = scalar_vec[0].clone(); + // workaround lack of Ord of ScalarValue + let cmp = |a: &ScalarValue, b: &ScalarValue| { + a.partial_cmp(b).expect("Can compare ScalarValues") + }; + scalars.sort_by(cmp); + assert_eq!( + scalars, + vec![ + ScalarValue::UInt32(Some(1)), + ScalarValue::UInt32(Some(2)), + ScalarValue::UInt32(Some(3)), + ScalarValue::UInt32(Some(4)), + ScalarValue::UInt32(Some(5)) + ] + ); Ok(()) } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index f5a6860299ab..cb3f13a51ec4 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -30,7 +30,10 @@ use arrow::{ error::ArrowError, record_batch::RecordBatch, }; -use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; +use datafusion_common::{ + cast::{as_large_list_array, as_list_array}, + tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}, +}; use datafusion_common::{ exec_err, internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; @@ -392,8 +395,11 @@ impl<'a> ConstEvaluator<'a> { "Could not evaluate the expression, found a result of length {}", a.len() ) + } else if as_list_array(&a).is_ok() || as_large_list_array(&a).is_ok() { + Ok(ScalarValue::List(a)) } else { - Ok(ScalarValue::try_from_array(&a, 0)?) + // Non-ListArray + ScalarValue::try_from_array(&a, 0) } } ColumnarValue::Scalar(s) => Ok(s), diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 0cf39888f133..834925b8d554 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -22,8 +22,11 @@ use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; +use arrow_array::Array; +use datafusion_common::cast::as_list_array; +use datafusion_common::utils::wrap_into_list_array; +use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::Accumulator; use std::any::Any; use std::sync::Arc; @@ -102,7 +105,7 @@ impl PartialEq for ArrayAgg { #[derive(Debug)] pub(crate) struct ArrayAggAccumulator { - values: Vec, + values: Vec, datatype: DataType, } @@ -117,34 +120,29 @@ impl ArrayAggAccumulator { } impl Accumulator for ArrayAggAccumulator { + // Append value like Int64Array(1,2,3) fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { if values.is_empty() { return Ok(()); } assert!(values.len() == 1, "array_agg can only take 1 param!"); - let arr = &values[0]; - (0..arr.len()).try_for_each(|index| { - let scalar = ScalarValue::try_from_array(arr, index)?; - self.values.push(scalar); - Ok(()) - }) + let val = values[0].clone(); + self.values.push(val); + Ok(()) } + // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { if states.is_empty() { return Ok(()); } assert!(states.len() == 1, "array_agg states must be singleton!"); - let arr = &states[0]; - (0..arr.len()).try_for_each(|index| { - let scalar = ScalarValue::try_from_array(arr, index)?; - if let ScalarValue::List(Some(values), _) = scalar { - self.values.extend(values); - Ok(()) - } else { - internal_err!("array_agg state must be list!") - } - }) + + let list_arr = as_list_array(&states[0])?; + for arr in list_arr.iter().flatten() { + self.values.push(arr); + } + Ok(()) } fn state(&self) -> Result> { @@ -152,15 +150,30 @@ impl Accumulator for ArrayAggAccumulator { } fn evaluate(&self) -> Result { - Ok(ScalarValue::new_list( - Some(self.values.clone()), - self.datatype.clone(), - )) + // Transform Vec to ListArr + + let element_arrays: Vec<&dyn Array> = + self.values.iter().map(|a| a.as_ref()).collect(); + + if element_arrays.is_empty() { + let arr = ScalarValue::new_list(&[], &self.datatype); + return Ok(ScalarValue::List(arr)); + } + + let concated_array = arrow::compute::concat(&element_arrays)?; + let list_array = wrap_into_list_array(concated_array); + + Ok(ScalarValue::List(Arc::new(list_array))) } fn size(&self) -> usize { - std::mem::size_of_val(self) + ScalarValue::size_of_vec(&self.values) - - std::mem::size_of_val(&self.values) + std::mem::size_of_val(self) + + (std::mem::size_of::() * self.values.capacity()) + + self + .values + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum::() + self.datatype.size() - std::mem::size_of_val(&self.datatype) } @@ -176,72 +189,78 @@ mod tests { use arrow::array::Int32Array; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; + use arrow_array::Array; + use arrow_array::ListArray; + use arrow_buffer::OffsetBuffer; + use datafusion_common::DataFusionError; use datafusion_common::Result; #[test] fn array_agg_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - let list = ScalarValue::new_list( - Some(vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(3)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ]), - DataType::Int32, - ); + let list = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + ])]); + let list = ScalarValue::List(Arc::new(list)); generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) } #[test] fn array_agg_nested() -> Result<()> { - let l1 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(4), + Some(5), + ])]); + let l1 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([a1.len() + a2.len()]), + arrow::compute::concat(&[&a1, &a2])?, + None, ); - let l2 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ScalarValue::from(6i32)]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(6)])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(7), + Some(8), + ])]); + let l2 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([a1.len() + a2.len()]), + arrow::compute::concat(&[&a1, &a2])?, + None, ); - let l3 = ScalarValue::new_list( - Some(vec![ScalarValue::new_list( - Some(vec![ScalarValue::from(9i32)]), - DataType::Int32, - )]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(9)])]); + let l3 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([a1.len()]), + arrow::compute::concat(&[&a1])?, + None, ); - let list = ScalarValue::new_list( - Some(vec![l1.clone(), l2.clone(), l3.clone()]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let list = ListArray::new( + Arc::new(Field::new("item", l1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([l1.len() + l2.len() + l3.len()]), + arrow::compute::concat(&[&l1, &l2, &l3])?, + None, ); + let list = ScalarValue::List(Arc::new(list)); + let l1 = ScalarValue::List(Arc::new(l1)); + let l2 = ScalarValue::List(Arc::new(l2)); + let l3 = ScalarValue::List(Arc::new(l3)); let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 422eecd20155..21143ce54a20 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -22,13 +22,13 @@ use std::any::Any; use std::fmt::Debug; use std::sync::Arc; -use arrow::array::{Array, ArrayRef}; +use arrow::array::ArrayRef; use std::collections::HashSet; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; /// Expression for a ARRAY_AGG(DISTINCT) aggregation. @@ -125,22 +125,18 @@ impl DistinctArrayAggAccumulator { impl Accumulator for DistinctArrayAggAccumulator { fn state(&self) -> Result> { - Ok(vec![ScalarValue::new_list( - Some(self.values.clone().into_iter().collect()), - self.datatype.clone(), - )]) + Ok(vec![self.evaluate()?]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { assert_eq!(values.len(), 1, "batch input should only include 1 column!"); let array = &values[0]; - (0..array.len()).try_for_each(|i| { - if !array.is_null(i) { - self.values.insert(ScalarValue::try_from_array(array, i)?); - } - Ok(()) - }) + let scalars = ScalarValue::convert_array_to_scalar_vec(array)?; + for scalar in scalars { + self.values.extend(scalar) + } + Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { @@ -154,25 +150,18 @@ impl Accumulator for DistinctArrayAggAccumulator { "array_agg_distinct states must contain single array" ); - let array = &states[0]; - (0..array.len()).try_for_each(|i| { - let scalar = ScalarValue::try_from_array(array, i)?; - if let ScalarValue::List(Some(values), _) = scalar { - self.values.extend(values); - Ok(()) - } else { - internal_err!("array_agg_distinct state must be list") - } - })?; + let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?; + for scalars in scalar_vec { + self.values.extend(scalars) + } Ok(()) } fn evaluate(&self) -> Result { - Ok(ScalarValue::new_list( - Some(self.values.clone().into_iter().collect()), - self.datatype.clone(), - )) + let values: Vec = self.values.iter().cloned().collect(); + let arr = ScalarValue::new_list(&values, &self.datatype); + Ok(ScalarValue::List(arr)) } fn size(&self) -> usize { @@ -185,34 +174,56 @@ impl Accumulator for DistinctArrayAggAccumulator { #[cfg(test)] mod tests { + use super::*; - use crate::aggregate::utils::get_accum_scalar_values_as_arrays; use crate::expressions::col; use crate::expressions::tests::aggregate; use arrow::array::{ArrayRef, Int32Array}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; + use arrow_array::cast::as_list_array; + use arrow_array::types::Int32Type; + use arrow_array::{Array, ListArray}; + use arrow_buffer::OffsetBuffer; + use datafusion_common::utils::wrap_into_list_array; + use datafusion_common::{internal_err, DataFusionError}; + + // arrow::compute::sort cann't sort ListArray directly, so we need to sort the inner primitive array and wrap it back into ListArray. + fn sort_list_inner(arr: ScalarValue) -> ScalarValue { + let arr = match arr { + ScalarValue::List(arr) => { + let list_arr = as_list_array(&arr); + list_arr.value(0) + } + _ => { + panic!("Expected ScalarValue::List, got {:?}", arr) + } + }; + + let arr = arrow::compute::sort(&arr, None).unwrap(); + let list_arr = wrap_into_list_array(arr); + ScalarValue::List(Arc::new(list_arr)) + } fn compare_list_contents(expected: ScalarValue, actual: ScalarValue) -> Result<()> { - match (expected, actual) { - (ScalarValue::List(Some(mut e), _), ScalarValue::List(Some(mut a), _)) => { - // workaround lack of Ord of ScalarValue - let cmp = |a: &ScalarValue, b: &ScalarValue| { - a.partial_cmp(b).expect("Can compare ScalarValues") - }; - - e.sort_by(cmp); - a.sort_by(cmp); - // Check that the inputs are the same - assert_eq!(e, a); + let actual = sort_list_inner(actual); + + match (&expected, &actual) { + (ScalarValue::List(arr1), ScalarValue::List(arr2)) => { + if arr1.eq(arr2) { + Ok(()) + } else { + internal_err!( + "Actual value {:?} not found in expected values {:?}", + actual, + expected + ) + } } _ => { - return Err(DataFusionError::Internal( - "Expected scalar lists as inputs".to_string(), - )); + internal_err!("Expected scalar lists as inputs") } } - Ok(()) } fn check_distinct_array_agg( @@ -252,8 +263,8 @@ mod tests { accum1.update_batch(&[input1])?; accum2.update_batch(&[input2])?; - let state = get_accum_scalar_values_as_arrays(accum2.as_ref())?; - accum1.merge_batch(&state)?; + let array = accum2.state()?[0].raw_data()?; + accum1.merge_batch(&[array])?; let actual = accum1.evaluate()?; @@ -263,19 +274,18 @@ mod tests { #[test] fn distinct_array_agg_i32() -> Result<()> { let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); - - let out = ScalarValue::new_list( - Some(vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(7)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ]), - DataType::Int32, - ); - - check_distinct_array_agg(col, out, DataType::Int32) + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(4), + Some(5), + Some(7), + ])]), + )); + + check_distinct_array_agg(col, expected, DataType::Int32) } #[test] @@ -283,78 +293,90 @@ mod tests { let col1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); let col2: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 7, 8, 4])); - let out = ScalarValue::new_list( - Some(vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(3)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ScalarValue::Int32(Some(7)), - ScalarValue::Int32(Some(8)), - ]), - DataType::Int32, - ); - - check_merge_distinct_array_agg(col1, col2, out, DataType::Int32) + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(7), + Some(8), + ])]), + )); + + check_merge_distinct_array_agg(col1, col2, expected, DataType::Int32) } #[test] fn distinct_array_agg_nested() -> Result<()> { // [[1, 2, 3], [4, 5]] - let l1 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(4), + Some(5), + ])]); + let l1 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([2]), + arrow::compute::concat(&[&a1, &a2]).unwrap(), + None, ); // [[6], [7, 8]] - let l2 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ScalarValue::from(6i32)]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(6)])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(7), + Some(8), + ])]); + let l2 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([2]), + arrow::compute::concat(&[&a1, &a2]).unwrap(), + None, ); // [[9]] - let l3 = ScalarValue::new_list( - Some(vec![ScalarValue::new_list( - Some(vec![ScalarValue::from(9i32)]), - DataType::Int32, - )]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(9)])]); + let l3 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([1]), + Arc::new(a1), + None, ); - let list = ScalarValue::new_list( - Some(vec![l1.clone(), l2.clone(), l3.clone()]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - ); + let l1 = ScalarValue::List(Arc::new(l1)); + let l2 = ScalarValue::List(Arc::new(l2)); + let l3 = ScalarValue::List(Arc::new(l3)); // Duplicate l1 in the input array and check that it is deduped in the output. let array = ScalarValue::iter_to_array(vec![l1.clone(), l2, l3, l1]).unwrap(); + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + ])]), + )); + check_distinct_array_agg( array, - list, + expected, DataType::List(Arc::new(Field::new_list( "item", Field::new("item", DataType::Int32, true), @@ -366,62 +388,66 @@ mod tests { #[test] fn merge_distinct_array_agg_nested() -> Result<()> { // [[1, 2], [3, 4]] - let l1 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ScalarValue::from(1i32), ScalarValue::from(2i32)]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(3i32), ScalarValue::from(4i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + ])]); + let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(3), + Some(4), + ])]); + let l1 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([2]), + arrow::compute::concat(&[&a1, &a2]).unwrap(), + None, ); - // [[5]] - let l2 = ScalarValue::new_list( - Some(vec![ScalarValue::new_list( - Some(vec![ScalarValue::from(5i32)]), - DataType::Int32, - )]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(5)])]); + let l2 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([1]), + Arc::new(a1), + None, ); // [[6, 7], [8]] - let l3 = ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - Some(vec![ScalarValue::from(6i32), ScalarValue::from(7i32)]), - DataType::Int32, - ), - ScalarValue::new_list( - Some(vec![ScalarValue::from(8i32)]), - DataType::Int32, - ), - ]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(6), + Some(7), + ])]); + let a2 = + ListArray::from_iter_primitive::(vec![Some(vec![Some(8)])]); + let l3 = ListArray::new( + Arc::new(Field::new("item", a1.data_type().to_owned(), true)), + OffsetBuffer::from_lengths([2]), + arrow::compute::concat(&[&a1, &a2]).unwrap(), + None, ); - let expected = ScalarValue::new_list( - Some(vec![l1.clone(), l2.clone(), l3.clone()]), - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - ); + let l1 = ScalarValue::List(Arc::new(l1)); + let l2 = ScalarValue::List(Arc::new(l2)); + let l3 = ScalarValue::List(Arc::new(l3)); // Duplicate l1 in the input array and check that it is deduped in the output. let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2]).unwrap(); let input2 = ScalarValue::iter_to_array(vec![l1, l3]).unwrap(); - check_merge_distinct_array_agg( - input1, - input2, - expected, - DataType::List(Arc::new(Field::new_list( - "item", - Field::new("item", DataType::Int32, true), - true, - ))), - ) + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + ])]), + )); + + check_merge_distinct_array_agg(input1, input2, expected, DataType::Int32) } } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index bf5dbfb4fda9..a53d53107add 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -30,10 +30,11 @@ use crate::{AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; -use arrow_array::{Array, ListArray}; +use arrow_array::Array; use arrow_schema::{Fields, SortOptions}; +use datafusion_common::cast::as_list_array; use datafusion_common::utils::{compare_rows, get_row_at_idx}; -use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; use itertools::izip; @@ -181,12 +182,14 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { if values.is_empty() { return Ok(()); } + let n_row = values[0].len(); for index in 0..n_row { let row = get_row_at_idx(values, index)?; self.values.push(row[0].clone()); self.ordering_values.push(row[1..].to_vec()); } + Ok(()) } @@ -197,10 +200,11 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { // First entry in the state is the aggregation result. let array_agg_values = &states[0]; // 2nd entry stores values received for ordering requirement columns, for each aggregation value inside ARRAY_AGG list. - // For each `ScalarValue` inside ARRAY_AGG list, we will receive a `Vec` that stores + // For each `StructArray` inside ARRAY_AGG list, we will receive an `Array` that stores // values received from its ordering requirement expression. (This information is necessary for during merging). let agg_orderings = &states[1]; - if agg_orderings.as_any().is::() { + + if as_list_array(agg_orderings).is_ok() { // Stores ARRAY_AGG results coming from each partition let mut partition_values = vec![]; // Stores ordering requirement expression results coming from each partition @@ -209,20 +213,21 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { // Existing values should be merged also. partition_values.push(self.values.clone()); partition_ordering_values.push(self.ordering_values.clone()); - for index in 0..agg_orderings.len() { - let ordering = ScalarValue::try_from_array(agg_orderings, index)?; - // Ordering requirement expression values for each entry in the ARRAY_AGG list - let other_ordering_values = - self.convert_array_agg_to_orderings(ordering)?; - // ARRAY_AGG result. (It is a `ScalarValue::List` under the hood, it stores `Vec`) - let array_agg_res = ScalarValue::try_from_array(array_agg_values, index)?; - if let ScalarValue::List(Some(other_values), _) = array_agg_res { - partition_values.push(other_values); - partition_ordering_values.push(other_ordering_values); - } else { - return internal_err!("ARRAY_AGG state must be list!"); - } + + let array_agg_res = + ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; + + for v in array_agg_res.into_iter() { + partition_values.push(v); } + + let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; + // Ordering requirement expression values for each entry in the ARRAY_AGG list + let other_ordering_values = self.convert_array_agg_to_orderings(orderings)?; + for v in other_ordering_values.into_iter() { + partition_ordering_values.push(v); + } + let sort_options = self .ordering_req .iter() @@ -248,10 +253,8 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } fn evaluate(&self) -> Result { - Ok(ScalarValue::new_list( - Some(self.values.clone()), - self.datatypes[0].clone(), - )) + let arr = ScalarValue::new_list(&self.values, &self.datatypes[0]); + Ok(ScalarValue::List(arr)) } fn size(&self) -> usize { @@ -280,33 +283,34 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } impl OrderSensitiveArrayAggAccumulator { + /// Inner Vec\ in the ordering_values can be thought as ordering information for the each ScalarValue in the values array. + /// See [`merge_ordered_arrays`] for more information. fn convert_array_agg_to_orderings( &self, - in_data: ScalarValue, - ) -> Result>> { - if let ScalarValue::List(Some(list_vals), _field_ref) = in_data { - list_vals.into_iter().map(|struct_vals| { - if let ScalarValue::Struct(Some(orderings), _fields) = struct_vals { - Ok(orderings) - } else { - exec_err!( - "Expects to receive ScalarValue::Struct(Some(..), _) but got:{:?}", - struct_vals.data_type() - ) - } - }).collect::>>() - } else { - exec_err!( - "Expects to receive ScalarValue::List(Some(..), _) but got:{:?}", - in_data.data_type() - ) + array_agg: Vec>, + ) -> Result>>> { + let mut orderings = vec![]; + // in_data is Vec where ScalarValue does not include ScalarValue::List + for in_data in array_agg.into_iter() { + let ordering = in_data.into_iter().map(|struct_vals| { + if let ScalarValue::Struct(Some(orderings), _) = struct_vals { + Ok(orderings) + } else { + exec_err!( + "Expects to receive ScalarValue::Struct(Some(..), _) but got:{:?}", + struct_vals.data_type() + ) + } + }).collect::>>()?; + orderings.push(ordering); } + Ok(orderings) } fn evaluate_orderings(&self) -> Result { let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); let struct_field = Fields::from(fields.clone()); - let orderings = self + let orderings: Vec = self .ordering_values .iter() .map(|ordering| { @@ -314,7 +318,9 @@ impl OrderSensitiveArrayAggAccumulator { }) .collect(); let struct_type = DataType::Struct(Fields::from(fields)); - Ok(ScalarValue::new_list(Some(orderings), struct_type)) + + let arr = ScalarValue::new_list(&orderings, &struct_type); + Ok(ScalarValue::List(arr)) } } diff --git a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs index 93b911c939d6..d7934e79c366 100644 --- a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs +++ b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs @@ -18,6 +18,7 @@ //! Defines BitAnd, BitOr, and BitXor Aggregate accumulators use ahash::RandomState; +use datafusion_common::cast::as_list_array; use std::any::Any; use std::sync::Arc; @@ -637,13 +638,14 @@ where // 1. Stores aggregate state in `ScalarValue::List` // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set let state_out = { - let values = self + let values: Vec = self .values .iter() .map(|x| ScalarValue::new_primitive::(Some(*x), &T::DATA_TYPE)) .collect(); - vec![ScalarValue::new_list(Some(values), T::DATA_TYPE)] + let arr = ScalarValue::new_list(&values, &T::DATA_TYPE); + vec![ScalarValue::List(arr)] }; Ok(state_out) } @@ -668,12 +670,11 @@ where } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - for x in states[0].as_list::().iter().flatten() { - self.update_batch(&[x])? + if let Some(state) = states.first() { + let list_arr = as_list_array(state)?; + for arr in list_arr.iter().flatten() { + self.update_batch(&[arr])?; + } } Ok(()) } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 05be8cbccb5f..f5242d983d4c 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -16,6 +16,7 @@ // under the License. use arrow::datatypes::{DataType, Field}; + use std::any::Any; use std::fmt::Debug; use std::sync::Arc; @@ -27,8 +28,8 @@ use std::collections::HashSet; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; +use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_expr::Accumulator; type DistinctScalarValues = ScalarValue; @@ -142,18 +143,11 @@ impl DistinctCountAccumulator { impl Accumulator for DistinctCountAccumulator { fn state(&self) -> Result> { - let mut cols_out = - ScalarValue::new_list(Some(Vec::new()), self.state_data_type.clone()); - self.values - .iter() - .enumerate() - .for_each(|(_, distinct_values)| { - if let ScalarValue::List(Some(ref mut v), _) = cols_out { - v.push(distinct_values.clone()); - } - }); - Ok(vec![cols_out]) + let scalars = self.values.iter().cloned().collect::>(); + let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); + Ok(vec![ScalarValue::List(arr)]) } + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { if values.is_empty() { return Ok(()); @@ -167,25 +161,17 @@ impl Accumulator for DistinctCountAccumulator { Ok(()) }) } + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { if states.is_empty() { return Ok(()); } - let arr = &states[0]; - (0..arr.len()).try_for_each(|index| { - let scalar = ScalarValue::try_from_array(arr, index)?; - - if let ScalarValue::List(Some(scalar), _) = scalar { - scalar.iter().for_each(|scalar| { - if !ScalarValue::is_null(scalar) { - self.values.insert(scalar.clone()); - } - }); - } else { - return internal_err!("Unexpected accumulator state"); - } - Ok(()) - }) + assert_eq!(states.len(), 1, "array_agg states must be singleton!"); + let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?; + for scalars in scalar_vec.into_iter() { + self.values.extend(scalars) + } + Ok(()) } fn evaluate(&self) -> Result { @@ -211,33 +197,21 @@ mod tests { Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow::datatypes::DataType; + use arrow::datatypes::{ + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, + }; + use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; use datafusion_common::internal_err; - - macro_rules! state_to_vec { - ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{ - match $LIST { - ScalarValue::List(_, field) => match field.data_type() { - &DataType::$DATA_TYPE => (), - _ => panic!("Unexpected DataType for list"), - }, - _ => panic!("Expected a ScalarValue::List"), - } - - match $LIST { - ScalarValue::List(None, _) => None, - ScalarValue::List(Some(scalar_values), _) => { - let vec = scalar_values - .iter() - .map(|scalar_value| match scalar_value { - ScalarValue::$DATA_TYPE(value) => *value, - _ => panic!("Unexpected ScalarValue variant"), - }) - .collect::>>(); - - Some(vec) - } - _ => unreachable!(), - } + use datafusion_common::DataFusionError; + + macro_rules! state_to_vec_primitive { + ($LIST:expr, $DATA_TYPE:ident) => {{ + let arr = ScalarValue::raw_data($LIST).unwrap(); + let list_arr = as_list_array(&arr).unwrap(); + let arr = list_arr.values(); + let arr = as_primitive_array::<$DATA_TYPE>(arr)?; + arr.values().iter().cloned().collect::>() }}; } @@ -259,18 +233,25 @@ mod tests { let (states, result) = run_update_batch(&arrays)?; - let mut state_vec = - state_to_vec!(&states[0], $DATA_TYPE, $PRIM_TYPE).unwrap(); + let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); state_vec.sort(); assert_eq!(states.len(), 1); - assert_eq!(state_vec, vec![Some(1), Some(2), Some(3)]); + assert_eq!(state_vec, vec![1, 2, 3]); assert_eq!(result, ScalarValue::Int64(Some(3))); Ok(()) }}; } + fn state_to_vec_bool(sv: &ScalarValue) -> Result> { + let arr = ScalarValue::raw_data(sv)?; + let list_arr = as_list_array(&arr)?; + let arr = list_arr.values(); + let bool_arr = as_boolean_array(arr)?; + Ok(bool_arr.iter().flatten().collect()) + } + fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec, ScalarValue)> { let agg = DistinctCount::new( arrays[0].data_type().clone(), @@ -353,13 +334,11 @@ mod tests { let (states, result) = run_update_batch(&arrays)?; - let mut state_vec = - state_to_vec!(&states[0], $DATA_TYPE, $PRIM_TYPE).unwrap(); + let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); dbg!(&state_vec); state_vec.sort_by(|a, b| match (a, b) { - (Some(lhs), Some(rhs)) => lhs.total_cmp(rhs), - _ => a.partial_cmp(b).unwrap(), + (lhs, rhs) => lhs.total_cmp(rhs), }); let nan_idx = state_vec.len() - 1; @@ -367,16 +346,16 @@ mod tests { assert_eq!( &state_vec[..nan_idx], vec![ - Some(<$PRIM_TYPE>::NEG_INFINITY), - Some(-4.5), - Some(<$PRIM_TYPE as SubNormal>::SUBNORMAL), - Some(1.0), - Some(2.0), - Some(3.0), - Some(<$PRIM_TYPE>::INFINITY) + <$PRIM_TYPE>::NEG_INFINITY, + -4.5, + <$PRIM_TYPE as SubNormal>::SUBNORMAL, + 1.0, + 2.0, + 3.0, + <$PRIM_TYPE>::INFINITY ] ); - assert!(state_vec[nan_idx].unwrap_or_default().is_nan()); + assert!(state_vec[nan_idx].is_nan()); assert_eq!(result, ScalarValue::Int64(Some(8))); Ok(()) @@ -385,61 +364,62 @@ mod tests { #[test] fn count_distinct_update_batch_i8() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int8Array, Int8, i8) + test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8) } #[test] fn count_distinct_update_batch_i16() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int16Array, Int16, i16) + test_count_distinct_update_batch_numeric!(Int16Array, Int16Type, i16) } #[test] fn count_distinct_update_batch_i32() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int32Array, Int32, i32) + test_count_distinct_update_batch_numeric!(Int32Array, Int32Type, i32) } #[test] fn count_distinct_update_batch_i64() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int64Array, Int64, i64) + test_count_distinct_update_batch_numeric!(Int64Array, Int64Type, i64) } #[test] fn count_distinct_update_batch_u8() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt8Array, UInt8, u8) + test_count_distinct_update_batch_numeric!(UInt8Array, UInt8Type, u8) } #[test] fn count_distinct_update_batch_u16() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt16Array, UInt16, u16) + test_count_distinct_update_batch_numeric!(UInt16Array, UInt16Type, u16) } #[test] fn count_distinct_update_batch_u32() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt32Array, UInt32, u32) + test_count_distinct_update_batch_numeric!(UInt32Array, UInt32Type, u32) } #[test] fn count_distinct_update_batch_u64() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt64Array, UInt64, u64) + test_count_distinct_update_batch_numeric!(UInt64Array, UInt64Type, u64) } #[test] fn count_distinct_update_batch_f32() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float32Array, Float32, f32) + test_count_distinct_update_batch_floating_point!(Float32Array, Float32Type, f32) } #[test] fn count_distinct_update_batch_f64() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float64Array, Float64, f64) + test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64) } #[test] fn count_distinct_update_batch_boolean() -> Result<()> { - let get_count = |data: BooleanArray| -> Result<(Vec>, i64)> { + let get_count = |data: BooleanArray| -> Result<(Vec, i64)> { let arrays = vec![Arc::new(data) as ArrayRef]; let (states, result) = run_update_batch(&arrays)?; - let mut state_vec = state_to_vec!(&states[0], Boolean, bool).unwrap(); + let mut state_vec = state_to_vec_bool(&states[0])?; state_vec.sort(); + let count = match result { ScalarValue::Int64(c) => c.ok_or_else(|| { DataFusionError::Internal("Found None count".to_string()) @@ -467,22 +447,13 @@ mod tests { Some(false), ]); - assert_eq!( - get_count(zero_count_values)?, - (Vec::>::new(), 0) - ); - assert_eq!(get_count(one_count_values)?, (vec![Some(false)], 1)); - assert_eq!( - get_count(one_count_values_with_null)?, - (vec![Some(true)], 1) - ); - assert_eq!( - get_count(two_count_values)?, - (vec![Some(false), Some(true)], 2) - ); + assert_eq!(get_count(zero_count_values)?, (Vec::::new(), 0)); + assert_eq!(get_count(one_count_values)?, (vec![false], 1)); + assert_eq!(get_count(one_count_values_with_null)?, (vec![true], 1)); + assert_eq!(get_count(two_count_values)?, (vec![false, true], 2)); assert_eq!( get_count(two_count_values_with_null)?, - (vec![Some(false), Some(true)], 2) + (vec![false, true], 2) ); Ok(()) } @@ -494,9 +465,9 @@ mod tests { )) as ArrayRef]; let (states, result) = run_update_batch(&arrays)?; - + let state_vec = state_to_vec_primitive!(&states[0], Int32Type); assert_eq!(states.len(), 1); - assert_eq!(state_to_vec!(&states[0], Int32, i32), Some(vec![])); + assert!(state_vec.is_empty()); assert_eq!(result, ScalarValue::Int64(Some(0))); Ok(()) @@ -507,9 +478,9 @@ mod tests { let arrays = vec![Arc::new(Int32Array::from(vec![0_i32; 0])) as ArrayRef]; let (states, result) = run_update_batch(&arrays)?; - + let state_vec = state_to_vec_primitive!(&states[0], Int32Type); assert_eq!(states.len(), 1); - assert_eq!(state_to_vec!(&states[0], Int32, i32), Some(vec![])); + assert!(state_vec.is_empty()); assert_eq!(result, ScalarValue::Int64(Some(0))); Ok(()) diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs index 1ec412402638..477dcadceee7 100644 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ b/datafusion/physical-expr/src/aggregate/median.rs @@ -146,14 +146,14 @@ impl std::fmt::Debug for MedianAccumulator { impl Accumulator for MedianAccumulator { fn state(&self) -> Result> { - let all_values = self + let all_values: Vec = self .all_values .iter() .map(|x| ScalarValue::new_primitive::(Some(*x), &self.data_type)) .collect(); - let state = ScalarValue::new_list(Some(all_values), self.data_type.clone()); - Ok(vec![state]) + let arr = ScalarValue::new_list(&all_values, &self.data_type); + Ok(vec![ScalarValue::List(arr)]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs index c3d8d5e87068..742e24b99e71 100644 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs @@ -166,10 +166,10 @@ impl Accumulator for DistinctSumAccumulator { &self.data_type, )) }); - vec![ScalarValue::new_list( - Some(distinct_values), - self.data_type.clone(), - )] + vec![ScalarValue::List(ScalarValue::new_list( + &distinct_values, + &self.data_type, + ))] }; Ok(state_out) } diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs b/datafusion/physical-expr/src/aggregate/tdigest.rs index 7e6d2dcf8f4f..90f5244f477d 100644 --- a/datafusion/physical-expr/src/aggregate/tdigest.rs +++ b/datafusion/physical-expr/src/aggregate/tdigest.rs @@ -28,6 +28,9 @@ //! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h use arrow::datatypes::DataType; +use arrow_array::cast::as_list_array; +use arrow_array::types::Float64Type; +use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::ScalarValue; use std::cmp::Ordering; @@ -566,20 +569,22 @@ impl TDigest { /// [`TDigest`]. pub(crate) fn to_scalar_state(&self) -> Vec { // Gather up all the centroids - let centroids: Vec<_> = self + let centroids: Vec = self .centroids .iter() .flat_map(|c| [c.mean(), c.weight()]) .map(|v| ScalarValue::Float64(Some(v))) .collect(); + let arr = ScalarValue::new_list(¢roids, &DataType::Float64); + vec![ ScalarValue::UInt64(Some(self.max_size as u64)), ScalarValue::Float64(Some(self.sum)), ScalarValue::Float64(Some(self.count)), ScalarValue::Float64(Some(self.max)), ScalarValue::Float64(Some(self.min)), - ScalarValue::new_list(Some(centroids), DataType::Float64), + ScalarValue::List(arr), ] } @@ -600,10 +605,18 @@ impl TDigest { }; let centroids: Vec<_> = match &state[5] { - ScalarValue::List(Some(c), f) if *f.data_type() == DataType::Float64 => c - .chunks(2) - .map(|v| Centroid::new(cast_scalar_f64!(v[0]), cast_scalar_f64!(v[1]))) - .collect(), + ScalarValue::List(arr) => { + let list_array = as_list_array(arr); + let arr = list_array.values(); + + let f64arr = + as_primitive_array::(arr).expect("expected f64 array"); + f64arr + .values() + .chunks(2) + .map(|v| Centroid::new(v[0], v[1])) + .collect() + } v => panic!("invalid centroids type {v:?}"), }; diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 0a9c2f56048e..067a4cfdffc0 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -395,7 +395,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { /// `ListArray` /// /// See [`array_array`] for more details. -fn array(values: &[ColumnarValue]) -> Result { +fn array(values: &[ColumnarValue]) -> Result { let arrays: Vec = values .iter() .map(|x| match x { @@ -417,26 +417,17 @@ fn array(values: &[ColumnarValue]) -> Result { match data_type { // empty array - None => Ok(ColumnarValue::Scalar(ScalarValue::new_list( - Some(vec![]), - DataType::Null, - ))), + None => { + let list_arr = ScalarValue::new_list(&[], &DataType::Null); + Ok(Arc::new(list_arr)) + } // all nulls, set default data type as int32 Some(DataType::Null) => { - let nulls = arrays.len(); - let null_arr = Int32Array::from(vec![None; nulls]); - let field = Arc::new(Field::new("item", DataType::Int32, true)); - let offsets = OffsetBuffer::from_lengths([nulls]); - let values = Arc::new(null_arr) as ArrayRef; - let nulls = None; - Ok(ColumnarValue::Array(Arc::new(ListArray::new( - field, offsets, values, nulls, - )))) + let null_arr = vec![ScalarValue::Int32(None); arrays.len()]; + let list_arr = ScalarValue::new_list(null_arr.as_slice(), &DataType::Int32); + Ok(Arc::new(list_arr)) } - Some(data_type) => Ok(ColumnarValue::Array(array_array( - arrays.as_slice(), - data_type, - )?)), + Some(data_type) => Ok(array_array(arrays.as_slice(), data_type)?), } } @@ -446,11 +437,7 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result { .iter() .map(|x| ColumnarValue::Array(x.clone())) .collect(); - - match array(values.as_slice())? { - ColumnarValue::Array(array) => Ok(array), - ColumnarValue::Scalar(scalar) => Ok(scalar.to_array().clone()), - } + array(values.as_slice()) } fn return_empty(return_null: bool, data_type: DataType) -> Arc { @@ -671,9 +658,7 @@ pub fn array_append(args: &[ArrayRef]) -> Result { check_datatypes("array_append", &[arr.values(), element])?; let res = match arr.value_type() { DataType::List(_) => concat_internal(args)?, - DataType::Null => { - return Ok(array(&[ColumnarValue::Array(args[1].clone())])?.into_array(1)) - } + DataType::Null => return array(&[ColumnarValue::Array(args[1].clone())]), data_type => { macro_rules! array_function { ($ARRAY_TYPE:ident) => { @@ -747,9 +732,7 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result { check_datatypes("array_prepend", &[element, arr.values()])?; let res = match arr.value_type() { DataType::List(_) => concat_internal(args)?, - DataType::Null => { - return Ok(array(&[ColumnarValue::Array(args[0].clone())])?.into_array(1)) - } + DataType::Null => return array(&[ColumnarValue::Array(args[0].clone())]), data_type => { macro_rules! array_function { ($ARRAY_TYPE:ident) => { @@ -1636,7 +1619,7 @@ fn flatten_internal( indexes: Option>, ) -> Result { let list_arr = as_list_array(array)?; - let (field, offsets, values, nulls) = list_arr.clone().into_parts(); + let (field, offsets, values, _) = list_arr.clone().into_parts(); let data_type = field.data_type(); match data_type { @@ -1653,7 +1636,7 @@ fn flatten_internal( _ => { if let Some(indexes) = indexes { let offsets = get_offsets_for_flatten(offsets, indexes); - let list_arr = ListArray::new(field, offsets, values, nulls); + let list_arr = ListArray::new(field, offsets, values, None); Ok(list_arr) } else { Ok(list_arr.clone()) @@ -1974,9 +1957,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), ]; - let array = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let array = array(&args).expect("failed to initialize function array"); let result = as_list_array(&array).expect("failed to initialize function array"); assert_eq!(result.len(), 1); assert_eq!( @@ -1998,9 +1979,7 @@ mod tests { ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 4]))), ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 6]))), ]; - let array = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let array = array(&args).expect("failed to initialize function array"); let result = as_list_array(&array).expect("failed to initialize function array"); assert_eq!(result.len(), 2); assert_eq!( @@ -3317,9 +3296,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let result = array(&args).expect("failed to initialize function array"); ColumnarValue::Array(result.clone()) } @@ -3331,9 +3308,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(13))), ColumnarValue::Scalar(ScalarValue::Int64(Some(14))), ]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let result = array(&args).expect("failed to initialize function array"); ColumnarValue::Array(result.clone()) } @@ -3345,9 +3320,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ]; - let arr1 = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let arr1 = array(&args).expect("failed to initialize function array"); let args = [ ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), @@ -3355,14 +3328,10 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(7))), ColumnarValue::Scalar(ScalarValue::Int64(Some(8))), ]; - let arr2 = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let arr2 = array(&args).expect("failed to initialize function array"); let args = [ColumnarValue::Array(arr1), ColumnarValue::Array(arr2)]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let result = array(&args).expect("failed to initialize function array"); ColumnarValue::Array(result.clone()) } @@ -3374,9 +3343,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), ColumnarValue::Scalar(ScalarValue::Null), ]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let result = array(&args).expect("failed to initialize function array"); ColumnarValue::Array(result.clone()) } @@ -3388,9 +3355,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), ColumnarValue::Scalar(ScalarValue::Null), ]; - let arr1 = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let arr1 = array(&args).expect("failed to initialize function array"); let args = [ ColumnarValue::Scalar(ScalarValue::Null), @@ -3398,14 +3363,10 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(7))), ColumnarValue::Scalar(ScalarValue::Null), ]; - let arr2 = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let arr2 = array(&args).expect("failed to initialize function array"); let args = [ColumnarValue::Array(arr1), ColumnarValue::Array(arr2)]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let result = array(&args).expect("failed to initialize function array"); ColumnarValue::Array(result.clone()) } @@ -3419,9 +3380,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), ]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let result = array(&args).expect("failed to initialize function array"); ColumnarValue::Array(result.clone()) } @@ -3433,9 +3392,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ]; - let arr1 = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let arr1 = array(&args).expect("failed to initialize function array"); let args = [ ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), @@ -3443,9 +3400,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(7))), ColumnarValue::Scalar(ScalarValue::Int64(Some(8))), ]; - let arr2 = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let arr2 = array(&args).expect("failed to initialize function array"); let args = [ ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), @@ -3453,9 +3408,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), ]; - let arr3 = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let arr3 = array(&args).expect("failed to initialize function array"); let args = [ ColumnarValue::Scalar(ScalarValue::Int64(Some(9))), @@ -3463,9 +3416,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(11))), ColumnarValue::Scalar(ScalarValue::Int64(Some(12))), ]; - let arr4 = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let arr4 = array(&args).expect("failed to initialize function array"); let args = [ ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), @@ -3473,9 +3424,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(7))), ColumnarValue::Scalar(ScalarValue::Int64(Some(8))), ]; - let arr5 = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let arr5 = array(&args).expect("failed to initialize function array"); let args = [ ColumnarValue::Array(arr1), @@ -3484,9 +3433,7 @@ mod tests { ColumnarValue::Array(arr4), ColumnarValue::Array(arr5), ]; - let result = array(&args) - .expect("failed to initialize function array") - .into_array(1); + let result = array(&args).expect("failed to initialize function array"); ColumnarValue::Array(result.clone()) } } diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index e04a68615a46..f23b45e26a03 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -349,13 +349,7 @@ where .collect::>(); let result = (inner)(&args); - - // maybe back to scalar - if len.is_some() { - result.map(ColumnarValue::Array) - } else { - ScalarValue::try_from_array(&result?, 0).map(ColumnarValue::Scalar) - } + result.map(ColumnarValue::Array) }) } diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index 2cf341d1fe60..383726401c5a 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -66,10 +66,11 @@ impl ValuesExec { (0..n_row) .map(|i| { let r = data[i][j].evaluate(&batch); + match r { Ok(ColumnarValue::Scalar(scalar)) => Ok(scalar), Ok(ColumnarValue::Array(a)) if a.len() == 1 => { - ScalarValue::try_from_array(&a, 0) + Ok(ScalarValue::List(a)) } Ok(ColumnarValue::Array(a)) => { plan_err!( diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index c60dae71ef86..1819d1a4392d 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -883,12 +883,10 @@ message Union{ repeated int32 type_ids = 3; } -message ScalarListValue{ - // encode null explicitly to distinguish a list with a null value - // from a list with no values) - bool is_null = 3; - Field field = 1; - repeated ScalarValue values = 2; +message ScalarListValue { + bytes ipc_message = 1; + bytes arrow_data = 2; + Schema schema = 3; } message ScalarTime32Value { @@ -965,7 +963,6 @@ message ScalarValue{ int32 date_32_value = 14; ScalarTime32Value time32_value = 15; ScalarListValue list_value = 17; - //WAS: ScalarType null_list_value = 18; Decimal128 decimal128_value = 20; Decimal256 decimal256_value = 39; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 266075e68922..9aed987491f3 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20003,24 +20003,26 @@ impl serde::Serialize for ScalarListValue { { use serde::ser::SerializeStruct; let mut len = 0; - if self.is_null { + if !self.ipc_message.is_empty() { len += 1; } - if self.field.is_some() { + if !self.arrow_data.is_empty() { len += 1; } - if !self.values.is_empty() { + if self.schema.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.ScalarListValue", len)?; - if self.is_null { - struct_ser.serialize_field("isNull", &self.is_null)?; + if !self.ipc_message.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("ipcMessage", pbjson::private::base64::encode(&self.ipc_message).as_str())?; } - if let Some(v) = self.field.as_ref() { - struct_ser.serialize_field("field", v)?; + if !self.arrow_data.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("arrowData", pbjson::private::base64::encode(&self.arrow_data).as_str())?; } - if !self.values.is_empty() { - struct_ser.serialize_field("values", &self.values)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } struct_ser.end() } @@ -20032,17 +20034,18 @@ impl<'de> serde::Deserialize<'de> for ScalarListValue { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "is_null", - "isNull", - "field", - "values", + "ipc_message", + "ipcMessage", + "arrow_data", + "arrowData", + "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - IsNull, - Field, - Values, + IpcMessage, + ArrowData, + Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -20064,9 +20067,9 @@ impl<'de> serde::Deserialize<'de> for ScalarListValue { E: serde::de::Error, { match value { - "isNull" | "is_null" => Ok(GeneratedField::IsNull), - "field" => Ok(GeneratedField::Field), - "values" => Ok(GeneratedField::Values), + "ipcMessage" | "ipc_message" => Ok(GeneratedField::IpcMessage), + "arrowData" | "arrow_data" => Ok(GeneratedField::ArrowData), + "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -20086,35 +20089,39 @@ impl<'de> serde::Deserialize<'de> for ScalarListValue { where V: serde::de::MapAccess<'de>, { - let mut is_null__ = None; - let mut field__ = None; - let mut values__ = None; + let mut ipc_message__ = None; + let mut arrow_data__ = None; + let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::IsNull => { - if is_null__.is_some() { - return Err(serde::de::Error::duplicate_field("isNull")); + GeneratedField::IpcMessage => { + if ipc_message__.is_some() { + return Err(serde::de::Error::duplicate_field("ipcMessage")); } - is_null__ = Some(map_.next_value()?); + ipc_message__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; } - GeneratedField::Field => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("field")); + GeneratedField::ArrowData => { + if arrow_data__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowData")); } - field__ = map_.next_value()?; + arrow_data__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; } - GeneratedField::Values => { - if values__.is_some() { - return Err(serde::de::Error::duplicate_field("values")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - values__ = Some(map_.next_value()?); + schema__ = map_.next_value()?; } } } Ok(ScalarListValue { - is_null: is_null__.unwrap_or_default(), - field: field__, - values: values__.unwrap_or_default(), + ipc_message: ipc_message__.unwrap_or_default(), + arrow_data: arrow_data__.unwrap_or_default(), + schema: schema__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 894afa570fb0..883799b1590d 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1072,14 +1072,12 @@ pub struct Union { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarListValue { - /// encode null explicitly to distinguish a list with a null value - /// from a list with no values) - #[prost(bool, tag = "3")] - pub is_null: bool, - #[prost(message, optional, tag = "1")] - pub field: ::core::option::Option, - #[prost(message, repeated, tag = "2")] - pub values: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "1")] + pub ipc_message: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "2")] + pub arrow_data: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "3")] + pub schema: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1224,7 +1222,6 @@ pub mod scalar_value { Date32Value(i32), #[prost(message, tag = "15")] Time32Value(super::ScalarTime32Value), - /// WAS: ScalarType null_list_value = 18; #[prost(message, tag = "17")] ListValue(super::ScalarListValue), #[prost(message, tag = "20")] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 2203016e08f1..b3873c01dd06 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -25,9 +25,13 @@ use crate::protobuf::{ AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; -use arrow::datatypes::{ - i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, - UnionFields, UnionMode, +use arrow::{ + buffer::{Buffer, MutableBuffer}, + datatypes::{ + i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, + UnionFields, UnionMode, + }, + ipc::{reader::read_record_batch, root_as_message}, }; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ @@ -643,23 +647,39 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::Date32Value(v) => Self::Date32(Some(*v)), Value::ListValue(scalar_list) => { let protobuf::ScalarListValue { - is_null, - values, - field, + ipc_message, + arrow_data, + schema, } = &scalar_list; - let field: Field = field.as_ref().required("field")?; - let field = Arc::new(field); - - let values: Result, Error> = - values.iter().map(|val| val.try_into()).collect(); - let values = values?; - - validate_list_values(field.as_ref(), &values)?; - - let values = if *is_null { None } else { Some(values) }; + let schema: Schema = if let Some(schema_ref) = schema { + schema_ref.try_into()? + } else { + return Err(Error::General("Unexpected schema".to_string())); + }; - Self::List(values, field) + let message = root_as_message(ipc_message.as_slice()).unwrap(); + + // TODO: Add comment to why adding 0 before arrow_data. + // This code is from https://github.com/apache/arrow-rs/blob/4320a753beaee0a1a6870c59ef46b59e88c9c323/arrow-ipc/src/reader.rs#L1670-L1674C45 + // Construct an unaligned buffer + let mut buffer = MutableBuffer::with_capacity(arrow_data.len() + 1); + buffer.push(0_u8); + buffer.extend_from_slice(arrow_data.as_slice()); + let b = Buffer::from(buffer).slice(1); + + let ipc_batch = message.header_as_record_batch().unwrap(); + let record_batch = read_record_batch( + &b, + ipc_batch, + Arc::new(schema), + &Default::default(), + None, + &message.version(), + ) + .unwrap(); + let arr = record_batch.column(0); + Self::List(arr.to_owned()) } Value::NullValue(v) => { let null_type: DataType = v.try_into()?; @@ -925,22 +945,6 @@ pub fn parse_i32_to_aggregate_function(value: &i32) -> Result Result<(), Error> { - for value in values { - let field_type = field.data_type(); - let value_type = value.data_type(); - - if field_type != &value_type { - return Err(proto_error(format!( - "Expected field type {field_type:?}, got scalar of type: {value_type:?}" - ))); - } - } - Ok(()) -} - pub fn parse_expr( proto: &protobuf::LogicalExprNode, registry: &dyn FunctionRegistry, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index c10855bf2514..e80d60931cf6 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -30,10 +30,13 @@ use crate::protobuf::{ AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; - -use arrow::datatypes::{ - DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, - UnionMode, +use arrow::{ + datatypes::{ + DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, + TimeUnit, UnionMode, + }, + ipc::writer::{DictionaryTracker, IpcDataGenerator}, + record_batch::RecordBatch, }; use datafusion_common::{ Column, Constraint, Constraints, DFField, DFSchema, DFSchemaRef, OwnedTableReference, @@ -1142,27 +1145,27 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { "Proto serialization error: ScalarValue::Fixedsizelist not supported" .to_string(), )), - ScalarValue::List(values, boxed_field) => { - let is_null = values.is_none(); - - let values = if let Some(values) = values.as_ref() { - values - .iter() - .map(|v| v.try_into()) - .collect::, _>>()? - } else { - vec![] + ScalarValue::List(arr) => { + let batch = + RecordBatch::try_from_iter(vec![("field_name", arr.to_owned())]) + .unwrap(); + let gen = IpcDataGenerator {}; + let mut dict_tracker = DictionaryTracker::new(false); + let (_, encoded_message) = gen + .encoded_batch(&batch, &mut dict_tracker, &Default::default()) + .unwrap(); + + let schema: protobuf::Schema = batch.schema().try_into()?; + + let scalar_list_value = protobuf::ScalarListValue { + ipc_message: encoded_message.ipc_message, + arrow_data: encoded_message.arrow_data, + schema: Some(schema), }; - let field = boxed_field.as_ref().try_into()?; - Ok(protobuf::ScalarValue { value: Some(protobuf::scalar_value::Value::ListValue( - protobuf::ScalarListValue { - is_null, - field: Some(field), - values, - }, + scalar_list_value, )), }) } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 11ee8c0876bc..ca801df337f1 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -24,6 +24,7 @@ use arrow::datatypes::{ DataType, Field, Fields, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; + use prost::Message; use datafusion::datasource::provider::TableProviderFactory; @@ -512,59 +513,6 @@ fn scalar_values_error_serialization() { Some(vec![]), vec![Field::new("item", DataType::Int16, true)].into(), ), - // Should fail due to inconsistent types in the list - ScalarValue::new_list( - Some(vec![ - ScalarValue::Int16(None), - ScalarValue::Float32(Some(32.0)), - ]), - DataType::List(new_arc_field("item", DataType::Int16, true)), - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(None), - ScalarValue::Float32(Some(32.0)), - ]), - DataType::List(new_arc_field("item", DataType::Int16, true)), - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(None), - ScalarValue::Float32(Some(32.0)), - ]), - DataType::Int16, - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - None, - DataType::List(new_arc_field("level2", DataType::Float32, true)), - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(Some(-213.1)), - ScalarValue::Float32(None), - ScalarValue::Float32(Some(5.5)), - ScalarValue::Float32(Some(2.0)), - ScalarValue::Float32(Some(1.0)), - ]), - DataType::List(new_arc_field("level2", DataType::Float32, true)), - ), - ScalarValue::new_list( - None, - DataType::List(new_arc_field( - "lists are typed inconsistently", - DataType::Int16, - true, - )), - ), - ]), - DataType::List(new_arc_field( - "level1", - DataType::List(new_arc_field("level2", DataType::Float32, true)), - true, - )), - ), ]; for test_case in should_fail_on_seralize.into_iter() { @@ -599,7 +547,7 @@ fn round_trip_scalar_values() { ScalarValue::UInt64(None), ScalarValue::Utf8(None), ScalarValue::LargeUtf8(None), - ScalarValue::new_list(None, DataType::Boolean), + ScalarValue::List(ScalarValue::new_list(&[], &DataType::Boolean)), ScalarValue::Date32(None), ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false)), @@ -690,32 +638,32 @@ fn round_trip_scalar_values() { i64::MAX, ))), ScalarValue::IntervalMonthDayNano(None), - ScalarValue::new_list( - Some(vec![ + ScalarValue::List(ScalarValue::new_list( + &[ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), ScalarValue::Float32(Some(5.5)), ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), - ]), - DataType::Float32, - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list(None, DataType::Float32), - ScalarValue::new_list( - Some(vec![ + ], + &DataType::Float32, + )), + ScalarValue::List(ScalarValue::new_list( + &[ + ScalarValue::List(ScalarValue::new_list(&[], &DataType::Float32)), + ScalarValue::List(ScalarValue::new_list( + &[ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), ScalarValue::Float32(Some(5.5)), ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), - ]), - DataType::Float32, - ), - ]), - DataType::List(new_arc_field("item", DataType::Float32, true)), - ), + ], + &DataType::Float32, + )), + ], + &DataType::List(new_arc_field("item", DataType::Float32, true)), + )), ScalarValue::Dictionary( Box::new(DataType::Int32), Box::new(ScalarValue::Utf8(Some("foo".into()))), @@ -978,7 +926,6 @@ fn roundtrip_null_scalar_values() { ScalarValue::Date32(None), ScalarValue::TimestampMicrosecond(None, None), ScalarValue::TimestampNanosecond(None, None), - ScalarValue::List(None, Arc::new(Field::new("item", DataType::Boolean, false))), ]; for test_case in test_types.into_iter() { diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index c949904cd84c..3a06fdb158f7 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -16,6 +16,7 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use arrow::array::new_null_array; use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; use arrow::datatypes::DECIMAL128_MAX_PRECISION; use arrow_schema::DataType; @@ -153,13 +154,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { values.iter().map(|e| e.data_type()).collect(); if data_types.is_empty() { - Ok(lit(ScalarValue::new_list(None, DataType::Utf8))) + Ok(lit(ScalarValue::List(new_null_array(&DataType::Null, 0)))) } else if data_types.len() > 1 { not_impl_err!("Arrays with different types are not supported: {data_types:?}") } else { let data_type = values[0].data_type(); - - Ok(lit(ScalarValue::new_list(Some(values), data_type))) + let arr = ScalarValue::new_list(&values, &data_type); + Ok(lit(ScalarValue::List(arr))) } }