diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index cc5b70796e88..9cbd9e292ff3 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -19,11 +19,13 @@
 use std::borrow::Borrow;
 use std::cmp::Ordering;
-use std::collections::HashSet;
-use std::convert::{Infallible, TryInto};
+use std::collections::{HashSet, VecDeque};
+use std::convert::{Infallible, TryFrom, TryInto};
+use std::fmt;
 use std::hash::Hash;
+use std::iter::repeat;
 use std::str::FromStr;
-use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
+use std::sync::Arc;
 
 use crate::arrow_datafusion_err;
 use crate::cast::{
@@ -33,23 +35,22 @@ use crate::cast::{
 use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err};
 use crate::hash_utils::create_hashes;
 use crate::utils::{array_into_large_list_array, array_into_list_array};
+
 use arrow::compute::kernels::numeric::*;
-use arrow::datatypes::{i256, Fields, SchemaBuilder};
 use arrow::util::display::{ArrayFormatter, FormatOptions};
 use arrow::{
     array::*,
     compute::kernels::cast::{cast_with_options, CastOptions},
     datatypes::{
-        ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, Int16Type,
-        Int32Type, Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType,
-        IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType,
+        i256, ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType,
+        Field, Fields, Float32Type, Int16Type, Int32Type, Int64Type, Int8Type,
+        IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
+        IntervalYearMonthType, SchemaBuilder, TimeUnit, TimestampMicrosecondType,
         TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
         UInt16Type, UInt32Type, UInt64Type, UInt8Type, DECIMAL128_MAX_PRECISION,
     },
 };
 use arrow_array::cast::as_list_array;
-use arrow_array::types::ArrowTimestampType;
-use arrow_array::{ArrowNativeTypeOp, Scalar};
 
 /// A dynamically typed, nullable single value, (the single-valued counter-part
 /// to arrow's [`Array`])
@@ -1728,6 +1729,43 @@ impl ScalarValue {
         Arc::new(array_into_list_array(values))
     }
 
+    /// Converts `IntoIterator<Item = ScalarValue>` where each element has type corresponding to
+    /// `data_type`, to a [`ListArray`].
+    ///
+    /// Example
+    /// ```
+    /// use datafusion_common::ScalarValue;
+    /// use arrow::array::{ListArray, Int32Array};
+    /// use arrow::datatypes::{DataType, Int32Type};
+    /// use datafusion_common::cast::as_list_array;
+    ///
+    /// let scalars = vec![
+    ///    ScalarValue::Int32(Some(1)),
+    ///    ScalarValue::Int32(None),
+    ///    ScalarValue::Int32(Some(2))
+    /// ];
+    ///
+    /// let result = ScalarValue::new_list_from_iter(scalars.into_iter(), &DataType::Int32);
+    ///
+    /// let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(
+    ///     vec![
+    ///        Some(vec![Some(1), None, Some(2)])
+    ///     ]);
+    ///
+    /// assert_eq!(*result, expected);
+    /// ```
+    pub fn new_list_from_iter(
+        values: impl IntoIterator<Item = ScalarValue> + ExactSizeIterator,
+        data_type: &DataType,
+    ) -> Arc<ListArray> {
+        let values = if values.len() == 0 {
+            new_empty_array(data_type)
+        } else {
+            Self::iter_to_array(values).unwrap()
+        };
+        Arc::new(array_into_list_array(values))
+    }
+
     /// Converts `Vec<ScalarValue>` where each element has type corresponding to
     /// `data_type`, to a [`LargeListArray`].
     ///
@@ -2626,6 +2664,18 @@ impl ScalarValue {
             .sum::<usize>()
     }
 
+    /// Estimates [size](Self::size) of [`VecDeque`] in bytes.
+    ///
+    /// Includes the size of the [`VecDeque`] container itself.
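+    ///
+    /// Illustrative doctest sketch (not part of the original patch); the exact
+    /// estimate depends on the deque's allocated capacity:
+    /// ```
+    /// use std::collections::VecDeque;
+    /// use datafusion_common::ScalarValue;
+    ///
+    /// let deque: VecDeque<ScalarValue> =
+    ///     vec![ScalarValue::Int32(Some(1)), ScalarValue::Int32(None)].into();
+    /// // The estimate covers the container itself plus per-element overhead,
+    /// // so it is at least the shallow size of the deque:
+    /// assert!(ScalarValue::size_of_vec_deque(&deque) >= std::mem::size_of_val(&deque));
+    /// ```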
+    pub fn size_of_vec_deque(vec_deque: &VecDeque<ScalarValue>) -> usize {
+        std::mem::size_of_val(vec_deque)
+            + (std::mem::size_of::<ScalarValue>() * vec_deque.capacity())
+            + vec_deque
+                .iter()
+                .map(|sv| sv.size() - std::mem::size_of_val(sv))
+                .sum::<usize>()
+    }
+
     /// Estimates [size](Self::size) of [`HashSet`] in bytes.
     ///
     /// Includes the size of the [`HashSet`] container itself.
@@ -3151,22 +3201,19 @@ impl ScalarType<i64> for TimestampNanosecondType {
 
 #[cfg(test)]
 mod tests {
-    use super::*;
-
     use std::cmp::Ordering;
     use std::sync::Arc;
 
-    use chrono::NaiveDate;
-    use rand::Rng;
+    use super::*;
+    use crate::cast::{as_string_array, as_uint32_array, as_uint64_array};
 
     use arrow::buffer::OffsetBuffer;
-    use arrow::compute::kernels;
-    use arrow::compute::{concat, is_null};
-    use arrow::datatypes::ArrowPrimitiveType;
+    use arrow::compute::{concat, is_null, kernels};
+    use arrow::datatypes::{ArrowNumericType, ArrowPrimitiveType};
     use arrow::util::pretty::pretty_format_columns;
-    use arrow_array::ArrowNumericType;
 
-    use crate::cast::{as_string_array, as_uint32_array, as_uint64_array};
+    use chrono::NaiveDate;
+    use rand::Rng;
 
     #[test]
     fn test_to_array_of_size_for_list() {
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index 9db7635d99a0..574de3e7082a 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -17,12 +17,15 @@
 
 //! Aggregate function module contains all built-in aggregate functions definitions
 
+use std::sync::Arc;
+use std::{fmt, str::FromStr};
+
 use crate::utils;
 use crate::{type_coercion::aggregates::*, Signature, TypeSignature, Volatility};
+
 use arrow::datatypes::{DataType, Field};
 use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result};
-use std::sync::Arc;
-use std::{fmt, str::FromStr};
+
 use strum_macros::EnumIter;
 
 /// Enum of all built-in aggregate functions
@@ -30,26 +33,28 @@ use strum_macros::EnumIter;
 // https://arrow.apache.org/datafusion/contributor-guide/index.html#how-to-add-a-new-aggregate-function
 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
 pub enum AggregateFunction {
-    /// count
+    /// Count
     Count,
-    /// sum
+    /// Sum
     Sum,
-    /// min
+    /// Minimum
     Min,
-    /// max
+    /// Maximum
     Max,
-    /// avg
+    /// Average
     Avg,
-    /// median
+    /// Median
     Median,
-    /// Approximate aggregate function
+    /// Approximate distinct function
     ApproxDistinct,
-    /// array_agg
+    /// Aggregation into an array
    ArrayAgg,
-    /// first_value
+    /// First value in a group according to some ordering
     FirstValue,
-    /// last_value
+    /// Last value in a group according to some ordering
     LastValue,
+    /// N'th value in a group according to some ordering
+    NthValue,
     /// Variance (Sample)
     Variance,
     /// Variance (Population)
@@ -100,7 +105,7 @@ pub enum AggregateFunction {
     BoolAnd,
     /// Bool Or
     BoolOr,
-    /// string_agg
+    /// String aggregation
     StringAgg,
 }
 
@@ -118,6 +123,7 @@ impl AggregateFunction {
             ArrayAgg => "ARRAY_AGG",
             FirstValue => "FIRST_VALUE",
             LastValue => "LAST_VALUE",
+            NthValue => "NTH_VALUE",
             Variance => "VAR",
             VariancePop => "VAR_POP",
             Stddev => "STDDEV",
@@ -174,6 +180,7 @@ impl FromStr for AggregateFunction {
             "array_agg" => AggregateFunction::ArrayAgg,
             "first_value" => AggregateFunction::FirstValue,
             "last_value" => AggregateFunction::LastValue,
+            "nth_value" => AggregateFunction::NthValue,
             "string_agg" => AggregateFunction::StringAgg,
             // statistical
             "corr" => AggregateFunction::Correlation,
@@ -300,9 +307,9 @@ impl AggregateFunction {
                 Ok(coerced_data_types[0].clone())
             }
             AggregateFunction::Grouping => Ok(DataType::Int32),
-            AggregateFunction::FirstValue | AggregateFunction::LastValue => {
-                Ok(coerced_data_types[0].clone())
-            }
+            AggregateFunction::FirstValue
+            | AggregateFunction::LastValue
+            | AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
             AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
         }
     }
@@ -371,6 +378,7 @@ impl AggregateFunction {
             | AggregateFunction::LastValue => {
                 Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
             }
+            AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
             AggregateFunction::Covariance
             | AggregateFunction::CovariancePop
             | AggregateFunction::Correlation
@@ -428,6 +436,7 @@ impl AggregateFunction {
 #[cfg(test)]
 mod tests {
     use super::*;
+
     use strum::IntoEnumIterator;
 
     #[test]
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs
index 56bb5c9b69c4..ab994c143ac2 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -15,17 +15,16 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::ops::Deref;
+
+use super::functions::can_coerce_from;
+use crate::{AggregateFunction, Signature, TypeSignature};
+
 use arrow::datatypes::{
     DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
     DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
 };
-
 use datafusion_common::{internal_err, plan_err, DataFusionError, Result};
-use std::ops::Deref;
-
-use crate::{AggregateFunction, Signature, TypeSignature};
-
-use super::functions::can_coerce_from;
 
 pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8];
 
@@ -297,6 +296,7 @@ pub fn coerce_types(
         AggregateFunction::Median
         | AggregateFunction::FirstValue
         | AggregateFunction::LastValue => Ok(input_types.to_vec()),
+        AggregateFunction::NthValue => Ok(input_types.to_vec()),
         AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
         AggregateFunction::StringAgg => {
             if !is_string_agg_supported_arg_type(&input_types[0]) {
@@ -584,6 +584,7 @@ pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool {
 #[cfg(test)]
 mod tests {
     use super::*;
+
     use arrow::datatypes::DataType;
 
     #[test]
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
index eb5ae8b0b0c3..34f8d20628dc 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
@@ -20,46 +20,43 @@
 
 use std::any::Any;
 use std::cmp::Ordering;
-use std::collections::BinaryHeap;
+use std::collections::{BinaryHeap, VecDeque};
 use std::fmt::Debug;
 use std::sync::Arc;
 
 use crate::aggregate::utils::{down_cast_any_ref, ordering_fields};
 use crate::expressions::format_state_name;
-use crate::{AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr};
+use crate::{
+    reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr,
+};
 
-use arrow::array::ArrayRef;
+use arrow::array::{Array, ArrayRef};
 use arrow::datatypes::{DataType, Field};
 use arrow_array::cast::AsArray;
-use arrow_array::Array;
 use arrow_schema::{Fields, SortOptions};
 use datafusion_common::utils::{compare_rows, get_row_at_idx};
 use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
 use datafusion_expr::Accumulator;
 
-use itertools::izip;
-
-/// Expression for a ARRAY_AGG(ORDER BY) aggregation.
-/// When aggregation works in multiple partitions
-/// aggregations are split into multiple partitions,
-/// then their results are merged. This aggregator
-/// is a version of ARRAY_AGG that can support producing
-/// intermediate aggregation (with necessary side information)
-/// and that can merge aggregations from multiple partitions.
+/// Expression for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi
+/// partition setting, partial aggregations are computed for every partition,
+/// and then their results are merged.
 #[derive(Debug)]
 pub struct OrderSensitiveArrayAgg {
     /// Column name
     name: String,
-    /// The DataType for the input expression
+    /// The `DataType` for the input expression
     input_data_type: DataType,
     /// The input expression
     expr: Arc<dyn PhysicalExpr>,
-    /// If the input expression can have NULLs
+    /// If the input expression can have `NULL`s
     nullable: bool,
     /// Ordering data types
     order_by_data_types: Vec<DataType>,
     /// Ordering requirement
     ordering_req: LexOrdering,
+    /// Whether the aggregation is running in reverse
+    reverse: bool,
 }
 
 impl OrderSensitiveArrayAgg {
@@ -79,6 +76,7 @@ impl OrderSensitiveArrayAgg {
             nullable,
             order_by_data_types,
             ordering_req,
+            reverse: false,
         }
     }
 }
@@ -98,11 +96,13 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(OrderSensitiveArrayAggAccumulator::try_new(
+        OrderSensitiveArrayAggAccumulator::try_new(
             &self.input_data_type,
             &self.order_by_data_types,
             self.ordering_req.clone(),
-        )?))
+            self.reverse,
+        )
+        .map(|acc| Box::new(acc) as _)
     }
 
     fn state_fields(&self) -> Result<Vec<Field>> {
@@ -125,16 +125,25 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
     }
 
     fn order_bys(&self) -> Option<&[PhysicalSortExpr]> {
-        if self.ordering_req.is_empty() {
-            None
-        } else {
-            Some(&self.ordering_req)
-        }
+        (!self.ordering_req.is_empty()).then_some(&self.ordering_req)
     }
 
     fn name(&self) -> &str {
         &self.name
     }
+
+    fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
+        Some(Arc::new(Self {
+            name: self.name.to_string(),
+            input_data_type: self.input_data_type.clone(),
+            expr: self.expr.clone(),
+            nullable: self.nullable,
+            order_by_data_types: self.order_by_data_types.clone(),
+            // Reverse requirement:
+            ordering_req: reverse_order_bys(&self.ordering_req),
+            reverse: !self.reverse,
+        }))
+    }
 }
 
 impl PartialEq<dyn Any> for OrderSensitiveArrayAgg {
@@ -153,19 +162,20 @@ impl PartialEq<dyn Any> for OrderSensitiveArrayAgg {
 
 #[derive(Debug)]
 pub(crate) struct OrderSensitiveArrayAggAccumulator {
-    // `values` stores entries in the ARRAY_AGG result.
+    /// Stores entries in the `ARRAY_AGG` result.
     values: Vec<ScalarValue>,
-    // `ordering_values` stores values of ordering requirement expression
-    // corresponding to each value in the ARRAY_AGG.
-    // For each `ScalarValue` inside `values`, there will be a corresponding
-    // `Vec<ScalarValue>` inside `ordering_values` which stores it ordering.
-    // This information is used during merging results of the different partitions.
-    // For detailed information how merging is done see [`merge_ordered_arrays`]
+    /// Stores values of ordering requirement expressions corresponding to each
+    /// entry in `values`. This information is used when merging results from
+    /// different partitions. For detailed information how merging is done, see
+    /// [`merge_ordered_arrays`].
     ordering_values: Vec<Vec<ScalarValue>>,
-    // `datatypes` stores, datatype of expression inside ARRAY_AGG and ordering requirement expressions.
+    /// Stores datatypes of expressions inside values and ordering requirement
+    /// expressions.
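+    /// The first element is the datatype of the `ARRAY_AGG` expression itself;
+    /// the remaining elements (`datatypes[1..]`) are the datatypes of the
+    /// ordering expressions, in order.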
     datatypes: Vec<DataType>,
-    // Stores ordering requirement of the Accumulator
+    /// Stores the ordering requirement of the `Accumulator`.
     ordering_req: LexOrdering,
+    /// Whether the aggregation is running in reverse.
+    reverse: bool,
 }
 
 impl OrderSensitiveArrayAggAccumulator {
@@ -175,6 +185,7 @@ impl OrderSensitiveArrayAggAccumulator {
         datatype: &DataType,
         ordering_dtypes: &[DataType],
         ordering_req: LexOrdering,
+        reverse: bool,
     ) -> Result<Self> {
         let mut datatypes = vec![datatype.clone()];
         datatypes.extend(ordering_dtypes.iter().cloned());
@@ -183,6 +194,7 @@ impl OrderSensitiveArrayAggAccumulator {
             ordering_values: vec![],
             datatypes,
             ordering_req,
+            reverse,
         })
     }
 }
@@ -207,63 +219,63 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
         if states.is_empty() {
             return Ok(());
         }
-        // First entry in the state is the aggregation result.
-        let array_agg_values = &states[0];
-        // 2nd entry stores values received for ordering requirement columns, for each aggregation value inside ARRAY_AGG list.
-        // For each `StructArray` inside ARRAY_AGG list, we will receive an `Array` that stores
-        // values received from its ordering requirement expression. (This information is necessary for during merging).
-        let agg_orderings = &states[1];
-
-        if let Some(agg_orderings) = agg_orderings.as_list_opt::<i32>() {
-            // Stores ARRAY_AGG results coming from each partition
-            let mut partition_values = vec![];
-            // Stores ordering requirement expression results coming from each partition
-            let mut partition_ordering_values = vec![];
-
-            // Existing values should be merged also.
-            partition_values.push(self.values.clone());
-            partition_ordering_values.push(self.ordering_values.clone());
-
-            let array_agg_res =
-                ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
-
-            for v in array_agg_res.into_iter() {
-                partition_values.push(v);
-            }
+        // First entry in the state is the aggregation result. Second entry
+        // stores values received for ordering requirement columns for each
+        // aggregation value inside the `ARRAY_AGG` list. For each `StructArray`
+        // inside the `ARRAY_AGG` list, we will receive an `Array` that stores
+        // values received from its ordering requirement expression. (This
+        // information is necessary during merging.)
+        let [array_agg_values, agg_orderings, ..] = &states else {
+            return exec_err!("State should have two elements");
+        };
+        let Some(agg_orderings) = agg_orderings.as_list_opt::<i32>() else {
+            return exec_err!("Expects to receive a list array");
+        };
 
-            let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
-
-            for partition_ordering_rows in orderings.into_iter() {
-                // Extract value from struct to ordering_rows for each group/partition
-                let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| {
-                    if let ScalarValue::Struct(Some(ordering_columns_per_row), _) = ordering_row {
-                        Ok(ordering_columns_per_row)
-                    } else {
-                        exec_err!(
-                            "Expects to receive ScalarValue::Struct(Some(..), _) but got:{:?}",
-                            ordering_row.data_type()
-                        )
-                    }
-                }).collect::<Result<Vec<_>>>()?;
-
-                partition_ordering_values.push(ordering_value);
-            }
+        // Stores ARRAY_AGG results coming from each partition
+        let mut partition_values = vec![];
+        // Stores ordering requirement expression results coming from each partition
+        let mut partition_ordering_values = vec![];
 
-            let sort_options = self
-                .ordering_req
-                .iter()
-                .map(|sort_expr| sort_expr.options)
-                .collect::<Vec<_>>();
-            let (new_values, new_orderings) = merge_ordered_arrays(
-                &partition_values,
-                &partition_ordering_values,
-                &sort_options,
-            )?;
-            self.values = new_values;
-            self.ordering_values = new_orderings;
-        } else {
-            return exec_err!("Expects to receive a list array");
+        // Existing values should be merged also.
+        partition_values.push(self.values.clone().into());
+        partition_ordering_values.push(self.ordering_values.clone().into());
+
+        let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
+
+        for v in array_agg_res.into_iter() {
+            partition_values.push(v.into());
+        }
+
+        let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
+
+        let ordering_values = orderings.into_iter().map(|partition_ordering_rows| {
+            // Extract value from struct to ordering_rows for each group/partition
+            partition_ordering_rows.into_iter().map(|ordering_row| {
+                if let ScalarValue::Struct(Some(ordering_columns_per_row), _) = ordering_row {
+                    Ok(ordering_columns_per_row)
+                } else {
+                    exec_err!(
+                        "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}",
+                        ordering_row.data_type()
+                    )
+                }
+            }).collect::<Result<VecDeque<_>>>()
+        }).collect::<Result<Vec<_>>>()?;
+        for ordering_values in ordering_values.into_iter() {
+            partition_ordering_values.push(ordering_values);
         }
+
+        let sort_options = self
+            .ordering_req
+            .iter()
+            .map(|sort_expr| sort_expr.options)
+            .collect::<Vec<_>>();
+        (self.values, self.ordering_values) = merge_ordered_arrays(
+            &mut partition_values,
+            &mut partition_ordering_values,
+            &sort_options,
+        )?;
         Ok(())
     }
 
@@ -274,8 +286,13 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
     }
 
     fn evaluate(&self) -> Result<ScalarValue> {
-        let arr = ScalarValue::new_list(&self.values, &self.datatypes[0]);
-        Ok(ScalarValue::List(arr))
+        let values = self.values.clone();
+        let array = if self.reverse {
+            ScalarValue::new_list_from_iter(values.into_iter().rev(), &self.datatypes[0])
+        } else {
+            ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0])
+        };
+        Ok(ScalarValue::List(array))
     }
 
     fn size(&self) -> usize {
@@ -306,7 +323,7 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
 impl OrderSensitiveArrayAggAccumulator {
     fn evaluate_orderings(&self) -> Result<ScalarValue> {
         let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);
-        let struct_field = Fields::from(fields.clone());
+        let struct_field = Fields::from(fields);
 
         let orderings: Vec<ScalarValue> = self
             .ordering_values
@@ -315,7 +332,7 @@ impl OrderSensitiveArrayAggAccumulator {
                 ScalarValue::Struct(Some(ordering.clone()), struct_field.clone())
             })
             .collect();
-        let struct_type = DataType::Struct(Fields::from(fields));
+        let struct_type = DataType::Struct(struct_field);
 
         // Wrap in List, so we have the same data structure ListArray(StructArray..) for group by cases
         let arr = ScalarValue::new_list(&orderings, &struct_type);
@@ -323,20 +340,19 @@ impl OrderSensitiveArrayAggAccumulator {
     }
 }
 
-/// This is a wrapper struct to be able to correctly merge ARRAY_AGG
-/// data from multiple partitions using `BinaryHeap`.
-/// When used inside `BinaryHeap` this struct returns smallest `CustomElement`,
-/// where smallest is determined by `ordering` values (`Vec<ScalarValue>`)
-/// according to `sort_options`
+/// This is a wrapper struct to be able to correctly merge `ARRAY_AGG` data from
+/// multiple partitions using `BinaryHeap`. When used inside `BinaryHeap`, this
+/// struct returns smallest `CustomElement`, where smallest is determined by
+/// `ordering` values (`Vec<ScalarValue>`) according to `sort_options`.
 #[derive(Debug, PartialEq, Eq)]
 struct CustomElement<'a> {
-    // Stores from which partition entry is received
+    /// Stores the partition this entry came from
    branch_idx: usize,
-    // values to be merged
+    /// Values to merge
     value: ScalarValue,
-    // according to `ordering` values, comparisons will be done.
+    /// Comparison "key"
     ordering: Vec<ScalarValue>,
-    // `sort_options` defines, desired ordering by the user
+    /// Options defining the ordering semantics
     sort_options: &'a [SortOptions],
 }
 
@@ -411,87 +427,86 @@ impl<'a> PartialOrd for CustomElement<'a> {
 /// For each ScalarValue in the `values` we have a corresponding `Vec<ScalarValue>` (like timestamp of it)
 /// for the example above `sort_options` will have size two, that defines ordering requirement of the merge.
 /// Inner `Vec<ScalarValue>`s of the `ordering_values` will be compared according `sort_options` (Their sizes should match)
-fn merge_ordered_arrays(
+pub(crate) fn merge_ordered_arrays(
     // We will merge values into single `Vec<ScalarValue>`.
-    values: &[Vec<ScalarValue>],
+    values: &mut [VecDeque<ScalarValue>],
     // `values` will be merged according to `ordering_values`.
     // Inner `Vec<ScalarValue>` can be thought as ordering information for the
     // each `ScalarValue` in the values`.
-    ordering_values: &[Vec<Vec<ScalarValue>>],
+    ordering_values: &mut [VecDeque<Vec<ScalarValue>>],
     // Defines according to which ordering comparisons should be done.
     sort_options: &[SortOptions],
 ) -> Result<(Vec<ScalarValue>, Vec<Vec<ScalarValue>>)> {
     // Keep track the most recent data of each branch, in binary heap data structure.
-    let mut heap: BinaryHeap<CustomElement> = BinaryHeap::new();
+    let mut heap = BinaryHeap::<CustomElement>::new();
 
-    if !(values.len() == ordering_values.len()
-        && values
+    if values.len() != ordering_values.len()
+        || values
             .iter()
             .zip(ordering_values.iter())
-            .all(|(vals, ordering_vals)| vals.len() == ordering_vals.len()))
+            .any(|(vals, ordering_vals)| vals.len() != ordering_vals.len())
     {
         return exec_err!(
             "Expects values arguments and/or ordering_values arguments to have same size"
         );
     }
     let n_branch = values.len();
-    // For each branch we keep track of indices of next will be merged entry
-    let mut indices = vec![0_usize; n_branch];
-    // Keep track of sizes of each branch.
-    let end_indices = (0..n_branch)
-        .map(|idx| values[idx].len())
-        .collect::<Vec<_>>();
     let mut merged_values = vec![];
     let mut merged_orderings = vec![];
     // Continue iterating the loop until consuming data of all branches.
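+    // Illustrative walk-through (not part of the original change): merging the
+    // ascending branches [0, 2, 4] and [1, 3] seeds the heap with {0, 1};
+    // popping the minimum emits 0 and refills from branch 0 to get {1, 2};
+    // repeating this yields the merged output [0, 1, 2, 3, 4].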
     loop {
-        let min_elem = if let Some(min_elem) = heap.pop() {
-            min_elem
+        let minimum = if let Some(minimum) = heap.pop() {
+            minimum
         } else {
             // Heap is empty, fill it with the next entries from each branch.
-            for (idx, end_idx, ordering, branch_index) in izip!(
-                indices.iter(),
-                end_indices.iter(),
-                ordering_values.iter(),
-                0..n_branch
-            ) {
-                // We consumed this branch, skip it
-                if idx == end_idx {
-                    continue;
+            for branch_idx in 0..n_branch {
+                if let Some(orderings) = ordering_values[branch_idx].pop_front() {
+                    // Their sizes should be the same, so we can safely `.unwrap` here.
+                    let value = values[branch_idx].pop_front().unwrap();
+                    // Push the next element to the heap:
+                    heap.push(CustomElement::new(
+                        branch_idx,
+                        value,
+                        orderings,
+                        sort_options,
+                    ));
                 }
-
-                // Push the next element to the heap.
-                let elem = CustomElement::new(
-                    branch_index,
-                    values[branch_index][*idx].clone(),
-                    ordering[*idx].to_vec(),
-                    sort_options,
-                );
-                heap.push(elem);
+                // If None, we consumed this branch, skip it.
             }
-            // Now we have filled the heap, get the largest entry (this will be the next element in merge)
-            if let Some(min_elem) = heap.pop() {
-                min_elem
+
+            // Now we have filled the heap, get the smallest entry w.r.t.
+            // `sort_options` (this will be the next element in merge).
+            if let Some(minimum) = heap.pop() {
+                minimum
             } else {
-                // Heap is empty, this means that all indices are same with end_indices. e.g
-                // We have consumed all of the branches. Merging is completed
-                // Exit from the loop
+                // Heap is empty, which means we have consumed all of the
+                // branches. Merging is complete, exit from the loop:
                 break;
             }
         };
-        let branch_idx = min_elem.branch_idx;
-        // Increment the index of merged branch,
-        indices[branch_idx] += 1;
-        let row_idx = indices[branch_idx];
-        merged_values.push(min_elem.value.clone());
-        merged_orderings.push(min_elem.ordering.clone());
-        if row_idx < end_indices[branch_idx] {
-            // Push next entry in the most recently consumed branch to the heap
-            // If there is an available entry
-            let value = values[branch_idx][row_idx].clone();
-            let ordering_row = ordering_values[branch_idx][row_idx].to_vec();
-            let elem = CustomElement::new(branch_idx, value, ordering_row, sort_options);
-            heap.push(elem);
+        let CustomElement {
+            branch_idx,
+            value,
+            ordering,
+            ..
+        } = minimum;
+        // Add the minimum value in the heap to the result
+        merged_values.push(value);
+        merged_orderings.push(ordering);
+
+        // If there is an available entry, push the next entry in the most
+        // recently consumed branch to the heap.
+        if let Some(orderings) = ordering_values[branch_idx].pop_front() {
+            // Their sizes should be the same, so we can safely `.unwrap` here.
+            let value = values[branch_idx].pop_front().unwrap();
+            // Push the next element to the heap:
+            heap.push(CustomElement::new(
+                branch_idx,
+                value,
+                orderings,
+                sort_options,
+            ));
         }
     }
 
@@ -500,12 +515,15 @@ fn merge_ordered_arrays(
 
 #[cfg(test)]
 mod tests {
+    use std::collections::VecDeque;
+    use std::sync::Arc;
+
     use crate::aggregate::array_agg_ordered::merge_ordered_arrays;
+
     use arrow_array::{Array, ArrayRef, Int64Array};
     use arrow_schema::SortOptions;
     use datafusion_common::utils::get_row_at_idx;
     use datafusion_common::{Result, ScalarValue};
-    use std::sync::Arc;
 
     #[test]
     fn test_merge_asc() -> Result<()> {
@@ -516,7 +534,7 @@ mod tests {
         let n_row = lhs_arrays[0].len();
         let lhs_orderings = (0..n_row)
             .map(|idx| get_row_at_idx(&lhs_arrays, idx))
-            .collect::<Result<Vec<_>>>()?;
+            .collect::<Result<VecDeque<_>>>()?;
 
         let rhs_arrays: Vec<ArrayRef> = vec![
             Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2])),
@@ -525,7 +543,7 @@ mod tests {
         let n_row = rhs_arrays[0].len();
         let rhs_orderings = (0..n_row)
             .map(|idx| get_row_at_idx(&rhs_arrays, idx))
-            .collect::<Result<Vec<_>>>()?;
+            .collect::<Result<VecDeque<_>>>()?;
         let sort_options = vec![
             SortOptions {
                 descending: false,
@@ -540,12 +558,12 @@ mod tests {
         let lhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])) as ArrayRef;
         let lhs_vals = (0..lhs_vals_arr.len())
             .map(|idx| ScalarValue::try_from_array(&lhs_vals_arr, idx))
-            .collect::<Result<Vec<_>>>()?;
+            .collect::<Result<VecDeque<_>>>()?;
 
         let rhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4])) as ArrayRef;
         let rhs_vals = (0..rhs_vals_arr.len())
             .map(|idx| ScalarValue::try_from_array(&rhs_vals_arr, idx))
-            .collect::<Result<Vec<_>>>()?;
+            .collect::<Result<VecDeque<_>>>()?;
         let expected =
             Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4])) as ArrayRef;
         let expected_ts = vec![
@@ -554,8 +572,8 @@ mod tests {
         ];
 
         let (merged_vals, merged_ts) = merge_ordered_arrays(
-            &[lhs_vals, rhs_vals],
-            &[lhs_orderings, rhs_orderings],
+            &mut [lhs_vals, rhs_vals],
+            &mut [lhs_orderings, rhs_orderings],
             &sort_options,
         )?;
         let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?;
@@ -583,7 +601,7 @@ mod tests {
         let n_row = lhs_arrays[0].len();
         let lhs_orderings = (0..n_row)
             .map(|idx| get_row_at_idx(&lhs_arrays, idx))
-            .collect::<Result<Vec<_>>>()?;
+            .collect::<Result<VecDeque<_>>>()?;
 
         let rhs_arrays: Vec<ArrayRef> = vec![
             Arc::new(Int64Array::from(vec![2, 1, 1, 0, 0])),
@@ -592,7 +610,7 @@ mod tests {
         let n_row = rhs_arrays[0].len();
         let rhs_orderings = (0..n_row)
             .map(|idx| get_row_at_idx(&rhs_arrays, idx))
-            .collect::<Result<Vec<_>>>()?;
+            .collect::<Result<VecDeque<_>>>()?;
         let sort_options = vec![
             SortOptions {
                 descending: true,
@@ -608,12 +626,12 @@ mod tests {
         let lhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 1, 2])) as ArrayRef;
         let lhs_vals = (0..lhs_vals_arr.len())
             .map(|idx| ScalarValue::try_from_array(&lhs_vals_arr, idx))
-            .collect::<Result<Vec<_>>>()?;
+            .collect::<Result<VecDeque<_>>>()?;
 
         let rhs_vals_arr = Arc::new(Int64Array::from(vec![0, 1, 2, 1, 2])) as ArrayRef;
         let rhs_vals = (0..rhs_vals_arr.len())
             .map(|idx| ScalarValue::try_from_array(&rhs_vals_arr, idx))
-            .collect::<Result<Vec<_>>>()?;
+            .collect::<Result<VecDeque<_>>>()?;
         let expected =
             Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 1, 1, 2, 2])) as ArrayRef;
         let expected_ts = vec![
@@ -621,8 +639,8 @@ mod tests {
             Arc::new(Int64Array::from(vec![4, 4, 3, 3, 2, 2, 1, 1, 0, 0])) as ArrayRef,
         ];
         let (merged_vals, merged_ts) = merge_ordered_arrays(
-            &[lhs_vals, rhs_vals],
-            &[lhs_orderings, rhs_orderings],
+            &mut [lhs_vals, rhs_vals],
+            &mut [lhs_orderings, rhs_orderings],
             &sort_options,
         )?;
         let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?;
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index c40f0db19405..1a3d21fc40bc 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -26,12 +26,15 @@
 //! * Signature: see `Signature`
 //! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64.
 
+use std::sync::Arc;
+
 use crate::aggregate::regr::RegrType;
-use crate::{expressions, AggregateExpr, PhysicalExpr, PhysicalSortExpr};
+use crate::expressions::{self, Literal};
+use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
+
 use arrow::datatypes::Schema;
-use datafusion_common::{not_impl_err, DataFusionError, Result};
-pub use datafusion_expr::AggregateFunction;
-use std::sync::Arc;
+use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
+use datafusion_expr::AggregateFunction;
 
 /// Create a physical aggregation expression.
 /// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function.
@@ -369,6 +372,28 @@ pub fn create_aggregate_expr(
             ordering_req.to_vec(),
             ordering_types,
         )),
+        (AggregateFunction::NthValue, _) => {
+            let expr = &input_phy_exprs[0];
+            let Some(n) = input_phy_exprs[1]
+                .as_any()
+                .downcast_ref::<Literal>()
+                .map(|literal| literal.value())
+            else {
+                return internal_err!(
+                    "Second argument of NTH_VALUE needs to be a literal"
+                );
+            };
+            let nullable = expr.nullable(input_schema)?;
+            Arc::new(expressions::NthValueAgg::new(
+                expr.clone(),
+                n.clone().try_into()?,
+                name,
+                input_phy_types[0].clone(),
+                nullable,
+                ordering_types,
+                ordering_req.to_vec(),
+            ))
+        }
         (AggregateFunction::StringAgg, false) => {
             if !ordering_req.is_empty() {
                 return not_impl_err!(
@@ -396,9 +421,9 @@ mod tests {
         BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Correlation, Count, Covariance,
         DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance,
     };
+
     use arrow::datatypes::{DataType, Field};
-    use datafusion_common::plan_err;
-    use datafusion_common::ScalarValue;
+    use datafusion_common::{plan_err, ScalarValue};
     use datafusion_expr::type_coercion::aggregates::NUMERICS;
     use datafusion_expr::{type_coercion, Signature};
 
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs
index 5bd1fca385b1..270a8e6f7705 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -20,7 +20,7 @@ use std::fmt::Debug;
 use std::sync::Arc;
 
 use self::groups_accumulator::GroupsAccumulator;
-use crate::expressions::OrderSensitiveArrayAgg;
+use crate::expressions::{NthValueAgg, OrderSensitiveArrayAgg};
 use crate::{PhysicalExpr, PhysicalSortExpr};
 
 use arrow::datatypes::Field;
@@ -47,6 +47,7 @@ pub(crate) mod covariance;
 pub(crate) mod first_last;
 pub(crate) mod grouping;
 pub(crate) mod median;
+pub(crate) mod nth_value;
 pub(crate) mod string_agg;
 #[macro_use]
 pub(crate) mod min_max;
@@ -140,4 +141,5 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq<dyn Any> {
 /// However, an `ARRAY_AGG` with `ORDER BY` depends on the input ordering.
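+/// The same applies to `NTH_VALUE`, whose result depends on which input rows
+/// come first; hence the `NthValueAgg` check added below.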
 pub fn is_order_sensitive(aggr_expr: &Arc<dyn AggregateExpr>) -> bool {
     aggr_expr.as_any().is::<OrderSensitiveArrayAgg>()
+        || aggr_expr.as_any().is::<NthValueAgg>()
 }
diff --git a/datafusion/physical-expr/src/aggregate/nth_value.rs b/datafusion/physical-expr/src/aggregate/nth_value.rs
new file mode 100644
index 000000000000..5a1ca90b7f5e
--- /dev/null
+++ b/datafusion/physical-expr/src/aggregate/nth_value.rs
@@ -0,0 +1,400 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines the NTH_VALUE aggregate expression, which may specify an ordering
+//! requirement that can be evaluated at runtime during query execution
+
+use std::any::Any;
+use std::collections::VecDeque;
+use std::sync::Arc;
+
+use crate::aggregate::array_agg_ordered::merge_ordered_arrays;
+use crate::aggregate::utils::{down_cast_any_ref, ordering_fields};
+use crate::expressions::format_state_name;
+use crate::{
+    reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr,
+};
+
+use arrow_array::cast::AsArray;
+use arrow_array::ArrayRef;
+use arrow_schema::{DataType, Field, Fields};
+use datafusion_common::utils::get_row_at_idx;
+use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue};
+use datafusion_expr::Accumulator;
+
+/// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi
+/// partition setting, partial aggregations are computed for every partition,
+/// and then their results are merged.
+#[derive(Debug)]
+pub struct NthValueAgg {
+    /// Column name
+    name: String,
+    /// The `DataType` for the input expression
+    input_data_type: DataType,
+    /// The input expression
+    expr: Arc<dyn PhysicalExpr>,
+    /// The `N` value.
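+    /// A positive `n` counts from the start of the ordered group, while a
+    /// negative `n` counts from the end (e.g. `-1` selects the last value).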
+    n: i64,
+    /// If the input expression can have `NULL`s
+    nullable: bool,
+    /// Ordering data types
+    order_by_data_types: Vec<DataType>,
+    /// Ordering requirement
+    ordering_req: LexOrdering,
+}
+
+impl NthValueAgg {
+    /// Create a new `NthValueAgg` aggregate function
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        n: i64,
+        name: impl Into<String>,
+        input_data_type: DataType,
+        nullable: bool,
+        order_by_data_types: Vec<DataType>,
+        ordering_req: LexOrdering,
+    ) -> Self {
+        Self {
+            name: name.into(),
+            input_data_type,
+            expr,
+            n,
+            nullable,
+            order_by_data_types,
+            ordering_req,
+        }
+    }
+}
+
+impl AggregateExpr for NthValueAgg {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(&self.name, self.input_data_type.clone(), true))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(NthValueAccumulator::try_new(
+            self.n,
+            &self.input_data_type,
+            &self.order_by_data_types,
+            self.ordering_req.clone(),
+        )?))
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        let mut fields = vec![Field::new_list(
+            format_state_name(&self.name, "nth_value"),
+            Field::new("item", self.input_data_type.clone(), true),
+            self.nullable, // This should be the same as field()
+        )];
+        if !self.ordering_req.is_empty() {
+            let orderings =
+                ordering_fields(&self.ordering_req, &self.order_by_data_types);
+            fields.push(Field::new_list(
+                format_state_name(&self.name, "nth_value_orderings"),
+                Field::new("item", DataType::Struct(Fields::from(orderings)), true),
+                self.nullable,
+            ));
+        }
+        Ok(fields)
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn order_bys(&self) -> Option<&[PhysicalSortExpr]> {
+        (!self.ordering_req.is_empty()).then_some(&self.ordering_req)
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
+        Some(Arc::new(Self {
+            name: self.name.to_string(),
+            input_data_type: self.input_data_type.clone(),
+            expr: self.expr.clone(),
+            // index should be from the opposite side
+            n: -self.n,
+            nullable: self.nullable,
+            order_by_data_types: self.order_by_data_types.clone(),
+            // reverse requirement
+            ordering_req: reverse_order_bys(&self.ordering_req),
+        }) as _)
+    }
+}
+
+impl PartialEq<dyn Any> for NthValueAgg {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| {
+                self.name == x.name
+                    && self.input_data_type == x.input_data_type
+                    && self.order_by_data_types == x.order_by_data_types
+                    && self.expr.eq(&x.expr)
+            })
+            .unwrap_or(false)
+    }
+}
+
+#[derive(Debug)]
+pub(crate) struct NthValueAccumulator {
+    n: i64,
+    /// Stores entries in the `NTH_VALUE` result.
+    values: VecDeque<ScalarValue>,
+    /// Stores values of ordering requirement expressions corresponding to each
+    /// entry in `values`. This information is used when merging results from
+    /// different partitions. For detailed information how merging is done, see
+    /// [`merge_ordered_arrays`].
+    ordering_values: VecDeque<Vec<ScalarValue>>,
+    /// Stores datatypes of expressions inside values and ordering requirement
+    /// expressions.
+    datatypes: Vec<DataType>,
+    /// Stores the ordering requirement of the `Accumulator`.
+    ordering_req: LexOrdering,
+}
+
+impl NthValueAccumulator {
+    /// Create a new order-sensitive NTH_VALUE accumulator based on the given
+    /// item data type.
+    pub fn try_new(
+        n: i64,
+        datatype: &DataType,
+        ordering_dtypes: &[DataType],
+        ordering_req: LexOrdering,
+    ) -> Result<Self> {
+        if n == 0 {
+            // n cannot be 0
+            return internal_err!("Nth value indices are 1 based. 0 is invalid index");
+        }
+        let mut datatypes = vec![datatype.clone()];
+        datatypes.extend(ordering_dtypes.iter().cloned());
+        Ok(Self {
+            n,
+            values: VecDeque::new(),
+            ordering_values: VecDeque::new(),
+            datatypes,
+            ordering_req,
+        })
+    }
+}
+
+impl Accumulator for NthValueAccumulator {
+    /// Updates its state with the `values`. Assumes data in the `values` satisfies the required
+    /// ordering for the accumulator (across consecutive batches, not just batch-wise).
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+
+        let n_required = self.n.unsigned_abs() as usize;
+        let from_start = self.n > 0;
+        if from_start {
+            // direction is from start
+            let n_remaining = n_required.saturating_sub(self.values.len());
+            self.append_new_data(values, Some(n_remaining))?;
+        } else {
+            // direction is from end
+            self.append_new_data(values, None)?;
+            let start_offset = self.values.len().saturating_sub(n_required);
+            if start_offset > 0 {
+                self.values.drain(0..start_offset);
+                self.ordering_values.drain(0..start_offset);
+            }
+        }
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+        // First entry in the state is the aggregation result.
+        let array_agg_values = &states[0];
+        let n_required = self.n.unsigned_abs() as usize;
+        if self.ordering_req.is_empty() {
+            let array_agg_res =
+                ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
+            for v in array_agg_res.into_iter() {
+                self.values.extend(v);
+                if self.values.len() > n_required {
+                    // There is enough data collected; we can stop merging.
+                    break;
+                }
+            }
+        } else if let Some(agg_orderings) = states[1].as_list_opt::<i32>() {
+            // 2nd entry stores values received for ordering requirement columns, for each aggregation value inside the NTH_VALUE list.
+            // For each `StructArray` inside the NTH_VALUE list, we will receive an `Array` that stores
+            // values received from its ordering requirement expression. (This information is necessary during merging.)
+
+            // Stores NTH_VALUE results coming from each partition
+            let mut partition_values: Vec<VecDeque<ScalarValue>> = vec![];
+            // Stores ordering requirement expression results coming from each partition
+            let mut partition_ordering_values: Vec<VecDeque<Vec<ScalarValue>>> = vec![];
+
+            // Existing values should be merged also.
+            partition_values.push(self.values.clone());
+
+            partition_ordering_values.push(self.ordering_values.clone());
+
+            let array_agg_res =
+                ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
+
+            for v in array_agg_res.into_iter() {
+                partition_values.push(v.into());
+            }
+
+            let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
+
+            let ordering_values = orderings.into_iter().map(|partition_ordering_rows| {
+                // Extract value from struct to ordering_rows for each group/partition
+                partition_ordering_rows.into_iter().map(|ordering_row| {
+                    if let ScalarValue::Struct(Some(ordering_columns_per_row), _) = ordering_row {
+                        Ok(ordering_columns_per_row)
+                    } else {
+                        exec_err!(
+                            "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}",
+                            ordering_row.data_type()
+                        )
+                    }
+                }).collect::<Result<Vec<_>>>()
+            }).collect::<Result<Vec<_>>>()?;
+            for ordering_values in ordering_values.into_iter() {
+                partition_ordering_values.push(ordering_values.into());
+            }
+
+            let sort_options = self
+                .ordering_req
+                .iter()
+                .map(|sort_expr| sort_expr.options)
+                .collect::<Vec<_>>();
+            let (new_values, new_orderings) = merge_ordered_arrays(
+                &mut partition_values,
+                &mut partition_ordering_values,
+                &sort_options,
+            )?;
+            self.values = new_values.into();
+            self.ordering_values = new_orderings.into();
+        } else {
+            return exec_err!("Expects to receive a list array");
+        }
+        Ok(())
+    }
+
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        let mut result = vec![self.evaluate_values()];
+        if !self.ordering_req.is_empty() {
+            result.push(self.evaluate_orderings());
+        }
+        Ok(result)
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        let n_required = self.n.unsigned_abs() as usize;
+        let from_start = self.n > 0;
+        let nth_value_idx = if from_start {
+            // index is from start
+            let forward_idx = n_required - 1;
+            (forward_idx < self.values.len()).then_some(forward_idx)
+        } else {
+            // index is from end
+            self.values.len().checked_sub(n_required)
+        };
+        if let Some(idx) = nth_value_idx {
+            Ok(self.values[idx].clone())
+        } else {
+            ScalarValue::try_from(self.datatypes[0].clone())
+        }
+    }
+
+    fn size(&self) -> usize {
+        let mut total = std::mem::size_of_val(self)
+            + ScalarValue::size_of_vec_deque(&self.values)
+            - std::mem::size_of_val(&self.values);
+
+        // Add size of the `self.ordering_values`
+        total +=
+            std::mem::size_of::<Vec<ScalarValue>>() * self.ordering_values.capacity();
+        for row in &self.ordering_values {
+            total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row);
+        }
+
+        // Add size of the `self.datatypes`
+        total += std::mem::size_of::<DataType>() * self.datatypes.capacity();
+        for dtype in &self.datatypes {
+            total += dtype.size() - std::mem::size_of_val(dtype);
+        }
+
+        // Add size of the `self.ordering_req`
+        total += std::mem::size_of::<PhysicalSortExpr>() * self.ordering_req.capacity();
+        // TODO: Calculate size of each `PhysicalSortExpr` more accurately.
+        total
+    }
+}
+
+impl NthValueAccumulator {
+    fn evaluate_orderings(&self) -> ScalarValue {
+        let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);
+        let struct_field = Fields::from(fields);
+
+        let orderings = self
+            .ordering_values
+            .iter()
+            .map(|ordering| {
+                ScalarValue::Struct(Some(ordering.clone()), struct_field.clone())
+            })
+            .collect::<Vec<_>>();
+        let struct_type = DataType::Struct(struct_field);
+
+        // Wrap in List, so we have the same data structure ListArray(StructArray..) for group by cases
+        ScalarValue::List(ScalarValue::new_list(&orderings, &struct_type))
+    }
+
+    fn evaluate_values(&self) -> ScalarValue {
+        let mut values_cloned = self.values.clone();
+        let values_slice = values_cloned.make_contiguous();
+        ScalarValue::List(ScalarValue::new_list(values_slice, &self.datatypes[0]))
+    }
+
+    /// Updates state with the `values`. `fetch` is the number of entries still
+    /// missing for the state to be complete; `None` means all of the new
+    /// `values` need to be added to the state.
+    fn append_new_data(
+        &mut self,
+        values: &[ArrayRef],
+        fetch: Option<usize>,
+    ) -> Result<()> {
+        let n_row = values[0].len();
+        let n_to_add = if let Some(fetch) = fetch {
+            std::cmp::min(fetch, n_row)
+        } else {
+            n_row
+        };
+        for index in 0..n_to_add {
+            let row = get_row_at_idx(values, index)?;
+            self.values.push_back(row[0].clone());
+            self.ordering_values.push_back(row[1..].to_vec());
+        }
+        Ok(())
+    }
+}
diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs
index d73c46a0f687..6dd586bfb8ce 100644
--- a/datafusion/physical-expr/src/aggregate/utils.rs
+++ b/datafusion/physical-expr/src/aggregate/utils.rs
@@ -196,9 +196,9 @@ pub(crate) fn ordering_fields(
     ordering_req
         .iter()
         .zip(data_types.iter())
-        .map(|(expr, dtype)| {
+        .map(|(sort_expr, dtype)| {
             Field::new(
-                expr.to_string().as_str(),
+                sort_expr.expr.to_string().as_str(),
                 dtype.clone(),
                 // Multi partitions may be empty hence field should be nullable.
                 true,
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index b6d0ad5b9104..bbfba4ad8310 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -60,6 +60,7 @@ pub use crate::aggregate::grouping::Grouping;
 pub use crate::aggregate::median::Median;
 pub use crate::aggregate::min_max::{Max, Min};
 pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator};
+pub use crate::aggregate::nth_value::NthValueAgg;
 pub use crate::aggregate::regr::{Regr, RegrType};
 pub use crate::aggregate::stats::StatsType;
 pub use crate::aggregate::stddev::{Stddev, StddevPop};
@@ -67,7 +68,6 @@ pub use crate::aggregate::string_agg::StringAgg;
 pub use crate::aggregate::sum::Sum;
 pub use crate::aggregate::sum_distinct::DistinctSum;
 pub use crate::aggregate::variance::{Variance, VariancePop};
-
 pub use crate::window::cume_dist::cume_dist;
 pub use crate::window::cume_dist::CumeDist;
 pub use crate::window::lead_lag::WindowShift;
@@ -77,6 +77,7 @@ pub use crate::window::ntile::Ntile;
 pub use crate::window::rank::{dense_rank, percent_rank, rank};
 pub use crate::window::rank::{Rank, RankType};
 pub use crate::window::row_number::RowNumber;
+pub use crate::PhysicalSortExpr;
 
 pub use binary::{binary, BinaryExpr};
 pub use case::{case, CaseExpr};
@@ -98,20 +99,20 @@ pub use try_cast::{try_cast, TryCastExpr};
 pub fn format_state_name(name: &str, state_name: &str) -> String {
     format!("{name}[{state_name}]")
 }
-pub use crate::PhysicalSortExpr;
 
 #[cfg(test)]
 pub(crate) mod tests {
+    use std::sync::Arc;
+
     use crate::expressions::{col, create_aggregate_expr, try_cast};
     use crate::{AggregateExpr, EmitTo};
+
     use arrow::record_batch::RecordBatch;
     use arrow_array::ArrayRef;
     use arrow_schema::{Field, Schema};
-    use datafusion_common::Result;
-    use datafusion_common::ScalarValue;
+    use datafusion_common::{Result, ScalarValue};
     use datafusion_expr::type_coercion::aggregates::coerce_types;
     use datafusion_expr::AggregateFunction;
-    use std::sync::Arc;
 
     /// macro to perform an aggregation using [`datafusion_expr::Accumulator`] and verify the
     /// result.
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index facd601955b6..d3ae0d5ce01f 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -25,7 +25,6 @@ use crate::aggregates::{
     no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream,
     topk_stream::GroupedTopKAggregateStream,
 };
-
 use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
 use crate::windows::get_ordered_partition_by_indices;
 use crate::{
@@ -909,6 +908,7 @@ fn get_aggregate_exprs_requirement(
         let aggr_req = PhysicalSortRequirement::from_sort_exprs(aggr_req);
         let reverse_aggr_req =
             PhysicalSortRequirement::from_sort_exprs(&reverse_aggr_req);
+
         if let Some(first_value) = aggr_expr.as_any().downcast_ref::<FirstValue>() {
             let mut first_value = first_value.clone();
             if eq_properties.ordering_satisfy_requirement(&concat_slices(
@@ -931,7 +931,9 @@ fn get_aggregate_exprs_requirement(
                 first_value = first_value.with_requirement_satisfied(false);
                 *aggr_expr = Arc::new(first_value) as _;
             }
-        } else if let Some(last_value) = aggr_expr.as_any().downcast_ref::<LastValue>() {
+            continue;
+        }
+        if let Some(last_value) = aggr_expr.as_any().downcast_ref::<LastValue>() {
             let mut last_value = last_value.clone();
             if eq_properties.ordering_satisfy_requirement(&concat_slices(
                 prefix_requirement,
@@ -953,18 +955,63 @@ fn get_aggregate_exprs_requirement(
                 last_value = last_value.with_requirement_satisfied(false);
                 *aggr_expr = Arc::new(last_value) as _;
             }
-        } else if let Some(finer_ordering) =
+            continue;
+        }
+        if let Some(finer_ordering) =
+            finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode)
+        {
+            if eq_properties.ordering_satisfy(&finer_ordering) {
+                // Requirement is satisfied by existing ordering
+                requirement = finer_ordering;
+                continue;
+            }
+        }
+        if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
+            if let Some(finer_ordering) = finer_ordering(
+                &requirement,
+                &reverse_aggr_expr,
+                group_by,
+                eq_properties,
+                agg_mode,
+            ) {
+                if eq_properties.ordering_satisfy(&finer_ordering) {
+                    // Reverse requirement is satisfied by existing ordering.
+                    // Hence reverse the aggregator
+                    requirement = finer_ordering;
+                    *aggr_expr = reverse_aggr_expr;
+                    continue;
+                }
+            }
+        }
+        if let Some(finer_ordering) =
             finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode)
         {
+            // There is a requirement that satisfies both the existing requirement
+            // and the current aggregate requirement. Use the updated requirement:
             requirement = finer_ordering;
-        } else {
-            // If neither of the requirements satisfy the other, this means
-            // requirements are conflicting. Currently, we do not support
-            // conflicting requirements.
-            return not_impl_err!(
-                "Conflicting ordering requirements in aggregate functions is not supported"
-            );
+            continue;
+        }
+        if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
+            if let Some(finer_ordering) = finer_ordering(
+                &requirement,
+                &reverse_aggr_expr,
+                group_by,
+                eq_properties,
+                agg_mode,
+            ) {
+                // There is a requirement that satisfies both the existing
+                // requirement and the reverse aggregate requirement. Use the
+                // updated requirement:
+                requirement = finer_ordering;
+                *aggr_expr = reverse_aggr_expr;
+                continue;
+            }
        }
+        // Neither the existing requirement nor the current aggregate requirement
+        // satisfies the other, which means the requirements are conflicting.
+        // Currently, we do not support conflicting requirements.
+        return not_impl_err!(
+            "Conflicting ordering requirements in aggregate functions is not supported"
+        );
     }
     Ok(PhysicalSortRequirement::from_sort_exprs(&requirement))
 }
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index c95465b5ae44..8bde0da133eb 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -715,6 +715,7 @@ enum AggregateFunction {
   REGR_SYY = 33;
   REGR_SXY = 34;
   STRING_AGG = 35;
+  NTH_VALUE_AGG = 36;
 }
 
 message AggregateExprNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index d5d86b2179fa..528761136ca3 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -457,6 +457,7 @@ impl serde::Serialize for AggregateFunction {
             Self::RegrSyy => "REGR_SYY",
             Self::RegrSxy => "REGR_SXY",
             Self::StringAgg => "STRING_AGG",
+            Self::NthValueAgg => "NTH_VALUE_AGG",
         };
         serializer.serialize_str(variant)
     }
@@ -504,6 +505,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
             "REGR_SYY",
             "REGR_SXY",
             "STRING_AGG",
+            "NTH_VALUE_AGG",
         ];
 
         struct GeneratedVisitor;
@@ -580,6 +582,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
                     "REGR_SYY" => Ok(AggregateFunction::RegrSyy),
                     "REGR_SXY" => Ok(AggregateFunction::RegrSxy),
                     "STRING_AGG" => Ok(AggregateFunction::StringAgg),
+                    "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg),
                     _ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
                 }
             }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 7e262e620fa7..9a0b7ab332a6 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -3079,6 +3079,7 @@ pub enum AggregateFunction {
     RegrSyy = 33,
     RegrSxy = 34,
     StringAgg = 35,
+    NthValueAgg = 36,
 }
 impl AggregateFunction {
     /// String value of the enum field names used in the ProtoBuf definition.
@@ -3125,6 +3126,7 @@ impl AggregateFunction {
             AggregateFunction::RegrSyy => "REGR_SYY",
             AggregateFunction::RegrSxy => "REGR_SXY",
             AggregateFunction::StringAgg => "STRING_AGG",
+            AggregateFunction::NthValueAgg => "NTH_VALUE_AGG",
         }
     }
     /// Creates an enum from field names used in the ProtoBuf definition.
@@ -3168,6 +3170,7 @@ impl AggregateFunction {
             "REGR_SYY" => Some(Self::RegrSyy),
             "REGR_SXY" => Some(Self::RegrSxy),
             "STRING_AGG" => Some(Self::StringAgg),
+            "NTH_VALUE_AGG" => Some(Self::NthValueAgg),
             _ => None,
         }
     }
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 2d9c7be46bc9..9185bdb80429 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::sync::Arc;
+
 use crate::protobuf::{
     self,
     plan_type::PlanTypeEnum::{
@@ -26,6 +28,7 @@ use crate::protobuf::{
     AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType,
     OptimizedPhysicalPlanType, PlaceholderNode, RollupNode,
 };
+
 use arrow::{
     array::AsArray,
     buffer::Buffer,
@@ -41,17 +44,19 @@ use datafusion_common::{
     Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference,
     Result, ScalarValue,
 };
+use datafusion_expr::expr::{Alias, Placeholder};
 use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
 use datafusion_expr::{
     abs, acos, acosh, array, array_append, array_concat, array_dims, array_distinct,
-    array_element, array_except, array_has, array_has_all, array_has_any,
-    array_intersect, array_length, array_ndims, array_position, array_positions,
-    array_prepend, array_remove, array_remove_all, array_remove_n, array_repeat,
-    array_replace, array_replace_all, array_replace_n, array_resize, array_slice,
-    array_sort, array_to_string, array_union, arrow_typeof, ascii, asin, asinh, atan,
-    atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, chr,
-    coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time,
-    date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp,
+    array_element, array_empty, array_except, array_has, array_has_all, array_has_any,
+    array_intersect, array_length, array_ndims, array_pop_back, array_pop_front,
+    array_position, array_positions, array_prepend, array_remove, array_remove_all,
+    array_remove_n, array_repeat, array_replace, array_replace_all, array_replace_n,
+    array_resize, array_slice, array_sort, array_to_string, array_union, arrow_typeof,
+    ascii, asin, asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil,
+    character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot,
+    current_date, current_time, date_bin, date_part, date_trunc, decode, degrees, digest,
+    encode, exp,
     expr::{self, InList, Sort, WindowFunction},
     factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero,
     lcm, left, levenshtein, ln, log, log10, log2,
@@ -68,11 +73,6 @@ use datafusion_expr::{
     JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound,
     WindowFrameUnits,
 };
-use datafusion_expr::{
-    array_empty, array_pop_back, array_pop_front,
-    expr::{Alias, Placeholder},
-};
-use std::sync::Arc;
 
 #[derive(Debug)]
 pub enum Error {
@@ -617,6 +617,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
             protobuf::AggregateFunction::Median => Self::Median,
             protobuf::AggregateFunction::FirstValueAgg => Self::FirstValue,
             protobuf::AggregateFunction::LastValueAgg => Self::LastValue,
+            protobuf::AggregateFunction::NthValueAgg => Self::NthValue,
             protobuf::AggregateFunction::StringAgg => Self::StringAgg,
         }
     }
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index ec9b886c1f22..7eef3da9519f 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -31,6 +31,7 @@ use crate::protobuf::{
     AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList,
     OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode,
 };
+
 use arrow::{
     array::ArrayRef,
     datatypes::{
@@ -409,6 +410,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
             AggregateFunction::Median => Self::Median,
             AggregateFunction::FirstValue => Self::FirstValueAgg,
AggregateFunction::LastValue => Self::LastValueAgg,
+ AggregateFunction::NthValue => Self::NthValueAgg,
AggregateFunction::StringAgg => Self::StringAgg,
}
}
@@ -728,6 +730,9 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
AggregateFunction::LastValue => {
protobuf::AggregateFunction::LastValueAgg
}
+ AggregateFunction::NthValue => {
+ protobuf::AggregateFunction::NthValueAgg
+ }
AggregateFunction::StringAgg => {
protobuf::AggregateFunction::StringAgg
}
diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt
index 7c5803d38594..79e6a9357b40 100644
--- a/datafusion/sqllogictest/test_files/group_by.slt
+++ b/datafusion/sqllogictest/test_files/group_by.slt
@@ -4714,3 +4714,143 @@ statement ok
DROP TABLE uint64_dict;

### END Group By with Dictionary Variants ###
+
+statement ok
+set datafusion.execution.target_partitions = 1;
+
+query III?
+SELECT a, b, NTH_VALUE(c, 2), ARRAY_AGG(c)
+FROM multiple_ordered_table
+GROUP BY a, b
+ORDER BY a, b;
+----
+0 0 1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
+0 1 26 [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
+1 2 51 [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]
+1 3 76 [75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
+
+query III?
+SELECT a, b, NTH_VALUE(c, 2 ORDER BY c ASC), ARRAY_AGG(c ORDER BY c ASC)
+FROM multiple_ordered_table
+GROUP BY a, b
+ORDER BY a, b;
+----
+0 0 1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
+0 1 26 [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
+1 2 51 [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]
+1 3 76 [75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
+
+query II?I
+SELECT a, b, ARRAY_AGG(c ORDER BY c ASC), NTH_VALUE(c, 2 ORDER BY c DESC)
+FROM multiple_ordered_table
+GROUP BY a, b
+ORDER BY a, b;
+----
+0 0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] 23
+0 1 [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] 48
+1 2 [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74] 73
+1 3 [75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] 98
+
+query IIIIII
+SELECT a, b, NTH_VALUE(c, 2 ORDER BY c ASC), NTH_VALUE(c, 3 ORDER BY c ASC), NTH_VALUE(c, 2 ORDER BY c DESC), NTH_VALUE(c, 3 ORDER BY c DESC)
+FROM multiple_ordered_table
+GROUP BY a, b
+ORDER BY a, b;
+----
+0 0 1 2 23 22
+0 1 26 27 48 47
+1 2 51 52 73 72
+1 3 76 77 98 97
+
+# We should be able to reverse the ARRAY_AGG ordering requirement if doing so helps remove a SortExec from the plan.
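+# (multiple_ordered_table already exposes an output ordering on column c, so the
+# DESC requirement can be satisfied by reading that ordering in reverse;
+# see the physical plan below.)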
+query TT
+EXPLAIN SELECT a, b, ARRAY_AGG(c ORDER BY c DESC)
+FROM multiple_ordered_table
+GROUP BY a, b
+ORDER BY a, b;
+----
+logical_plan
+Sort: multiple_ordered_table.a ASC NULLS LAST, multiple_ordered_table.b ASC NULLS LAST
+--Aggregate: groupBy=[[multiple_ordered_table.a, multiple_ordered_table.b]], aggr=[[ARRAY_AGG(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]]]
+----TableScan: multiple_ordered_table projection=[a, b, c]
+physical_plan
+AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[ARRAY_AGG(multiple_ordered_table.c)], ordering_mode=Sorted
+--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true
+
+query II?
+SELECT a, b, ARRAY_AGG(c ORDER BY c DESC)
+FROM multiple_ordered_table
+GROUP BY a, b
+ORDER BY a, b;
+----
+0 0 [24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
+0 1 [49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25]
+1 2 [74, 73, 72, 71, 70, 69, 68, 67, 66, 65, 64, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50]
+1 3 [99, 98, 97, 96, 95, 94, 93, 92, 91, 90, 89, 88, 87, 86, 85, 84, 83, 82, 81, 80, 79, 78, 77, 76, 75]
+
+query II?II
+SELECT a, b, ARRAY_AGG(d ORDER BY d DESC), NTH_VALUE(d, 1 ORDER BY d DESC), NTH_VALUE(d, 1 ORDER BY d ASC)
+FROM multiple_ordered_table
+GROUP BY a, b
+ORDER BY a, b;
+----
+0 0 [4, 4, 4, 4, 3, 3, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0] 4 0
+0 1 [4, 4, 4, 3, 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0] 4 0
+1 2 [4, 4, 4, 4, 4, 4, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0] 4 0
+1 3 [4, 4, 4, 4, 4, 4, 3, 3, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0] 4 0
+
+# increase target_partitions to 8
+statement ok
+set datafusion.execution.target_partitions = 8;
+
+# NTH_VALUE(c, 2) and ARRAY_AGG(c)[2] should produce the same result
+query III
+SELECT a, b, NTH_VALUE(c, 2) - ARRAY_AGG(c)[2]
+FROM multiple_ordered_table
+GROUP BY a, b
+ORDER BY a, b;
+----
+0 0 0
+0 1 0
+1 2 0
+1 3 0
+
+query III?
+SELECT a, b, NTH_VALUE(c, 2 ORDER BY c ASC), ARRAY_AGG(c ORDER BY c ASC)
+FROM multiple_ordered_table
+GROUP BY a, b
+ORDER BY a, b;
+----
+0 0 1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
+0 1 26 [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
+1 2 51 [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]
+1 3 76 [75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
+
+query II?I
+SELECT a, b, ARRAY_AGG(c ORDER BY c ASC), NTH_VALUE(c, 2 ORDER BY c DESC)
+FROM multiple_ordered_table
+GROUP BY a, b
+ORDER BY a, b;
+----
+0 0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] 23
+0 1 [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49] 48
+1 2 [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74] 73
+1 3 [75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] 98
+
+query IIIIII
+SELECT a, b, NTH_VALUE(c, 2 ORDER BY c ASC), NTH_VALUE(c, 3 ORDER BY c ASC), NTH_VALUE(c, 2 ORDER BY c DESC), NTH_VALUE(c, 3 ORDER BY c DESC)
+FROM multiple_ordered_table
+GROUP BY a, b
+ORDER BY a, b;
+----
+0 0 1 2 23 22
+0 1 26 27 48 47
+1 2 51 52 73 72
+1 3 76 77 98 97
+
+# NTH_VALUE cannot work with conflicting ordering requirements
+statement error DataFusion error: This feature is not implemented: Conflicting ordering requirements in aggregate functions are not supported
+SELECT a, b, NTH_VALUE(c, 2 ORDER BY c ASC), NTH_VALUE(c, 3 ORDER BY d ASC)
+FROM multiple_ordered_table
+GROUP BY a, b
+ORDER BY a, b;
\ No newline at end of file
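
A minimal sketch (not part of the patch) of how the new variant round-trips
through the generated conversions above. It assumes the public
`datafusion_proto::protobuf` module path and relies only on the `From` impls
shown in from_proto.rs / to_proto.rs and the generated `as_str_name` /
`from_str_name` helpers in prost.rs:

    use datafusion_expr::AggregateFunction;
    use datafusion_proto::protobuf;

    fn main() {
        // logical -> proto, via the `From<&AggregateFunction>` impl in to_proto.rs
        let proto = protobuf::AggregateFunction::from(&AggregateFunction::NthValue);
        // the generated prost code maps the variant to its proto field name
        assert_eq!(proto.as_str_name(), "NTH_VALUE_AGG");

        // proto -> logical, via the `From<protobuf::AggregateFunction>` impl in from_proto.rs
        assert_eq!(AggregateFunction::from(proto), AggregateFunction::NthValue);

        // field name -> proto variant, via the generated `from_str_name`
        assert_eq!(
            protobuf::AggregateFunction::from_str_name("NTH_VALUE_AGG"),
            Some(protobuf::AggregateFunction::NthValueAgg)
        );
    }

The same field name is what the pbjson serializer emits and what its
deserializer accepts, so JSON-encoded plans stay consistent with the
binary protobuf encoding.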