Skip to content

Commit

Permalink
Support for order sensitive NTH_VALUE aggregation, make reverse `ARRAY_AGG` more efficient (apache#8841)
Browse files Browse the repository at this point in the history

* Initial commit

* minor changes

* Parse index argument

* Move nth_value to array_agg

* Initial implementation (with redundant data)

* Add new test

* Add reverse support

* Add new slt tests

* Add multi partition support

* Minor changes

* Minor changes

* Add new aggregator to the proto

* Remove redundant tests

* Keep n entries in the state for nth value

* Change implementation

* Move nth value to its own file

* Minor changes

* minor changes

* Review

* Update comments

* Use drain method to remove from the beginning.

* Add reverse support, convert buffer to vecdeque

* Minor changes

* Minor changes

* Review Part 2

* Review Part 3

* Add new_list from iter

* Convert API to receive vecdeque

* Receive mutable argument

* Refactor merge implementation

* Fix doctest

---------

Co-authored-by: Mehmet Ozan Kabak <[email protected]>
  • Loading branch information
mustafasrepo and ozankabak authored Jan 16, 2024
1 parent d2ff112 commit 8cf1abb
Show file tree
Hide file tree
Showing 16 changed files with 943 additions and 240 deletions.
83 changes: 65 additions & 18 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
use std::borrow::Borrow;
use std::cmp::Ordering;
use std::collections::HashSet;
use std::convert::{Infallible, TryInto};
use std::collections::{HashSet, VecDeque};
use std::convert::{Infallible, TryFrom, TryInto};
use std::fmt;
use std::hash::Hash;
use std::iter::repeat;
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
use std::sync::Arc;

use crate::arrow_datafusion_err;
use crate::cast::{
Expand All @@ -33,23 +35,22 @@ use crate::cast::{
use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err};
use crate::hash_utils::create_hashes;
use crate::utils::{array_into_large_list_array, array_into_list_array};

use arrow::compute::kernels::numeric::*;
use arrow::datatypes::{i256, Fields, SchemaBuilder};
use arrow::util::display::{ArrayFormatter, FormatOptions};
use arrow::{
array::*,
compute::kernels::cast::{cast_with_options, CastOptions},
datatypes::{
ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, Int16Type,
Int32Type, Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType,
IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType,
i256, ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType,
Field, Fields, Float32Type, Int16Type, Int32Type, Int64Type, Int8Type,
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
IntervalYearMonthType, SchemaBuilder, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
UInt16Type, UInt32Type, UInt64Type, UInt8Type, DECIMAL128_MAX_PRECISION,
},
};
use arrow_array::cast::as_list_array;
use arrow_array::types::ArrowTimestampType;
use arrow_array::{ArrowNativeTypeOp, Scalar};

/// A dynamically typed, nullable single value, (the single-valued counter-part
/// to arrow's [`Array`])
Expand Down Expand Up @@ -1728,6 +1729,43 @@ impl ScalarValue {
Arc::new(array_into_list_array(values))
}

/// Builds a single-row [`ListArray`] out of an exact-size iterator of
/// [`ScalarValue`]s, every one of which is expected to have the type
/// described by `data_type`.
///
/// When the iterator yields nothing, the list's single row is an empty
/// array of `data_type`; this is the only case where `data_type` is read.
///
/// # Panics
///
/// Panics if [`ScalarValue::iter_to_array`] returns an error for the
/// supplied values (presumably when the elements do not share one type;
/// confirm against `iter_to_array`).
///
/// # Example
/// ```
/// use datafusion_common::ScalarValue;
/// use arrow::array::ListArray;
/// use arrow::datatypes::{DataType, Int32Type};
///
/// let scalars = vec![
///     ScalarValue::Int32(Some(1)),
///     ScalarValue::Int32(None),
///     ScalarValue::Int32(Some(2))
/// ];
///
/// let result = ScalarValue::new_list_from_iter(scalars.into_iter(), &DataType::Int32);
///
/// let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(
///     vec![
///         Some(vec![Some(1), None, Some(2)])
///     ]);
///
/// assert_eq!(*result, expected);
/// ```
pub fn new_list_from_iter(
    values: impl IntoIterator<Item = ScalarValue> + ExactSizeIterator,
    data_type: &DataType,
) -> Arc<ListArray> {
    // Nothing to convert: wrap an empty array of the requested type instead
    // of handing an empty iterator to `iter_to_array`.
    if values.len() == 0 {
        return Arc::new(array_into_list_array(new_empty_array(data_type)));
    }

    let elements = Self::iter_to_array(values).unwrap();
    Arc::new(array_into_list_array(elements))
}

/// Converts `Vec<ScalarValue>` where each element has type corresponding to
/// `data_type`, to a [`LargeListArray`].
///
Expand Down Expand Up @@ -2626,6 +2664,18 @@ impl ScalarValue {
.sum::<usize>()
}

/// Estimates the number of bytes occupied by a [`VecDeque`] of
/// [`ScalarValue`]s, using each element's [size](Self::size).
///
/// The estimate is the sum of three terms:
/// - the [`VecDeque`] container value itself,
/// - one `ScalarValue`-sized slot for every unit of allocated capacity
///   (not just the occupied length),
/// - for each stored element, its `size()` minus its inline size, since the
///   inline part is already counted in the capacity term.
pub fn size_of_vec_deque(vec_deque: &VecDeque<Self>) -> usize {
    let container_bytes = std::mem::size_of_val(vec_deque);
    let slot_bytes = std::mem::size_of::<ScalarValue>() * vec_deque.capacity();
    let extra_bytes = vec_deque
        .iter()
        .map(|scalar| scalar.size() - std::mem::size_of_val(scalar))
        .sum::<usize>();

    container_bytes + slot_bytes + extra_bytes
}

/// Estimates [size](Self::size) of [`HashSet`] in bytes.
///
/// Includes the size of the [`HashSet`] container itself.
Expand Down Expand Up @@ -3151,22 +3201,19 @@ impl ScalarType<i64> for TimestampNanosecondType {

#[cfg(test)]
mod tests {
use super::*;

use std::cmp::Ordering;
use std::sync::Arc;

use chrono::NaiveDate;
use rand::Rng;
use super::*;
use crate::cast::{as_string_array, as_uint32_array, as_uint64_array};

use arrow::buffer::OffsetBuffer;
use arrow::compute::kernels;
use arrow::compute::{concat, is_null};
use arrow::datatypes::ArrowPrimitiveType;
use arrow::compute::{concat, is_null, kernels};
use arrow::datatypes::{ArrowNumericType, ArrowPrimitiveType};
use arrow::util::pretty::pretty_format_columns;
use arrow_array::ArrowNumericType;

use crate::cast::{as_string_array, as_uint32_array, as_uint64_array};
use chrono::NaiveDate;
use rand::Rng;

#[test]
fn test_to_array_of_size_for_list() {
Expand Down
41 changes: 25 additions & 16 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,44 @@

//! Aggregate function module contains all built-in aggregate functions definitions
use std::sync::Arc;
use std::{fmt, str::FromStr};

use crate::utils;
use crate::{type_coercion::aggregates::*, Signature, TypeSignature, Volatility};

use arrow::datatypes::{DataType, Field};
use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result};
use std::sync::Arc;
use std::{fmt, str::FromStr};

use strum_macros::EnumIter;

/// Enum of all built-in aggregate functions
// Contributor's guide for adding new aggregate functions
// https://arrow.apache.org/datafusion/contributor-guide/index.html#how-to-add-a-new-aggregate-function
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
pub enum AggregateFunction {
/// count
/// Count
Count,
/// sum
/// Sum
Sum,
/// min
/// Minimum
Min,
/// max
/// Maximum
Max,
/// avg
/// Average
Avg,
/// median
/// Median
Median,
/// Approximate aggregate function
/// Approximate distinct function
ApproxDistinct,
/// array_agg
/// Aggregation into an array
ArrayAgg,
/// first_value
/// First value in a group according to some ordering
FirstValue,
/// last_value
/// Last value in a group according to some ordering
LastValue,
/// N'th value in a group according to some ordering
NthValue,
/// Variance (Sample)
Variance,
/// Variance (Population)
Expand Down Expand Up @@ -100,7 +105,7 @@ pub enum AggregateFunction {
BoolAnd,
/// Bool Or
BoolOr,
/// string_agg
/// String aggregation
StringAgg,
}

Expand All @@ -118,6 +123,7 @@ impl AggregateFunction {
ArrayAgg => "ARRAY_AGG",
FirstValue => "FIRST_VALUE",
LastValue => "LAST_VALUE",
NthValue => "NTH_VALUE",
Variance => "VAR",
VariancePop => "VAR_POP",
Stddev => "STDDEV",
Expand Down Expand Up @@ -174,6 +180,7 @@ impl FromStr for AggregateFunction {
"array_agg" => AggregateFunction::ArrayAgg,
"first_value" => AggregateFunction::FirstValue,
"last_value" => AggregateFunction::LastValue,
"nth_value" => AggregateFunction::NthValue,
"string_agg" => AggregateFunction::StringAgg,
// statistical
"corr" => AggregateFunction::Correlation,
Expand Down Expand Up @@ -300,9 +307,9 @@ impl AggregateFunction {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::Grouping => Ok(DataType::Int32),
AggregateFunction::FirstValue | AggregateFunction::LastValue => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::FirstValue
| AggregateFunction::LastValue
| AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
}
}
Expand Down Expand Up @@ -371,6 +378,7 @@ impl AggregateFunction {
| AggregateFunction::LastValue => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
AggregateFunction::Covariance
| AggregateFunction::CovariancePop
| AggregateFunction::Correlation
Expand Down Expand Up @@ -428,6 +436,7 @@ impl AggregateFunction {
#[cfg(test)]
mod tests {
use super::*;

use strum::IntoEnumIterator;

#[test]
Expand Down
13 changes: 7 additions & 6 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,16 @@
// specific language governing permissions and limitations
// under the License.

use std::ops::Deref;

use super::functions::can_coerce_from;
use crate::{AggregateFunction, Signature, TypeSignature};

use arrow::datatypes::{
DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
};

use datafusion_common::{internal_err, plan_err, DataFusionError, Result};
use std::ops::Deref;

use crate::{AggregateFunction, Signature, TypeSignature};

use super::functions::can_coerce_from;

pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8];

Expand Down Expand Up @@ -297,6 +296,7 @@ pub fn coerce_types(
AggregateFunction::Median
| AggregateFunction::FirstValue
| AggregateFunction::LastValue => Ok(input_types.to_vec()),
AggregateFunction::NthValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
AggregateFunction::StringAgg => {
if !is_string_agg_supported_arg_type(&input_types[0]) {
Expand Down Expand Up @@ -584,6 +584,7 @@ pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool {
#[cfg(test)]
mod tests {
use super::*;

use arrow::datatypes::DataType;

#[test]
Expand Down
Loading

0 comments on commit 8cf1abb

Please sign in to comment.