diff --git a/Cargo.toml b/Cargo.toml
index 65ef191d7421f..aa1ba1f214d5c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -52,7 +52,7 @@ homepage = "https://datafusion.apache.org"
 license = "Apache-2.0"
 readme = "README.md"
 repository = "https://github.com/apache/datafusion"
-rust-version = "1.73"
+rust-version = "1.75"
 version = "39.0.0"

 [workspace.dependencies]
@@ -107,7 +107,7 @@ doc-comment = "0.3"
 env_logger = "0.11"
 futures = "0.3"
 half = { version = "2.2.1", default-features = false }
-hashbrown = { version = "0.14", features = ["raw"] }
+hashbrown = { version = "0.14.5", features = ["raw"] }
 indexmap = "2.0.0"
 itertools = "0.12"
 log = "^0.4"
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 932f44d984866..c5b34df4f1cfc 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1376,9 +1376,11 @@ dependencies = [
 name = "datafusion-physical-expr-common"
 version = "39.0.0"
 dependencies = [
+ "ahash",
 "arrow",
 "datafusion-common",
 "datafusion-expr",
+ "hashbrown 0.14.5",
 "rand",
 ]

diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml
index 5e393246b9589..8f4b3cd81f366 100644
--- a/datafusion-cli/Cargo.toml
+++ b/datafusion-cli/Cargo.toml
@@ -26,7 +26,7 @@ license = "Apache-2.0"
 homepage = "https://datafusion.apache.org"
 repository = "https://github.com/apache/datafusion"
 # Specify MSRV here as `cargo msrv` doesn't support workspace version
-rust-version = "1.73"
+rust-version = "1.75"
 readme = "README.md"

 [dependencies]
diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs
index ae444c2cb285e..a0e4d1a76c03b 100644
--- a/datafusion/common/src/utils/mod.rs
+++ b/datafusion/common/src/utils/mod.rs
@@ -354,7 +354,7 @@ pub fn longest_consecutive_prefix<T: Borrow<usize>>(
 pub fn array_into_list_array(arr: ArrayRef) -> ListArray {
     let offsets = OffsetBuffer::from_lengths([arr.len()]);
     ListArray::new(
-        Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
+        Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)),
         offsets,
         arr,
         None,
@@ -366,7 +366,7 @@ pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray {
     let offsets = OffsetBuffer::from_lengths([arr.len()]);
     LargeListArray::new(
-        Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
+        Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)),
         offsets,
         arr,
         None,
@@ -379,7 +379,7 @@ pub fn array_into_fixed_size_list_array(
 ) -> FixedSizeListArray {
     let list_size = list_size as i32;
     FixedSizeListArray::new(
-        Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
+        Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)),
         list_size,
         arr,
         None,
@@ -420,7 +420,7 @@ pub fn arrays_into_list_array(
     let data_type = arr[0].data_type().to_owned();
     let values = arr.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
     Ok(ListArray::new(
-        Arc::new(Field::new("item", data_type, true)),
+        Arc::new(Field::new_list_field(data_type, true)),
         OffsetBuffer::from_lengths(lens),
         arrow::compute::concat(values.as_slice())?,
         None,
@@ -435,7 +435,7 @@
 /// use datafusion_common::utils::base_type;
 /// use std::sync::Arc;
 ///
-/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
+/// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)));
 /// assert_eq!(base_type(&data_type), DataType::Int32);
 ///
 /// let data_type = DataType::Int32;
@@ -458,10 +458,10 @@ pub fn base_type(data_type: &DataType) -> DataType {
 /// use datafusion_common::utils::coerced_type_with_base_type_only;
 /// use std::sync::Arc;
 ///
-/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
+/// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)));
 /// let base_type = DataType::Float64;
 /// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type);
-/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new("item", DataType::Float64, true))));
+/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true))));
 pub fn coerced_type_with_base_type_only(
     data_type: &DataType,
     base_type: &DataType,
diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml
index 7533e2cff1984..45617d88dc0cf 100644
--- a/datafusion/core/Cargo.toml
+++ b/datafusion/core/Cargo.toml
@@ -30,7 +30,7 @@ authors = { workspace = true }
 # Specify MSRV here as `cargo msrv` doesn't support workspace version and fails with
 # "Unable to find key 'package.rust-version' (or 'package.metadata.msrv') in 'arrow-datafusion/Cargo.toml'"
 # https://github.com/foresterre/cargo-msrv/issues/590
-rust-version = "1.73"
+rust-version = "1.75"

 [lints]
 workspace = true
diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 06a85d3036879..950cb7ddb2d30 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -50,12 +50,11 @@ use datafusion_common::{
 };
 use datafusion_expr::lit;
 use datafusion_expr::{
-    avg, count, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown,
+    avg, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown,
     UNNAMED_TABLE,
 };
 use datafusion_expr::{case, is_null};
-use datafusion_functions_aggregate::expr_fn::sum;
-use datafusion_functions_aggregate::expr_fn::{median, stddev};
+use datafusion_functions_aggregate::expr_fn::{count, median, stddev, sum};

 use async_trait::async_trait;

@@ -854,10 +853,7 @@ impl DataFrame {
     /// ```
     pub async fn count(self) -> Result<usize> {
         let rows = self
-            .aggregate(
-                vec![],
-                vec![datafusion_expr::count(Expr::Literal(COUNT_STAR_EXPANSION))],
-            )?
+            .aggregate(vec![], vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])?
             .collect()
             .await?;
         let len = *rows
@@ -1594,9 +1590,10 @@ mod tests {
     use datafusion_common::{Constraint, Constraints};
     use datafusion_common_runtime::SpawnedTask;
     use datafusion_expr::{
-        array_agg, cast, count_distinct, create_udf, expr, lit, BuiltInWindowFunction,
+        array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction,
         ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition,
     };
+    use datafusion_functions_aggregate::expr_fn::count_distinct;
    use datafusion_physical_expr::expressions::Column;
    use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties};
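With `count` and `count_distinct` now exported from `datafusion_functions_aggregate` rather than `datafusion_expr`, call sites change only their imports. A minimal sketch of the new usage — not part of the patch; the column name `a` is hypothetical:

```rust
use datafusion::dataframe::DataFrame;
use datafusion::error::Result;
use datafusion_expr::{col, utils::COUNT_STAR_EXPANSION, Expr};
use datafusion_functions_aggregate::expr_fn::{count, count_distinct};

// COUNT(*) is spelled as COUNT of the star-expansion literal, mirroring
// DataFrame::count() above; the result is a single-row DataFrame
fn counts(df: DataFrame) -> Result<DataFrame> {
    df.aggregate(
        vec![],
        vec![
            count(Expr::Literal(COUNT_STAR_EXPANSION)),
            count_distinct(col("a")),
        ],
    )
}
```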
diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs
index 99c38d3f09808..572904254fd75 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -455,6 +455,8 @@ async fn fetch_schema(
 }

 /// Read and parse the statistics of the Parquet file at location `path`
+///
+/// See [`statistics_from_parquet_meta`] for more details
 async fn fetch_statistics(
     store: &dyn ObjectStore,
     table_schema: SchemaRef,
@@ -462,6 +464,17 @@
     metadata_size_hint: Option<usize>,
 ) -> Result<Statistics> {
     let metadata = fetch_parquet_metadata(store, file, metadata_size_hint).await?;
+    statistics_from_parquet_meta(&metadata, table_schema).await
+}
+
+/// Convert statistics in [`ParquetMetaData`] into [`Statistics`]
+///
+/// The statistics are calculated for each column in the table schema
+/// using the row group statistics in the parquet metadata.
+pub async fn statistics_from_parquet_meta(
+    metadata: &ParquetMetaData,
+    table_schema: SchemaRef,
+) -> Result<Statistics> {
     let file_metadata = metadata.file_metadata();

     let file_schema = parquet_to_arrow_schema(
@@ -1402,6 +1415,66 @@ mod tests {
         Ok(())
     }

+    #[tokio::test]
+    async fn test_statistics_from_parquet_metadata() -> Result<()> {
+        // Data for column c1: ["Foo", null, "bar"]
+        let c1: ArrayRef =
+            Arc::new(StringArray::from(vec![Some("Foo"), None, Some("bar")]));
+        let batch1 = RecordBatch::try_from_iter(vec![("c1", c1.clone())]).unwrap();
+
+        // Data for column c2: [1, 2, null]
+        let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None]));
+        let batch2 = RecordBatch::try_from_iter(vec![("c2", c2)]).unwrap();
+
+        // Use store_parquet to write each batch to its own file
+        // . batch1 written into first file and includes:
+        //   - column c1 that has 3 rows with one null. Stats min and max of the string column are missing for this test even though the column has values
+        // . batch2 written into second file and includes:
+        //   - column c2 that has 3 rows with one null. Stats min and max of the int column are available and are 1 and 2 respectively
+        let store = Arc::new(LocalFileSystem::new()) as _;
+        let (files, _file_names) = store_parquet(vec![batch1, batch2], false).await?;
+
+        let state = SessionContext::new().state();
+        let format = ParquetFormat::default();
+        let schema = format.infer_schema(&state, &store, &files).await.unwrap();
+
+        let null_i64 = ScalarValue::Int64(None);
+        let null_utf8 = ScalarValue::Utf8(None);
+
+        // Fetch statistics for first file
+        let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[0], None).await?;
+        let stats = statistics_from_parquet_meta(&pq_meta, schema.clone()).await?;
+        //
+        assert_eq!(stats.num_rows, Precision::Exact(3));
+        // column c1
+        let c1_stats = &stats.column_statistics[0];
+        assert_eq!(c1_stats.null_count, Precision::Exact(1));
+        assert_eq!(c1_stats.max_value, Precision::Absent);
+        assert_eq!(c1_stats.min_value, Precision::Absent);
+        // column c2: missing from the file so the table treats all 3 rows as null
+        let c2_stats = &stats.column_statistics[1];
+        assert_eq!(c2_stats.null_count, Precision::Exact(3));
+        assert_eq!(c2_stats.max_value, Precision::Exact(null_i64.clone()));
+        assert_eq!(c2_stats.min_value, Precision::Exact(null_i64.clone()));
+
+        // Fetch statistics for second file
+        let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[1], None).await?;
+        let stats = statistics_from_parquet_meta(&pq_meta, schema.clone()).await?;
+        assert_eq!(stats.num_rows, Precision::Exact(3));
+        // column c1: missing from the file so the table treats all 3 rows as null
+        let c1_stats = &stats.column_statistics[0];
+        assert_eq!(c1_stats.null_count, Precision::Exact(3));
+        assert_eq!(c1_stats.max_value, Precision::Exact(null_utf8.clone()));
+        assert_eq!(c1_stats.min_value, Precision::Exact(null_utf8.clone()));
+        // column c2
+        let c2_stats = &stats.column_statistics[1];
+        assert_eq!(c2_stats.null_count, Precision::Exact(1));
+        assert_eq!(c2_stats.max_value, Precision::Exact(2i64.into()));
+        assert_eq!(c2_stats.min_value, Precision::Exact(1i64.into()));
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn read_small_batches() -> Result<()> {
         let config = SessionConfig::new().with_batch_size(2);
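The new public `statistics_from_parquet_meta` composes with `fetch_parquet_metadata` as the test above exercises. A hedged sketch, with import paths assumed from this crate layout:

```rust
use arrow_schema::SchemaRef;
use datafusion::datasource::file_format::parquet::{
    fetch_parquet_metadata, statistics_from_parquet_meta,
};
use datafusion::error::Result;
use datafusion_common::Statistics;
use object_store::{ObjectMeta, ObjectStore};

// Fetch the footer metadata once, then derive table-level statistics
// from it instead of re-reading the file for every statistics request
async fn file_statistics(
    store: &dyn ObjectStore,
    file: &ObjectMeta,
    table_schema: SchemaRef,
) -> Result<Statistics> {
    let metadata = fetch_parquet_metadata(store, file, None).await?;
    statistics_from_parquet_meta(&metadata, table_schema).await
}
```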
diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs
index 746e4b8e3330a..7f5e80c4988a5 100644
--- a/datafusion/core/src/datasource/listing/table.rs
+++ b/datafusion/core/src/datasource/listing/table.rs
@@ -547,20 +547,49 @@ impl ListingOptions {
     }
 }

-/// Reads data from one or more files via an
-/// [`ObjectStore`]. For example, from
-/// local files or objects from AWS S3. Implements [`TableProvider`],
-/// a DataFusion data source.
+/// Reads data from one or more files as a single table.
 ///
-/// # Features
+/// Implements [`TableProvider`], a DataFusion data source. The files are read
+/// using an [`ObjectStore`] instance, for example from local files or objects
+/// from AWS S3.
 ///
-/// 1. Merges schemas if the files have compatible but not identical schemas
+/// For example, given the `table1` directory (or object store prefix)
 ///
-/// 2. Hive-style partitioning support, where a path such as
-/// `/files/date=1/1/2022/data.parquet` is injected as a `date` column.
+/// ```text
+/// table1
+/// ├── file1.parquet
+/// └── file2.parquet
+/// ```
+///
+/// A `ListingTable` would read the files `file1.parquet` and `file2.parquet` as
+/// a single table, merging the schemas if the files have compatible but not
+/// identical schemas.
+///
+/// Given the `table2` directory (or object store prefix)
+///
+/// ```text
+/// table2
+/// ├── date=2024-06-01
+/// │   ├── file3.parquet
+/// │   └── file4.parquet
+/// └── date=2024-06-02
+///     └── file5.parquet
+/// ```
+///
+/// A `ListingTable` would read the files `file3.parquet`, `file4.parquet`, and
+/// `file5.parquet` as a single table, again merging schemas if necessary.
+///
+/// Given the Hive-style partitioning structure (e.g., directories named
+/// `date=2024-06-01` and `date=2024-06-02`), `ListingTable` also adds a `date`
+/// column when reading the table:
+/// * The files in `table2/date=2024-06-01` will have the value `2024-06-01`.
+/// * The files in `table2/date=2024-06-02` will have the value `2024-06-02`.
+///
+/// If the query has a predicate like `WHERE date = '2024-06-01'`
+/// only the corresponding directory will be read.
 ///
-/// 3. Projection pushdown for formats that support it such as such as
-/// Parquet
+/// `ListingTable` also supports filter and projection pushdown for formats
+/// that support it, such as Parquet.
 ///
 /// # Example
 ///
diff --git a/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs b/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs
index f51f2c49e896c..e15e907cd9b80 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs
@@ -384,7 +384,7 @@ mod test {
         let access_plan = ParquetAccessPlan::new(vec![
             RowGroupAccess::Scan,
             RowGroupAccess::Selection(
-                // select / skip all 20 rows in row group 1
+                // specifies all 20 rows in row group 1
                 vec![
                     RowSelector::select(5),
                     RowSelector::skip(7),
@@ -463,7 +463,7 @@
     fn test_invalid_too_few() {
         let access_plan = ParquetAccessPlan::new(vec![
             RowGroupAccess::Scan,
-            // select 12 rows, but row group 1 has 20
+            // specify only 12 rows in selection, but row group 1 has 20
             RowGroupAccess::Selection(
                 vec![RowSelector::select(5), RowSelector::skip(7)].into(),
             ),
@@ -484,7 +484,7 @@
     fn test_invalid_too_many() {
         let access_plan = ParquetAccessPlan::new(vec![
             RowGroupAccess::Scan,
-            // select 22 rows, but row group 1 has only 20
+            // specify 22 rows in selection, but row group 1 has only 20
             RowGroupAccess::Selection(
                 vec![
                     RowSelector::select(10),
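The selection-counting rule the tests above assert is easiest to see in code: the selectors of a `RowGroupAccess::Selection` must add up exactly to the row group's row count. A minimal sketch, with the re-export path for `ParquetAccessPlan` assumed:

```rust
use datafusion::datasource::physical_plan::parquet::{ParquetAccessPlan, RowGroupAccess};
use parquet::arrow::arrow_reader::RowSelector;

fn main() {
    // Row group 0 is scanned in full; row group 1 (20 rows) reads rows
    // 0..5, skips rows 5..12, then reads rows 12..20 (5 + 7 + 8 = 20)
    let _access_plan = ParquetAccessPlan::new(vec![
        RowGroupAccess::Scan,
        RowGroupAccess::Selection(
            vec![
                RowSelector::select(5),
                RowSelector::skip(7),
                RowSelector::select(8),
            ]
            .into(),
        ),
    ]);
}
```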
diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs
index 5e5cc93bc54f4..ec21c5504c694 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs
@@ -156,9 +156,8 @@ pub use writer::plan_to_parquet;
 /// used to implement external indexes on top of parquet files and select only
 /// portions of the files.
 ///
-/// The `ParquetExec` will try and further reduce any provided
-/// `ParquetAccessPlan` further based on the contents of `ParquetMetadata` and
-/// other settings.
+/// The `ParquetExec` will try to reduce any provided `ParquetAccessPlan`
+/// further based on the contents of `ParquetMetadata` and other settings.
 ///
 /// ## Example of providing a ParquetAccessPlan
 ///
diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs
index 8557c6d5f9508..36335863032c1 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs
@@ -238,6 +238,8 @@ fn create_initial_plan(
             // check row group count matches the plan
             return Ok(access_plan.clone());
+        } else {
+            debug!("ParquetExec Ignoring unknown extension specified for {file_name}");
         }
     }

diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
index a4a919f20d0f1..c0d36f1fc4d7e 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.

-//! [`min_statistics`] and [`max_statistics`] convert statistics in parquet format to arrow [`ArrayRef`].
+//! [`StatisticsConverter`] converts statistics in parquet format to arrow [`ArrayRef`].

 // TODO: potentially move this to arrow-rs: https://github.com/apache/arrow-rs/issues/4328

@@ -542,8 +542,11 @@ pub(crate) fn parquet_column<'a>(
     Some((parquet_idx, field))
 }

-/// Extracts the min statistics from an iterator of [`ParquetStatistics`] to an [`ArrayRef`]
-pub(crate) fn min_statistics<'a, I: Iterator<Item = Option<&'a ParquetStatistics>>>(
+/// Extracts the min statistics from an iterator of [`ParquetStatistics`] to an
+/// [`ArrayRef`]
+///
+/// This is an internal helper -- see [`StatisticsConverter`] for public API
+fn min_statistics<'a, I: Iterator<Item = Option<&'a ParquetStatistics>>>(
     data_type: &DataType,
     iterator: I,
 ) -> Result<ArrayRef> {
@@ -551,7 +554,9 @@ pub(crate) fn min_statistics<'a, I: Iterator<Item = Option<&'a ParquetStatistics>>>(
+///
+/// This is an internal helper -- see [`StatisticsConverter`] for public API
+fn max_statistics<'a, I: Iterator<Item = Option<&'a ParquetStatistics>>>(
     data_type: &DataType,
     iterator: I,
 ) -> Result<ArrayRef> {
@@ -1425,9 +1430,10 @@
         assert_eq!(idx, 2);

         let row_groups = metadata.row_groups();
-        let iter = row_groups.iter().map(|x| x.column(idx).statistics());
+        let converter =
+            StatisticsConverter::try_new("int_col", &schema, parquet_schema).unwrap();

-        let min = min_statistics(&DataType::Int32, iter.clone()).unwrap();
+        let min = converter.row_group_mins(row_groups.iter()).unwrap();
         assert_eq!(
             &min, &expected_min,
             "Min. Statistics\n\n{}\n\n",
             DisplayStats(row_groups)
         );

-        let max = max_statistics(&DataType::Int32, iter).unwrap();
+        let max = converter.row_group_maxes(row_groups.iter()).unwrap();
         assert_eq!(
             &max, &expected_max,
             "Max. Statistics\n\n{}\n\n",
             DisplayStats(row_groups)
@@ -1623,22 +1629,23 @@
                 continue;
             }

-            let (idx, f) =
-                parquet_column(parquet_schema, &schema, field.name()).unwrap();
-            assert_eq!(f, field);
+            let converter =
+                StatisticsConverter::try_new(field.name(), &schema, parquet_schema)
+                    .unwrap();

-            let iter = row_groups.iter().map(|x| x.column(idx).statistics());
-            let min = min_statistics(f.data_type(), iter.clone()).unwrap();
+            assert_eq!(converter.arrow_field, field.as_ref());
+
+            let mins = converter.row_group_mins(row_groups.iter()).unwrap();
             assert_eq!(
-                &min,
+                &mins,
                 &expected_min,
                 "Min. Statistics\n\n{}\n\n",
                 DisplayStats(row_groups)
             );

-            let max = max_statistics(f.data_type(), iter).unwrap();
+            let maxes = converter.row_group_maxes(row_groups.iter()).unwrap();
             assert_eq!(
-                &max,
+                &maxes,
                 &expected_max,
                 "Max. Statistics\n\n{}\n\n",
                 DisplayStats(row_groups)
@@ -1705,7 +1712,7 @@
             self
         }

-        /// Reads the specified parquet file and validates that the exepcted min/max
+        /// Reads the specified parquet file and validates that the expected min/max
         /// values for the specified columns are as expected.
         fn run(self) {
             let path = PathBuf::from(parquet_test_data()).join(self.file_name);
@@ -1723,14 +1730,13 @@
                 expected_max,
             } = expected_column;

-                let (idx, field) =
-                    parquet_column(parquet_schema, arrow_schema, name).unwrap();
-
-                let iter = row_groups.iter().map(|x| x.column(idx).statistics());
-                let actual_min = min_statistics(field.data_type(), iter.clone()).unwrap();
+                let converter =
+                    StatisticsConverter::try_new(name, arrow_schema, parquet_schema)
+                        .unwrap();
+                let actual_min = converter.row_group_mins(row_groups.iter()).unwrap();
                 assert_eq!(&expected_min, &actual_min, "column {name}");

-                let actual_max = max_statistics(field.data_type(), iter).unwrap();
+                let actual_max = converter.row_group_maxes(row_groups.iter()).unwrap();
                 assert_eq!(&expected_max, &actual_max, "column {name}");
             }
         }
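A condensed sketch of the converter-based flow the rewritten tests above use: resolve the column once with `StatisticsConverter::try_new`, then extract per-row-group minimums and maximums as arrays. The import path is an assumption; at this point in the diff the type lives in the crate-internal `statistics` module:

```rust
use arrow_array::ArrayRef;
use arrow_schema::Schema;
use datafusion::error::Result;
use parquet::file::metadata::RowGroupMetaData;
use parquet::schema::types::SchemaDescriptor;

// Assumed re-export; see the statistics module changes in this diff
use datafusion::datasource::physical_plan::parquet::StatisticsConverter;

// Replaces the per-call-site min_statistics/max_statistics plumbing
fn column_min_max(
    column: &str,
    arrow_schema: &Schema,
    parquet_schema: &SchemaDescriptor,
    row_groups: &[RowGroupMetaData],
) -> Result<(ArrayRef, ArrayRef)> {
    let converter = StatisticsConverter::try_new(column, arrow_schema, parquet_schema)?;
    let mins = converter.row_group_mins(row_groups.iter())?;
    let maxes = converter.row_group_maxes(row_groups.iter())?;
    Ok((mins, maxes))
}
```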
diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
index 05f05d95b8db7..eeacc48b85dbc 100644
--- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -170,38 +170,6 @@ fn take_optimizable_column_and_table_count(
             }
         }
     }
-    // TODO: Remove this after revmoing Builtin Count
-    else if let (&Precision::Exact(num_rows), Some(casted_expr)) = (
-        &stats.num_rows,
-        agg_expr.as_any().downcast_ref::<expressions::Count>(),
-    ) {
-        // TODO implementing Eq on PhysicalExpr would help a lot here
-        if casted_expr.expressions().len() == 1 {
-            // TODO optimize with exprs other than Column
-            if let Some(col_expr) = casted_expr.expressions()[0]
-                .as_any()
-                .downcast_ref::<expressions::Column>()
-            {
-                let current_val = &col_stats[col_expr.index()].null_count;
-                if let &Precision::Exact(val) = current_val {
-                    return Some((
-                        ScalarValue::Int64(Some((num_rows - val) as i64)),
-                        casted_expr.name().to_string(),
-                    ));
-                }
-            } else if let Some(lit_expr) = casted_expr.expressions()[0]
-                .as_any()
-                .downcast_ref::<expressions::Literal>()
-            {
-                if lit_expr.value() == &COUNT_STAR_EXPANSION {
-                    return Some((
-                        ScalarValue::Int64(Some(num_rows as i64)),
-                        casted_expr.name().to_owned(),
-                    ));
-                }
-            }
-        }
-    }
     None
 }

@@ -307,13 +275,12 @@ fn take_optimizable_max(

 #[cfg(test)]
 pub(crate) mod tests {
-
     use super::*;
+
     use crate::logical_expr::Operator;
     use crate::physical_plan::aggregates::PhysicalGroupBy;
     use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
     use crate::physical_plan::common;
-    use crate::physical_plan::expressions::Count;
     use crate::physical_plan::filter::FilterExec;
     use crate::physical_plan::memory::MemoryExec;
     use crate::prelude::SessionContext;
@@ -322,8 +289,10 @@ pub(crate) mod tests {
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow::record_batch::RecordBatch;
     use datafusion_common::cast::as_int64_array;
+    use datafusion_functions_aggregate::count::count_udaf;
     use datafusion_physical_expr::expressions::cast;
     use datafusion_physical_expr::PhysicalExpr;
+    use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
     use datafusion_physical_plan::aggregates::AggregateMode;

     /// Mock data using a MemoryExec which has an exact count statistic
@@ -414,13 +383,19 @@
         Self::ColumnA(schema.clone())
     }

-    /// Return appropriate expr depending if COUNT is for col or table (*)
-    pub(crate) fn count_expr(&self) -> Arc<dyn AggregateExpr> {
-        Arc::new(Count::new(
-            self.column(),
+    // Return appropriate expr depending on whether COUNT is for col or table (*)
+    pub(crate) fn count_expr(&self, schema: &Schema) -> Arc<dyn AggregateExpr> {
+        create_aggregate_expr(
+            &count_udaf(),
+            &[self.column()],
+            &[],
+            &[],
+            schema,
             self.column_name(),
-            DataType::Int64,
-        ))
+            false,
+            false,
+        )
+        .unwrap()
     }

     /// what argument would this aggregate need in the plan?
@@ -458,7 +433,7 @@
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
             PhysicalGroupBy::default(),
-            vec![agg.count_expr()],
+            vec![agg.count_expr(&schema)],
             vec![None],
             source,
             Arc::clone(&schema),
@@ -467,7 +442,7 @@
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
             PhysicalGroupBy::default(),
-            vec![agg.count_expr()],
+            vec![agg.count_expr(&schema)],
             vec![None],
             Arc::new(partial_agg),
             Arc::clone(&schema),
@@ -488,7 +463,7 @@
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
             PhysicalGroupBy::default(),
-            vec![agg.count_expr()],
+            vec![agg.count_expr(&schema)],
             vec![None],
             source,
             Arc::clone(&schema),
@@ -497,7 +472,7 @@
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
             PhysicalGroupBy::default(),
-            vec![agg.count_expr()],
+            vec![agg.count_expr(&schema)],
             vec![None],
             Arc::new(partial_agg),
             Arc::clone(&schema),
@@ -517,7 +492,7 @@
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
             PhysicalGroupBy::default(),
-            vec![agg.count_expr()],
+            vec![agg.count_expr(&schema)],
             vec![None],
             source,
             Arc::clone(&schema),
@@ -529,7 +504,7 @@
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
             PhysicalGroupBy::default(),
-            vec![agg.count_expr()],
+            vec![agg.count_expr(&schema)],
             vec![None],
             Arc::new(coalesce),
             Arc::clone(&schema),
@@ -549,7 +524,7 @@
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
             PhysicalGroupBy::default(),
-            vec![agg.count_expr()],
+            vec![agg.count_expr(&schema)],
             vec![None],
             source,
             Arc::clone(&schema),
@@ -561,7 +536,7 @@
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
             PhysicalGroupBy::default(),
-            vec![agg.count_expr()],
+            vec![agg.count_expr(&schema)],
             vec![None],
             Arc::new(coalesce),
             Arc::clone(&schema),
@@ -592,7 +567,7 @@
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
             PhysicalGroupBy::default(),
-            vec![agg.count_expr()],
+            vec![agg.count_expr(&schema)],
             vec![None],
             filter,
             Arc::clone(&schema),
@@ -601,7 +576,7 @@
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
             PhysicalGroupBy::default(),
-            vec![agg.count_expr()],
+            vec![agg.count_expr(&schema)],
             vec![None],
             Arc::new(partial_agg),
             Arc::clone(&schema),
@@ -637,7 +612,7 @@
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
             PhysicalGroupBy::default(),
-            vec![agg.count_expr()],
+            vec![agg.count_expr(&schema)],
             vec![None],
             filter,
             Arc::clone(&schema),
@@ -646,7 +621,7 @@
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
             PhysicalGroupBy::default(),
-            vec![agg.count_expr()],
+            vec![agg.count_expr(&schema)],
             vec![None],
             Arc::new(partial_agg),
             Arc::clone(&schema),
diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
index 3ad61e52c82e0..38b92959e8412 100644
--- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
+++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
@@ -206,8 +206,9 @@ mod tests {
     use crate::physical_plan::{displayable, Partitioning};

     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+    use datafusion_functions_aggregate::count::count_udaf;
     use datafusion_functions_aggregate::sum::sum_udaf;
-    use datafusion_physical_expr::expressions::{col, Count};
+    use datafusion_physical_expr::expressions::col;
     use datafusion_physical_plan::udaf::create_aggregate_expr;

     /// Runs the CombinePartialFinalAggregate optimizer and asserts the plan against the expected
@@ -303,15 +304,31 @@
         )
     }

+    // Return appropriate expr depending on whether COUNT is for col or table (*)
+    fn count_expr(
+        expr: Arc<dyn PhysicalExpr>,
+        name: &str,
+        schema: &Schema,
+    ) -> Arc<dyn AggregateExpr> {
+        create_aggregate_expr(
+            &count_udaf(),
+            &[expr],
+            &[],
+            &[],
+            schema,
+            name,
+            false,
+            false,
+        )
+        .unwrap()
+    }
+
     #[test]
     fn aggregations_not_combined() -> Result<()> {
         let schema = schema();

-        let aggr_expr = vec![Arc::new(Count::new(
-            lit(1i8),
-            "COUNT(1)".to_string(),
-            DataType::Int64,
-        )) as _];
+        let aggr_expr = vec![count_expr(lit(1i8), "COUNT(1)", &schema)];
+
         let plan = final_aggregate_exec(
             repartition_exec(partial_aggregate_exec(
                 parquet_exec(&schema),
@@ -330,16 +347,8 @@
         ];
         assert_optimized!(expected, plan);

-        let aggr_expr1 = vec![Arc::new(Count::new(
-            lit(1i8),
-            "COUNT(1)".to_string(),
-            DataType::Int64,
-        )) as _];
-        let aggr_expr2 = vec![Arc::new(Count::new(
-            lit(1i8),
-            "COUNT(2)".to_string(),
-            DataType::Int64,
-        )) as _];
+        let aggr_expr1 = vec![count_expr(lit(1i8), "COUNT(1)", &schema)];
+        let aggr_expr2 = vec![count_expr(lit(1i8), "COUNT(2)", &schema)];

         let plan = final_aggregate_exec(
             partial_aggregate_exec(
@@ -365,11 +374,7 @@
     #[test]
     fn aggregations_combined() -> Result<()> {
         let schema = schema();
-        let aggr_expr = vec![Arc::new(Count::new(
-            lit(1i8),
-            "COUNT(1)".to_string(),
-            DataType::Int64,
-        )) as _];
+        let aggr_expr = vec![count_expr(lit(1i8), "COUNT(1)", &schema)];

         let plan = final_aggregate_exec(
             partial_aggregate_exec(
diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs
index 1274fbe50a5fb..f9d5a4c186eee 100644
--- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs
+++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs
@@ -517,10 +517,10 @@ mod tests {
         let single_agg = AggregateExec::try_new(
             AggregateMode::Single,
             build_group_by(&schema.clone(), vec!["a".to_string()]),
-            vec![agg.count_expr()], /* aggr_expr */
-            vec![None],             /* filter_expr */
-            source,                 /* input */
-            schema.clone(),         /* input_schema */
+            vec![agg.count_expr(&schema)], /* aggr_expr */
+            vec![None],                    /* filter_expr */
+            source,                        /* input */
+            schema.clone(),                /* input_schema */
         )?;
         let limit_exec = LocalLimitExec::new(
             Arc::new(single_agg),
@@ -554,10 +554,10 @@
         let single_agg = AggregateExec::try_new(
             AggregateMode::Single,
             build_group_by(&schema.clone(), vec!["a".to_string()]),
-            vec![agg.count_expr()], /* aggr_expr */
-            vec![filter_expr],      /* filter_expr */
-            source,                 /* input */
-            schema.clone(),         /* input_schema */
+            vec![agg.count_expr(&schema)], /* aggr_expr */
+            vec![filter_expr],             /* filter_expr */
+            source,                        /* input */
+            schema.clone(),                /* input_schema */
         )?;
         let limit_exec = LocalLimitExec::new(
             Arc::new(single_agg),
diff --git 
a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 5895c39a5f87d..154e77cd23ae8 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -43,7 +43,8 @@ use arrow_schema::{Schema, SchemaRef, SortOptions}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::JoinType; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunctionDefinition}; +use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; +use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use datafusion_physical_plan::displayable; @@ -240,7 +241,7 @@ pub fn bounded_window_exec( Arc::new( crate::physical_plan::windows::BoundedWindowAggExec::try_new( vec![create_window_expr( - &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateUDF(count_udaf()), "count".to_owned(), &[col(col_name, &schema).unwrap()], &[], diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 79033643cf378..4f91875950183 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2181,7 +2181,6 @@ impl DefaultPhysicalPlanner { expr: &[Expr], ) -> Result> { let input_schema = input.as_ref().schema(); - let physical_exprs = expr .iter() .map(|e| { diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index 8c9cffcf08d1b..068383b200315 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -35,6 +35,7 @@ use datafusion::scalar::ScalarValue; use datafusion_common::cast::as_primitive_array; use datafusion_common::{internal_err, not_impl_err}; use datafusion_expr::expr::{BinaryExpr, Cast}; +use datafusion_functions_aggregate::expr_fn::count; use datafusion_physical_expr::EquivalenceProperties; use async_trait::async_trait; diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index befd98d043022..fa364c5f2a653 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -31,6 +31,7 @@ use arrow::{ }; use arrow_array::Float32Array; use arrow_schema::ArrowError; +use datafusion_functions_aggregate::count::count_udaf; use object_store::local::LocalFileSystem; use std::fs; use std::sync::Arc; @@ -51,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, avg, cast, col, count, exists, expr, in_subquery, lit, max, out_ref_col, - placeholder, scalar_subquery, when, wildcard, AggregateFunction, Expr, ExprSchemable, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, + placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::sum; +use datafusion_functions_aggregate::expr_fn::{count, sum}; #[tokio::test] 
async fn test_count_wildcard_on_sort() -> Result<()> {
@@ -178,7 +179,7 @@ async fn test_count_wildcard_on_window() -> Result<()> {
         .table("t1")
         .await?
         .select(vec![Expr::WindowFunction(expr::WindowFunction::new(
-            WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
+            WindowFunctionDefinition::AggregateUDF(count_udaf()),
             vec![wildcard()],
             vec![],
             vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
index 824f1eec4a853..516749e82a531 100644
--- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -22,6 +22,11 @@ use arrow::compute::SortOptions;
 use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::pretty_format_batches;
 use arrow_schema::Schema;
+
+use datafusion_common::ScalarValue;
+use datafusion_physical_expr::expressions::Literal;
+use datafusion_physical_expr::PhysicalExprRef;
+
 use rand::Rng;

 use datafusion::common::JoinSide;
@@ -40,92 +45,210 @@ use test_utils::stagger_batch_with_seed;

 #[tokio::test]
 async fn test_inner_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::Inner,
+        None,
     )
+    .run_test()
+    .await
+}
+
+fn less_than_100_join_filter(schema1: Arc<Schema>, _schema2: Arc<Schema>) -> JoinFilter {
+    let less_than_100 = Arc::new(BinaryExpr::new(
+        Arc::new(Column::new("a", 0)),
+        Operator::Lt,
+        Arc::new(Literal::new(ScalarValue::from(100))),
+    )) as _;
+    let column_indices = vec![ColumnIndex {
+        index: 0,
+        side: JoinSide::Left,
+    }];
+    let intermediate_schema =
+        Schema::new(vec![schema1.field_with_name("a").unwrap().to_owned()]);
+
+    JoinFilter::new(less_than_100, column_indices, intermediate_schema)
+}
+
+#[tokio::test]
+async fn test_inner_join_1k_filtered() {
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::Inner,
+        Some(Box::new(less_than_100_join_filter)),
+    )
+    .run_test()
+    .await
+}
+
+#[tokio::test]
+async fn test_inner_join_1k_smjoin() {
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::Inner,
+        None,
+    )
+    .run_test()
     .await
 }

 #[tokio::test]
 async fn test_left_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::Left,
+        None,
+    )
+    .run_test()
+    .await
+}
+
+#[tokio::test]
+async fn test_left_join_1k_filtered() {
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::Left,
+        Some(Box::new(less_than_100_join_filter)),
     )
+    .run_test()
     .await
 }

 #[tokio::test]
 async fn test_right_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::Right,
+        None,
     )
+    .run_test()
+    .await
+}
+
+// TODO: Add support for Right filtered joins
+#[ignore]
+#[tokio::test]
+async fn test_right_join_1k_filtered() {
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::Right,
+        Some(Box::new(less_than_100_join_filter)),
+    )
+    .run_test()
     .await
 }

 #[tokio::test]
 async fn test_full_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::Full,
+        None,
+    )
+    .run_test()
+    .await
+}
+
+#[tokio::test]
+async fn test_full_join_1k_filtered() {
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::Full,
+        Some(Box::new(less_than_100_join_filter)),
     )
+    .run_test()
     .await
 }

 #[tokio::test]
 async fn test_semi_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::LeftSemi,
+        None,
     )
+    .run_test()
+    .await
+}
+
+// The test is flaky
+// https://github.com/apache/datafusion/issues/10886
+#[ignore]
+#[tokio::test]
+async fn test_semi_join_1k_filtered() {
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::LeftSemi,
+        Some(Box::new(less_than_100_join_filter)),
+    )
+    .run_test()
     .await
 }

 #[tokio::test]
 async fn test_anti_join_1k() {
-    run_join_test(
+    JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::LeftAnti,
+        None,
     )
+    .run_test()
     .await
 }

-/// Perform sort-merge join and hash join on same input
-/// and verify two outputs are equal
-async fn run_join_test(
+// This test currently fails; see https://github.com/apache/datafusion/issues/10872
+#[ignore]
+#[tokio::test]
+async fn test_anti_join_1k_filtered() {
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::LeftAnti,
+        Some(Box::new(less_than_100_join_filter)),
+    )
+    .run_test()
+    .await
+}
+
+type JoinFilterBuilder = Box<dyn Fn(Arc<Schema>, Arc<Schema>) -> JoinFilter>;
+
+struct JoinFuzzTestCase {
+    batch_sizes: &'static [usize],
     input1: Vec<RecordBatch>,
     input2: Vec<RecordBatch>,
     join_type: JoinType,
-) {
-    let batch_sizes = [1, 2, 7, 49, 50, 51, 100];
-    for batch_size in batch_sizes {
-        let session_config = SessionConfig::new().with_batch_size(batch_size);
-        let ctx = SessionContext::new_with_config(session_config);
-        let task_ctx = ctx.task_ctx();
-
-        let schema1 = input1[0].schema();
-        let schema2 = input2[0].schema();
-        let on_columns = vec![
-            (
-                Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _,
-                Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _,
-            ),
-            (
-                Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _,
-                Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _,
-            ),
-        ];
+    join_filter_builder: Option<JoinFilterBuilder>,
+}

-        // Nested loop join uses filter for joining records
-        let column_indices = vec![
+impl JoinFuzzTestCase {
+    fn new(
+        input1: Vec<RecordBatch>,
+        input2: Vec<RecordBatch>,
+        join_type: JoinType,
+        join_filter_builder: Option<JoinFilterBuilder>,
+    ) -> Self {
+        Self {
+            batch_sizes: &[1, 2, 7, 49, 50, 51, 100],
+            input1,
+            input2,
+            join_type,
+            join_filter_builder,
+        }
+    }
+
+    fn column_indices(&self) -> Vec<ColumnIndex> {
+        vec![
             ColumnIndex {
                 index: 0,
                 side: JoinSide::Left,
@@ -142,120 +265,193 @@
                 index: 1,
                 side: JoinSide::Right,
             },
-        ];
-        let intermediate_schema = Schema::new(vec![
-            schema1.field_with_name("a").unwrap().to_owned(),
-            schema1.field_with_name("b").unwrap().to_owned(),
-            schema2.field_with_name("a").unwrap().to_owned(),
-            schema2.field_with_name("b").unwrap().to_owned(),
-        ]);
+        ]
+    }

-        let equal_a = Arc::new(BinaryExpr::new(
-            Arc::new(Column::new("a", 0)),
-            Operator::Eq,
-            Arc::new(Column::new("a", 2)),
-        )) as _;
-        let equal_b = Arc::new(BinaryExpr::new(
-            Arc::new(Column::new("b", 1)),
-            Operator::Eq,
-            Arc::new(Column::new("b", 3)),
-        )) as _;
-        let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _;
+    fn on_columns(&self) -> Vec<(PhysicalExprRef, PhysicalExprRef)> {
+        let schema1 = self.input1[0].schema();
+        let schema2 = self.input2[0].schema();
+        vec![
+            (
+                Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _,
+                Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _,
+            ),
+            (
+                Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _,
+                Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _,
+            ),
+        ]
+    }

-        let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema);
+    fn intermediate_schema(&self) -> Schema {
+        let schema1 = self.input1[0].schema();
+        let schema2 = self.input2[0].schema();
+        Schema::new(vec![
+            schema1
+                .field_with_name("a")
+                .unwrap()
+                .to_owned()
+                .with_nullable(true),
+            schema1
+                .field_with_name("b")
+                .unwrap()
+                .to_owned()
+                .with_nullable(true),
+            schema2.field_with_name("a").unwrap().to_owned(),
+            schema2.field_with_name("b").unwrap().to_owned(),
+        ])
+    }

-        // sort-merge join
+    fn left_right(&self) -> (Arc<MemoryExec>, Arc<MemoryExec>) {
+        let schema1 = self.input1[0].schema();
+        let schema2 = self.input2[0].schema();
         let left = Arc::new(
-            MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(),
+            MemoryExec::try_new(&[self.input1.clone()], schema1.clone(), None).unwrap(),
         );
         let right = Arc::new(
-            MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(),
+            MemoryExec::try_new(&[self.input2.clone()], schema2.clone(), None).unwrap(),
         );
-        let smj = Arc::new(
+        (left, right)
+    }
+
+    fn join_filter(&self) -> Option<JoinFilter> {
+        let schema1 = self.input1[0].schema();
+        let schema2 = self.input2[0].schema();
+        self.join_filter_builder
+            .as_ref()
+            .map(|builder| builder(schema1, schema2))
+    }
+
+    fn sort_merge_join(&self) -> Arc<SortMergeJoinExec> {
+        let (left, right) = self.left_right();
+        Arc::new(
             SortMergeJoinExec::try_new(
                 left,
                 right,
-                on_columns.clone(),
-                None,
-                join_type,
+                self.on_columns().clone(),
+                self.join_filter(),
+                self.join_type,
                 vec![SortOptions::default(), SortOptions::default()],
                 false,
             )
             .unwrap(),
-        );
-        let smj_collected = collect(smj, task_ctx.clone()).await.unwrap();
+        )
+    }

-        // hash join
-        let left = Arc::new(
-            MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(),
-        );
-        let right = Arc::new(
-            MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(),
-        );
-        let hj = Arc::new(
+    fn hash_join(&self) -> Arc<HashJoinExec> {
+        let (left, right) = self.left_right();
+        Arc::new(
             HashJoinExec::try_new(
                 left,
                 right,
-                on_columns.clone(),
-                None,
-                &join_type,
+                self.on_columns().clone(),
+                self.join_filter(),
+                &self.join_type,
                 None,
                 PartitionMode::Partitioned,
                 false,
             )
             .unwrap(),
-        );
-        let hj_collected = collect(hj, task_ctx.clone()).await.unwrap();
+        )
+    }

-        // nested loop join
-        let left = Arc::new(
-            MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(),
-        );
-        let right = Arc::new(
-            MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(),
-        );
-        let nlj = Arc::new(
-            NestedLoopJoinExec::try_new(left, right, Some(on_filter), &join_type)
-                .unwrap(),
-        );
-        let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();
+    fn nested_loop_join(&self) -> Arc<NestedLoopJoinExec> {
+        let (left, right) = self.left_right();
+        // Nested loop join uses filter for joining records
+        let column_indices = self.column_indices();
+        let intermediate_schema = self.intermediate_schema();
+
+        let equal_a = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a", 0)),
+            Operator::Eq,
+            Arc::new(Column::new("a", 2)),
+        )) as _;
+        let equal_b = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("b", 1)),
+            Operator::Eq,
+            Arc::new(Column::new("b", 3)),
+        )) as _;
+        let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _;

-        // compare
-        let smj_formatted = pretty_format_batches(&smj_collected).unwrap().to_string();
-        let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string();
-        let nlj_formatted = pretty_format_batches(&nlj_collected).unwrap().to_string();
+        let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema);

-        let mut smj_formatted_sorted: Vec<&str> = smj_formatted.trim().lines().collect();
-        smj_formatted_sorted.sort_unstable();
+        Arc::new(
+            NestedLoopJoinExec::try_new(left, right, Some(on_filter), &self.join_type)
+                .unwrap(),
+        )
+    }

-        let mut hj_formatted_sorted: Vec<&str> = hj_formatted.trim().lines().collect();
-        hj_formatted_sorted.sort_unstable();
+    /// Perform sort-merge join and hash join on the same input
+    /// and verify the two outputs are equal
+    async fn run_test(&self) {
+        for batch_size in self.batch_sizes {
+            let session_config = SessionConfig::new().with_batch_size(*batch_size);
+            let ctx = SessionContext::new_with_config(session_config);
+            let task_ctx = ctx.task_ctx();
+            let smj = self.sort_merge_join();
+            let smj_collected = collect(smj, task_ctx.clone()).await.unwrap();

-        let mut nlj_formatted_sorted: Vec<&str> = nlj_formatted.trim().lines().collect();
-        nlj_formatted_sorted.sort_unstable();
+            let hj = self.hash_join();
+            let hj_collected = collect(hj, task_ctx.clone()).await.unwrap();

-        for (i, (smj_line, hj_line)) in smj_formatted_sorted
-            .iter()
-            .zip(&hj_formatted_sorted)
-            .enumerate()
-        {
-            assert_eq!(
-                (i, smj_line),
-                (i, hj_line),
-                "SortMergeJoinExec and HashJoinExec produced different results"
-            );
-        }
+            // Get actual row counts (without formatting overhead) for HJ and SMJ
+            let hj_rows = hj_collected.iter().fold(0, |acc, b| acc + b.num_rows());
+            let smj_rows = smj_collected.iter().fold(0, |acc, b| acc + b.num_rows());

-        for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
-            .iter()
-            .zip(&hj_formatted_sorted)
-            .enumerate()
-        {
             assert_eq!(
-                (i, nlj_line),
-                (i, hj_line),
-                "NestedLoopJoinExec and HashJoinExec produced different results"
+                hj_rows, smj_rows,
+                "SortMergeJoinExec and HashJoinExec produced different row counts"
             );
+
+            let nlj = self.nested_loop_join();
+            let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();
+
+            // compare
+            let smj_formatted =
+                pretty_format_batches(&smj_collected).unwrap().to_string();
+            let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string();
+            let nlj_formatted =
+                pretty_format_batches(&nlj_collected).unwrap().to_string();
+
+            let mut smj_formatted_sorted: Vec<&str> =
+                smj_formatted.trim().lines().collect();
+            smj_formatted_sorted.sort_unstable();
+
+            let mut hj_formatted_sorted: Vec<&str> =
+                hj_formatted.trim().lines().collect();
+            hj_formatted_sorted.sort_unstable();
+
+            let mut nlj_formatted_sorted: Vec<&str> =
+                nlj_formatted.trim().lines().collect();
+            nlj_formatted_sorted.sort_unstable();
+
+            // Compare row by row if any of the joins returned rows;
+            // when there are no rows the formatted output differs
+            if smj_rows > 0 || hj_rows > 0 {
+                for (i, (smj_line, hj_line)) in smj_formatted_sorted
+                    .iter()
+                    .zip(&hj_formatted_sorted)
+                    .enumerate()
+                {
+                    assert_eq!(
+                        (i, smj_line),
+                        (i, hj_line),
+                        "SortMergeJoinExec and HashJoinExec produced different results"
+                    );
+                }
+            }
+
+            for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
+                .iter()
+                .zip(&hj_formatted_sorted)
+                .enumerate()
+            {
+                assert_eq!(
+                    (i, nlj_line),
+                    (i, hj_line),
+                    "NestedLoopJoinExec and HashJoinExec produced different results"
+                );
+            }
         }
     }
 }
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index b85f6376c3f27..4358691ee5a58 
100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -38,6 +38,7 @@ use datafusion_expr::{ AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; +use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -165,7 +166,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), // its name "COUNT", // window function argument @@ -350,7 +351,7 @@ fn get_random_function( window_fn_map.insert( "count", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![arg.clone()], ), ); diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index 0e23e6824027c..2ea18d7cf8232 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -30,8 +30,7 @@ use arrow::datatypes::{ use arrow_array::{ make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array, FixedSizeBinaryArray, Float16Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, - IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, + Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, RecordBatch, StringArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, @@ -1061,84 +1060,6 @@ async fn test_dates_64_diff_rg_sizes() { .run(); } -#[tokio::test] -#[should_panic] -// Currently this test `should_panic` since statistics for `Intervals` -// are not supported and `IntervalMonthDayNano` cannot be written -// to parquet yet. 
-// Refer to issue: https://github.com/apache/arrow-rs/issues/5847 -// and https://github.com/apache/arrow-rs/blob/master/parquet/src/arrow/arrow_writer/mod.rs#L747 -async fn test_interval_diff_rg_sizes() { - // This creates a parquet files of 3 columns: - // "year_month" --> IntervalYearMonthArray - // "day_time" --> IntervalDayTimeArray - // "month_day_nano" --> IntervalMonthDayNanoArray - // - // The file is created by 4 record batches (each has a null row) - // each has 5 rows but then will be split into 2 row groups with size 13, 7 - let reader = TestReader { - scenario: Scenario::Interval, - row_per_group: 13, - } - .build() - .await; - - // TODO: expected values need to be changed once issue is resolved - // expected_min: Arc::new(IntervalYearMonthArray::from(vec![ - // IntervalYearMonthType::make_value(1, 10), - // IntervalYearMonthType::make_value(4, 13), - // ])), - // expected_max: Arc::new(IntervalYearMonthArray::from(vec![ - // IntervalYearMonthType::make_value(6, 51), - // IntervalYearMonthType::make_value(8, 53), - // ])), - Test { - reader: &reader, - expected_min: Arc::new(IntervalYearMonthArray::from(vec![None, None])), - expected_max: Arc::new(IntervalYearMonthArray::from(vec![None, None])), - expected_null_counts: UInt64Array::from(vec![2, 2]), - expected_row_counts: UInt64Array::from(vec![13, 7]), - column_name: "year_month", - } - .run(); - - // expected_min: Arc::new(IntervalDayTimeArray::from(vec![ - // IntervalDayTimeType::make_value(1, 10), - // IntervalDayTimeType::make_value(4, 13), - // ])), - // expected_max: Arc::new(IntervalDayTimeArray::from(vec![ - // IntervalDayTimeType::make_value(6, 51), - // IntervalDayTimeType::make_value(8, 53), - // ])), - Test { - reader: &reader, - expected_min: Arc::new(IntervalDayTimeArray::from(vec![None, None])), - expected_max: Arc::new(IntervalDayTimeArray::from(vec![None, None])), - expected_null_counts: UInt64Array::from(vec![2, 2]), - expected_row_counts: UInt64Array::from(vec![13, 7]), - column_name: "day_time", - } - .run(); - - // expected_min: Arc::new(IntervalMonthDayNanoArray::from(vec![ - // IntervalMonthDayNanoType::make_value(1, 10, 100), - // IntervalMonthDayNanoType::make_value(4, 13, 103), - // ])), - // expected_max: Arc::new(IntervalMonthDayNanoArray::from(vec![ - // IntervalMonthDayNanoType::make_value(6, 51, 501), - // IntervalMonthDayNanoType::make_value(8, 53, 503), - // ])), - Test { - reader: &reader, - expected_min: Arc::new(IntervalMonthDayNanoArray::from(vec![None, None])), - expected_max: Arc::new(IntervalMonthDayNanoArray::from(vec![None, None])), - expected_null_counts: UInt64Array::from(vec![2, 2]), - expected_row_counts: UInt64Array::from(vec![13, 7]), - column_name: "month_day_nano", - } - .run(); -} - #[tokio::test] async fn test_uint() { // This creates a parquet files of 4 columns named "u8", "u16", "u32", "u64" diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 5ab268beb92f9..9546ab30c9e01 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -18,9 +18,7 @@ //! 
Parquet integration tests
 use crate::parquet::utils::MetricsFinder;
 use arrow::array::Decimal128Array;
-use arrow::datatypes::{
-    i256, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType,
-};
+use arrow::datatypes::i256;
 use arrow::{
     array::{
         make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array,
@@ -36,10 +34,6 @@
     record_batch::RecordBatch,
     util::pretty::pretty_format_batches,
 };
-use arrow_array::{
-    IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray,
-};
-use arrow_schema::IntervalUnit;
 use chrono::{Datelike, Duration, TimeDelta};
 use datafusion::{
     datasource::{provider_as_source, TableProvider},
@@ -92,7 +86,6 @@ enum Scenario {
     Time32Millisecond,
     Time64Nanosecond,
     Time64Microsecond,
-    Interval,
     /// 7 Rows, for each i8, i16, i32, i64, u8, u16, u32, u64, f32, f64
     /// -MIN, -100, -1, 0, 1, 100, MAX
     NumericLimits,
@@ -921,71 +914,6 @@ fn make_dict_batch() -> RecordBatch {
     .unwrap()
 }

-fn make_interval_batch(offset: i32) -> RecordBatch {
-    let schema = Schema::new(vec![
-        Field::new(
-            "year_month",
-            DataType::Interval(IntervalUnit::YearMonth),
-            true,
-        ),
-        Field::new("day_time", DataType::Interval(IntervalUnit::DayTime), true),
-        Field::new(
-            "month_day_nano",
-            DataType::Interval(IntervalUnit::MonthDayNano),
-            true,
-        ),
-    ]);
-    let schema = Arc::new(schema);
-
-    let ym_arr = IntervalYearMonthArray::from(vec![
-        Some(IntervalYearMonthType::make_value(1 + offset, 10 + offset)),
-        Some(IntervalYearMonthType::make_value(2 + offset, 20 + offset)),
-        Some(IntervalYearMonthType::make_value(3 + offset, 30 + offset)),
-        None,
-        Some(IntervalYearMonthType::make_value(5 + offset, 50 + offset)),
-    ]);
-
-    let dt_arr = IntervalDayTimeArray::from(vec![
-        Some(IntervalDayTimeType::make_value(1 + offset, 10 + offset)),
-        Some(IntervalDayTimeType::make_value(2 + offset, 20 + offset)),
-        Some(IntervalDayTimeType::make_value(3 + offset, 30 + offset)),
-        None,
-        Some(IntervalDayTimeType::make_value(5 + offset, 50 + offset)),
-    ]);
-
-    // Not yet implemented, refer to:
-    // https://github.com/apache/arrow-rs/blob/master/parquet/src/arrow/arrow_writer/mod.rs#L747
-    let mdn_arr = IntervalMonthDayNanoArray::from(vec![
-        Some(IntervalMonthDayNanoType::make_value(
-            1 + offset,
-            10 + offset,
-            100 + (offset as i64),
-        )),
-        Some(IntervalMonthDayNanoType::make_value(
-            2 + offset,
-            20 + offset,
-            200 + (offset as i64),
-        )),
-        Some(IntervalMonthDayNanoType::make_value(
-            3 + offset,
-            30 + offset,
-            300 + (offset as i64),
-        )),
-        None,
-        Some(IntervalMonthDayNanoType::make_value(
-            5 + offset,
-            50 + offset,
-            500 + (offset as i64),
-        )),
-    ]);
-
-    RecordBatch::try_new(
-        schema,
-        vec![Arc::new(ym_arr), Arc::new(dt_arr), Arc::new(mdn_arr)],
-    )
-    .unwrap()
-}
-
 fn create_data_batch(scenario: Scenario) -> Vec<RecordBatch> {
     match scenario {
         Scenario::Boolean => {
@@ -1407,12 +1335,6 @@ fn create_data_batch(scenario: Scenario) -> Vec<RecordBatch> {
             ]),
         ]
     }
-        Scenario::Interval => vec![
-            make_interval_batch(0),
-            make_interval_batch(1),
-            make_interval_batch(2),
-            make_interval_batch(3),
-        ],
     }
 }
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index e3d2e6555d5cd..81562bf12476a 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -33,8 +33,6 @@ use strum_macros::EnumIter;
 // https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function
 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
 pub enum 
AggregateFunction { - /// Count - Count, /// Minimum Min, /// Maximum @@ -47,24 +45,6 @@ pub enum AggregateFunction { NthValue, /// Correlation Correlation, - /// Slope from linear regression - RegrSlope, - /// Intercept from linear regression - RegrIntercept, - /// Number of input rows in which both expressions are not null - RegrCount, - /// R-squared value from linear regression - RegrR2, - /// Average of the independent variable - RegrAvgx, - /// Average of the dependent variable - RegrAvgy, - /// Sum of squares of the independent variable - RegrSXX, - /// Sum of squares of the dependent variable - RegrSYY, - /// Sum of products of pairs of numbers - RegrSXY, /// Approximate continuous percentile function ApproxPercentileCont, /// Approximate continuous percentile function with weight @@ -89,22 +69,12 @@ impl AggregateFunction { pub fn name(&self) -> &str { use AggregateFunction::*; match self { - Count => "COUNT", Min => "MIN", Max => "MAX", Avg => "AVG", ArrayAgg => "ARRAY_AGG", NthValue => "NTH_VALUE", Correlation => "CORR", - RegrSlope => "REGR_SLOPE", - RegrIntercept => "REGR_INTERCEPT", - RegrCount => "REGR_COUNT", - RegrR2 => "REGR_R2", - RegrAvgx => "REGR_AVGX", - RegrAvgy => "REGR_AVGY", - RegrSXX => "REGR_SXX", - RegrSYY => "REGR_SYY", - RegrSXY => "REGR_SXY", ApproxPercentileCont => "APPROX_PERCENTILE_CONT", ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", Grouping => "GROUPING", @@ -135,7 +105,6 @@ impl FromStr for AggregateFunction { "bit_xor" => AggregateFunction::BitXor, "bool_and" => AggregateFunction::BoolAnd, "bool_or" => AggregateFunction::BoolOr, - "count" => AggregateFunction::Count, "max" => AggregateFunction::Max, "mean" => AggregateFunction::Avg, "min" => AggregateFunction::Min, @@ -144,15 +113,6 @@ impl FromStr for AggregateFunction { "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, - "regr_slope" => AggregateFunction::RegrSlope, - "regr_intercept" => AggregateFunction::RegrIntercept, - "regr_count" => AggregateFunction::RegrCount, - "regr_r2" => AggregateFunction::RegrR2, - "regr_avgx" => AggregateFunction::RegrAvgx, - "regr_avgy" => AggregateFunction::RegrAvgy, - "regr_sxx" => AggregateFunction::RegrSXX, - "regr_syy" => AggregateFunction::RegrSYY, - "regr_sxy" => AggregateFunction::RegrSXY, // approximate "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, "approx_percentile_cont_with_weight" => { @@ -190,7 +150,6 @@ impl AggregateFunction { })?; match self { - AggregateFunction::Count => Ok(DataType::Int64), AggregateFunction::Max | AggregateFunction::Min => { // For min and max agg function, the returned type is same as input type. // The coerced_data_types is same with input_types. @@ -205,15 +164,6 @@ impl AggregateFunction { AggregateFunction::Correlation => { correlation_return_type(&coerced_data_types[0]) } - AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY => Ok(DataType::Float64), AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new( "item", @@ -249,7 +199,6 @@ impl AggregateFunction { pub fn signature(&self) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. 
match self { - AggregateFunction::Count => Signature::variadic_any(Volatility::Immutable), AggregateFunction::Grouping | AggregateFunction::ArrayAgg => { Signature::any(1, Volatility::Immutable) } @@ -278,16 +227,7 @@ impl AggregateFunction { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable), - AggregateFunction::Correlation - | AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY => { + AggregateFunction::Correlation => { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::ApproxPercentileCont => { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 98ab8ec251f4f..9ba866a4c9198 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1861,6 +1861,7 @@ fn write_name(w: &mut W, e: &Expr) -> Result<()> { null_treatment, }) => { write_function_name(w, &fun.to_string(), false, args)?; + if let Some(nt) = null_treatment { w.write_str(" ")?; write!(w, "{}", nt)?; @@ -1885,7 +1886,6 @@ fn write_name(w: &mut W, e: &Expr) -> Result<()> { null_treatment, }) => { write_function_name(w, func_def.name(), *distinct, args)?; - if let Some(fe) = filter { write!(w, " FILTER (WHERE {fe})")?; }; @@ -2135,18 +2135,6 @@ mod test { use super::*; - #[test] - fn test_count_return_type() -> Result<()> { - let fun = find_df_window_func("count").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::Int64, observed); - - Ok(()) - } - #[test] fn test_first_value_return_type() -> Result<()> { let fun = find_df_window_func("first_value").unwrap(); @@ -2250,7 +2238,6 @@ mod test { "nth_value", "min", "max", - "count", "avg", ]; for name in names { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 4203120508708..fb5b3991ecd8d 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -192,18 +192,6 @@ pub fn avg(expr: Expr) -> Expr { )) } -/// Create an expression to represent the count() aggregate function -pub fn count(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Count, - vec![expr], - false, - None, - None, - None, - )) -} - /// Return a new expression with bitwise AND pub fn bitwise_and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new( @@ -249,18 +237,6 @@ pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr { )) } -/// Create an expression to represent the count(distinct) aggregate function -pub fn count_distinct(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Count, - vec![expr], - true, - None, - None, - None, - )) -} - /// Create an in_list expression pub fn in_list(expr: Expr, list: Vec, negated: bool) -> Expr { Expr::InList(InList::new(Box::new(expr), list, negated)) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9ea2abe64edef..02378ab3fc1b9 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2965,11 +2965,13 @@ mod tests { use super::*; use 
crate::builder::LogicalTableSource;
 use crate::logical_plan::table_scan;
-    use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
+    use crate::{col, exists, in_subquery, lit, placeholder, GroupingSet};
 use datafusion_common::tree_node::TreeNodeVisitor;
 use datafusion_common::{not_impl_err, Constraint, ScalarValue};
+    use crate::test::function_stub::count;
+
 fn employee_schema() -> Schema {
     Schema::new(vec![
         Field::new("id", DataType::Int32, false),
diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs
index b9aa1e636d949..ac98ee9747cc1 100644
--- a/datafusion/expr/src/test/function_stub.rs
+++ b/datafusion/expr/src/test/function_stub.rs
@@ -31,7 +31,7 @@ use crate::{
 use arrow::datatypes::{
     DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
 };
-use datafusion_common::{exec_err, Result};
+use datafusion_common::{exec_err, not_impl_err, Result};

 macro_rules! create_func {
     ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
@@ -69,6 +69,19 @@ pub fn sum(expr: Expr) -> Expr {
     ))
 }

+create_func!(Count, count_udaf);
+
+pub fn count(expr: Expr) -> Expr {
+    Expr::AggregateFunction(AggregateFunction::new_udf(
+        count_udaf(),
+        vec![expr],
+        false,
+        None,
+        None,
+        None,
+    ))
+}
+
 /// Stub `sum` used for optimizer testing
 #[derive(Debug)]
 pub struct Sum {
@@ -189,3 +202,74 @@ impl AggregateUDFImpl for Sum {
         AggregateOrderSensitivity::Insensitive
     }
 }
+
+/// Testing stub implementation of COUNT aggregate
+pub struct Count {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl std::fmt::Debug for Count {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        f.debug_struct("Count")
+            .field("name", &self.name())
+            .field("signature", &self.signature)
+            .finish()
+    }
+}
+
+impl Default for Count {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl Count {
+    pub fn new() -> Self {
+        Self {
+            aliases: vec!["count".to_string()],
+            signature: Signature::variadic_any(Volatility::Immutable),
+        }
+    }
+}
+
+impl AggregateUDFImpl for Count {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "COUNT"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Int64)
+    }
+
+    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
+        not_impl_err!("no impl for stub")
+    }
+
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        not_impl_err!("no impl for stub")
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn create_groups_accumulator(
+        &self,
+        _args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        not_impl_err!("no impl for stub")
+    }
+
+    fn reverse_expr(&self) -> ReversedUDAF {
+        ReversedUDAF::Identical
+    }
+}
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs
index ab7deaff9885b..6c9a71bab46a9 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -96,7 +96,6 @@ pub fn coerce_types(
     check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?;

     match agg_fun {
-        AggregateFunction::Count => Ok(input_types.to_vec()),
         AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
         AggregateFunction::Min | AggregateFunction::Max => {
             // min and max support the dictionary data type
@@ -159,27 +158,6 @@ pub fn coerce_types(
             }
             Ok(vec![Float64, Float64])
         }
-        AggregateFunction::RegrSlope
-        | AggregateFunction::RegrIntercept
-        | AggregateFunction::RegrCount
-        |
AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY => { - let valid_types = [NUMERICS.to_vec(), vec![Null]].concat(); - let input_types_valid = // number of input already checked before - valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]); - if !input_types_valid { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(vec![Float64, Float64]) - } AggregateFunction::ApproxPercentileCont => { if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { return plan_err!( @@ -525,7 +503,6 @@ mod tests { // test count, array_agg, approx_distinct, min, max. // the coerced types is same with input types let funs = vec![ - AggregateFunction::Count, AggregateFunction::ArrayAgg, AggregateFunction::Min, AggregateFunction::Max, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 71a3a5fe7309d..3ab0c180dcba9 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -818,7 +818,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( } } Expr::Literal(_) => { - indexes.push(std::usize::MAX); + indexes.push(usize::MAX); } _ => {} } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs new file mode 100644 index 0000000000000..cfd56619537bd --- /dev/null +++ b/datafusion/functions-aggregate/src/count.rs @@ -0,0 +1,562 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
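The new `datafusion/functions-aggregate/src/count.rs` that follows re-implements COUNT as a user-defined aggregate function (UDAF), replacing the built-in `AggregateFunction::Count` removed above. As a hedged sketch of what this means for callers of the expression API (the exports are added in the `lib.rs` hunk further down; the column name is illustrative):

use datafusion_expr::col;
use datafusion_functions_aggregate::expr_fn::{count, count_distinct};

// COUNT(a) and COUNT(DISTINCT a), now built through the Count UDAF
// rather than the removed built-in enum variant.
let count_a = count(col("a"));
let count_distinct_a = count_distinct(col("a"));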
+
+use ahash::RandomState;
+use std::collections::HashSet;
+use std::ops::BitAnd;
+use std::{fmt::Debug, sync::Arc};
+
+use arrow::{
+    array::{ArrayRef, AsArray},
+    datatypes::{
+        DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
+        Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
+        Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
+        Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
+        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
+        UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+    },
+};
+
+use arrow::{
+    array::{Array, BooleanArray, Int64Array, PrimitiveArray},
+    buffer::BooleanBuffer,
+};
+use datafusion_common::{
+    downcast_value, internal_err, DataFusionError, Result, ScalarValue,
+};
+use datafusion_expr::function::StateFieldsArgs;
+use datafusion_expr::{
+    function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
+    EmitTo, GroupsAccumulator, Signature, Volatility,
+};
+use datafusion_expr::{Expr, ReversedUDAF};
+use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
+use datafusion_physical_expr_common::{
+    aggregate::count_distinct::{
+        BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
+        PrimitiveDistinctCountAccumulator,
+    },
+    binary_map::OutputType,
+};
+
+make_udaf_expr_and_func!(
+    Count,
+    count,
+    expr,
+    "Count the number of non-null values in the column",
+    count_udaf
+);
+
+pub fn count_distinct(expr: Expr) -> datafusion_expr::Expr {
+    datafusion_expr::Expr::AggregateFunction(
+        datafusion_expr::expr::AggregateFunction::new_udf(
+            count_udaf(),
+            vec![expr],
+            true,
+            None,
+            None,
+            None,
+        ),
+    )
+}
+
+pub struct Count {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Debug for Count {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        f.debug_struct("Count")
+            .field("name", &self.name())
+            .field("signature", &self.signature)
+            .finish()
+    }
+}
+
+impl Default for Count {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl Count {
+    pub fn new() -> Self {
+        Self {
+            aliases: vec!["count".to_string()],
+            signature: Signature::variadic_any(Volatility::Immutable),
+        }
+    }
+}
+
+impl AggregateUDFImpl for Count {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "COUNT"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Int64)
+    }
+
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+        if args.is_distinct {
+            Ok(vec![Field::new_list(
+                format_state_name(args.name, "count distinct"),
+                Field::new("item", args.input_type.clone(), true),
+                false,
+            )])
+        } else {
+            Ok(vec![Field::new(
+                format_state_name(args.name, "count"),
+                DataType::Int64,
+                true,
+            )])
+        }
+    }
+
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        if !acc_args.is_distinct {
+            return Ok(Box::new(CountAccumulator::new()));
+        }
+
+        let data_type = acc_args.input_type;
+        Ok(match data_type {
+            // try to use a specialized accumulator if possible, otherwise fall back to the generic accumulator
+            DataType::Int8 => Box::new(
+                PrimitiveDistinctCountAccumulator::<Int8Type>::new(data_type),
+            ),
+            DataType::Int16 => Box::new(
+                PrimitiveDistinctCountAccumulator::<Int16Type>::new(data_type),
+            ),
+            DataType::Int32 => Box::new(
+                PrimitiveDistinctCountAccumulator::<Int32Type>::new(data_type),
+            ),
+            DataType::Int64 => Box::new(
+                PrimitiveDistinctCountAccumulator::<Int64Type>::new(data_type),
+            ),
+            DataType::UInt8 => Box::new(
+                PrimitiveDistinctCountAccumulator::<UInt8Type>::new(data_type),
+            ),
+            DataType::UInt16 => Box::new(
+                PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
+            ),
+            DataType::UInt32 => Box::new(
+                PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
+            ),
+            DataType::UInt64 => Box::new(
+                PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
+            ),
+            DataType::Decimal128(_, _) => Box::new(
+                PrimitiveDistinctCountAccumulator::<Decimal128Type>::new(data_type),
+            ),
+            DataType::Decimal256(_, _) => Box::new(
+                PrimitiveDistinctCountAccumulator::<Decimal256Type>::new(data_type),
+            ),
+
+            DataType::Date32 => Box::new(
+                PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
+            ),
+            DataType::Date64 => Box::new(
+                PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
+            ),
+            DataType::Time32(TimeUnit::Millisecond) => Box::new(
+                PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(data_type),
+            ),
+            DataType::Time32(TimeUnit::Second) => Box::new(
+                PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
+            ),
+            DataType::Time64(TimeUnit::Microsecond) => Box::new(
+                PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(data_type),
+            ),
+            DataType::Time64(TimeUnit::Nanosecond) => Box::new(
+                PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
+            ),
+            DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
+                PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(data_type),
+            ),
+            DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
+                PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(data_type),
+            ),
+            DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
+                PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(data_type),
+            ),
+            DataType::Timestamp(TimeUnit::Second, _) => Box::new(
+                PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
+            ),
+
+            DataType::Float16 => {
+                Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
+            }
+            DataType::Float32 => {
+                Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
+            }
+            DataType::Float64 => {
+                Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
+            }
+
+            DataType::Utf8 => {
+                Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
+            }
+            DataType::LargeUtf8 => {
+                Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
+            }
+            DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
+                OutputType::Binary,
+            )),
+            DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
+                OutputType::Binary,
+            )),
+
+            // Use the generic accumulator based on `ScalarValue` for all other types
+            _ => Box::new(DistinctCountAccumulator {
+                values: HashSet::default(),
+                state_data_type: data_type.clone(),
+            }),
+        })
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        // the groups accumulator only supports `COUNT(c1)`, not
+        // `COUNT(c1, c2)`, etc.
+        if args.is_distinct {
+            return false;
+        }
+        args.args_num == 1
+    }
+
+    fn create_groups_accumulator(
+        &self,
+        _args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        // instantiate the specialized accumulator
+        Ok(Box::new(CountGroupsAccumulator::new()))
+    }
+
+    fn reverse_expr(&self) -> ReversedUDAF {
+        ReversedUDAF::Identical
+    }
+}
+
+#[derive(Debug)]
+struct CountAccumulator {
+    count: i64,
+}
+
+impl CountAccumulator {
+    /// new count accumulator
+    pub fn new() -> Self {
+        Self { count: 0 }
+    }
+}
+
+impl Accumulator for CountAccumulator {
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![ScalarValue::Int64(Some(self.count))])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let array = &values[0];
+        self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
+        Ok(())
+    }
+
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let array = &values[0];
+        self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        let counts = downcast_value!(states[0], Int64Array);
+        let delta = &arrow::compute::sum(counts);
+        if let Some(d) = delta {
+            self.count += *d;
+        }
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.count)))
+    }
+
+    fn supports_retract_batch(&self) -> bool {
+        true
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+}
+
+/// An accumulator that computes the count of non-null values per group.
+///
+/// Unlike most other accumulators, COUNT never produces NULLs. If no
+/// non-null values are seen in any group the output is 0. Thus, this
+/// accumulator has no additional null or seen filter tracking.
+#[derive(Debug)]
+struct CountGroupsAccumulator {
+    /// Count per group.
+    ///
+    /// Note this is an i64 and not a u64 (or usize) because the
+    /// output type of count is `DataType::Int64`. Thus by using `i64`
+    /// for the counts, the output [`Int64Array`] can be created
+    /// without copy.
+    counts: Vec<i64>,
+}
+
+impl CountGroupsAccumulator {
+    pub fn new() -> Self {
+        Self { counts: vec![] }
+    }
+}
+
+impl GroupsAccumulator for CountGroupsAccumulator {
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "single argument to update_batch");
+        let values = &values[0];
+
+        // Add one to each group's counter for each non null, non
+        // filtered value
+        self.counts.resize(total_num_groups, 0);
+        accumulate_indices(
+            group_indices,
+            values.logical_nulls().as_ref(),
+            opt_filter,
+            |group_index| {
+                self.counts[group_index] += 1;
+            },
+        );
+
+        Ok(())
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "one argument to merge_batch");
+        // the state from each partial aggregate is a single Int64 column of counts
+        let partial_counts = values[0].as_primitive::<Int64Type>();
+
+        // intermediate counts are always created as non null
+        assert_eq!(partial_counts.null_count(), 0);
+        let partial_counts = partial_counts.values();
+
+        // Adds the counts with the partial counts
+        self.counts.resize(total_num_groups, 0);
+        match opt_filter {
+            Some(filter) => filter
+                .iter()
+                .zip(group_indices.iter())
+                .zip(partial_counts.iter())
+                .for_each(|((filter_value, &group_index), partial_count)| {
+                    if let Some(true) = filter_value {
+                        self.counts[group_index] += partial_count;
+                    }
+                }),
+            None => group_indices.iter().zip(partial_counts.iter()).for_each(
+                |(&group_index, partial_count)| {
+                    self.counts[group_index] += partial_count;
+                },
+            ),
+        }
+
+        Ok(())
+    }
+
+    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+        let counts = emit_to.take_needed(&mut self.counts);
+
+        // Count is always non null (null inputs just don't contribute to the overall values)
+        let nulls = None;
+        let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
+
+        Ok(Arc::new(array))
+    }
+
+    // return arrays for counts
+    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        let counts = emit_to.take_needed(&mut self.counts);
+        let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); // zero copy, no nulls
+        Ok(vec![Arc::new(counts) as ArrayRef])
+    }
+
+    fn size(&self) -> usize {
+        self.counts.capacity() * std::mem::size_of::<i64>()
+    }
+}
+
+/// Counts null values across multiple columns: for each row, if any column
+/// value is null, the row adds one to the null count.
+fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
+    if values.len() > 1 {
+        let result_bool_buf: Option<BooleanBuffer> = values
+            .iter()
+            .map(|a| a.logical_nulls())
+            .fold(None, |acc, b| match (acc, b) {
+                (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
+                (Some(acc), None) => Some(acc),
+                (None, Some(b)) => Some(b.into_inner()),
+                _ => None,
+            });
+        result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
+    } else {
+        values[0]
+            .logical_nulls()
+            .map_or(0, |nulls| nulls.null_count())
+    }
+}
+
+/// General purpose distinct accumulator that works for any DataType by using
+/// [`ScalarValue`].
+///
+/// It stores intermediate results as a `ListArray`.
+///
+/// Note that many types have specialized accumulators that are (much)
+/// more efficient, such as [`PrimitiveDistinctCountAccumulator`] and
+/// [`BytesDistinctCountAccumulator`].
+#[derive(Debug)]
+struct DistinctCountAccumulator {
+    values: HashSet<ScalarValue>,
+    state_data_type: DataType,
+}
+
+impl DistinctCountAccumulator {
+    // Estimates the size for fixed-length values, taking the size of the
+    // first value * the number of values. This method is faster than
+    // `full_size()`, however it is not suitable for variable-length values
+    // like strings or complex types.
+    fn fixed_size(&self) -> usize {
+        std::mem::size_of_val(self)
+            + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
+            + self
+                .values
+                .iter()
+                .next()
+                .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
+                .unwrap_or(0)
+            + std::mem::size_of::<DataType>()
+    }
+
+    // Calculates the size as accurately as possible. Note that calling this
+    // method is expensive.
+    fn full_size(&self) -> usize {
+        std::mem::size_of_val(self)
+            + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
+            + self
+                .values
+                .iter()
+                .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
+                .sum::<usize>()
+            + std::mem::size_of::<DataType>()
+    }
+}
+
+impl Accumulator for DistinctCountAccumulator {
+    /// Returns the distinct values seen so far as a (one element) ListArray.
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        let scalars = self.values.iter().cloned().collect::<Vec<_>>();
+        let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type);
+        Ok(vec![ScalarValue::List(arr)])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+
+        let arr = &values[0];
+        if arr.data_type() == &DataType::Null {
+            return Ok(());
+        }
+
+        (0..arr.len()).try_for_each(|index| {
+            if !arr.is_null(index) {
+                let scalar = ScalarValue::try_from_array(arr, index)?;
+                self.values.insert(scalar);
+            }
+            Ok(())
+        })
+    }
+
+    /// Merges multiple sets of distinct values into the current set.
+    ///
+    /// The input to this function is a `ListArray` with **multiple** rows,
+    /// where each row contains the distinct values from one partial
+    /// aggregate (e.g. the result of calling `Self::state` on multiple
+    /// accumulators).
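+    /// For example, if two partial accumulators saw the distinct values
+    /// `[1, 2]` and `[2, 3]` respectively, the input here is a two-row
+    /// `ListArray` `[[1, 2], [2, 3]]`, and each row is folded back in
+    /// through `update_batch` (the values are illustrative).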
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+        assert_eq!(states.len(), 1, "count distinct states must be singleton!");
+        let array = &states[0];
+        let list_array = array.as_list::<i32>();
+        for inner_array in list_array.iter() {
+            let Some(inner_array) = inner_array else {
+                return internal_err!(
+                    "Intermediate results of COUNT DISTINCT should always be non null"
+                );
+            };
+            self.update_batch(&[inner_array])?;
+        }
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
+    }
+
+    fn size(&self) -> usize {
+        match &self.state_data_type {
+            DataType::Boolean | DataType::Null => self.fixed_size(),
+            d if d.is_primitive() => self.fixed_size(),
+            _ => self.full_size(),
+        }
+    }
+}
diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs
index 2d062cf2cb9b4..fabe15e416f40 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -56,10 +56,12 @@ pub mod macros;
 pub mod approx_distinct;
+pub mod count;
 pub mod covariance;
 pub mod first_last;
 pub mod hyperloglog;
 pub mod median;
+pub mod regr;
 pub mod stddev;
 pub mod sum;
 pub mod variance;
@@ -77,11 +79,22 @@ use std::sync::Arc;
 pub mod expr_fn {
     pub use super::approx_distinct;
     pub use super::approx_median::approx_median;
+    pub use super::count::count;
+    pub use super::count::count_distinct;
     pub use super::covariance::covar_pop;
     pub use super::covariance::covar_samp;
     pub use super::first_last::first_value;
     pub use super::first_last::last_value;
     pub use super::median::median;
+    pub use super::regr::regr_avgx;
+    pub use super::regr::regr_avgy;
+    pub use super::regr::regr_count;
+    pub use super::regr::regr_intercept;
+    pub use super::regr::regr_r2;
+    pub use super::regr::regr_slope;
+    pub use super::regr::regr_sxx;
+    pub use super::regr::regr_sxy;
+    pub use super::regr::regr_syy;
     pub use super::stddev::stddev;
     pub use super::stddev::stddev_pop;
     pub use super::sum::sum;
@@ -98,6 +111,16 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
         sum::sum_udaf(),
         covariance::covar_pop_udaf(),
         median::median_udaf(),
+        count::count_udaf(),
+        regr::regr_slope_udaf(),
+        regr::regr_intercept_udaf(),
+        regr::regr_count_udaf(),
+        regr::regr_r2_udaf(),
+        regr::regr_avgx_udaf(),
+        regr::regr_avgy_udaf(),
+        regr::regr_sxx_udaf(),
+        regr::regr_syy_udaf(),
+        regr::regr_sxy_udaf(),
         variance::var_samp_udaf(),
         variance::var_pop_udaf(),
         stddev::stddev_udaf(),
@@ -133,8 +156,8 @@ mod tests {
         let mut names = HashSet::new();
         for func in all_default_aggregate_functions() {
             // TODO: remove this
-            // sum is in intermidiate migration state, skip this
-            if func.name().to_lowercase() == "sum" {
+            // These functions are in intermediate migration state, skip them
+            if func.name().to_lowercase() == "count" {
                 continue;
             }
             assert!(
diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs
index 75bb9dc54719d..cae72cf352238 100644
--- a/datafusion/functions-aggregate/src/macros.rs
+++ b/datafusion/functions-aggregate/src/macros.rs
@@ -32,8 +32,8 @@
 // specific language governing permissions and limitations
 // under the License.

-macro_rules! make_udaf_expr_and_func {
-    ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
+macro_rules!
make_udaf_expr { + ($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { // "fluent expr_fn" style function #[doc = $DOC] pub fn $EXPR_FN( @@ -48,7 +48,12 @@ macro_rules! make_udaf_expr_and_func { None, )) } + }; +} +macro_rules! make_udaf_expr_and_func { + ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + make_udaf_expr!($EXPR_FN, $($arg)*, $DOC, $AGGREGATE_UDF_FN); create_func!($UDAF, $AGGREGATE_UDF_FN); }; ($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { @@ -73,6 +78,9 @@ macro_rules! make_udaf_expr_and_func { macro_rules! create_func { ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { + create_func!($UDAF, $AGGREGATE_UDF_FN, <$UDAF>::default()); + }; + ($UDAF:ty, $AGGREGATE_UDF_FN:ident, $CREATE:expr) => { paste::paste! { /// Singleton instance of [$UDAF], ensures the UDAF is only created once /// named STATIC_$(UDAF). For example `STATIC_FirstValue` @@ -86,7 +94,7 @@ macro_rules! create_func { pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc { [< STATIC_ $UDAF >] .get_or_init(|| { - std::sync::Arc::new(datafusion_expr::AggregateUDF::from(<$UDAF>::default())) + std::sync::Arc::new(datafusion_expr::AggregateUDF::from($CREATE)) }) .clone() } diff --git a/datafusion/physical-expr/src/aggregate/regr.rs b/datafusion/functions-aggregate/src/regr.rs similarity index 84% rename from datafusion/physical-expr/src/aggregate/regr.rs rename to datafusion/functions-aggregate/src/regr.rs index 36e7b7c9b3e43..8d04ae87157d4 100644 --- a/datafusion/physical-expr/src/aggregate/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -18,9 +18,8 @@ //! Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; -use std::sync::Arc; +use std::fmt::Debug; -use crate::{AggregateExpr, PhysicalExpr}; use arrow::array::Float64Array; use arrow::{ array::{ArrayRef, UInt64Array}, @@ -28,13 +27,56 @@ use arrow::{ datatypes::DataType, datatypes::Field, }; -use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue}; +use datafusion_common::{downcast_value, plan_err, unwrap_or_internal_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; + +macro_rules! 
make_regr_udaf_expr_and_func { + ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => { + make_udaf_expr!($EXPR_FN, expr_y expr_x, concat!("Compute a linear regression of type [", stringify!($REGR_TYPE), "]"), $AGGREGATE_UDF_FN); + create_func!($EXPR_FN, $AGGREGATE_UDF_FN, Regr::new($REGR_TYPE, stringify!($EXPR_FN))); + } +} + +make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope); +make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept); +make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count); +make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2); +make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX); +make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY); +make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX); +make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY); +make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY); + +pub struct Regr { + signature: Signature, + regr_type: RegrType, + func_name: &'static str, +} -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; +impl Debug for Regr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("regr") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} +impl Regr { + pub fn new(regr_type: RegrType, func_name: &'static str) -> Self { + Self { + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + regr_type, + func_name, + } + } +} + +/* #[derive(Debug)] pub struct Regr { name: String, @@ -48,6 +90,7 @@ impl Regr { self.regr_type.clone() } } +*/ #[derive(Debug, Clone)] #[allow(clippy::upper_case_acronyms)] @@ -92,86 +135,75 @@ pub enum RegrType { SXY, } -impl Regr { - pub fn new( - expr_y: Arc, - expr_x: Arc, - name: impl Into, - regr_type: RegrType, - return_type: DataType, - ) -> Self { - // the result of regr_slope only support FLOAT64 data type. 
-        assert!(matches!(return_type, DataType::Float64));
-        Self {
-            name: name.into(),
-            regr_type,
-            expr_y,
-            expr_x,
-        }
-    }
-}
-
-impl AggregateExpr for Regr {
+impl AggregateUDFImpl for Regr {
     fn as_any(&self) -> &dyn Any {
         self
     }

-    fn field(&self) -> Result<Field> {
-        Ok(Field::new(&self.name, DataType::Float64, true))
+    fn name(&self) -> &str {
+        self.func_name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        if !arg_types[0].is_numeric() {
+            return plan_err!("Regression aggregate functions require numeric input types");
+        }
+
+        Ok(DataType::Float64)
     }

-    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
         Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
     }

-    fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+    fn create_sliding_accumulator(
+        &self,
+        _args: AccumulatorArgs,
+    ) -> Result<Box<dyn Accumulator>> {
         Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?))
     }

-    fn state_fields(&self) -> Result<Vec<Field>> {
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
         Ok(vec![
             Field::new(
-                format_state_name(&self.name, "count"),
+                format_state_name(args.name, "count"),
                 DataType::UInt64,
                 true,
             ),
             Field::new(
-                format_state_name(&self.name, "mean_x"),
+                format_state_name(args.name, "mean_x"),
                 DataType::Float64,
                 true,
             ),
             Field::new(
-                format_state_name(&self.name, "mean_y"),
+                format_state_name(args.name, "mean_y"),
                 DataType::Float64,
                 true,
             ),
             Field::new(
-                format_state_name(&self.name, "m2_x"),
+                format_state_name(args.name, "m2_x"),
                 DataType::Float64,
                 true,
             ),
             Field::new(
-                format_state_name(&self.name, "m2_y"),
+                format_state_name(args.name, "m2_y"),
                 DataType::Float64,
                 true,
             ),
             Field::new(
-                format_state_name(&self.name, "algo_const"),
+                format_state_name(args.name, "algo_const"),
                 DataType::Float64,
                 true,
             ),
         ])
     }
-
-    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        vec![self.expr_y.clone(), self.expr_x.clone()]
-    }
-
-    fn name(&self) -> &str {
-        &self.name
-    }
 }

+/*
 impl PartialEq<dyn Any> for Regr {
     fn eq(&self, other: &dyn Any) -> bool {
         down_cast_any_ref(other)
@@ -184,6 +216,7 @@ impl PartialEq<dyn Any> for Regr {
             .unwrap_or(false)
     }
 }
+*/

 /// `RegrAccumulator` is used to compute linear regression aggregate functions
 /// by maintaining statistics needed to compute them in an online fashion.
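The state fields above (a count, running means `mean_x`/`mean_y`, second moments `m2_x`/`m2_y`, and a cross term `algo_const`) are the kind of state a Welford-style online formulation keeps. As a hedged sketch of the recurrence such an accumulator maintains for a single variable (illustrative only, not the exact `RegrAccumulator` code that follows):

// One online update step: `m2` accumulates the sum of squared
// deviations from the running mean, without a second pass over the data.
fn welford_step(count: &mut u64, mean: &mut f64, m2: &mut f64, x: f64) {
    *count += 1;
    let delta = x - *mean;            // deviation from the old mean
    *mean += delta / (*count as f64); // update the running mean
    let delta2 = x - *mean;           // deviation from the updated mean
    *m2 += delta * delta2;            // update the second moment
}

The same pattern extends to the cross term, which is also what makes `retract_batch` (enabled just below via `supports_retract_batch`) cheap for sliding windows.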
@@ -305,6 +338,10 @@ impl Accumulator for RegrAccumulator { Ok(()) } + fn supports_retract_batch(&self) -> bool { + true + } + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values_y = &cast(&values[0], &DataType::Float64)?; let values_x = &cast(&values[1], &DataType::Float64)?; diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index cb14f6bdd4a37..1a9e9630c0768 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -56,5 +56,6 @@ regex-syntax = "0.8.0" [dev-dependencies] arrow-buffer = { workspace = true } ctor = { workspace = true } +datafusion-functions-aggregate = { workspace = true } datafusion-sql = { workspace = true } env_logger = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index af1c99c52390a..de2af520053a2 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -25,9 +25,7 @@ use datafusion_expr::expr::{ AggregateFunction, AggregateFunctionDefinition, WindowFunction, }; use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::{ - aggregate_function, lit, Expr, LogicalPlan, WindowFunctionDefinition, -}; +use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// @@ -56,37 +54,19 @@ fn is_wildcard(expr: &Expr) -> bool { } fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { - match aggregate_function { + matches!(aggregate_function, AggregateFunction { func_def: AggregateFunctionDefinition::UDF(udf), args, .. - } if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => true, - AggregateFunction { - func_def: - AggregateFunctionDefinition::BuiltIn( - datafusion_expr::aggregate_function::AggregateFunction::Count, - ), - args, - .. 
- } if args.len() == 1 && is_wildcard(&args[0]) => true, - _ => false, - } + } if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0])) } fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { let args = &window_function.args; - match window_function.fun { - WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ) if args.len() == 1 && is_wildcard(&args[0]) => true, + matches!(window_function.fun, WindowFunctionDefinition::AggregateUDF(ref udaf) - if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => - { - true - } - _ => false, - } + if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0])) } fn analyze_internal(plan: LogicalPlan) -> Result> { @@ -121,14 +101,16 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_expr::expr::Sort; - use datafusion_expr::test::function_stub::sum; use datafusion_expr::{ - col, count, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, - out_ref_col, scalar_subquery, wildcard, AggregateFunction, WindowFrame, - WindowFrameBound, WindowFrameUnits, + col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, + out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; + use datafusion_functions_aggregate::count::count_udaf; use std::sync::Arc; + use datafusion_functions_aggregate::expr_fn::{count, sum}; + fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_analyzed_plan_eq_display_indent( Arc::new(CountWildcardRule::new()), @@ -239,7 +221,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index b55b1a7f8f2df..e949e1921b972 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -432,8 +432,11 @@ fn agg_exprs_evaluation_result_on_empty_batch( Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - if matches!(fun, datafusion_expr::AggregateFunction::Count) { + AggregateFunctionDefinition::BuiltIn(_fun) => { + Transformed::yes(Expr::Literal(ScalarValue::Null)) + } + AggregateFunctionDefinition::UDF(fun) => { + if fun.name() == "COUNT" { Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( 0, )))) @@ -441,9 +444,6 @@ fn agg_exprs_evaluation_result_on_empty_batch( Transformed::yes(Expr::Literal(ScalarValue::Null)) } } - AggregateFunctionDefinition::UDF { .. 
} => { - Transformed::yes(Expr::Literal(ScalarValue::Null)) - } }, _ => Transformed::no(expr), }; diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index cef226d67b6c7..7a8dd7aac2497 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -129,10 +129,12 @@ mod tests { use datafusion_common::Result; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ - col, count, lit, ColumnarValue, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, - Signature, TypeSignature, + col, lit, ColumnarValue, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, + TypeSignature, }; + use datafusion_functions_aggregate::expr_fn::count; + use std::sync::Arc; #[derive(Debug)] diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index af51814c96862..11540d3e162e4 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -818,10 +818,11 @@ mod tests { use datafusion_common::{ Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference, }; + use datafusion_expr::AggregateExt; use datafusion_expr::{ binary_expr, build_join_schema, builder::table_scan_with_filters, - col, count, + col, expr::{self, Cast}, lit, logical_plan::{builder::LogicalPlanBuilder, table_scan}, @@ -830,6 +831,9 @@ mod tests { WindowFunctionDefinition, }; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::count; + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) } @@ -1886,16 +1890,10 @@ mod tests { #[test] fn aggregate_filter_pushdown() -> Result<()> { let table_scan = test_table_scan()?; - - let aggr_with_filter = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("b")], - false, - Some(Box::new(col("c").gt(lit(42)))), - None, - None, - )); - + let aggr_with_filter = count_udaf() + .call(vec![col("b")]) + .filter(col("c").gt(lit(42))) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("a")], diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 32b6703bcae59..d3d22eb53f395 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -361,13 +361,14 @@ impl OptimizerRule for SingleDistinctToGroupBy { mod tests { use super::*; use crate::test::*; - use datafusion_expr::expr; - use datafusion_expr::expr::GroupingSet; - use datafusion_expr::test::function_stub::{sum, sum_udaf}; + use datafusion_expr::expr::{self, GroupingSet}; + use datafusion_expr::AggregateExt; use datafusion_expr::{ - count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, min, - AggregateFunction, + lit, logical_plan::builder::LogicalPlanBuilder, max, min, AggregateFunction, }; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::{count, count_distinct, sum}; + use datafusion_functions_aggregate::sum::sum_udaf; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( @@ -680,14 +681,11 @@ mod tests { let table_scan = test_table_scan()?; // COUNT(DISTINCT a) FILTER 
(WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - Some(Box::new(col("a").gt(lit(5)))), - None, - None, - )); + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .filter(col("a").gt(lit(5))) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; @@ -726,19 +724,16 @@ mod tests { let table_scan = test_table_scan()?; // COUNT(DISTINCT a ORDER BY a) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - None, - Some(vec![col("a")]), - None, - )); + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .order_by(vec![col("a").sort(true, false)]) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -749,19 +744,17 @@ mod tests { let table_scan = test_table_scan()?; // COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - Some(Box::new(col("a").gt(lit(5)))), - Some(vec![col("a")]), - None, - )); + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .filter(col("a").gt(lit(5))) + .order_by(vec![col("a").sort(true, false)]) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? 
.build()?;

         // Do nothing
-        let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\
+        let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\
         \n      TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

         assert_optimized_plan_equal(plan, expected)
diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs
index b3501cca9efa9..f60bf6609005c 100644
--- a/datafusion/optimizer/tests/optimizer_integration.rs
+++ b/datafusion/optimizer/tests/optimizer_integration.rs
@@ -25,6 +25,7 @@ use datafusion_common::config::ConfigOptions;
 use datafusion_common::{plan_err, Result};
 use datafusion_expr::test::function_stub::sum_udaf;
 use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF};
+use datafusion_functions_aggregate::count::count_udaf;
 use datafusion_optimizer::analyzer::Analyzer;
 use datafusion_optimizer::optimizer::Optimizer;
 use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule};
@@ -323,7 +324,9 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
     let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
     let ast: Vec<Statement> = Parser::parse_sql(&dialect, sql).unwrap();
     let statement = &ast[0];
-    let context_provider = MyContextProvider::default().with_udaf(sum_udaf());
+    let context_provider = MyContextProvider::default()
+        .with_udaf(sum_udaf())
+        .with_udaf(count_udaf());
     let sql_to_rel = SqlToRel::new(&context_provider);
     let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();

@@ -345,7 +348,8 @@ struct MyContextProvider {

 impl MyContextProvider {
     fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self {
-        self.udafs.insert(udaf.name().to_string(), udaf);
+        // TODO: switch back to `to_string()` once all function names are converted to lowercase
+        self.udafs.insert(udaf.name().to_lowercase(), udaf);
         self
     }
 }
diff --git a/datafusion/physical-expr-common/Cargo.toml b/datafusion/physical-expr-common/Cargo.toml
index 637b8775112ea..3ef2d53455339 100644
--- a/datafusion/physical-expr-common/Cargo.toml
+++ b/datafusion/physical-expr-common/Cargo.toml
@@ -36,7 +36,9 @@ name = "datafusion_physical_expr_common"
 path = "src/lib.rs"

 [dependencies]
+ahash = { workspace = true }
 arrow = { workspace = true }
 datafusion-common = { workspace = true, default-features = true }
 datafusion-expr = { workspace = true }
+hashbrown = { workspace = true }
 rand = { workspace = true }
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs
similarity index 93%
rename from datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs
rename to datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs
index 2ed9b002c8415..5c888ca66caa6 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs
+++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs
@@ -18,7 +18,7 @@
 //! [`BytesDistinctCountAccumulator`] for Utf8/LargeUtf8/Binary/LargeBinary values

 use crate::binary_map::{ArrowBytesSet, OutputType};
-use arrow_array::{ArrayRef, OffsetSizeTrait};
+use arrow::array::{ArrayRef, OffsetSizeTrait};
 use datafusion_common::cast::as_list_array;
 use datafusion_common::utils::array_into_list_array;
 use datafusion_common::ScalarValue;
@@ -35,10 +35,10 @@ use std::sync::Arc;
 /// [`BinaryArray`]: arrow::array::BinaryArray
 /// [`LargeBinaryArray`]: arrow::array::LargeBinaryArray
 #[derive(Debug)]
-pub(super) struct BytesDistinctCountAccumulator<O: OffsetSizeTrait>(ArrowBytesSet<O>);
+pub struct BytesDistinctCountAccumulator<O: OffsetSizeTrait>(ArrowBytesSet<O>);

 impl<O: OffsetSizeTrait> BytesDistinctCountAccumulator<O> {
-    pub(super) fn new(output_type: OutputType) -> Self {
+    pub fn new(output_type: OutputType) -> Self {
         Self(ArrowBytesSet::new(output_type))
     }
 }
diff --git a/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs
new file mode 100644
index 0000000000000..f216406d0dd74
--- /dev/null
+++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs
@@ -0,0 +1,23 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
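The new `count_distinct/mod.rs` below re-exports the specialized distinct-count accumulators so both the physical-expr crate and the new Count UDAF can share them. A hedged sketch of direct construction (signature taken from the moved `native.rs`; the Int32 input is illustrative):

use arrow::datatypes::{DataType, Int32Type};
use datafusion_physical_expr_common::aggregate::count_distinct::PrimitiveDistinctCountAccumulator;

// Specialized COUNT(DISTINCT) state for an Int32 column; `new` takes the
// input DataType so the list-typed intermediate state can be rebuilt.
let acc = PrimitiveDistinctCountAccumulator::<Int32Type>::new(&DataType::Int32);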
+
+mod bytes;
+mod native;
+
+pub use bytes::BytesDistinctCountAccumulator;
+pub use native::FloatDistinctCountAccumulator;
+pub use native::PrimitiveDistinctCountAccumulator;
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs
similarity index 93%
rename from datafusion/physical-expr/src/aggregate/count_distinct/native.rs
rename to datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs
index 0e7483d4a1cd9..72b83676e81d9 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs
+++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs
@@ -26,10 +26,10 @@ use std::hash::Hash;
 use std::sync::Arc;

 use ahash::RandomState;
+use arrow::array::types::ArrowPrimitiveType;
 use arrow::array::ArrayRef;
-use arrow_array::types::ArrowPrimitiveType;
-use arrow_array::PrimitiveArray;
-use arrow_schema::DataType;
+use arrow::array::PrimitiveArray;
+use arrow::datatypes::DataType;

 use datafusion_common::cast::{as_list_array, as_primitive_array};
 use datafusion_common::utils::array_into_list_array;
@@ -40,7 +40,7 @@ use datafusion_expr::Accumulator;
 use crate::aggregate::utils::Hashable;

 #[derive(Debug)]
-pub(super) struct PrimitiveDistinctCountAccumulator<T>
+pub struct PrimitiveDistinctCountAccumulator<T>
 where
     T: ArrowPrimitiveType + Send,
     T::Native: Eq + Hash,
@@ -54,7 +54,7 @@ where
     T: ArrowPrimitiveType + Send,
     T::Native: Eq + Hash,
 {
-    pub(super) fn new(data_type: &DataType) -> Self {
+    pub fn new(data_type: &DataType) -> Self {
         Self {
             values: HashSet::default(),
             data_type: data_type.clone(),
@@ -125,7 +125,7 @@ where
 }

 #[derive(Debug)]
-pub(super) struct FloatDistinctCountAccumulator<T>
+pub struct FloatDistinctCountAccumulator<T>
 where
     T: ArrowPrimitiveType + Send,
 {
@@ -136,13 +136,22 @@ impl<T> FloatDistinctCountAccumulator<T>
 where
     T: ArrowPrimitiveType + Send,
 {
-    pub(super) fn new() -> Self {
+    pub fn new() -> Self {
         Self {
             values: HashSet::default(),
         }
     }
 }

+impl<T> Default for FloatDistinctCountAccumulator<T>
+where
+    T: ArrowPrimitiveType + Send,
+{
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 impl<T> Accumulator for FloatDistinctCountAccumulator<T>
 where
     T: ArrowPrimitiveType + Send + Debug,
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs
index ec02df57b82d4..21884f840dbdb 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.

+pub mod count_distinct;
 pub mod groups_accumulator;
 pub mod stats;
 pub mod tdigest;
diff --git a/datafusion/physical-expr/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs
similarity index 98%
rename from datafusion/physical-expr/src/binary_map.rs
rename to datafusion/physical-expr-common/src/binary_map.rs
index 0923fcdaeb91b..6d5ba737a1df5 100644
--- a/datafusion/physical-expr/src/binary_map.rs
+++ b/datafusion/physical-expr-common/src/binary_map.rs
@@ -19,17 +19,16 @@
 //! StringArray / LargeStringArray / BinaryArray / LargeBinaryArray.
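`binary_map.rs` also moves wholesale into `datafusion-physical-expr-common` (its imports are rewired in the hunk that follows), since `BytesDistinctCountAccumulator` wraps its `ArrowBytesSet`. A hedged sketch of the new import path for downstream code (i32 offsets correspond to `Utf8`; i64 would be used for `LargeUtf8`):

use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType};

// A deduplicating set over string values, backed by the shared byte map.
let mut set = ArrowBytesSet::<i32>::new(OutputType::Utf8);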
use ahash::RandomState; -use arrow_array::cast::AsArray; -use arrow_array::types::{ByteArrayType, GenericBinaryType, GenericStringType}; -use arrow_array::{ - Array, ArrayRef, GenericBinaryArray, GenericStringArray, OffsetSizeTrait, +use arrow::array::cast::AsArray; +use arrow::array::types::{ByteArrayType, GenericBinaryType, GenericStringType}; +use arrow::array::{ + Array, ArrayRef, BooleanBufferBuilder, BufferBuilder, GenericBinaryArray, + GenericStringArray, OffsetSizeTrait, }; -use arrow_buffer::{ - BooleanBufferBuilder, BufferBuilder, NullBuffer, OffsetBuffer, ScalarBuffer, -}; -use arrow_schema::DataType; +use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::DataType; use datafusion_common::hash_utils::create_hashes; -use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; use std::any::type_name; use std::fmt::Debug; use std::mem; @@ -605,8 +604,8 @@ where #[cfg(test)] mod tests { use super::*; - use arrow_array::{BinaryArray, LargeBinaryArray, StringArray}; - use hashbrown::HashMap; + use arrow::array::{BinaryArray, LargeBinaryArray, StringArray}; + use std::collections::HashMap; #[test] fn string_set_empty() { diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index f335958698ab2..0ddb84141a073 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -16,6 +16,7 @@ // under the License. pub mod aggregate; +pub mod binary_map; pub mod expressions; pub mod physical_expr; pub mod sort_expr; diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index ac24dd2e7603c..df87a2e261a10 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -30,12 +30,12 @@ use std::sync::Arc; use arrow::datatypes::Schema; +use datafusion_common::{exec_err, not_impl_err, Result}; +use datafusion_expr::AggregateFunction; + use crate::aggregate::average::Avg; -use crate::aggregate::regr::RegrType; use crate::expressions::{self, Literal}; use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; -use datafusion_common::{exec_err, not_impl_err, Result}; -use datafusion_expr::AggregateFunction; /// Create a physical aggregation expression. /// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function. 
pub fn create_aggregate_expr( @@ -60,14 +60,6 @@ pub fn create_aggregate_expr( .collect::>>()?; let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { - (AggregateFunction::Count, false) => Arc::new( - expressions::Count::new_with_multiple_exprs(input_phy_exprs, name, data_type), - ), - (AggregateFunction::Count, true) => Arc::new(expressions::DistinctCount::new( - data_type, - input_phy_exprs[0].clone(), - name, - )), (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( input_phy_exprs[0].clone(), name, @@ -162,83 +154,6 @@ pub fn create_aggregate_expr( (AggregateFunction::Correlation, true) => { return not_impl_err!("CORR(DISTINCT) aggregations are not available"); } - (AggregateFunction::RegrSlope, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::Slope, - data_type, - )), - (AggregateFunction::RegrIntercept, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::Intercept, - data_type, - )), - (AggregateFunction::RegrCount, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::Count, - data_type, - )), - (AggregateFunction::RegrR2, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::R2, - data_type, - )), - (AggregateFunction::RegrAvgx, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::AvgX, - data_type, - )), - (AggregateFunction::RegrAvgy, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::AvgY, - data_type, - )), - (AggregateFunction::RegrSXX, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::SXX, - data_type, - )), - (AggregateFunction::RegrSYY, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::SYY, - data_type, - )), - (AggregateFunction::RegrSXY, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::SXY, - data_type, - )), - ( - AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY, - true, - ) => { - return not_impl_err!("{}(DISTINCT) aggregations are not available", fun); - } (AggregateFunction::ApproxPercentileCont, false) => { if input_phy_exprs.len() == 2 { Arc::new(expressions::ApproxPercentileCont::new( @@ -320,7 +235,7 @@ mod tests { use super::*; use crate::expressions::{ try_cast, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, - BoolOr, Count, DistinctArrayAgg, DistinctCount, Max, Min, + BoolOr, DistinctArrayAgg, Max, Min, }; use datafusion_common::{plan_err, DataFusionError, ScalarValue}; @@ -328,8 +243,8 @@ mod tests { use datafusion_expr::{type_coercion, Signature}; #[test] - fn test_count_arragg_approx_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Count, AggregateFunction::ArrayAgg]; + fn test_approx_expr() -> Result<()> { + let funcs = vec![AggregateFunction::ArrayAgg]; let data_types = vec![ DataType::UInt32, DataType::Int32, @@ 
-352,29 +267,18 @@ mod tests { &input_schema, "c1", )?; - match fun { - AggregateFunction::Count => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Int64, true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::ArrayAgg => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; + if fun == AggregateFunction::ArrayAgg { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new_list( + "c1", + Field::new("item", data_type.clone(), true), + true, + ), + result_agg_phy_exprs.field().unwrap() + ); + } let result_distinct = create_physical_agg_expr_for_test( &fun, @@ -383,29 +287,18 @@ mod tests { &input_schema, "c1", )?; - match fun { - AggregateFunction::Count => { - assert!(result_distinct.as_any().is::()); - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new("c1", DataType::Int64, true), - result_distinct.field().unwrap() - ); - } - AggregateFunction::ArrayAgg => { - assert!(result_distinct.as_any().is::()); - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; + if fun == AggregateFunction::ArrayAgg { + assert!(result_distinct.as_any().is::()); + assert_eq!("c1", result_distinct.name()); + assert_eq!( + Field::new_list( + "c1", + Field::new("item", data_type.clone(), true), + true, + ), + result_distinct.field().unwrap() + ); + } } } Ok(()) @@ -668,20 +561,6 @@ mod tests { Ok(()) } - #[test] - fn test_count_return_type() -> Result<()> { - let observed = AggregateFunction::Count.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = AggregateFunction::Count.return_type(&[DataType::Int8])?; - assert_eq!(DataType::Int64, observed); - - let observed = - AggregateFunction::Count.return_type(&[DataType::Decimal128(28, 13)])?; - assert_eq!(DataType::Int64, observed); - Ok(()) - } - #[test] fn test_avg_return_type() -> Result<()> { let observed = AggregateFunction::Avg.return_type(&[DataType::Float32])?; diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs deleted file mode 100644 index aad18a82ab878..0000000000000 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ /dev/null @@ -1,348 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//!
Defines physical expressions that can evaluated at runtime during query execution - -use std::any::Any; -use std::fmt::Debug; -use std::ops::BitAnd; -use std::sync::Arc; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::{Array, Int64Array}; -use arrow::compute; -use arrow::datatypes::DataType; -use arrow::{array::ArrayRef, datatypes::Field}; -use arrow_array::cast::AsArray; -use arrow_array::types::Int64Type; -use arrow_array::PrimitiveArray; -use arrow_buffer::BooleanBuffer; -use datafusion_common::{downcast_value, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator}; - -use crate::expressions::format_state_name; - -use super::groups_accumulator::accumulate::accumulate_indices; - -/// COUNT aggregate expression -/// Returns the amount of non-null values of the given expression. -#[derive(Debug, Clone)] -pub struct Count { - name: String, - data_type: DataType, - nullable: bool, - /// Input exprs - /// - /// For `COUNT(c1)` this is `[c1]` - /// For `COUNT(c1, c2)` this is `[c1, c2]` - exprs: Vec>, -} - -impl Count { - /// Create a new COUNT aggregate function. - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - exprs: vec![expr], - data_type, - nullable: true, - } - } - - pub fn new_with_multiple_exprs( - exprs: Vec>, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - exprs, - data_type, - nullable: true, - } - } -} - -/// An accumulator to compute the counts of [`PrimitiveArray`]. -/// Stores values as native types, and does overflow checking -/// -/// Unlike most other accumulators, COUNT never produces NULLs. If no -/// non-null values are seen in any group the output is 0. Thus, this -/// accumulator has no additional null or seen filter tracking. -#[derive(Debug)] -struct CountGroupsAccumulator { - /// Count per group. - /// - /// Note this is an i64 and not a u64 (or usize) because the - /// output type of count is `DataType::Int64`. Thus by using `i64` - /// for the counts, the output [`Int64Array`] can be created - /// without copy. 
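// A minimal standalone sketch (names local to the sketch, null and filter
// handling elided) of the grouped-count update described above: counts live
// in a Vec<i64> indexed by group, and each surviving input row adds one, so
// the finished vector converts into an Int64Array without copying.
fn sketch_grouped_count(
    counts: &mut Vec<i64>,
    group_indices: &[usize],
    total_num_groups: usize,
) {
    // Ensure a slot exists for every group seen so far, then bump the
    // counter for each incoming row's group.
    counts.resize(total_num_groups, 0);
    for &group_index in group_indices {
        counts[group_index] += 1;
    }
}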
- counts: Vec, -} - -impl CountGroupsAccumulator { - pub fn new() -> Self { - Self { counts: vec![] } - } -} - -impl GroupsAccumulator for CountGroupsAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = &values[0]; - - // Add one to each group's counter for each non null, non - // filtered value - self.counts.resize(total_num_groups, 0); - accumulate_indices( - group_indices, - values.logical_nulls().as_ref(), - opt_filter, - |group_index| { - self.counts[group_index] += 1; - }, - ); - - Ok(()) - } - - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 1, "one argument to merge_batch"); - // first batch is counts, second is partial sums - let partial_counts = values[0].as_primitive::(); - - // intermediate counts are always created as non null - assert_eq!(partial_counts.null_count(), 0); - let partial_counts = partial_counts.values(); - - // Adds the counts with the partial counts - self.counts.resize(total_num_groups, 0); - match opt_filter { - Some(filter) => filter - .iter() - .zip(group_indices.iter()) - .zip(partial_counts.iter()) - .for_each(|((filter_value, &group_index), partial_count)| { - if let Some(true) = filter_value { - self.counts[group_index] += partial_count; - } - }), - None => group_indices.iter().zip(partial_counts.iter()).for_each( - |(&group_index, partial_count)| { - self.counts[group_index] += partial_count; - }, - ), - } - - Ok(()) - } - - fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let counts = emit_to.take_needed(&mut self.counts); - - // Count is always non null (null inputs just don't contribute to the overall values) - let nulls = None; - let array = PrimitiveArray::::new(counts.into(), nulls); - - Ok(Arc::new(array)) - } - - // return arrays for counts - fn state(&mut self, emit_to: EmitTo) -> Result> { - let counts = emit_to.take_needed(&mut self.counts); - let counts: PrimitiveArray = Int64Array::from(counts); // zero copy, no nulls - Ok(vec![Arc::new(counts) as ArrayRef]) - } - - fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() - } -} - -/// count null values for multiple columns -/// for each row if one column value is null, then null_count + 1 -fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { - if values.len() > 1 { - let result_bool_buf: Option = values - .iter() - .map(|a| a.logical_nulls()) - .fold(None, |acc, b| match (acc, b) { - (Some(acc), Some(b)) => Some(acc.bitand(b.inner())), - (Some(acc), None) => Some(acc), - (None, Some(b)) => Some(b.into_inner()), - _ => None, - }); - result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits()) - } else { - values[0] - .logical_nulls() - .map_or(0, |nulls| nulls.null_count()) - } -} - -impl AggregateExpr for Count { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Int64, self.nullable)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "count"), - DataType::Int64, - true, - )]) - } - - fn expressions(&self) -> Vec> { - self.exprs.clone() - } - - fn create_accumulator(&self) -> Result> { - 
Ok(Box::new(CountAccumulator::new())) - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - // groups accumulator only supports `COUNT(c1)`, not - // `COUNT(c1, c2)`, etc - self.exprs.len() == 1 - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(CountAccumulator::new())) - } - - fn create_groups_accumulator(&self) -> Result> { - // instantiate specialized accumulator - Ok(Box::new(CountGroupsAccumulator::new())) - } - - fn with_new_expressions( - &self, - args: Vec>, - order_by_exprs: Vec>, - ) -> Option> { - debug_assert_eq!(self.exprs.len(), args.len()); - debug_assert!(order_by_exprs.is_empty()); - Some(Arc::new(Count { - name: self.name.clone(), - data_type: self.data_type.clone(), - nullable: self.nullable, - exprs: args, - })) - } -} - -impl PartialEq for Count { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.exprs.len() == x.exprs.len() - && self - .exprs - .iter() - .zip(x.exprs.iter()) - .all(|(expr1, expr2)| expr1.eq(expr2)) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -struct CountAccumulator { - count: i64, -} - -impl CountAccumulator { - /// new count accumulator - pub fn new() -> Self { - Self { count: 0 } - } -} - -impl Accumulator for CountAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ScalarValue::Int64(Some(self.count))]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let array = &values[0]; - self.count += (array.len() - null_count_for_multiple_cols(values)) as i64; - Ok(()) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let array = &values[0]; - self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], Int64Array); - let delta = &compute::sum(counts); - if let Some(d) = delta { - self.count += *d; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Int64(Some(self.count))) - } - - fn supports_retract_batch(&self) -> bool { - true - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs deleted file mode 100644 index 52f1c5c0f9a0b..0000000000000 --- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs +++ /dev/null @@ -1,718 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
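// With both built-in implementations deleted, COUNT and COUNT(DISTINCT) are
// expressed through the `count` UDAF instead. A hedged sketch (assuming the
// `count_udaf` accessor and the six-argument `new_udf` constructor used
// elsewhere in this patch; the column name is illustrative):
use datafusion_expr::{col, expr::AggregateFunction as AggregateFunctionExpr, Expr};
use datafusion_functions_aggregate::count::count_udaf;

fn sketch_count_exprs() -> (Expr, Expr) {
    let plain = Expr::AggregateFunction(AggregateFunctionExpr::new_udf(
        count_udaf(),
        vec![col("c1")],
        false, // distinct
        None,  // filter
        None,  // order_by
        None,  // null_treatment
    ));
    let distinct = Expr::AggregateFunction(AggregateFunctionExpr::new_udf(
        count_udaf(),
        vec![col("c1")],
        true, // COUNT(DISTINCT c1)
        None,
        None,
        None,
    ));
    (plain, distinct)
}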
- -mod bytes; -mod native; - -use std::any::Any; -use std::collections::HashSet; -use std::fmt::Debug; -use std::sync::Arc; - -use ahash::RandomState; -use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::{DataType, Field, TimeUnit}; -use arrow_array::cast::AsArray; -use arrow_array::types::{ - Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float16Type, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, - Time32SecondType, Time64MicrosecondType, Time64NanosecondType, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, -}; - -use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_expr::Accumulator; - -use crate::aggregate::count_distinct::bytes::BytesDistinctCountAccumulator; -use crate::aggregate::count_distinct::native::{ - FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator, -}; -use crate::aggregate::utils::down_cast_any_ref; -use crate::binary_map::OutputType; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; - -/// Expression for a `COUNT(DISTINCT)` aggregation. -#[derive(Debug)] -pub struct DistinctCount { - /// Column name - name: String, - /// The DataType used to hold the state for each input - state_data_type: DataType, - /// The input arguments - expr: Arc, -} - -impl DistinctCount { - /// Create a new COUNT(DISTINCT) aggregate function. - pub fn new( - input_data_type: DataType, - expr: Arc, - name: impl Into, - ) -> Self { - Self { - name: name.into(), - state_data_type: input_data_type, - expr, - } - } -} - -impl AggregateExpr for DistinctCount { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Int64, true)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "count distinct"), - Field::new("item", self.state_data_type.clone(), true), - false, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn create_accumulator(&self) -> Result> { - use DataType::*; - use TimeUnit::*; - - let data_type = &self.state_data_type; - Ok(match data_type { - // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator - Int8 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Int16 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Int32 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Int64 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt8 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt16 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt32 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt64 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal128Type, - >::new(data_type)), - Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal256Type, - >::new(data_type)), - - Date32 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Date64 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Time32(Millisecond) => Box::new(PrimitiveDistinctCountAccumulator::< - Time32MillisecondType, - 
>::new(data_type)), - Time32(Second) => Box::new(PrimitiveDistinctCountAccumulator::< - Time32SecondType, - >::new(data_type)), - Time64(Microsecond) => Box::new(PrimitiveDistinctCountAccumulator::< - Time64MicrosecondType, - >::new(data_type)), - Time64(Nanosecond) => Box::new(PrimitiveDistinctCountAccumulator::< - Time64NanosecondType, - >::new(data_type)), - Timestamp(Microsecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampMicrosecondType, - >::new(data_type)), - Timestamp(Millisecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampMillisecondType, - >::new(data_type)), - Timestamp(Nanosecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampNanosecondType, - >::new(data_type)), - Timestamp(Second, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampSecondType, - >::new(data_type)), - - Float16 => Box::new(FloatDistinctCountAccumulator::::new()), - Float32 => Box::new(FloatDistinctCountAccumulator::::new()), - Float64 => Box::new(FloatDistinctCountAccumulator::::new()), - - Utf8 => Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)), - LargeUtf8 => { - Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) - } - Binary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - - // Use the generic accumulator based on `ScalarValue` for all other types - _ => Box::new(DistinctCountAccumulator { - values: HashSet::default(), - state_data_type: self.state_data_type.clone(), - }), - }) - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for DistinctCount { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.state_data_type == x.state_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -/// General purpose distinct accumulator that works for any DataType by using -/// [`ScalarValue`]. -/// -/// It stores intermediate results as a `ListArray` -/// -/// Note that many types have specialized accumulators that are (much) -/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and -/// [`BytesDistinctCountAccumulator`] -#[derive(Debug)] -struct DistinctCountAccumulator { - values: HashSet, - state_data_type: DataType, -} - -impl DistinctCountAccumulator { - // calculating the size for fixed length values, taking first batch size * - // number of batches This method is faster than .full_size(), however it is - // not suitable for variable length values like strings or complex types - fn fixed_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .next() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) - .unwrap_or(0) - + std::mem::size_of::() - } - - // calculates the size as accurately as possible. Note that calling this - // method is expensive - fn full_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) - .sum::() - + std::mem::size_of::() - } -} - -impl Accumulator for DistinctCountAccumulator { - /// Returns the distinct values seen so far as (one element) ListArray. 
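// A condensed sketch of the generic fallback above (std HashSet, as in the
// deleted file; names are local to the sketch): non-null inputs are
// converted to `ScalarValue` and deduplicated, and the distinct count is
// simply the set's length, returned as Int64.
use std::collections::HashSet;
use datafusion_common::ScalarValue;

fn sketch_generic_distinct_count(values: &[ScalarValue]) -> i64 {
    let mut seen: HashSet<ScalarValue> = HashSet::new();
    for v in values {
        if !v.is_null() {
            seen.insert(v.clone());
        }
    }
    seen.len() as i64
}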
- fn state(&mut self) -> Result> { - let scalars = self.values.iter().cloned().collect::>(); - let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); - Ok(vec![ScalarValue::List(arr)]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let arr = &values[0]; - if arr.data_type() == &DataType::Null { - return Ok(()); - } - - (0..arr.len()).try_for_each(|index| { - if !arr.is_null(index) { - let scalar = ScalarValue::try_from_array(arr, index)?; - self.values.insert(scalar); - } - Ok(()) - }) - } - - /// Merges multiple sets of distinct values into the current set. - /// - /// The input to this function is a `ListArray` with **multiple** rows, - /// where each row contains the values from a partial aggregate's phase (e.g. - /// the result of calling `Self::state` on multiple accumulators). - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - assert_eq!(states.len(), 1, "array_agg states must be singleton!"); - let array = &states[0]; - let list_array = array.as_list::(); - for inner_array in list_array.iter() { - let Some(inner_array) = inner_array else { - return internal_err!( - "Intermediate results of COUNT DISTINCT should always be non null" - ); - }; - self.update_batch(&[inner_array])?; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Int64(Some(self.values.len() as i64))) - } - - fn size(&self) -> usize { - match &self.state_data_type { - DataType::Boolean | DataType::Null => self.fixed_size(), - d if d.is_primitive() => self.fixed_size(), - _ => self.full_size(), - } - } -} - -#[cfg(test)] -mod tests { - use arrow::array::{ - BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, - }; - use arrow_array::Decimal256Array; - use arrow_buffer::i256; - - use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; - use datafusion_common::internal_err; - use datafusion_common::DataFusionError; - - use crate::expressions::NoOp; - - use super::*; - - macro_rules! state_to_vec_primitive { - ($LIST:expr, $DATA_TYPE:ident) => {{ - let arr = ScalarValue::raw_data($LIST).unwrap(); - let list_arr = as_list_array(&arr).unwrap(); - let arr = list_arr.values(); - let arr = as_primitive_array::<$DATA_TYPE>(arr)?; - arr.values().iter().cloned().collect::>() - }}; - } - - macro_rules! 
test_count_distinct_update_batch_numeric { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(1), - Some(1), - None, - Some(3), - Some(2), - None, - Some(2), - Some(3), - Some(1), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); - state_vec.sort(); - - assert_eq!(states.len(), 1); - assert_eq!(state_vec, vec![1, 2, 3]); - assert_eq!(result, ScalarValue::Int64(Some(3))); - - Ok(()) - }}; - } - - fn state_to_vec_bool(sv: &ScalarValue) -> Result> { - let arr = ScalarValue::raw_data(sv)?; - let list_arr = as_list_array(&arr)?; - let arr = list_arr.values(); - let bool_arr = as_boolean_array(arr)?; - Ok(bool_arr.iter().flatten().collect()) - } - - fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec, ScalarValue)> { - let agg = DistinctCount::new( - arrays[0].data_type().clone(), - Arc::new(NoOp::new()), - String::from("__col_name__"), - ); - - let mut accum = agg.create_accumulator()?; - accum.update_batch(arrays)?; - - Ok((accum.state()?, accum.evaluate()?)) - } - - fn run_update( - data_types: &[DataType], - rows: &[Vec], - ) -> Result<(Vec, ScalarValue)> { - let agg = DistinctCount::new( - data_types[0].clone(), - Arc::new(NoOp::new()), - String::from("__col_name__"), - ); - - let mut accum = agg.create_accumulator()?; - - let cols = (0..rows[0].len()) - .map(|i| { - rows.iter() - .map(|inner| inner[i].clone()) - .collect::>() - }) - .collect::>(); - - let arrays: Vec = cols - .iter() - .map(|c| ScalarValue::iter_to_array(c.clone())) - .collect::>>()?; - - accum.update_batch(&arrays)?; - - Ok((accum.state()?, accum.evaluate()?)) - } - - // Used trait to create associated constant for f32 and f64 - trait SubNormal: 'static { - const SUBNORMAL: Self; - } - - impl SubNormal for f64 { - const SUBNORMAL: Self = 1.0e-308_f64; - } - - impl SubNormal for f32 { - const SUBNORMAL: Self = 1.0e-38_f32; - } - - macro_rules! test_count_distinct_update_batch_floating_point { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(<$PRIM_TYPE>::INFINITY), - Some(<$PRIM_TYPE>::NAN), - Some(1.0), - Some(<$PRIM_TYPE as SubNormal>::SUBNORMAL), - Some(1.0), - Some(<$PRIM_TYPE>::INFINITY), - None, - Some(3.0), - Some(-4.5), - Some(2.0), - None, - Some(2.0), - Some(3.0), - Some(<$PRIM_TYPE>::NEG_INFINITY), - Some(1.0), - Some(<$PRIM_TYPE>::NAN), - Some(<$PRIM_TYPE>::NEG_INFINITY), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); - - dbg!(&state_vec); - state_vec.sort_by(|a, b| match (a, b) { - (lhs, rhs) => lhs.total_cmp(rhs), - }); - - let nan_idx = state_vec.len() - 1; - assert_eq!(states.len(), 1); - assert_eq!( - &state_vec[..nan_idx], - vec![ - <$PRIM_TYPE>::NEG_INFINITY, - -4.5, - <$PRIM_TYPE as SubNormal>::SUBNORMAL, - 1.0, - 2.0, - 3.0, - <$PRIM_TYPE>::INFINITY - ] - ); - assert!(state_vec[nan_idx].is_nan()); - assert_eq!(result, ScalarValue::Int64(Some(8))); - - Ok(()) - }}; - } - - macro_rules! 
test_count_distinct_update_batch_bigint { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(i256::from(1)), - Some(i256::from(1)), - None, - Some(i256::from(3)), - Some(i256::from(2)), - None, - Some(i256::from(2)), - Some(i256::from(3)), - Some(i256::from(1)), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); - state_vec.sort(); - - assert_eq!(states.len(), 1); - assert_eq!(state_vec, vec![i256::from(1), i256::from(2), i256::from(3)]); - assert_eq!(result, ScalarValue::Int64(Some(3))); - - Ok(()) - }}; - } - - #[test] - fn count_distinct_update_batch_i8() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8) - } - - #[test] - fn count_distinct_update_batch_i16() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int16Array, Int16Type, i16) - } - - #[test] - fn count_distinct_update_batch_i32() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int32Array, Int32Type, i32) - } - - #[test] - fn count_distinct_update_batch_i64() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int64Array, Int64Type, i64) - } - - #[test] - fn count_distinct_update_batch_u8() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt8Array, UInt8Type, u8) - } - - #[test] - fn count_distinct_update_batch_u16() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt16Array, UInt16Type, u16) - } - - #[test] - fn count_distinct_update_batch_u32() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt32Array, UInt32Type, u32) - } - - #[test] - fn count_distinct_update_batch_u64() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt64Array, UInt64Type, u64) - } - - #[test] - fn count_distinct_update_batch_f32() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float32Array, Float32Type, f32) - } - - #[test] - fn count_distinct_update_batch_f64() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64) - } - - #[test] - fn count_distinct_update_batch_i256() -> Result<()> { - test_count_distinct_update_batch_bigint!(Decimal256Array, Decimal256Type, i256) - } - - #[test] - fn count_distinct_update_batch_boolean() -> Result<()> { - let get_count = |data: BooleanArray| -> Result<(Vec, i64)> { - let arrays = vec![Arc::new(data) as ArrayRef]; - let (states, result) = run_update_batch(&arrays)?; - let mut state_vec = state_to_vec_bool(&states[0])?; - state_vec.sort(); - - let count = match result { - ScalarValue::Int64(c) => c.ok_or_else(|| { - DataFusionError::Internal("Found None count".to_string()) - }), - scalar => { - internal_err!("Found non int64 scalar value from count: {scalar}") - } - }?; - Ok((state_vec, count)) - }; - - let zero_count_values = BooleanArray::from(Vec::::new()); - - let one_count_values = BooleanArray::from(vec![false, false]); - let one_count_values_with_null = - BooleanArray::from(vec![Some(true), Some(true), None, None]); - - let two_count_values = BooleanArray::from(vec![true, false, true, false, true]); - let two_count_values_with_null = BooleanArray::from(vec![ - Some(true), - Some(false), - None, - None, - Some(true), - Some(false), - ]); - - assert_eq!(get_count(zero_count_values)?, (Vec::::new(), 0)); - assert_eq!(get_count(one_count_values)?, (vec![false], 1)); - assert_eq!(get_count(one_count_values_with_null)?, (vec![true], 1)); - 
assert_eq!(get_count(two_count_values)?, (vec![false, true], 2)); - assert_eq!( - get_count(two_count_values_with_null)?, - (vec![false, true], 2) - ); - Ok(()) - } - - #[test] - fn count_distinct_update_batch_all_nulls() -> Result<()> { - let arrays = vec![Arc::new(Int32Array::from( - vec![None, None, None, None] as Vec> - )) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - let state_vec = state_to_vec_primitive!(&states[0], Int32Type); - assert_eq!(states.len(), 1); - assert!(state_vec.is_empty()); - assert_eq!(result, ScalarValue::Int64(Some(0))); - - Ok(()) - } - - #[test] - fn count_distinct_update_batch_empty() -> Result<()> { - let arrays = vec![Arc::new(Int32Array::from(vec![0_i32; 0])) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - let state_vec = state_to_vec_primitive!(&states[0], Int32Type); - assert_eq!(states.len(), 1); - assert!(state_vec.is_empty()); - assert_eq!(result, ScalarValue::Int64(Some(0))); - - Ok(()) - } - - #[test] - fn count_distinct_update() -> Result<()> { - let (states, result) = run_update( - &[DataType::Int32], - &[ - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(5))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(5))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(2))], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(3))); - - let (states, result) = run_update( - &[DataType::UInt64], - &[ - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(5))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(5))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(2))], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(3))); - Ok(()) - } - - #[test] - fn count_distinct_update_with_nulls() -> Result<()> { - let (states, result) = run_update( - &[DataType::Int32], - &[ - // None of these updates contains a None, so these are accumulated. - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(-2))], - // Each of these updates contains at least one None, so these - // won't be accumulated. - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(None)], - vec![ScalarValue::Int32(None)], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(2))); - - let (states, result) = run_update( - &[DataType::UInt64], - &[ - // None of these updates contains a None, so these are accumulated. - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(2))], - // Each of these updates contains at least one None, so these - // won't be accumulated. 
- vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(None)], - vec![ScalarValue::UInt64(None)], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(2))); - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs index 65227b727be70..a6946e739c97d 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs @@ -20,7 +20,7 @@ pub use adapter::GroupsAccumulatorAdapter; // Backward compatibility pub(crate) mod accumulate { - pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::{accumulate_indices, NullState}; + pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; } pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 7a6c5f9d0e247..9079a81e62418 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -26,8 +26,6 @@ pub(crate) mod average; pub(crate) mod bit_and_or_xor; pub(crate) mod bool_and_or; pub(crate) mod correlation; -pub(crate) mod count; -pub(crate) mod count_distinct; pub(crate) mod covariance; pub(crate) mod grouping; pub(crate) mod nth_value; @@ -35,7 +33,6 @@ pub(crate) mod string_agg; #[macro_use] pub(crate) mod min_max; pub(crate) mod groups_accumulator; -pub(crate) mod regr; pub(crate) mod stats; pub(crate) mod stddev; pub(crate) mod variance; diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index c56229e07a636..08d8cd4413347 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -26,7 +26,7 @@ use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::zip::zip; -use arrow::compute::{and, is_null, not, nullif, or, prep_null_mask_filter}; +use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::as_boolean_array; use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; @@ -168,7 +168,7 @@ impl CaseExpr { } }; - remainder = and(&remainder, ¬(&when_match)?)?; + remainder = and_not(&remainder, &when_match)?; } if let Some(e) = &self.else_expr { @@ -241,7 +241,7 @@ impl CaseExpr { // Succeed tuples should be filtered out for short-circuit evaluation, // null values for the current when expr should be kept - remainder = and(&remainder, ¬(&when_value)?)?; + remainder = and_not(&remainder, &when_value)?; } if let Some(e) = &self.else_expr { diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index a96d021730180..beba25740501e 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -47,12 +47,9 @@ pub use crate::aggregate::bit_and_or_xor::{BitAnd, BitOr, BitXor, DistinctBitXor pub use crate::aggregate::bool_and_or::{BoolAnd, BoolOr}; pub use crate::aggregate::build_in::create_aggregate_expr; pub use crate::aggregate::correlation::Correlation; -pub use crate::aggregate::count::Count; -pub use crate::aggregate::count_distinct::DistinctCount; pub use 
crate::aggregate::grouping::Grouping; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::nth_value::NthValueAgg; -pub use crate::aggregate::regr::{Regr, RegrType}; pub use crate::aggregate::stats::StatsType; pub use crate::aggregate::string_agg::StringAgg; pub use crate::window::cume_dist::{cume_dist, CumeDist}; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 72f5f2d50cb89..b764e81a95d13 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -17,7 +17,9 @@ pub mod aggregate; pub mod analysis; -pub mod binary_map; +pub mod binary_map { + pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; +} pub mod equivalence; pub mod expressions; pub mod functions; diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 55d112e1f6e0a..4bd40066ff341 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -125,7 +125,6 @@ impl BuiltInWindowFunctionExpr for NthValue { fn create_evaluator(&self) -> Result> { let state = NthValueState { - range: Default::default(), finalized_result: None, kind: self.kind, }; diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 065371d9e43e0..3cf68379d72b8 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -559,7 +559,6 @@ pub enum NthValueKind { #[derive(Debug, Clone)] pub struct NthValueState { - pub range: Range, // In certain cases, we can finalize the result early. Consider this usage: // ``` // FIRST_VALUE(increasing_col) OVER window AS my_first_value diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs index d073c8995a9bf..f789af8b8a024 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs @@ -18,7 +18,7 @@ use crate::aggregates::group_values::GroupValues; use arrow_array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch}; use datafusion_expr::EmitTo; -use datafusion_physical_expr::binary_map::{ArrowBytesMap, OutputType}; +use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; /// A [`GroupValues`] storing single column of Utf8/LargeUtf8/Binary/LargeBinary values /// diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 79abbdb52ca24..b6fc70be7cbc5 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1194,12 +1194,14 @@ mod tests { use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::expr::Sort; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::median::median_udaf; use datafusion_physical_expr::expressions::{ - lit, Count, FirstValue, LastValue, OrderSensitiveArrayAgg, + lit, FirstValue, LastValue, OrderSensitiveArrayAgg, }; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr_common::aggregate::create_aggregate_expr; use futures::{FutureExt, Stream}; // Generate a schema which consists of 5 columns (a, b, c, d, e) @@ -1334,11 +1336,16 @@ mod tests { ], }; - let 
aggregates: Vec> = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - ))]; + let aggregates = vec![create_aggregate_expr( + &count_udaf(), + &[lit(1i8)], + &[], + &[], + &input_schema, + "COUNT(1)", + false, + false, + )?]; let task_ctx = if spill { new_spill_ctx(4, 1000) diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index fa30141a19341..30c3353d4b71d 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -175,11 +175,6 @@ impl DataSinkExec { &self.sort_order } - /// Returns the metrics of the underlying [DataSink] - pub fn metrics(&self) -> Option { - self.sink.metrics() - } - fn create_schema( input: &Arc, schema: SchemaRef, @@ -289,6 +284,11 @@ impl ExecutionPlan for DataSinkExec { stream, ))) } + + /// Returns the metrics of the underlying [DataSink] + fn metrics(&self) -> Option { + self.sink.metrics() + } } /// Create a output record batch with a count diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 784584f03f0f5..cd66ab093f881 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -21,7 +21,7 @@ use std::fmt; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; -use std::{any::Any, usize, vec}; +use std::{any::Any, vec}; use super::{ utils::{OnceAsync, OnceFut}, diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 0a01d84141e7c..46d3ac5acf1eb 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -20,7 +20,6 @@ use std::collections::{HashMap, VecDeque}; use std::sync::Arc; -use std::usize; use crate::joins::utils::{JoinFilter, JoinHashMapType}; use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 7b4d790479b14..e11e6dd2f627b 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -29,7 +29,7 @@ use std::any::Any; use std::fmt::{self, Debug}; use std::sync::Arc; use std::task::{Context, Poll}; -use std::{usize, vec}; +use std::vec; use crate::common::SharedMemoryReservation; use crate::handle_state; diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 9598ed83aa580..7e05ded6f69dd 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -18,7 +18,6 @@ //! 
This file has test utils for hash joins use std::sync::Arc; -use std::usize; use crate::joins::utils::{JoinFilter, JoinOn}; use crate::joins::{ diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 0d99d7a163567..c08b0e3d091c8 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -23,7 +23,6 @@ use std::future::Future; use std::ops::{IndexMut, Range}; use std::sync::Arc; use std::task::{Context, Poll}; -use std::usize; use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics}; diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 48f1bee59bbfa..56d780e51394a 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1194,9 +1194,9 @@ mod tests { RecordBatchStream, SendableRecordBatchStream, TaskContext, }; use datafusion_expr::{ - AggregateFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::{col, Column, NthValue}; use datafusion_physical_expr::window::{ BuiltInWindowExpr, BuiltInWindowFunctionExpr, @@ -1298,8 +1298,7 @@ mod tests { order_by: &str, ) -> Result> { let schema = input.schema(); - let window_fn = - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count); + let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf()); let col_expr = Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc; let args = vec![col_expr]; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 9b392d941ef45..63ce473fc57e6 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -597,7 +597,6 @@ pub fn get_window_mode( #[cfg(test)] mod tests { use super::*; - use crate::aggregates::AggregateFunction; use crate::collect; use crate::expressions::col; use crate::streaming::StreamingTableExec; @@ -607,6 +606,7 @@ mod tests { use arrow::compute::SortOptions; use datafusion_execution::TaskContext; + use datafusion_functions_aggregate::count::count_udaf; use futures::FutureExt; use InputOrderMode::{Linear, PartiallySorted, Sorted}; @@ -749,7 +749,7 @@ mod tests { let refs = blocking_exec.refs(); let window_agg_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( - &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateUDF(count_udaf()), "count".to_owned(), &[col("a", &schema)?], &[], diff --git a/datafusion/proto-common/Cargo.toml b/datafusion/proto-common/Cargo.toml index 97568fb5f678b..66ce7cbd838f4 100644 --- a/datafusion/proto-common/Cargo.toml +++ b/datafusion/proto-common/Cargo.toml @@ -26,7 +26,7 @@ homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = "1.73" +rust-version = "1.75" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml index 49884c48b3cc0..9f8f03de6dc9e 100644 --- 
a/datafusion/proto-common/gen/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen-common" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.73" +rust-version = "1.75" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 358ba7e3eb94f..aa8d0e55b68fc 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -27,7 +27,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.73" +rust-version = "1.75" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] @@ -59,6 +59,7 @@ serde_json = { workspace = true, optional = true } [dev-dependencies] datafusion-functions = { workspace = true, default-features = true } +datafusion-functions-aggregate = { workspace = true } doc-comment = { workspace = true } strum = { version = "0.26.1", features = ["derive"] } tokio = { workspace = true, features = ["rt-multi-thread"] } diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index b6993f6c040bf..eabaf7ba8e14f 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.73" +rust-version = "1.75" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index b401ff8810db2..83223a04d0233 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -476,7 +476,7 @@ enum AggregateFunction { MAX = 1; // SUM = 2; AVG = 3; - COUNT = 4; + // COUNT = 4; // APPROX_DISTINCT = 5; ARRAY_AGG = 6; // VARIANCE = 7; @@ -496,15 +496,15 @@ enum AggregateFunction { BIT_XOR = 21; BOOL_AND = 22; BOOL_OR = 23; - REGR_SLOPE = 26; - REGR_INTERCEPT = 27; - REGR_COUNT = 28; - REGR_R2 = 29; - REGR_AVGX = 30; - REGR_AVGY = 31; - REGR_SXX = 32; - REGR_SYY = 33; - REGR_SXY = 34; + // REGR_SLOPE = 26; + // REGR_INTERCEPT = 27; + // REGR_COUNT = 28; + // REGR_R2 = 29; + // REGR_AVGX = 30; + // REGR_AVGY = 31; + // REGR_SXX = 32; + // REGR_SYY = 33; + // REGR_SXY = 34; STRING_AGG = 35; NTH_VALUE_AGG = 36; } @@ -520,6 +520,7 @@ message AggregateExprNode { message AggregateUDFExprNode { string fun_name = 1; repeated LogicalExprNode args = 2; + bool distinct = 5; LogicalExprNode filter = 3; repeated LogicalExprNode order_by = 4; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index d6632c77d8da7..f298dd241abff 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -535,7 +535,6 @@ impl serde::Serialize for AggregateFunction { Self::Min => "MIN", Self::Max => "MAX", Self::Avg => "AVG", - Self::Count => "COUNT", Self::ArrayAgg => "ARRAY_AGG", Self::Correlation => "CORRELATION", Self::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", @@ -546,15 +545,6 @@ impl serde::Serialize for AggregateFunction { Self::BitXor => "BIT_XOR", Self::BoolAnd => "BOOL_AND", Self::BoolOr => "BOOL_OR", - Self::RegrSlope => "REGR_SLOPE", - Self::RegrIntercept => "REGR_INTERCEPT", - Self::RegrCount => "REGR_COUNT", - 
Self::RegrR2 => "REGR_R2", - Self::RegrAvgx => "REGR_AVGX", - Self::RegrAvgy => "REGR_AVGY", - Self::RegrSxx => "REGR_SXX", - Self::RegrSyy => "REGR_SYY", - Self::RegrSxy => "REGR_SXY", Self::StringAgg => "STRING_AGG", Self::NthValueAgg => "NTH_VALUE_AGG", }; @@ -571,7 +561,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN", "MAX", "AVG", - "COUNT", "ARRAY_AGG", "CORRELATION", "APPROX_PERCENTILE_CONT", @@ -582,15 +571,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BIT_XOR", "BOOL_AND", "BOOL_OR", - "REGR_SLOPE", - "REGR_INTERCEPT", - "REGR_COUNT", - "REGR_R2", - "REGR_AVGX", - "REGR_AVGY", - "REGR_SXX", - "REGR_SYY", - "REGR_SXY", "STRING_AGG", "NTH_VALUE_AGG", ]; @@ -636,7 +616,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN" => Ok(AggregateFunction::Min), "MAX" => Ok(AggregateFunction::Max), "AVG" => Ok(AggregateFunction::Avg), - "COUNT" => Ok(AggregateFunction::Count), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), "CORRELATION" => Ok(AggregateFunction::Correlation), "APPROX_PERCENTILE_CONT" => Ok(AggregateFunction::ApproxPercentileCont), @@ -647,15 +626,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BIT_XOR" => Ok(AggregateFunction::BitXor), "BOOL_AND" => Ok(AggregateFunction::BoolAnd), "BOOL_OR" => Ok(AggregateFunction::BoolOr), - "REGR_SLOPE" => Ok(AggregateFunction::RegrSlope), - "REGR_INTERCEPT" => Ok(AggregateFunction::RegrIntercept), - "REGR_COUNT" => Ok(AggregateFunction::RegrCount), - "REGR_R2" => Ok(AggregateFunction::RegrR2), - "REGR_AVGX" => Ok(AggregateFunction::RegrAvgx), - "REGR_AVGY" => Ok(AggregateFunction::RegrAvgy), - "REGR_SXX" => Ok(AggregateFunction::RegrSxx), - "REGR_SYY" => Ok(AggregateFunction::RegrSyy), - "REGR_SXY" => Ok(AggregateFunction::RegrSxy), "STRING_AGG" => Ok(AggregateFunction::StringAgg), "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), @@ -886,6 +856,9 @@ impl serde::Serialize for AggregateUdfExprNode { if !self.args.is_empty() { len += 1; } + if self.distinct { + len += 1; + } if self.filter.is_some() { len += 1; } @@ -899,6 +872,9 @@ impl serde::Serialize for AggregateUdfExprNode { if !self.args.is_empty() { struct_ser.serialize_field("args", &self.args)?; } + if self.distinct { + struct_ser.serialize_field("distinct", &self.distinct)?; + } if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; } @@ -918,6 +894,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { "fun_name", "funName", "args", + "distinct", "filter", "order_by", "orderBy", @@ -927,6 +904,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { enum GeneratedField { FunName, Args, + Distinct, Filter, OrderBy, } @@ -952,6 +930,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { match value { "funName" | "fun_name" => Ok(GeneratedField::FunName), "args" => Ok(GeneratedField::Args), + "distinct" => Ok(GeneratedField::Distinct), "filter" => Ok(GeneratedField::Filter), "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -975,6 +954,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { { let mut fun_name__ = None; let mut args__ = None; + let mut distinct__ = None; let mut filter__ = None; let mut order_by__ = None; while let Some(k) = map_.next_key()? 
{
@@ -991,6 +971,12 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
                            }
                            args__ = Some(map_.next_value()?);
                        }
+                        GeneratedField::Distinct => {
+                            if distinct__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("distinct"));
+                            }
+                            distinct__ = Some(map_.next_value()?);
+                        }
                        GeneratedField::Filter => {
                            if filter__.is_some() {
                                return Err(serde::de::Error::duplicate_field("filter"));
@@ -1008,6 +994,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode {
                Ok(AggregateUdfExprNode {
                    fun_name: fun_name__.unwrap_or_default(),
                    args: args__.unwrap_or_default(),
+                    distinct: distinct__.unwrap_or_default(),
                    filter: filter__,
                    order_by: order_by__.unwrap_or_default(),
                })
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 0aca5ef1ffb80..fa0217e9ef4f5 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -767,6 +767,8 @@ pub struct AggregateUdfExprNode {
     pub fun_name: ::prost::alloc::string::String,
     #[prost(message, repeated, tag = "2")]
     pub args: ::prost::alloc::vec::Vec<LogicalExprNode>,
+    #[prost(bool, tag = "5")]
+    pub distinct: bool,
     #[prost(message, optional, boxed, tag = "3")]
     pub filter: ::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
     #[prost(message, repeated, tag = "4")]
@@ -1928,7 +1930,7 @@ pub enum AggregateFunction {
     Max = 1,
     /// SUM = 2;
     Avg = 3,
-    Count = 4,
+    /// COUNT = 4;
     /// APPROX_DISTINCT = 5;
     ArrayAgg = 6,
     /// VARIANCE = 7;
@@ -1948,15 +1950,15 @@ pub enum AggregateFunction {
     BitXor = 21,
     BoolAnd = 22,
     BoolOr = 23,
-    RegrSlope = 26,
-    RegrIntercept = 27,
-    RegrCount = 28,
-    RegrR2 = 29,
-    RegrAvgx = 30,
-    RegrAvgy = 31,
-    RegrSxx = 32,
-    RegrSyy = 33,
-    RegrSxy = 34,
+    /// REGR_SLOPE = 26;
+    /// REGR_INTERCEPT = 27;
+    /// REGR_COUNT = 28;
+    /// REGR_R2 = 29;
+    /// REGR_AVGX = 30;
+    /// REGR_AVGY = 31;
+    /// REGR_SXX = 32;
+    /// REGR_SYY = 33;
+    /// REGR_SXY = 34;
     StringAgg = 35,
     NthValueAgg = 36,
 }
@@ -1970,7 +1972,6 @@ impl AggregateFunction {
             AggregateFunction::Min => "MIN",
             AggregateFunction::Max => "MAX",
             AggregateFunction::Avg => "AVG",
-            AggregateFunction::Count => "COUNT",
             AggregateFunction::ArrayAgg => "ARRAY_AGG",
             AggregateFunction::Correlation => "CORRELATION",
             AggregateFunction::ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
@@ -1983,15 +1984,6 @@ impl AggregateFunction {
             AggregateFunction::BitXor => "BIT_XOR",
             AggregateFunction::BoolAnd => "BOOL_AND",
             AggregateFunction::BoolOr => "BOOL_OR",
-            AggregateFunction::RegrSlope => "REGR_SLOPE",
-            AggregateFunction::RegrIntercept => "REGR_INTERCEPT",
-            AggregateFunction::RegrCount => "REGR_COUNT",
-            AggregateFunction::RegrR2 => "REGR_R2",
-            AggregateFunction::RegrAvgx => "REGR_AVGX",
-            AggregateFunction::RegrAvgy => "REGR_AVGY",
-            AggregateFunction::RegrSxx => "REGR_SXX",
-            AggregateFunction::RegrSyy => "REGR_SYY",
-            AggregateFunction::RegrSxy => "REGR_SXY",
             AggregateFunction::StringAgg => "STRING_AGG",
             AggregateFunction::NthValueAgg => "NTH_VALUE_AGG",
         }
@@ -2002,7 +1994,6 @@ impl AggregateFunction {
             "MIN" => Some(Self::Min),
             "MAX" => Some(Self::Max),
             "AVG" => Some(Self::Avg),
-            "COUNT" => Some(Self::Count),
             "ARRAY_AGG" => Some(Self::ArrayAgg),
             "CORRELATION" => Some(Self::Correlation),
             "APPROX_PERCENTILE_CONT" => Some(Self::ApproxPercentileCont),
@@ -2015,15 +2006,6 @@ impl AggregateFunction {
             "BIT_XOR" => Some(Self::BitXor),
             "BOOL_AND" => Some(Self::BoolAnd),
             "BOOL_OR" => Some(Self::BoolOr),
-            "REGR_SLOPE" => Some(Self::RegrSlope),
-            "REGR_INTERCEPT" => Some(Self::RegrIntercept),
-            "REGR_COUNT" => Some(Self::RegrCount),
-            "REGR_R2" => Some(Self::RegrR2),
-            "REGR_AVGX" => Some(Self::RegrAvgx),
-            "REGR_AVGY" => Some(Self::RegrAvgy),
-            "REGR_SXX" => Some(Self::RegrSxx),
-            "REGR_SYY" => Some(Self::RegrSyy),
-            "REGR_SXY" => Some(Self::RegrSxy),
             "STRING_AGG" => Some(Self::StringAgg),
             "NTH_VALUE_AGG" => Some(Self::NthValueAgg),
             _ => None,
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 3ad5973380ede..ed7b0129cc48f 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -145,18 +145,8 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
             protobuf::AggregateFunction::BitXor => Self::BitXor,
             protobuf::AggregateFunction::BoolAnd => Self::BoolAnd,
             protobuf::AggregateFunction::BoolOr => Self::BoolOr,
-            protobuf::AggregateFunction::Count => Self::Count,
             protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg,
             protobuf::AggregateFunction::Correlation => Self::Correlation,
-            protobuf::AggregateFunction::RegrSlope => Self::RegrSlope,
-            protobuf::AggregateFunction::RegrIntercept => Self::RegrIntercept,
-            protobuf::AggregateFunction::RegrCount => Self::RegrCount,
-            protobuf::AggregateFunction::RegrR2 => Self::RegrR2,
-            protobuf::AggregateFunction::RegrAvgx => Self::RegrAvgx,
-            protobuf::AggregateFunction::RegrAvgy => Self::RegrAvgy,
-            protobuf::AggregateFunction::RegrSxx => Self::RegrSXX,
-            protobuf::AggregateFunction::RegrSyy => Self::RegrSYY,
-            protobuf::AggregateFunction::RegrSxy => Self::RegrSXY,
             protobuf::AggregateFunction::ApproxPercentileCont => {
                 Self::ApproxPercentileCont
             }
@@ -642,7 +632,7 @@ pub fn parse_expr(
             Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
                 agg_fn,
                 parse_exprs(&pb.args, registry, codec)?,
-                false,
+                pb.distinct,
                 parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new),
                 parse_vec_expr(&pb.order_by, registry, codec)?,
                 None,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index d42470f198e38..04f7b596fea80 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -116,18 +116,8 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
             AggregateFunction::BitXor => Self::BitXor,
             AggregateFunction::BoolAnd => Self::BoolAnd,
             AggregateFunction::BoolOr => Self::BoolOr,
-            AggregateFunction::Count => Self::Count,
             AggregateFunction::ArrayAgg => Self::ArrayAgg,
             AggregateFunction::Correlation => Self::Correlation,
-            AggregateFunction::RegrSlope => Self::RegrSlope,
-            AggregateFunction::RegrIntercept => Self::RegrIntercept,
-            AggregateFunction::RegrCount => Self::RegrCount,
-            AggregateFunction::RegrR2 => Self::RegrR2,
-            AggregateFunction::RegrAvgx => Self::RegrAvgx,
-            AggregateFunction::RegrAvgy => Self::RegrAvgy,
-            AggregateFunction::RegrSXX => Self::RegrSxx,
-            AggregateFunction::RegrSYY => Self::RegrSyy,
-            AggregateFunction::RegrSXY => Self::RegrSxy,
             AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont,
             AggregateFunction::ApproxPercentileContWithWeight => {
                 Self::ApproxPercentileContWithWeight
@@ -406,25 +396,9 @@ pub fn serialize_expr(
                        AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd,
                        AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr,
                        AggregateFunction::Avg => protobuf::AggregateFunction::Avg,
-                        AggregateFunction::Count => protobuf::AggregateFunction::Count,
                        AggregateFunction::Correlation => {
                            protobuf::AggregateFunction::Correlation
                        }
-                        AggregateFunction::RegrSlope => {
-                            protobuf::AggregateFunction::RegrSlope
-                        }
-                        AggregateFunction::RegrIntercept => {
-                            protobuf::AggregateFunction::RegrIntercept
-                        }
-                        AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2,
-                        AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx,
-                        AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy,
-                        AggregateFunction::RegrCount => {
-                            protobuf::AggregateFunction::RegrCount
-                        }
-                        AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx,
-                        AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy,
-                        AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy,
                        AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping,
                        AggregateFunction::NthValue => {
                            protobuf::AggregateFunction::NthValueAgg
@@ -456,6 +430,7 @@ pub fn serialize_expr(
                protobuf::AggregateUdfExprNode {
                    fun_name: fun.name().to_string(),
                    args: serialize_exprs(args, codec)?,
+                    distinct: *distinct,
                    filter: match filter {
                        Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)),
                        None => None,
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index 5258bdd11d865..ef462ac94b9a9 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -25,10 +25,10 @@ use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr};
 use datafusion::physical_plan::expressions::{
     ApproxPercentileCont, ApproxPercentileContWithWeight, ArrayAgg, Avg, BinaryExpr,
     BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation,
-    Count, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, Grouping,
-    InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr,
-    NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType,
-    RowNumber, StringAgg, TryCastExpr, WindowShift,
+    CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, InListExpr, IsNotNullExpr,
+    IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile,
+    OrderSensitiveArrayAgg, Rank, RankType, RowNumber, StringAgg, TryCastExpr,
+    WindowShift,
 };
 use datafusion::physical_plan::udaf::AggregateFunctionExpr;
 use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr};
@@ -240,12 +240,7 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result<AggrFn> {
     let aggr_expr = expr.as_any();
     let mut distinct = false;
-    let inner = if aggr_expr.downcast_ref::<Count>().is_some() {
-        protobuf::AggregateFunction::Count
-    } else if aggr_expr.downcast_ref::<DistinctCount>().is_some() {
-        distinct = true;
-        protobuf::AggregateFunction::Count
-    } else if aggr_expr.downcast_ref::<Grouping>().is_some() {
+    let inner = if aggr_expr.downcast_ref::<Grouping>().is_some() {
         protobuf::AggregateFunction::Grouping
     } else if aggr_expr.downcast_ref::<BitAnd>().is_some() {
         protobuf::AggregateFunction::BitAnd
     } else if aggr_expr.downcast_ref::<BitOr>().is_some() {
@@ -275,18 +270,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result<AggrFn> {
         protobuf::AggregateFunction::Avg
     } else if aggr_expr.downcast_ref::<Correlation>().is_some() {
         protobuf::AggregateFunction::Correlation
-    } else if let Some(regr_expr) = aggr_expr.downcast_ref::<Regr>() {
-        match regr_expr.get_regr_type() {
-            RegrType::Slope => protobuf::AggregateFunction::RegrSlope,
-            RegrType::Intercept => protobuf::AggregateFunction::RegrIntercept,
-            RegrType::Count => protobuf::AggregateFunction::RegrCount,
-            RegrType::R2 => protobuf::AggregateFunction::RegrR2,
-            RegrType::AvgX => protobuf::AggregateFunction::RegrAvgx,
-            RegrType::AvgY => protobuf::AggregateFunction::RegrAvgy,
-            RegrType::SXX => protobuf::AggregateFunction::RegrSxx,
-            RegrType::SYY => protobuf::AggregateFunction::RegrSyy,
-            RegrType::SXY => protobuf::AggregateFunction::RegrSxy,
-        }
     } else if aggr_expr.downcast_ref::<ApproxPercentileCont>().is_some() {
         protobuf::AggregateFunction::ApproxPercentileCont
     } else if aggr_expr
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 699697dd2f2ce..d0f1c4aade5e4 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -26,6 +26,7 @@ use arrow::datatypes::{
     DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType,
     IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
 };
+use datafusion_functions_aggregate::count::count_udaf;
 use prost::Message;
 
 use datafusion::datasource::provider::TableProviderFactory;
@@ -35,8 +36,8 @@ use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
 use datafusion::execution::FunctionRegistry;
 use datafusion::functions_aggregate::approx_median::approx_median;
 use datafusion::functions_aggregate::expr_fn::{
-    covar_pop, covar_samp, first_value, median, stddev, stddev_pop, sum, var_pop,
-    var_sample,
+    count, count_distinct, covar_pop, covar_samp, first_value, median, stddev,
+    stddev_pop, sum, var_pop, var_sample,
 };
 use datafusion::prelude::*;
 use datafusion::test_util::{TestTableFactory, TestTableProvider};
@@ -53,10 +54,10 @@ use datafusion_expr::expr::{
 };
 use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore};
 use datafusion_expr::{
-    Accumulator, AggregateFunction, ColumnarValue, ExprSchemable, LogicalPlan, Operator,
-    PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, TryCast, Volatility,
-    WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF,
-    WindowUDFImpl,
+    Accumulator, AggregateExt, AggregateFunction, ColumnarValue, ExprSchemable,
+    LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature,
+    TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits,
+    WindowFunctionDefinition, WindowUDF, WindowUDFImpl,
 };
 use datafusion_proto::bytes::{
     logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec,
@@ -649,6 +650,8 @@ async fn roundtrip_expr_api() -> Result<()> {
             lit(1),
         ),
         array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)),
+        count(lit(1)),
+        count_distinct(lit(1)),
         first_value(lit(1), None),
         first_value(lit(1), Some(vec![lit(2).sort(true, true)])),
         covar_samp(lit(1.5), lit(2.2)),
@@ -1780,28 +1783,18 @@ fn roundtrip_similar_to() {
 
 #[test]
 fn roundtrip_count() {
-    let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
-        AggregateFunction::Count,
-        vec![col("bananas")],
-        false,
-        None,
-        None,
-        None,
-    ));
+    let test_expr = count(col("bananas"));
     let ctx = SessionContext::new();
     roundtrip_expr_test(test_expr, ctx);
 }
 
 #[test]
 fn roundtrip_count_distinct() {
-    let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
-        AggregateFunction::Count,
-        vec![col("bananas")],
-        true,
-        None,
-        None,
-        None,
-    ));
+    let test_expr = count_udaf()
+        .call(vec![col("bananas")])
+        .distinct()
+        .build()
+        .unwrap();
     let ctx = SessionContext::new();
     roundtrip_expr_test(test_expr, ctx);
 }
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 9cf686dbd3d68..e517482f1db02 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -38,7 +38,7 @@ use datafusion::datasource::physical_plan::{
 };
 use datafusion::execution::FunctionRegistry;
 use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility};
-use datafusion::physical_expr::expressions::{Count, Max, NthValueAgg};
+use datafusion::physical_expr::expressions::{Max, NthValueAgg};
 use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
 use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr};
 use datafusion::physical_plan::aggregates::{
@@ -47,8 +47,8 @@ use datafusion::physical_plan::aggregates::{
 use datafusion::physical_plan::analyze::AnalyzeExec;
 use datafusion::physical_plan::empty::EmptyExec;
 use datafusion::physical_plan::expressions::{
-    binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, DistinctCount,
-    NotExpr, NthValue, PhysicalSortExpr, StringAgg,
+    binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, NotExpr, NthValue,
+    PhysicalSortExpr, StringAgg,
 };
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::insert::DataSinkExec;
@@ -806,7 +806,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
     let aggregate = Arc::new(AggregateExec::try_new(
         AggregateMode::Final,
         PhysicalGroupBy::new(vec![], vec![], vec![]),
-        vec![Arc::new(Count::new(udf_expr, "count", DataType::Int64))],
+        vec![Arc::new(Max::new(udf_expr, "max", DataType::Int64))],
         vec![None],
         window,
         schema.clone(),
@@ -818,31 +818,6 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
     Ok(())
 }
 
-#[test]
-fn roundtrip_distinct_count() -> Result<()> {
-    let field_a = Field::new("a", DataType::Int64, false);
-    let field_b = Field::new("b", DataType::Int64, false);
-    let schema = Arc::new(Schema::new(vec![field_a, field_b]));
-
-    let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(DistinctCount::new(
-        DataType::Int64,
-        col("b", &schema)?,
-        "COUNT(DISTINCT b)".to_string(),
-    ))];
-
-    let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
-        vec![(col("a", &schema)?, "unused".to_string())];
-
-    roundtrip_test(Arc::new(AggregateExec::try_new(
-        AggregateMode::Final,
-        PhysicalGroupBy::new_single(groups),
-        aggregates.clone(),
-        vec![None],
-        Arc::new(EmptyExec::new(schema.clone())),
-        schema,
-    )?))
-}
-
 #[test]
 fn roundtrip_like() -> Result<()> {
     let schema = Schema::new(vec![
diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs
index 893db018c8af6..aee4cf5a38ed3 100644
--- a/datafusion/sql/examples/sql.rs
+++ b/datafusion/sql/examples/sql.rs
@@ -18,11 +18,12 @@ use arrow_schema::{DataType, Field, Schema};
 
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::{plan_err, Result};
-use datafusion_expr::test::function_stub::sum_udaf;
 use datafusion_expr::WindowUDF;
 use datafusion_expr::{
     logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, TableSource,
 };
+use datafusion_functions_aggregate::count::count_udaf;
+use datafusion_functions_aggregate::sum::sum_udaf;
 use datafusion_sql::{
     planner::{ContextProvider, SqlToRel},
     sqlparser::{dialect::GenericDialect, parser::Parser},
@@ -50,7 +51,9 @@ fn main() {
     let statement = &ast[0];
 
     // create a logical query plan
-    let context_provider = MyContextProvider::new().with_udaf(sum_udaf());
+    let context_provider = MyContextProvider::new()
+        .with_udaf(sum_udaf())
+        .with_udaf(count_udaf());
     let sql_to_rel = SqlToRel::new(&context_provider);
     let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();
 
@@ -66,7 +69,8 @@ struct MyContextProvider {
 
 impl MyContextProvider {
     fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self {
-        self.udafs.insert(udaf.name().to_string(), udaf);
+        // TODO: change back to to_string() once all function names are lowercase
+        self.udafs.insert(udaf.name().to_lowercase(), udaf);
         self
     }
 
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index dc25a6c33ecef..12c48054f1a73 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -960,13 +960,14 @@ mod tests {
     use arrow_schema::DataType::Int8;
     use datafusion_common::TableReference;
+    use datafusion_expr::AggregateExt;
     use datafusion_expr::{
-        case, col, cube, exists,
-        expr::{AggregateFunction, AggregateFunctionDefinition},
-        grouping_set, lit, not, not_exists, out_ref_col, placeholder, rollup, table_scan,
-        try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature,
-        Volatility, WindowFrame, WindowFunctionDefinition,
+        case, col, cube, exists, grouping_set, lit, not, not_exists, out_ref_col,
+        placeholder, rollup, table_scan, try_cast, when, wildcard, ColumnarValue,
+        ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame,
+        WindowFunctionDefinition,
     };
+    use datafusion_functions_aggregate::count::count_udaf;
     use datafusion_functions_aggregate::expr_fn::sum;
 
     use crate::unparser::dialect::CustomDialect;
@@ -1127,29 +1128,19 @@ mod tests {
             ),
             (sum(col("a")), r#"sum(a)"#),
             (
-                Expr::AggregateFunction(AggregateFunction {
-                    func_def: AggregateFunctionDefinition::BuiltIn(
-                        datafusion_expr::AggregateFunction::Count,
-                    ),
-                    args: vec![Expr::Wildcard { qualifier: None }],
-                    distinct: true,
-                    filter: None,
-                    order_by: None,
-                    null_treatment: None,
-                }),
+                count_udaf()
+                    .call(vec![Expr::Wildcard { qualifier: None }])
+                    .distinct()
+                    .build()
+                    .unwrap(),
                 "COUNT(DISTINCT *)",
             ),
             (
-                Expr::AggregateFunction(AggregateFunction {
-                    func_def: AggregateFunctionDefinition::BuiltIn(
-                        datafusion_expr::AggregateFunction::Count,
-                    ),
-                    args: vec![Expr::Wildcard { qualifier: None }],
-                    distinct: false,
-                    filter: Some(Box::new(lit(true))),
-                    order_by: None,
-                    null_treatment: None,
-                }),
+                count_udaf()
+                    .call(vec![Expr::Wildcard { qualifier: None }])
+                    .filter(lit(true))
+                    .build()
+                    .unwrap(),
                 "COUNT(*) FILTER (WHERE true)",
             ),
             (
@@ -1167,9 +1158,7 @@ mod tests {
             ),
             (
                 Expr::WindowFunction(WindowFunction {
-                    fun: WindowFunctionDefinition::AggregateFunction(
-                        datafusion_expr::AggregateFunction::Count,
-                    ),
+                    fun: WindowFunctionDefinition::AggregateUDF(count_udaf()),
                     args: vec![wildcard()],
                     partition_by: vec![],
                     order_by: vec![Expr::Sort(Sort::new(
diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs
index 51bacb5f702b0..bc27d25cf216e 100644
--- a/datafusion/sql/src/utils.rs
+++ b/datafusion/sql/src/utils.rs
@@ -350,7 +350,8 @@ mod tests {
     use arrow::datatypes::{DataType as ArrowDataType, Field, Schema};
     use arrow_schema::Fields;
     use datafusion_common::{DFSchema, Result};
-    use datafusion_expr::{col, count, lit, unnest, EmptyRelation, LogicalPlan};
+    use datafusion_expr::{col, lit, unnest, EmptyRelation, LogicalPlan};
+    use datafusion_functions_aggregate::expr_fn::count;
 
     use crate::utils::{recursive_transform_unnest, resolve_positions_to_exprs};
 
diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs
index 72018371a5f1a..33e28e7056b9f 100644
--- a/datafusion/sql/tests/cases/plan_to_sql.rs
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -19,7 +19,7 @@ use std::vec;
 
 use arrow_schema::*;
 use datafusion_common::{DFSchema, Result, TableReference};
-use datafusion_expr::test::function_stub::sum_udaf;
+use datafusion_expr::test::function_stub::{count_udaf, sum_udaf};
 use datafusion_expr::{col, table_scan};
 use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel};
 use datafusion_sql::unparser::dialect::{
@@ -153,7 +153,9 @@ fn roundtrip_statement() -> Result<()> {
         .try_with_sql(query)?
         .parse_statement()?;
 
-    let context = MockContextProvider::default().with_udaf(sum_udaf());
+    let context = MockContextProvider::default()
+        .with_udaf(sum_udaf())
+        .with_udaf(count_udaf());
     let sql_to_rel = SqlToRel::new(&context);
     let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
 
diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs
index d91c09ae12875..893678d6b3742 100644
--- a/datafusion/sql/tests/common/mod.rs
+++ b/datafusion/sql/tests/common/mod.rs
@@ -46,7 +46,8 @@ impl MockContextProvider {
     }
 
     pub(crate) fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self {
-        self.udafs.insert(udaf.name().to_string(), udaf);
+        // TODO: change back to to_string() once all function names are lowercase
+        self.udafs.insert(udaf.name().to_lowercase(), udaf);
         self
     }
 }
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index 7b9d39a2b51e8..8eb2a2b609e73 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -37,7 +37,9 @@ use datafusion_sql::{
     planner::{ParserOptions, SqlToRel},
 };
 
-use datafusion_functions_aggregate::approx_median::approx_median_udaf;
+use datafusion_functions_aggregate::{
+    approx_median::approx_median_udaf, count::count_udaf,
+};
 use rstest::rstest;
 use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};
 
@@ -2702,7 +2704,8 @@ fn logical_plan_with_dialect_and_options(
         ))
         .with_udf(make_udf("sqrt", vec![DataType::Int64], DataType::Int64))
         .with_udaf(sum_udaf())
-        .with_udaf(approx_median_udaf());
+        .with_udaf(approx_median_udaf())
+        .with_udaf(count_udaf());
 
     let planner = SqlToRel::new_with_options(&context, options);
     let result = DFParser::parse_sql_with_dialect(sql, dialect);
diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt
index e930af107f772..d51c69496d46e 100644
--- a/datafusion/sqllogictest/test_files/errors.slt
+++ b/datafusion/sqllogictest/test_files/errors.slt
@@ -46,7 +46,7 @@ statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'c
 SELECT CAST(c1 AS INT) FROM aggregate_test_100
 
 # aggregation_with_bad_arguments
-statement error DataFusion error: SQL error: ParserError\("Expected an expression:, found: \)"\)
+query error
 SELECT COUNT(DISTINCT) FROM aggregate_test_100
 
 # query_cte_incorrect
@@ -104,7 +104,7 @@ SELECT power(1, 2, 3);
 
 # AggregateFunction with wrong number of arguments
-statement error DataFusion error: Error during planning: No function matches the given name and argument types 'COUNT\(\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tCOUNT\(Any, \.\., Any\)
+query error
 select count();
 
 # AggregateFunction with wrong number of arguments
@@ -112,11 +112,11 @@ statement error DataFusion error: Error during planning: No function matches the
 select avg(c1, c12) from aggregate_test_100;
 
 # AggregateFunction with wrong argument type
-statement error DataFusion error: Error during planning: No function matches the given name and argument types 'REGR_SLOPE\(Int64, Utf8\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tREGR_SLOPE\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\)
+statement error Coercion
 select regr_slope(1, '2');
 
 # WindowFunction using AggregateFunction wrong signature
-statement error DataFusion error: Error during planning: No function matches the given name and argument types 'REGR_SLOPE\(Float32, Utf8\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tREGR_SLOPE\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\)
+statement error Coercion
 select
 c9,
 regr_slope(c11, '2') over () as min1
diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt
index f04d768221249..df6295d63b817 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -487,7 +487,7 @@ statement error Did you mean 'to_timestamp_seconds'?
 SELECT to_TIMESTAMPS_second(v2) from test;
 
 # Aggregate function
-statement error Did you mean 'COUNT'?
+query error DataFusion error: Error during planning: Invalid function 'counter'
 SELECT counter(*) from test;
 
 # Aggregate function
diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml
index ee96ffa670441..d934dba4cfea3 100644
--- a/datafusion/substrait/Cargo.toml
+++ b/datafusion/substrait/Cargo.toml
@@ -26,7 +26,7 @@ repository = { workspace = true }
 license = { workspace = true }
 authors = { workspace = true }
 # Specify MSRV here as `cargo msrv` doesn't support workspace version
-rust-version = "1.73"
+rust-version = "1.75"
 
 [lints]
 workspace = true
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 648a281832e10..93f197885c0ab 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -982,18 +982,16 @@ pub async fn from_substrait_agg_func(
     let function_name = substrait_fun_name((**function_name).as_str());
     // try udaf first, then built-in aggr fn.
     if let Ok(fun) = ctx.udaf(function_name) {
+        // handle the case where count(*) arrives with no arguments
+        if fun.name() == "COUNT" && args.is_empty() {
+            args.push(Expr::Literal(ScalarValue::Int64(Some(1))));
+        }
+
         Ok(Arc::new(Expr::AggregateFunction(
             expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None),
         )))
     } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name)
     {
-        match &fun {
-            // deal with situation that count(*) got no arguments
-            aggregate_function::AggregateFunction::Count if args.is_empty() => {
-                args.push(Expr::Literal(ScalarValue::Int64(Some(1))));
-            }
-            _ => {}
-        }
         Ok(Arc::new(Expr::AggregateFunction(
             expr::AggregateFunction::new(fun, args, distinct, filter, order_by, None),
         )))
@@ -1395,7 +1393,9 @@ fn from_substrait_type(
     })?;
     let field = Arc::new(Field::new_list_field(
         from_substrait_type(inner_type, dfs_names, name_idx)?,
-        is_substrait_type_nullable(inner_type)?,
+        // We ignore Substrait's nullability here to match to_substrait_literal
+        // which always creates nullable lists
+        true,
     ));
     match list.type_variation_reference {
         DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::List(field)),
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 88dc894eccd28..c0469d3331647 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -2309,14 +2309,12 @@ mod test {
         round_trip_type(DataType::Decimal128(10, 2))?;
         round_trip_type(DataType::Decimal256(30, 2))?;
 
-        for nullable in [true, false] {
-            round_trip_type(DataType::List(
-                Field::new_list_field(DataType::Int32, nullable).into(),
-            ))?;
-            round_trip_type(DataType::LargeList(
-                Field::new_list_field(DataType::Int32, nullable).into(),
-            ))?;
-        }
+        round_trip_type(DataType::List(
+            Field::new_list_field(DataType::Int32, true).into(),
+        ))?;
+        round_trip_type(DataType::LargeList(
+            Field::new_list_field(DataType::Int32, true).into(),
+        ))?;
 
         round_trip_type(DataType::Struct(
             vec![
diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs
index 994a932c30e0f..94572e098b2ca 100644
--- a/datafusion/substrait/tests/cases/logical_plans.rs
+++ b/datafusion/substrait/tests/cases/logical_plans.rs
@@ -20,6 +20,7 @@
 #[cfg(test)]
 mod tests {
     use datafusion::common::Result;
+    use datafusion::dataframe::DataFrame;
     use datafusion::prelude::{CsvReadOptions, SessionContext};
     use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
     use std::fs::File;
@@ -38,11 +39,7 @@ mod tests {
 
         // File generated with substrait-java's Isthmus:
         // ./isthmus-cli/build/graal/isthmus "select not d from data" -c "create table data (d boolean)"
-        let path = "tests/testdata/select_not_bool.substrait.json";
-        let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
-            File::open(path).expect("file not found"),
-        ))
-        .expect("failed to parse json");
+        let proto = read_json("tests/testdata/select_not_bool.substrait.json");
 
         let plan = from_substrait_plan(&ctx, &proto).await?;
 
@@ -54,6 +51,31 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn non_nullable_lists() -> Result<()> {
+        // DataFusion's Substrait consumer treats all lists as nullable, even if the Substrait plan specifies them as non-nullable.
+        // That's because implementing the non-nullability consistently is non-trivial.
+        // This test confirms that reading a plan with non-nullable lists works as expected.
+        let ctx = create_context().await?;
+        let proto = read_json("tests/testdata/non_nullable_lists.substrait.json");
+
+        let plan = from_substrait_plan(&ctx, &proto).await?;
+
+        assert_eq!(format!("{:?}", &plan), "Values: (List([1, 2]))");
+
+        // Need to trigger execution to ensure that Arrow has validated the plan
+        DataFrame::new(ctx.state(), plan).show().await?;
+
+        Ok(())
+    }
+
+    fn read_json(path: &str) -> Plan {
+        serde_json::from_reader::<_, Plan>(BufReader::new(
+            File::open(path).expect("file not found"),
+        ))
+        .expect("failed to parse json")
+    }
+
     async fn create_context() -> datafusion::common::Result<SessionContext> {
         let ctx = SessionContext::new();
         ctx.register_csv("DATA", "tests/testdata/data.csv", CsvReadOptions::new())
diff --git a/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json b/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json
new file mode 100644
index 0000000000000..e1c5574f8bec2
--- /dev/null
+++ b/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json
@@ -0,0 +1,71 @@
+{
+  "extensionUris": [],
+  "extensions": [],
+  "relations": [
+    {
+      "root": {
+        "input": {
+          "read": {
+            "common": {
+              "direct": {
+              }
+            },
+            "baseSchema": {
+              "names": [
+                "col"
+              ],
+              "struct": {
+                "types": [
+                  {
+                    "list": {
+                      "type": {
+                        "i32": {
+                          "typeVariationReference": 0,
+                          "nullability": "NULLABILITY_REQUIRED"
+                        }
+                      },
+                      "typeVariationReference": 0,
+                      "nullability": "NULLABILITY_REQUIRED"
+                    }
+                  }
+                ],
+                "typeVariationReference": 0,
+                "nullability": "NULLABILITY_REQUIRED"
+              }
+            },
+            "virtualTable": {
+              "values": [
+                {
+                  "fields": [
+                    {
+                      "list": {
+                        "values": [
+                          {
+                            "i32": 1,
+                            "nullable": false,
+                            "typeVariationReference": 0
+                          },
+                          {
+                            "i32": 2,
+                            "nullable": false,
+                            "typeVariationReference": 0
+                          }
+                        ]
+                      },
+                      "nullable": false,
+                      "typeVariationReference": 0
+                    }
+                  ]
+                }
+              ]
+            }
+          }
+        },
+        "names": [
+          "col"
+        ]
+      }
+    }
+  ],
+  "expectedTypeUrls": []
+}
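
A note for reviewers on the expression API this diff migrates to: with `Count` removed from the built-in `AggregateFunction` enum, plain, distinct, and filtered counts are all built against the count UDAF, and the distinct flag now survives (de)serialization via the new `AggregateUdfExprNode.distinct` field (tag 5) instead of being hardcoded to `false`. A minimal sketch of the three forms, using the same v39-era imports the updated tests above use:

```rust
use datafusion_common::Result;
use datafusion_expr::{col, lit, AggregateExt, Expr};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::expr_fn::{count, count_distinct};

fn count_exprs() -> Result<()> {
    // Plain COUNT(bananas), as in the updated roundtrip_count test
    let plain: Expr = count(col("bananas"));

    // COUNT(DISTINCT bananas) via the one-liner helper ...
    let distinct_short: Expr = count_distinct(col("bananas"));

    // ... or via the builder form used by roundtrip_count_distinct, which is
    // what exercises the new AggregateUdfExprNode.distinct proto field
    let distinct_built: Expr = count_udaf()
        .call(vec![col("bananas")])
        .distinct()
        .build()?;

    // COUNT(bananas) FILTER (WHERE true), as in the unparser test
    let filtered: Expr = count_udaf()
        .call(vec![col("bananas")])
        .filter(lit(true))
        .build()?;

    let _ = (plain, distinct_short, distinct_built, filtered);
    Ok(())
}
```

`.build()` returns `Result<Expr>`, which is why the test code in this diff calls `.unwrap()`.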
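The two `to_lowercase()` TODOs above are the same workaround applied in both test context providers: the migrated UDAF still reports its historical upper-case name (`COUNT`), while the SQL planner resolves functions by lower-cased identifier. A sketch of the pattern, with a hypothetical `Provider` standing in for the `MyContextProvider`/`MockContextProvider` structs in the diff:

```rust
use std::collections::HashMap;
use std::sync::Arc;
use datafusion_expr::AggregateUDF;

#[derive(Default)]
struct Provider {
    udafs: HashMap<String, Arc<AggregateUDF>>,
}

impl Provider {
    fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self {
        // Key by lower-case name so a SQL identifier like `count(...)`
        // resolves even while the UDAF itself still reports "COUNT".
        self.udafs.insert(udaf.name().to_lowercase(), udaf);
        self
    }

    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
        self.udafs.get(name).cloned()
    }
}
```

Once the function names themselves are lower-cased, the `to_lowercase()` call can revert to `to_string()`, which is what the TODOs track.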
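On the Substrait consumer change: because `count` now resolves through the UDAF registry first, the zero-argument `count(*)` special case moves into the UDAF branch. Substrait encodes `COUNT(*)` as a call with an empty argument list, but the count UDAF expects exactly one argument, so the consumer substitutes the literal `1` (`COUNT(1)` and `COUNT(*)` are equivalent). A condensed restatement of the hunk's logic as a hypothetical helper:

```rust
use datafusion_common::ScalarValue;
use datafusion_expr::Expr;

// `fun_name` is the name of the UDAF just resolved via ctx.udaf(...);
// `args` are the arguments decoded from the Substrait call.
fn normalize_count_args(fun_name: &str, args: &mut Vec<Expr>) {
    // Note: at this point in the migration the UDAF is still registered
    // under the upper-case name "COUNT" (see the to_lowercase() TODOs above).
    if fun_name == "COUNT" && args.is_empty() {
        // COUNT(1) is semantically identical to COUNT(*)
        args.push(Expr::Literal(ScalarValue::Int64(Some(1))));
    }
}
```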
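Finally, the list-nullability change in `from_substrait_type`: the consumer now deliberately ignores the element type's `NULLABILITY_REQUIRED` marker and always produces a nullable list field, matching what `to_substrait_literal` emits on the producer side (which is also why the producer test stops iterating over `nullable in [true, false]`). In Arrow terms the consumer does roughly the following, where `list_type_of` is a hypothetical helper for illustration:

```rust
use std::sync::Arc;
use arrow_schema::{DataType, Field};

// Regardless of what the Substrait plan declared for the element type, the
// resulting Arrow list field is marked nullable:
fn list_type_of(element: DataType) -> DataType {
    DataType::List(Arc::new(Field::new_list_field(element, /* nullable */ true)))
}

fn main() {
    // e.g. the non_nullable_lists.substrait.json plan above decodes to a
    // List(Field { name: "item", data_type: Int32, nullable: true, .. })
    assert_eq!(
        list_type_of(DataType::Int32),
        DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)))
    );
}
```

The new `non_nullable_lists` test then executes the decoded plan to confirm Arrow accepts the relaxed (nullable) type.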