From 98ba11f3cda4fc1ac5cf9a18da94296fd8d4d061 Mon Sep 17 00:00:00 2001 From: comphead Date: Sat, 6 Apr 2024 16:19:58 -0700 Subject: [PATCH 01/14] MSRV 1.73 (#9977) --- Cargo.toml | 2 +- datafusion-cli/Cargo.toml | 2 +- datafusion-cli/Dockerfile | 2 +- datafusion/core/Cargo.toml | 2 +- datafusion/proto/Cargo.toml | 2 +- datafusion/proto/gen/Cargo.toml | 2 +- datafusion/substrait/Cargo.toml | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ca34ea9c2a240..c04f13b6c18b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,7 +49,7 @@ homepage = "https://github.com/apache/arrow-datafusion" license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/arrow-datafusion" -rust-version = "1.72" +rust-version = "1.73" version = "37.0.0" [workspace.dependencies] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index c9241fcf10b4b..98588edcd18ed 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -26,7 +26,7 @@ license = "Apache-2.0" homepage = "https://github.com/apache/arrow-datafusion" repository = "https://github.com/apache/arrow-datafusion" # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.72" +rust-version = "1.73" readme = "README.md" [dependencies] diff --git a/datafusion-cli/Dockerfile b/datafusion-cli/Dockerfile index 5ddedad2a6f4c..9dbab5b1ed750 100644 --- a/datafusion-cli/Dockerfile +++ b/datafusion-cli/Dockerfile @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -FROM rust:1.72-bullseye as builder +FROM rust:1.73-bullseye as builder COPY . /usr/src/arrow-datafusion COPY ./datafusion /usr/src/arrow-datafusion/datafusion diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 610784f91dec0..0236e8587d660 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -30,7 +30,7 @@ authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version and fails with # "Unable to find key 'package.rust-version' (or 'package.metadata.msrv') in 'arrow-datafusion/Cargo.toml'" # https://github.com/foresterre/cargo-msrv/issues/590 -rust-version = "1.72" +rust-version = "1.73" [lib] name = "datafusion" diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index bec2b8c53a7a7..325cd8704ccff 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -27,7 +27,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.72" +rust-version = "1.73" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index 0feece4218d99..e843827a91ac2 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.72" +rust-version = "1.73" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index f9523446980ea..d4800eca90f82 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -26,7 +26,7 @@ repository = { workspace = true } license = { workspace = 
true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.72" +rust-version = "1.73" [dependencies] async-recursion = "1.0" From 2f1c3ab679fd553ea9e18067ed894fc6cf1f567a Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 7 Apr 2024 10:25:09 +0800 Subject: [PATCH 02/14] Move First Value UDAF and builtin first / last function to `aggregate-functions` (#9960) * backup Signed-off-by: jayzhan211 * move PhysicalExpr Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * move physical sort Signed-off-by: jayzhan211 * cleanup dependencies Signed-off-by: jayzhan211 * add readme Signed-off-by: jayzhan211 * disable doc test Signed-off-by: jayzhan211 * move column Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * move aggregatexp Signed-off-by: jayzhan211 * move other two utils Signed-off-by: jayzhan211 * license Signed-off-by: jayzhan211 * switch to ignore Signed-off-by: jayzhan211 * move reverse order Signed-off-by: jayzhan211 * rename to common Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * backup Signed-off-by: jayzhan211 * move acc to first value Signed-off-by: jayzhan211 * move builtin expr too Signed-off-by: jayzhan211 * use macro Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix doc Signed-off-by: jayzhan211 * add todo Signed-off-by: jayzhan211 * rm comments Signed-off-by: jayzhan211 * rm unused Signed-off-by: jayzhan211 * rm unused code Signed-off-by: jayzhan211 * change to private Signed-off-by: jayzhan211 * fix lock Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * support roundtrip Signed-off-by: jayzhan211 * remmove old format state Signed-off-by: jayzhan211 * move aggregate related things to aggr crate Signed-off-by: jayzhan211 * move back to common Signed-off-by: jayzhan211 * taplo Signed-off-by: jayzhan211 * rm comment Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * lock Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- Cargo.toml | 2 + datafusion-cli/Cargo.lock | 17 + datafusion/core/Cargo.toml | 1 + datafusion/core/src/execution/context/mod.rs | 29 +- datafusion/core/src/lib.rs | 5 + datafusion/expr/src/expr_fn.rs | 84 --- datafusion/expr/src/udaf.rs | 7 +- datafusion/expr/src/utils.rs | 5 + datafusion/functions-aggregate/Cargo.toml | 44 ++ .../src}/first_last.rs | 511 ++++++++++-------- datafusion/functions-aggregate/src/lib.rs | 84 +++ datafusion/functions-aggregate/src/macros.rs | 53 ++ .../physical-expr-common/src/aggregate/mod.rs | 198 ++++++- datafusion/physical-expr/Cargo.toml | 1 + datafusion/physical-expr/src/aggregate/mod.rs | 1 - .../physical-expr/src/aggregate/utils.rs | 6 +- .../physical-expr/src/expressions/mod.rs | 10 +- datafusion/physical-expr/src/lib.rs | 2 - datafusion/physical-plan/Cargo.toml | 2 + .../physical-plan/src/aggregates/mod.rs | 2 +- datafusion/physical-plan/src/lib.rs | 6 +- datafusion/physical-plan/src/udaf.rs | 218 -------- .../tests/cases/roundtrip_logical_plan.rs | 2 + 23 files changed, 720 insertions(+), 570 deletions(-) create mode 100644 datafusion/functions-aggregate/Cargo.toml rename datafusion/{physical-expr/src/aggregate => functions-aggregate/src}/first_last.rs (85%) create mode 100644 datafusion/functions-aggregate/src/lib.rs create mode 100644 datafusion/functions-aggregate/src/macros.rs delete mode 100644 datafusion/physical-plan/src/udaf.rs diff --git a/Cargo.toml b/Cargo.toml index c04f13b6c18b3..e1e09d9893b1e 100644 --- a/Cargo.toml +++ 
b/Cargo.toml @@ -23,6 +23,7 @@ members = [ "datafusion/core", "datafusion/expr", "datafusion/execution", + "datafusion/functions-aggregate", "datafusion/functions", "datafusion/functions-array", "datafusion/optimizer", @@ -78,6 +79,7 @@ datafusion-common-runtime = { path = "datafusion/common-runtime", version = "37. datafusion-execution = { path = "datafusion/execution", version = "37.0.0" } datafusion-expr = { path = "datafusion/expr", version = "37.0.0" } datafusion-functions = { path = "datafusion/functions", version = "37.0.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "37.0.0" } datafusion-functions-array = { path = "datafusion/functions-array", version = "37.0.0" } datafusion-optimizer = { path = "datafusion/optimizer", version = "37.0.0", default-features = false } datafusion-physical-expr = { path = "datafusion/physical-expr", version = "37.0.0", default-features = false } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 8a8aed249b7a3..447b69e414cf6 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1135,6 +1135,7 @@ dependencies = [ "datafusion-execution", "datafusion-expr", "datafusion-functions", + "datafusion-functions-aggregate", "datafusion-functions-array", "datafusion-optimizer", "datafusion-physical-expr", @@ -1278,6 +1279,19 @@ dependencies = [ "uuid", ] +[[package]] +name = "datafusion-functions-aggregate" +version = "37.0.0" +dependencies = [ + "arrow", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-expr-common", + "log", + "paste", +] + [[package]] name = "datafusion-functions-array" version = "37.0.0" @@ -1330,6 +1344,7 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", + "datafusion-functions-aggregate", "datafusion-physical-expr-common", "half", "hashbrown 0.14.3", @@ -1369,7 +1384,9 @@ dependencies = [ "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", + "datafusion-functions-aggregate", "datafusion-physical-expr", + "datafusion-physical-expr-common", "futures", "half", "hashbrown 0.14.3", diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 0236e8587d660..018b5083e0311 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -98,6 +98,7 @@ datafusion-common-runtime = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } +datafusion-functions-aggregate = { workspace = true } datafusion-functions-array = { workspace = true, optional = true } datafusion-optimizer = { workspace = true } datafusion-physical-expr = { workspace = true } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 31a474bd217c3..f15c1c218db6d 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -44,6 +44,7 @@ use crate::{ datasource::{provider_as_source, MemTable, TableProvider, ViewTable}, error::{DataFusionError, Result}, execution::{options::ArrowReadOptions, runtime_env::RuntimeEnv, FunctionRegistry}, + logical_expr::AggregateUDF, logical_expr::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable, @@ -53,10 +54,11 @@ use crate::{ optimizer::analyzer::{Analyzer, AnalyzerRule}, optimizer::optimizer::{Optimizer, OptimizerConfig, OptimizerRule}, 
physical_optimizer::optimizer::{PhysicalOptimizer, PhysicalOptimizerRule}, - physical_plan::{udaf::AggregateUDF, udf::ScalarUDF, ExecutionPlan}, + physical_plan::{udf::ScalarUDF, ExecutionPlan}, physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}, variable::{VarProvider, VarType}, }; +use crate::{functions, functions_aggregate, functions_array}; use arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; @@ -69,14 +71,11 @@ use datafusion_common::{ SchemaReference, TableReference, }; use datafusion_execution::registry::SerializerRegistry; -use datafusion_expr::type_coercion::aggregates::NUMERICS; -use datafusion_expr::{create_first_value, Signature, Volatility}; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, var_provider::is_system_variables, Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; -use datafusion_physical_expr::create_first_value_accumulator; use datafusion_sql::{ parser::{CopyToSource, CopyToStatement, DFParser}, planner::{object_name_to_table_reference, ContextProvider, ParserOptions, SqlToRel}, @@ -85,7 +84,6 @@ use datafusion_sql::{ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use log::debug; use parking_lot::RwLock; use sqlparser::dialect::dialect_from_str; use url::Url; @@ -1452,29 +1450,16 @@ impl SessionState { }; // register built in functions - datafusion_functions::register_all(&mut new_self) + functions::register_all(&mut new_self) .expect("can not register built in functions"); // register crate of array expressions (if enabled) #[cfg(feature = "array_expressions")] - datafusion_functions_array::register_all(&mut new_self) + functions_array::register_all(&mut new_self) .expect("can not register array expressions"); - let first_value = create_first_value( - "FIRST_VALUE", - Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable), - Arc::new(create_first_value_accumulator), - ); - - match new_self.register_udaf(Arc::new(first_value)) { - Ok(Some(existing_udaf)) => { - debug!("Overwrite existing UDAF: {}", existing_udaf.name()); - } - Ok(None) => {} - Err(err) => { - panic!("Failed to register UDAF: {}", err); - } - } + functions_aggregate::register_all(&mut new_self) + .expect("can not register aggregate functions"); new_self } diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index feeace3b5cfd0..c213f4554fb8b 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -541,6 +541,11 @@ pub mod functions_array { pub use datafusion_functions_array::*; } +/// re-export of [`datafusion_functions_aggregate`] crate +pub mod functions_aggregate { + pub use datafusion_functions_aggregate::*; +} + #[cfg(test)] pub mod test; pub mod test_util; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a1235a093d760..f68685a87f13c 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -24,7 +24,6 @@ use crate::expr::{ use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, }; -use crate::udaf::format_state_name; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AggregateUDF, BuiltinScalarFunction, Expr, LogicalPlan, @@ -708,17 +707,6 @@ pub fn create_udaf( )) } -/// Creates a new UDAF with a specific signature, state type and return type. -/// The signature and state type must match the `Accumulator's implementation`. -/// TOOD: We plan to move aggregate function to its own crate. 
This function will be deprecated then. -pub fn create_first_value( - name: &str, - signature: Signature, - accumulator: AccumulatorFactoryFunction, -) -> AggregateUDF { - AggregateUDF::from(FirstValue::new(name, signature, accumulator)) -} - /// Implements [`AggregateUDFImpl`] for functions that have a single signature and /// return type. pub struct SimpleAggregateUDF { @@ -813,78 +801,6 @@ impl AggregateUDFImpl for SimpleAggregateUDF { } } -pub struct FirstValue { - name: String, - signature: Signature, - accumulator: AccumulatorFactoryFunction, -} - -impl Debug for FirstValue { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("FirstValue") - .field("name", &self.name) - .field("signature", &self.signature) - .field("accumulator", &"") - .finish() - } -} - -impl FirstValue { - pub fn new( - name: impl Into, - signature: Signature, - accumulator: AccumulatorFactoryFunction, - ) -> Self { - let name = name.into(); - Self { - name, - signature, - accumulator, - } - } -} - -impl AggregateUDFImpl for FirstValue { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - &self.name - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) - } - - fn accumulator( - &self, - acc_args: AccumulatorArgs, - ) -> Result> { - (self.accumulator)(acc_args) - } - - fn state_fields( - &self, - name: &str, - value_type: DataType, - ordering_fields: Vec, - ) -> Result> { - let mut fields = vec![Field::new( - format_state_name(name, "first_value"), - value_type, - true, - )]; - fields.extend(ordering_fields); - fields.push(Field::new("is_set", DataType::Boolean, true)); - Ok(fields) - } -} - /// Creates a new UDWF with a specific signature, state type and return type. /// /// The signature and state type must match the [`PartitionEvaluator`]'s implementation`. diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 3cf1845aacd69..856f0dc44246e 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -19,6 +19,7 @@ use crate::function::AccumulatorArgs; use crate::groups_accumulator::GroupsAccumulator; +use crate::utils::format_state_name; use crate::{Accumulator, Expr}; use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature}; use arrow::datatypes::{DataType, Field}; @@ -447,9 +448,3 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { (self.accumulator)(acc_args) } } - -/// returns the name of the state -/// TODO: Remove duplicated function in physical-expr -pub(crate) fn format_state_name(name: &str, state_name: &str) -> String { - format!("{name}[{state_name}]") -} diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 72d01da204482..a93282574e8a2 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1240,6 +1240,11 @@ pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { } } +/// Build state name. State is the intermidiate state of the aggregate function. +pub fn format_state_name(name: &str, state_name: &str) -> String { + format!("{name}[{state_name}]") +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml new file mode 100644 index 0000000000000..d42932d8abddb --- /dev/null +++ b/datafusion/functions-aggregate/Cargo.toml @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-functions-aggregate" +description = "Aggregate function packages for the DataFusion query engine" +keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "datafusion_functions_aggregate" +path = "src/lib.rs" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +arrow = { workspace = true } +datafusion-common = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-physical-expr-common = { workspace = true } +log = { workspace = true } +paste = "1.0.14" diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs similarity index 85% rename from datafusion/physical-expr/src/aggregate/first_last.rs rename to datafusion/functions-aggregate/src/first_last.rs index 26bd219f65f00..d5367ad34163e 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -17,209 +17,149 @@ //! Defines the FIRST_VALUE/LAST_VALUE aggregations. 
-use std::any::Any; -use std::sync::Arc; - -use crate::aggregate::utils::{down_cast_any_ref, get_sort_options, ordering_fields}; -use crate::expressions::{self, format_state_name}; -use crate::{ - reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, -}; - -use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; -use arrow::compute::{self, lexsort_to_indices, SortColumn}; +use arrow::array::{ArrayRef, AsArray, BooleanArray}; +use arrow::compute::{self, lexsort_to_indices, SortColumn, SortOptions}; use arrow::datatypes::{DataType, Field}; -use arrow_schema::SortOptions; use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::function::AccumulatorArgs; -use datafusion_expr::{Accumulator, Expr}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Expr, Signature, Volatility}; +use datafusion_physical_expr_common::aggregate::utils::{ + down_cast_any_ref, get_sort_options, ordering_fields, +}; +use datafusion_physical_expr_common::aggregate::AggregateExpr; +use datafusion_physical_expr_common::expressions; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr_common::utils::reverse_order_bys; +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + +make_udaf_function!( + FirstValue, + first_value, + value, + "Returns the first value in a group of values.", + first_value_udaf +); -/// FIRST_VALUE aggregate expression -#[derive(Debug, Clone)] pub struct FirstValue { - name: String, - input_data_type: DataType, - order_by_data_types: Vec, - expr: Arc, - ordering_req: LexOrdering, - requirement_satisfied: bool, - ignore_nulls: bool, - state_fields: Vec, + signature: Signature, + aliases: Vec, +} + +impl Debug for FirstValue { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("FirstValue") + .field("name", &self.name()) + .field("signature", &self.signature) + .field("accumulator", &"") + .finish() + } +} + +impl Default for FirstValue { + fn default() -> Self { + Self::new() + } } impl FirstValue { - /// Creates a new FIRST_VALUE aggregation function. - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - ordering_req: LexOrdering, - order_by_data_types: Vec, - state_fields: Vec, - ) -> Self { - let requirement_satisfied = ordering_req.is_empty(); + pub fn new() -> Self { Self { - name: name.into(), - input_data_type, - order_by_data_types, - expr, - ordering_req, - requirement_satisfied, - ignore_nulls: false, - state_fields, + aliases: vec![String::from("FIRST_VALUE")], + signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable), } } +} - pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { - self.ignore_nulls = ignore_nulls; +impl AggregateUDFImpl for FirstValue { + fn as_any(&self) -> &dyn Any { self } - /// Returns the name of the aggregate expression. - pub fn name(&self) -> &str { - &self.name + fn name(&self) -> &str { + "FIRST_VALUE" } - /// Returns the input data type of the aggregate expression. - pub fn input_data_type(&self) -> &DataType { - &self.input_data_type + fn signature(&self) -> &Signature { + &self.signature } - /// Returns the data types of the order-by columns. 
- pub fn order_by_data_types(&self) -> &Vec { - &self.order_by_data_types + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) } - /// Returns the expression associated with the aggregate function. - pub fn expr(&self) -> &Arc { - &self.expr - } + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let mut all_sort_orders = vec![]; - /// Returns the lexical ordering requirements of the aggregate expression. - pub fn ordering_req(&self) -> &LexOrdering { - &self.ordering_req - } + // Construct PhysicalSortExpr objects from Expr objects: + let mut sort_exprs = vec![]; + for expr in acc_args.sort_exprs { + if let Expr::Sort(sort) = expr { + if let Expr::Column(col) = sort.expr.as_ref() { + let name = &col.name; + let e = expressions::column::col(name, acc_args.schema)?; + sort_exprs.push(PhysicalSortExpr { + expr: e, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); + } + } + } + if !sort_exprs.is_empty() { + all_sort_orders.extend(sort_exprs); + } - pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { - self.requirement_satisfied = requirement_satisfied; - self - } + let ordering_req = all_sort_orders; - pub fn convert_to_last(self) -> LastValue { - let name = if self.name.starts_with("FIRST") { - format!("LAST{}", &self.name[5..]) - } else { - format!("LAST_VALUE({})", self.expr) - }; - let FirstValue { - expr, - input_data_type, - ordering_req, - order_by_data_types, - .. - } = self; - LastValue::new( - expr, - name, - input_data_type, - reverse_order_bys(&ordering_req), - order_by_data_types, - ) - } -} - -impl AggregateExpr for FirstValue { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } + let ordering_dtypes = ordering_req + .iter() + .map(|e| e.expr.data_type(acc_args.schema)) + .collect::>>()?; - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.input_data_type.clone(), true)) - } + let requirement_satisfied = ordering_req.is_empty(); - fn create_accumulator(&self) -> Result> { FirstValueAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.ignore_nulls, + acc_args.data_type, + &ordering_dtypes, + ordering_req, + acc_args.ignore_nulls, ) - .map(|acc| { - Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ - }) + .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) } - fn state_fields(&self) -> Result> { - if !self.state_fields.is_empty() { - return Ok(self.state_fields.clone()); - } - + fn state_fields( + &self, + name: &str, + value_type: DataType, + ordering_fields: Vec, + ) -> Result> { let mut fields = vec![Field::new( - format_state_name(&self.name, "first_value"), - self.input_data_type.clone(), + format_state_name(name, "first_value"), + value_type, true, )]; - fields.extend(ordering_fields( - &self.ordering_req, - &self.order_by_data_types, - )); - fields.push(Field::new( - format_state_name(&self.name, "is_set"), - DataType::Boolean, - true, - )); + fields.extend(ordering_fields); + fields.push(Field::new("is_set", DataType::Boolean, true)); Ok(fields) } - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) - } - - fn name(&self) -> &str { - &self.name - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone().convert_to_last())) - } - - 
fn create_sliding_accumulator(&self) -> Result> { - FirstValueAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.ignore_nulls, - ) - .map(|acc| { - Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ - }) - } -} - -impl PartialEq for FirstValue { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.order_by_data_types == x.order_by_data_types - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) + fn aliases(&self) -> &[String] { + &self.aliases } } #[derive(Debug)] -struct FirstValueAccumulator { +pub struct FirstValueAccumulator { first: ScalarValue, // At the beginning, `is_set` is false, which means `first` is not seen yet. // Once we see the first value, we set the `is_set` flag and do not update `first` anymore. @@ -258,6 +198,11 @@ impl FirstValueAccumulator { }) } + pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } + // Updates state with the values in the given row. fn update_with_new_row(&mut self, row: &[ScalarValue]) { self.first = row[0].clone(); @@ -307,11 +252,6 @@ impl FirstValueAccumulator { Ok((!indices.is_empty()).then_some(indices.value(0) as _)) } } - - fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { - self.requirement_satisfied = requirement_satisfied; - self - } } impl Accumulator for FirstValueAccumulator { @@ -393,53 +333,190 @@ impl Accumulator for FirstValueAccumulator { } } -pub fn create_first_value_accumulator( - acc_args: AccumulatorArgs, -) -> Result> { - let mut all_sort_orders = vec![]; - - // Construct PhysicalSortExpr objects from Expr objects: - let mut sort_exprs = vec![]; - for expr in acc_args.sort_exprs { - if let Expr::Sort(sort) = expr { - if let Expr::Column(col) = sort.expr.as_ref() { - let name = &col.name; - let e = expressions::col(name, acc_args.schema)?; - sort_exprs.push(PhysicalSortExpr { - expr: e, - options: SortOptions { - descending: !sort.asc, - nulls_first: sort.nulls_first, - }, - }); - } +/// TO BE DEPRECATED: Builtin FIRST_VALUE physical aggregate expression will be replaced by udf in the future +#[derive(Debug, Clone)] +pub struct FirstValuePhysicalExpr { + name: String, + input_data_type: DataType, + order_by_data_types: Vec, + expr: Arc, + ordering_req: LexOrdering, + requirement_satisfied: bool, + ignore_nulls: bool, + state_fields: Vec, +} + +impl FirstValuePhysicalExpr { + /// Creates a new FIRST_VALUE aggregation function. + pub fn new( + expr: Arc, + name: impl Into, + input_data_type: DataType, + ordering_req: LexOrdering, + order_by_data_types: Vec, + state_fields: Vec, + ) -> Self { + let requirement_satisfied = ordering_req.is_empty(); + Self { + name: name.into(), + input_data_type, + order_by_data_types, + expr, + ordering_req, + requirement_satisfied, + ignore_nulls: false, + state_fields, } } - if !sort_exprs.is_empty() { - all_sort_orders.extend(sort_exprs); + + pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { + self.ignore_nulls = ignore_nulls; + self } - let ordering_req = all_sort_orders; + /// Returns the name of the aggregate expression. 
+ pub fn name(&self) -> &str { + &self.name + } - let ordering_dtypes = ordering_req - .iter() - .map(|e| e.expr.data_type(acc_args.schema)) - .collect::>>()?; - - let requirement_satisfied = ordering_req.is_empty(); - - FirstValueAccumulator::try_new( - acc_args.data_type, - &ordering_dtypes, - ordering_req, - acc_args.ignore_nulls, - ) - .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) + /// Returns the input data type of the aggregate expression. + pub fn input_data_type(&self) -> &DataType { + &self.input_data_type + } + + /// Returns the data types of the order-by columns. + pub fn order_by_data_types(&self) -> &Vec { + &self.order_by_data_types + } + + /// Returns the expression associated with the aggregate function. + pub fn expr(&self) -> &Arc { + &self.expr + } + + /// Returns the lexical ordering requirements of the aggregate expression. + pub fn ordering_req(&self) -> &LexOrdering { + &self.ordering_req + } + + pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } + + pub fn convert_to_last(self) -> LastValuePhysicalExpr { + let name = if self.name.starts_with("FIRST") { + format!("LAST{}", &self.name[5..]) + } else { + format!("LAST_VALUE({})", self.expr) + }; + let FirstValuePhysicalExpr { + expr, + input_data_type, + ordering_req, + order_by_data_types, + .. + } = self; + LastValuePhysicalExpr::new( + expr, + name, + input_data_type, + reverse_order_bys(&ordering_req), + order_by_data_types, + ) + } } -/// LAST_VALUE aggregate expression +impl AggregateExpr for FirstValuePhysicalExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.input_data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + FirstValueAccumulator::try_new( + &self.input_data_type, + &self.order_by_data_types, + self.ordering_req.clone(), + self.ignore_nulls, + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) + } + + fn state_fields(&self) -> Result> { + if !self.state_fields.is_empty() { + return Ok(self.state_fields.clone()); + } + + let mut fields = vec![Field::new( + format_state_name(&self.name, "first_value"), + self.input_data_type.clone(), + true, + )]; + fields.extend(ordering_fields( + &self.ordering_req, + &self.order_by_data_types, + )); + fields.push(Field::new( + format_state_name(&self.name, "is_set"), + DataType::Boolean, + true, + )); + Ok(fields) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) + } + + fn name(&self) -> &str { + &self.name + } + + fn reverse_expr(&self) -> Option> { + Some(Arc::new(self.clone().convert_to_last())) + } + + fn create_sliding_accumulator(&self) -> Result> { + FirstValueAccumulator::try_new( + &self.input_data_type, + &self.order_by_data_types, + self.ordering_req.clone(), + self.ignore_nulls, + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) + } +} + +impl PartialEq for FirstValuePhysicalExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.input_data_type == x.input_data_type + && self.order_by_data_types == x.order_by_data_types + && self.expr.eq(&x.expr) + }) + 
.unwrap_or(false) + } +} + +/// TO BE DEPRECATED: Builtin LAST_VALUE physical aggregate expression will be replaced by udf in the future #[derive(Debug, Clone)] -pub struct LastValue { +pub struct LastValuePhysicalExpr { name: String, input_data_type: DataType, order_by_data_types: Vec, @@ -449,7 +526,7 @@ pub struct LastValue { ignore_nulls: bool, } -impl LastValue { +impl LastValuePhysicalExpr { /// Creates a new LAST_VALUE aggregation function. pub fn new( expr: Arc, @@ -505,20 +582,20 @@ impl LastValue { self } - pub fn convert_to_first(self) -> FirstValue { + pub fn convert_to_first(self) -> FirstValuePhysicalExpr { let name = if self.name.starts_with("LAST") { format!("FIRST{}", &self.name[4..]) } else { format!("FIRST_VALUE({})", self.expr) }; - let LastValue { + let LastValuePhysicalExpr { expr, input_data_type, ordering_req, order_by_data_types, .. } = self; - FirstValue::new( + FirstValuePhysicalExpr::new( expr, name, input_data_type, @@ -529,7 +606,7 @@ impl LastValue { } } -impl AggregateExpr for LastValue { +impl AggregateExpr for LastValuePhysicalExpr { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -598,7 +675,7 @@ impl AggregateExpr for LastValue { } } -impl PartialEq for LastValue { +impl PartialEq for LastValuePhysicalExpr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() @@ -820,15 +897,9 @@ fn convert_to_sort_cols( #[cfg(test)] mod tests { - use std::sync::Arc; - - use crate::aggregate::first_last::{FirstValueAccumulator, LastValueAccumulator}; + use arrow::array::Int64Array; - use arrow::compute::concat; - use arrow_array::{ArrayRef, Int64Array}; - use arrow_schema::DataType; - use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::Accumulator; + use super::*; #[test] fn test_first_last_value_value() -> Result<()> { @@ -888,7 +959,7 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(concat(&[ + states.push(arrow::compute::concat(&[ &state1[idx].to_array()?, &state2[idx].to_array()?, ])?); @@ -918,7 +989,7 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(concat(&[ + states.push(arrow::compute::concat(&[ &state1[idx].to_array()?, &state2[idx].to_array()?, ])?); diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs new file mode 100644 index 0000000000000..8016b76889f71 --- /dev/null +++ b/datafusion/functions-aggregate/src/lib.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Aggregate Function packages for [DataFusion]. +//! +//! This crate contains a collection of various aggregate function packages for DataFusion, +//! implemented using the extension API. 
Users may wish to control which functions +//! are available to control the binary size of their application as well as +//! use dialect specific implementations of functions (e.g. Spark vs Postgres) +//! +//! Each package is implemented as a separate +//! module, activated by a feature flag. +//! +//! [DataFusion]: https://crates.io/crates/datafusion +//! +//! # Available Packages +//! See the list of [modules](#modules) in this crate for available packages. +//! +//! # Using A Package +//! You can register all functions in all packages using the [`register_all`] function. +//! +//! Each package also exports an `expr_fn` submodule to help create [`Expr`]s that invoke +//! functions using a fluent style. For example: +//! +//![`Expr`]: datafusion_expr::Expr +//! +//! # Implementing A New Package +//! +//! To add a new package to this crate, you should follow the model of existing +//! packages. The high level steps are: +//! +//! 1. Create a new module with the appropriate [AggregateUDF] implementations. +//! +//! 2. Use the macros in [`macros`] to create standard entry points. +//! +//! 3. Add a new feature to `Cargo.toml`, with any optional dependencies +//! +//! 4. Use the `make_package!` macro to expose the module when the +//! feature is enabled. + +#[macro_use] +pub mod macros; + +pub mod first_last; + +use datafusion_common::Result; +use datafusion_execution::FunctionRegistry; +use datafusion_expr::AggregateUDF; +use log::debug; +use std::sync::Arc; + +/// Fluent-style API for creating `Expr`s +pub mod expr_fn { + pub use super::first_last::first_value; +} + +/// Registers all enabled packages with a [`FunctionRegistry`] +pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { + let functions: Vec> = vec![first_last::first_value_udaf()]; + + functions.into_iter().try_for_each(|udf| { + let existing_udaf = registry.register_udaf(udf)?; + if let Some(existing_udaf) = existing_udaf { + debug!("Overwrite existing UDAF: {}", existing_udaf.name()); + } + Ok(()) as Result<()> + })?; + + Ok(()) +} diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs new file mode 100644 index 0000000000000..d24c60f932701 --- /dev/null +++ b/datafusion/functions-aggregate/src/macros.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +macro_rules! make_udaf_function { + ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + paste::paste! 
{ + // "fluent expr_fn" style function + #[doc = $DOC] + pub fn $EXPR_FN($($arg: Expr),*) -> Expr { + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + $AGGREGATE_UDF_FN(), + vec![$($arg),*], + // TODO: Support arguments for `expr` API + false, + None, + None, + None, + )) + } + + /// Singleton instance of [$UDAF], ensures the UDAF is only created once + /// named STATIC_$(UDAF). For example `STATIC_FirstValue` + #[allow(non_upper_case_globals)] + static [< STATIC_ $UDAF >]: std::sync::OnceLock> = + std::sync::OnceLock::new(); + + /// AggregateFunction that returns a [AggregateUDF] for [$UDAF] + /// + /// [AggregateUDF]: datafusion_expr::AggregateUDF + pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc { + [< STATIC_ $UDAF >] + .get_or_init(|| { + std::sync::Arc::new(datafusion_expr::AggregateUDF::from(<$UDAF>::default())) + }) + .clone() + } + } + } +} diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 579f51815d849..33044fd9beee8 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -17,16 +17,54 @@ pub mod utils; -use std::any::Any; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::{ + function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, +}; use std::fmt::Debug; -use std::sync::Arc; +use std::{any::Any, sync::Arc}; use crate::physical_expr::PhysicalExpr; -use crate::sort_expr::PhysicalSortExpr; +use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; -use arrow::datatypes::Field; -use datafusion_common::{not_impl_err, Result}; -use datafusion_expr::{Accumulator, GroupsAccumulator}; +use self::utils::{down_cast_any_ref, ordering_fields}; + +/// Creates a physical expression of the UDAF, that includes all necessary type coercion. +/// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. +pub fn create_aggregate_expr( + fun: &AggregateUDF, + input_phy_exprs: &[Arc], + sort_exprs: &[Expr], + ordering_req: &[PhysicalSortExpr], + schema: &Schema, + name: impl Into, + ignore_nulls: bool, +) -> Result> { + let input_exprs_types = input_phy_exprs + .iter() + .map(|arg| arg.data_type(schema)) + .collect::>>()?; + + let ordering_types = ordering_req + .iter() + .map(|e| e.expr.data_type(schema)) + .collect::>>()?; + + let ordering_fields = ordering_fields(ordering_req, &ordering_types); + + Ok(Arc::new(AggregateFunctionExpr { + fun: fun.clone(), + args: input_phy_exprs.to_vec(), + data_type: fun.return_type(&input_exprs_types)?, + name: name.into(), + schema: schema.clone(), + sort_exprs: sort_exprs.to_vec(), + ordering_req: ordering_req.to_vec(), + ignore_nulls, + ordering_fields, + })) +} /// An aggregate expression that: /// * knows its resulting field @@ -100,3 +138,151 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { not_impl_err!("Retractable Accumulator hasn't been implemented for {self:?} yet") } } + +/// Physical aggregate expression of a UDAF. 
+#[derive(Debug)] +pub struct AggregateFunctionExpr { + fun: AggregateUDF, + args: Vec>, + /// Output / return type of this aggregate + data_type: DataType, + name: String, + schema: Schema, + // The logical order by expressions + sort_exprs: Vec, + // The physical order by expressions + ordering_req: LexOrdering, + ignore_nulls: bool, + ordering_fields: Vec, +} + +impl AggregateFunctionExpr { + /// Return the `AggregateUDF` used by this `AggregateFunctionExpr` + pub fn fun(&self) -> &AggregateUDF { + &self.fun + } +} + +impl AggregateExpr for AggregateFunctionExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn expressions(&self) -> Vec> { + self.args.clone() + } + + fn state_fields(&self) -> Result> { + self.fun.state_fields( + self.name(), + self.data_type.clone(), + self.ordering_fields.clone(), + ) + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + let acc_args = AccumulatorArgs::new( + &self.data_type, + &self.schema, + self.ignore_nulls, + &self.sort_exprs, + ); + + self.fun.accumulator(acc_args) + } + + fn create_sliding_accumulator(&self) -> Result> { + let accumulator = self.create_accumulator()?; + + // Accumulators that have window frame startings different + // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to + // implement retract_batch method in order to run correctly + // currently in DataFusion. + // + // If this `retract_batches` is not present, there is no way + // to calculate result correctly. For example, the query + // + // ```sql + // SELECT + // SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sum_a + // FROM + // t + // ``` + // + // 1. First sum value will be the sum of rows between `[0, 1)`, + // + // 2. Second sum value will be the sum of rows between `[0, 2)` + // + // 3. Third sum value will be the sum of rows between `[1, 3)`, etc. + // + // Since the accumulator keeps the running sum: + // + // 1. First sum we add to the state sum value between `[0, 1)` + // + // 2. Second sum we add to the state sum value between `[1, 2)` + // (`[0, 1)` is already in the state sum, hence running sum will + // cover `[0, 2)` range) + // + // 3. Third sum we add to the state sum value between `[2, 3)` + // (`[0, 2)` is already in the state sum). Also we need to + // retract values between `[0, 1)` by this way we can obtain sum + // between [1, 3) which is indeed the apropriate range. + // + // When we use `UNBOUNDED PRECEDING` in the query starting + // index will always be 0 for the desired range, and hence the + // `retract_batch` method will not be called. In this case + // having retract_batch is not a requirement. + // + // This approach is a a bit different than window function + // approach. In window function (when they use a window frame) + // they get all the desired range during evaluation. 
+ if !accumulator.supports_retract_batch() { + return not_impl_err!( + "Aggregate can not be used as a sliding accumulator because \ + `retract_batch` is not implemented: {}", + self.name + ); + } + Ok(accumulator) + } + + fn name(&self) -> &str { + &self.name + } + + fn groups_accumulator_supported(&self) -> bool { + self.fun.groups_accumulator_supported() + } + + fn create_groups_accumulator(&self) -> Result> { + self.fun.create_groups_accumulator() + } + + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) + } +} + +impl PartialEq for AggregateFunctionExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.data_type == x.data_type + && self.fun == x.fun + && self.args.len() == x.args.len() + && self + .args + .iter() + .zip(x.args.iter()) + .all(|(this_arg, other_arg)| this_arg.eq(other_arg)) + }) + .unwrap_or(false) + } +} diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 87d73183d0ddf..72fac5370ae0d 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -59,6 +59,7 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-functions-aggregate = { workspace = true } datafusion-physical-expr-common = { workspace = true } half = { workspace = true } hashbrown = { version = "0.14", features = ["raw"] } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index e176084ae6ec2..eff008e8f8256 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -38,7 +38,6 @@ pub(crate) mod correlation; pub(crate) mod count; pub(crate) mod count_distinct; pub(crate) mod covariance; -pub(crate) mod first_last; pub(crate) mod grouping; pub(crate) mod median; pub(crate) mod nth_value; diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index d14a52f5752d0..6d97ad3da6de5 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -20,9 +20,9 @@ use std::sync::Arc; // For backwards compatibility -pub use datafusion_physical_expr_common::aggregate::utils::down_cast_any_ref; -pub use datafusion_physical_expr_common::aggregate::utils::get_sort_options; -pub use datafusion_physical_expr_common::aggregate::utils::ordering_fields; +pub use datafusion_physical_expr_common::aggregate::utils::{ + down_cast_any_ref, get_sort_options, ordering_fields, +}; use arrow::array::{ArrayRef, ArrowNativeTypeOp}; use arrow_array::cast::AsArray; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index f0cc4b175ea58..688d5ce6eabf2 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -53,7 +53,6 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; -pub use crate::aggregate::first_last::{FirstValue, LastValue}; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; @@ -76,11 +75,15 @@ 
pub use crate::window::rank::{dense_rank, percent_rank, rank}; pub use crate::window::rank::{Rank, RankType}; pub use crate::window::row_number::RowNumber; pub use crate::PhysicalSortExpr; +pub use datafusion_functions_aggregate::first_last::{ + FirstValuePhysicalExpr as FirstValue, LastValuePhysicalExpr as LastValue, +}; pub use binary::{binary, BinaryExpr}; pub use case::{case, CaseExpr}; pub use cast::{cast, cast_with_options, CastExpr}; pub use column::UnKnownColumn; +pub use datafusion_expr::utils::format_state_name; pub use datafusion_physical_expr_common::expressions::column::{col, Column}; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; @@ -92,11 +95,6 @@ pub use no_op::NoOp; pub use not::{not, NotExpr}; pub use try_cast::{try_cast, TryCastExpr}; -/// returns the name of the state -pub fn format_state_name(name: &str, state_name: &str) -> String { - format!("{name}[{state_name}]") -} - #[cfg(test)] pub(crate) mod tests { use std::sync::Arc; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index c88f1b32bbc6c..7b81e8f8a5c47 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -61,8 +61,6 @@ pub use scalar_function::ScalarFunctionExpr; pub use datafusion_physical_expr_common::utils::reverse_order_bys; pub use utils::split_conjunction; -pub use aggregate::first_last::create_first_value_accumulator; - // For backwards compatibility pub mod sort_properties { pub use datafusion_physical_expr_common::sort_properties::{ diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 1ba32bff746e1..6a78bd596a46e 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -46,7 +46,9 @@ datafusion-common = { workspace = true, default-features = true } datafusion-common-runtime = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-functions-aggregate = { workspace = true } datafusion-physical-expr = { workspace = true, default-features = true } +datafusion-physical-expr-common = { workspace = true } futures = { workspace = true } half = { workspace = true } hashbrown = { version = "0.14", features = ["raw"] } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index f8ad03bf6d977..98c44e23c6c77 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1235,7 +1235,7 @@ mod tests { use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_physical_expr::expressions::{ - lit, ApproxDistinct, Count, FirstValue, LastValue, Median, OrderSensitiveArrayAgg, + lit, ApproxDistinct, Count, LastValue, Median, OrderSensitiveArrayAgg, }; use datafusion_physical_expr::{ reverse_order_bys, AggregateExpr, EquivalenceProperties, PhysicalExpr, diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 3decf2e34015b..e1c8489655bf5 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -66,7 +66,6 @@ pub mod sorts; pub mod stream; pub mod streaming; pub mod tree_node; -pub mod udaf; pub mod union; pub mod unnest; pub mod values; @@ -91,6 +90,11 @@ pub use datafusion_physical_expr::{ // Backwards compatibility pub use crate::stream::EmptyRecordBatchStream; pub use 
datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; +pub mod udaf { + pub use datafusion_physical_expr_common::aggregate::{ + create_aggregate_expr, AggregateFunctionExpr, + }; +} /// Represent nodes in the DataFusion Physical Plan. /// diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs deleted file mode 100644 index 74a5603c0c817..0000000000000 --- a/datafusion/physical-plan/src/udaf.rs +++ /dev/null @@ -1,218 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This module contains functions and structs supporting user-defined aggregate functions. - -use datafusion_expr::function::AccumulatorArgs; -use datafusion_expr::{Expr, GroupsAccumulator}; -use fmt::Debug; -use std::any::Any; -use std::fmt; - -use arrow::datatypes::{DataType, Field, Schema}; - -use super::{Accumulator, AggregateExpr}; -use datafusion_common::{not_impl_err, Result}; -pub use datafusion_expr::AggregateUDF; -use datafusion_physical_expr::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; - -use datafusion_physical_expr::aggregate::utils::{down_cast_any_ref, ordering_fields}; -use std::sync::Arc; - -/// Creates a physical expression of the UDAF, that includes all necessary type coercion. -/// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. -pub fn create_aggregate_expr( - fun: &AggregateUDF, - input_phy_exprs: &[Arc], - sort_exprs: &[Expr], - ordering_req: &[PhysicalSortExpr], - schema: &Schema, - name: impl Into, - ignore_nulls: bool, -) -> Result> { - let input_exprs_types = input_phy_exprs - .iter() - .map(|arg| arg.data_type(schema)) - .collect::>>()?; - - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(schema)) - .collect::>>()?; - - let ordering_fields = ordering_fields(ordering_req, &ordering_types); - - Ok(Arc::new(AggregateFunctionExpr { - fun: fun.clone(), - args: input_phy_exprs.to_vec(), - data_type: fun.return_type(&input_exprs_types)?, - name: name.into(), - schema: schema.clone(), - sort_exprs: sort_exprs.to_vec(), - ordering_req: ordering_req.to_vec(), - ignore_nulls, - ordering_fields, - })) -} - -/// Physical aggregate expression of a UDAF. 
-#[derive(Debug)] -pub struct AggregateFunctionExpr { - fun: AggregateUDF, - args: Vec>, - /// Output / return type of this aggregate - data_type: DataType, - name: String, - schema: Schema, - // The logical order by expressions - sort_exprs: Vec, - // The physical order by expressions - ordering_req: LexOrdering, - ignore_nulls: bool, - ordering_fields: Vec, -} - -impl AggregateFunctionExpr { - /// Return the `AggregateUDF` used by this `AggregateFunctionExpr` - pub fn fun(&self) -> &AggregateUDF { - &self.fun - } -} - -impl AggregateExpr for AggregateFunctionExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn expressions(&self) -> Vec> { - self.args.clone() - } - - fn state_fields(&self) -> Result> { - self.fun.state_fields( - self.name(), - self.data_type.clone(), - self.ordering_fields.clone(), - ) - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { - let acc_args = AccumulatorArgs::new( - &self.data_type, - &self.schema, - self.ignore_nulls, - &self.sort_exprs, - ); - - self.fun.accumulator(acc_args) - } - - fn create_sliding_accumulator(&self) -> Result> { - let accumulator = self.create_accumulator()?; - - // Accumulators that have window frame startings different - // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to - // implement retract_batch method in order to run correctly - // currently in DataFusion. - // - // If this `retract_batches` is not present, there is no way - // to calculate result correctly. For example, the query - // - // ```sql - // SELECT - // SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sum_a - // FROM - // t - // ``` - // - // 1. First sum value will be the sum of rows between `[0, 1)`, - // - // 2. Second sum value will be the sum of rows between `[0, 2)` - // - // 3. Third sum value will be the sum of rows between `[1, 3)`, etc. - // - // Since the accumulator keeps the running sum: - // - // 1. First sum we add to the state sum value between `[0, 1)` - // - // 2. Second sum we add to the state sum value between `[1, 2)` - // (`[0, 1)` is already in the state sum, hence running sum will - // cover `[0, 2)` range) - // - // 3. Third sum we add to the state sum value between `[2, 3)` - // (`[0, 2)` is already in the state sum). Also we need to - // retract values between `[0, 1)` by this way we can obtain sum - // between [1, 3) which is indeed the apropriate range. - // - // When we use `UNBOUNDED PRECEDING` in the query starting - // index will always be 0 for the desired range, and hence the - // `retract_batch` method will not be called. In this case - // having retract_batch is not a requirement. - // - // This approach is a a bit different than window function - // approach. In window function (when they use a window frame) - // they get all the desired range during evaluation. 
- if !accumulator.supports_retract_batch() { - return not_impl_err!( - "Aggregate can not be used as a sliding accumulator because \ - `retract_batch` is not implemented: {}", - self.name - ); - } - Ok(accumulator) - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - self.fun.groups_accumulator_supported() - } - - fn create_groups_accumulator(&self) -> Result> { - self.fun.create_groups_accumulator() - } - - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) - } -} - -impl PartialEq for AggregateFunctionExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.fun == x.fun - && self.args.len() == x.args.len() - && self - .args - .iter() - .zip(x.args.iter()) - .all(|(this_arg, other_arg)| this_arg.eq(other_arg)) - }) - .unwrap_or(false) - } -} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index f136e314559b0..e680a1b2ff1e4 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -30,6 +30,7 @@ use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::functions_aggregate::expr_fn::first_value; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; @@ -612,6 +613,7 @@ async fn roundtrip_expr_api() -> Result<()> { lit(1), ), array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), + first_value(lit(1)), ]; // ensure expressions created with the expr api can be round tripped From f7b4ed0ae1382fc10498c053597c202974753514 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 7 Apr 2024 00:55:27 -0400 Subject: [PATCH 03/14] Minor: Avoid copying all expressions in check_plan (#9974) --- datafusion/optimizer/src/analyzer/mod.rs | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index ae61aea997b70..c7eb6e895d577 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -26,7 +26,6 @@ use datafusion_common::{DataFusionError, Result}; use datafusion_expr::expr::Exists; use datafusion_expr::expr::InSubquery; use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::utils::inspect_expr_pre; use datafusion_expr::{Expr, LogicalPlan}; use crate::analyzer::count_wildcard_rule::CountWildcardRule; @@ -156,18 +155,21 @@ impl Analyzer { /// Do necessary check and fail the invalid plan fn check_plan(plan: &LogicalPlan) -> Result<()> { plan.apply(&mut |plan: &LogicalPlan| { - for expr in plan.expressions().iter() { + plan.inspect_expressions(|expr| { // recursively look for subqueries - inspect_expr_pre(expr, |expr| match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - check_subquery_expr(plan, &subquery.subquery, expr) - } - _ => Ok(()), + expr.apply(&mut |expr| { + match expr { + Expr::Exists(Exists { subquery, .. 
}) + | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::ScalarSubquery(subquery) => { + check_subquery_expr(plan, &subquery.subquery, expr)?; + } + _ => {} + }; + Ok(TreeNodeRecursion::Continue) })?; - } - + Ok::<(), DataFusionError>(()) + })?; Ok(TreeNodeRecursion::Continue) })?; From 1a002bccd420ff91ec149ee1ba9c42061510f906 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 7 Apr 2024 03:39:49 -0400 Subject: [PATCH 04/14] Minor: Improve documentation about optimizer (#9967) * Minor: Improve documentation about optimizer * fix unused commit --- datafusion/optimizer/src/analyzer/mod.rs | 1 + datafusion/optimizer/src/decorrelate.rs | 13 ++++++++++--- .../src/decorrelate_predicate_subquery.rs | 1 + datafusion/optimizer/src/eliminate_cross_join.rs | 2 +- .../optimizer/src/eliminate_duplicated_expr.rs | 2 ++ datafusion/optimizer/src/eliminate_filter.rs | 11 +++++++---- datafusion/optimizer/src/eliminate_join.rs | 3 ++- datafusion/optimizer/src/eliminate_limit.rs | 15 ++++++++------- .../optimizer/src/eliminate_nested_union.rs | 2 +- datafusion/optimizer/src/eliminate_one_union.rs | 2 +- datafusion/optimizer/src/eliminate_outer_join.rs | 2 +- .../optimizer/src/extract_equijoin_predicate.rs | 2 +- .../optimizer/src/filter_null_join_keys.rs | 5 +---- datafusion/optimizer/src/lib.rs | 16 +++++++++++++++- datafusion/optimizer/src/optimize_projections.rs | 16 ++++++++-------- datafusion/optimizer/src/optimizer.rs | 4 ++-- .../optimizer/src/propagate_empty_relation.rs | 1 + datafusion/optimizer/src/push_down_filter.rs | 3 +-- datafusion/optimizer/src/push_down_limit.rs | 7 ++++--- datafusion/optimizer/src/push_down_projection.rs | 3 --- .../optimizer/src/replace_distinct_aggregate.rs | 1 + .../src/rewrite_disjunctive_predicate.rs | 2 ++ .../optimizer/src/scalar_subquery_to_join.rs | 2 ++ .../optimizer/src/simplify_expressions/mod.rs | 3 +++ .../optimizer/src/single_distinct_to_groupby.rs | 2 +- .../optimizer/src/unwrap_cast_in_comparison.rs | 4 +--- datafusion/optimizer/src/utils.rs | 2 +- 27 files changed, 79 insertions(+), 48 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index c7eb6e895d577..b446fe2f320ec 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +//! [`Analyzer`] and [`AnalyzerRule`] use std::sync::Arc; use log::debug; diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 12e84a63ea150..dbcf02b26ba66 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! [`PullUpCorrelatedExpr`] converts correlated subqueries to `Joins` + use std::collections::{BTreeSet, HashMap}; use std::ops::Deref; @@ -31,8 +33,11 @@ use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_physical_expr::execution_props::ExecutionProps; -/// This struct rewrite the sub query plan by pull up the correlated expressions(contains outer reference columns) from the inner subquery's 'Filter'. -/// It adds the inner reference columns to the 'Projection' or 'Aggregate' of the subquery if they are missing, so that they can be evaluated by the parent operator as the join condition. 
+/// This struct rewrite the sub query plan by pull up the correlated +/// expressions(contains outer reference columns) from the inner subquery's +/// 'Filter'. It adds the inner reference columns to the 'Projection' or +/// 'Aggregate' of the subquery if they are missing, so that they can be +/// evaluated by the parent operator as the join condition. pub struct PullUpCorrelatedExpr { pub join_filters: Vec, // mapping from the plan to its holding correlated columns @@ -54,7 +59,9 @@ pub struct PullUpCorrelatedExpr { /// This is used to handle the Count bug pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true"; -/// Mapping from expr display name to its evaluation result on empty record batch (for example: 'count(*)' is 'ScalarValue(0)', 'count(*) + 2' is 'ScalarValue(2)') +/// Mapping from expr display name to its evaluation result on empty record +/// batch (for example: 'count(*)' is 'ScalarValue(0)', 'count(*) + 2' is +/// 'ScalarValue(2)') pub type ExprResultMap = HashMap; impl TreeNodeRewriter for PullUpCorrelatedExpr { diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index b94cf37c5c12b..019e7507b1228 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +//! [`DecorrelatePredicateSubquery`] converts `IN`/`EXISTS` subquery predicates to `SEMI`/`ANTI` joins use std::collections::BTreeSet; use std::ops::Deref; use std::sync::Arc; diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 7f65690a4a7cb..18a9c05b9dc65 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to eliminate cross join to inner join if join predicates are available in filters. +//! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available. use std::collections::HashSet; use std::sync::Arc; diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index de05717a72e27..349d4d8878e02 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! [`EliminateDuplicatedExpr`] Removes redundant expressions + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index fea14342ca774..9411dc192bebf 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to replace `where false or null` on a plan with an empty relation. -//! This saves time in planning and executing the query. -//! Note that this rule should be applied after simplify expressions optimizer rule. +//! [`EliminateFilter`] replaces `where false` or `where null` with an empty relation. 
+ use crate::optimizer::ApplyOrder; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ @@ -27,7 +26,11 @@ use datafusion_expr::{ use crate::{OptimizerConfig, OptimizerRule}; -/// Optimization rule that eliminate the scalar value (true/false/null) filter with an [LogicalPlan::EmptyRelation] +/// Optimization rule that eliminate the scalar value (true/false/null) filter +/// with an [LogicalPlan::EmptyRelation] +/// +/// This saves time in planning and executing the query. +/// Note that this rule should be applied after simplify expressions optimizer rule. #[derive(Default)] pub struct EliminateFilter; diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 0dbebcc8a0519..e685229c61b26 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +//! [`EliminateJoin`] rewrites `INNER JOIN` with `true`/`null` use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{Result, ScalarValue}; @@ -24,7 +25,7 @@ use datafusion_expr::{ CrossJoin, Expr, }; -/// Eliminates joins when inner join condition is false. +/// Eliminates joins when join condition is false. /// Replaces joins when inner join condition is true with a cross join. #[derive(Default)] pub struct EliminateJoin; diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 4386253740aaa..fb5d0d17b839a 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -15,18 +15,19 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to replace `LIMIT 0` or -//! `LIMIT whose ancestor LIMIT's skip is greater than or equal to current's fetch` -//! on a plan with an empty relation. -//! This rule also removes OFFSET 0 from the [LogicalPlan] -//! This saves time in planning and executing the query. +//! [`EliminateLimit`] eliminates `LIMIT` when possible use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_expr::logical_plan::{EmptyRelation, LogicalPlan}; -/// Optimization rule that eliminate LIMIT 0 or useless LIMIT(skip:0, fetch:None). -/// It can cooperate with `propagate_empty_relation` and `limit_push_down`. +/// Optimizer rule to replace `LIMIT 0` or `LIMIT` whose ancestor LIMIT's skip is +/// greater than or equal to current's fetch +/// +/// It can cooperate with `propagate_empty_relation` and `limit_push_down`. on a +/// plan with an empty relation. +/// +/// This rule also removes OFFSET 0 from the [LogicalPlan] #[derive(Default)] pub struct EliminateLimit; diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 5771ea2e19a29..924a0853418cd 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to replace nested unions to single union. +//! 
[`EliminateNestedUnion`]: flattens nested `Union` to a single `Union` use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index 70ee490346ffb..63c3e789daa67 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to eliminate one union. +//! [`EliminateOneUnion`] eliminates single element `Union` use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_expr::logical_plan::{LogicalPlan, Union}; diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 56a4a76987f75..a004da2bff19b 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to eliminate left/right/full join to inner join if possible. +//! [`EliminateOuterJoin`] converts `LEFT/RIGHT/FULL` joins to `INNER` joins use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, Result}; use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan}; diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 24664d57c38d8..4cfcd07b47d93 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`ExtractEquijoinPredicate`] rule that extracts equijoin predicates +//! [`ExtractEquijoinPredicate`] identifies equality join (equijoin) predicates use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::DFSchema; diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 95cd8a9fd36ca..16039b182bb20 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -15,10 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! The FilterNullJoinKeys rule will identify inner joins with equi-join conditions -//! where the join key is nullable on one side and non-nullable on the other side -//! and then insert an `IsNotNull` filter on the nullable side since null values -//! can never match. +//! [`FilterNullJoinKeys`] adds filters to join inputs when input isn't nullable use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index b54facc5d6825..f1f49727c39c2 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -15,6 +15,19 @@ // specific language governing permissions and limitations // under the License. +//! # DataFusion Optimizer +//! +//! Contains rules for rewriting [`LogicalPlan`]s +//! +//! 1. [`Analyzer`] applies [`AnalyzerRule`]s to transform `LogicalPlan`s +//! to make the plan valid prior to the rest of the DataFusion optimization +//! process (for example, [`TypeCoercion`]). +//! +//! 2. 
[`Optimizer`] applies [`OptimizerRule`]s to transform `LogicalPlan`s +//! into equivalent, but more efficient plans. +//! +//! [`LogicalPlan`]: datafusion_expr::LogicalPlan +//! [`TypeCoercion`]: analyzer::type_coercion::TypeCoercion pub mod analyzer; pub mod common_subexpr_eliminate; pub mod decorrelate; @@ -46,7 +59,8 @@ pub mod utils; #[cfg(test)] pub mod test; -pub use optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule}; +pub use analyzer::{Analyzer, AnalyzerRule}; +pub use optimizer::{Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule}; pub use utils::optimize_children; mod plan_signature; diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index c40a9bb704ebf..147702cc04411 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -15,13 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to prune unnecessary columns from intermediate schemas -//! inside the [`LogicalPlan`]. This rule: -//! - Removes unnecessary columns that do not appear at the output and/or are -//! not used during any computation step. -//! - Adds projections to decrease table column size before operators that -//! benefit from a smaller memory footprint at its input. -//! - Removes unnecessary [`LogicalPlan::Projection`]s from the [`LogicalPlan`]. +//! [`OptimizeProjections`] identifies and eliminates unused columns use std::collections::HashSet; use std::sync::Arc; @@ -44,7 +38,13 @@ use datafusion_expr::utils::inspect_expr_pre; use hashbrown::HashMap; use itertools::{izip, Itertools}; -/// A rule for optimizing logical plans by removing unused columns/fields. +/// Optimizer rule to prune unnecessary columns from intermediate schemas +/// inside the [`LogicalPlan`]. This rule: +/// - Removes unnecessary columns that do not appear at the output and/or are +/// not used during any computation step. +/// - Adds projections to decrease table column size before operators that +/// benefit from a smaller memory footprint at its input. +/// - Removes unnecessary [`LogicalPlan::Projection`]s from the [`LogicalPlan`]. /// /// `OptimizeProjections` is an optimizer rule that identifies and eliminates /// columns from a logical plan that are not used by downstream operations. diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 3153f72d7ee70..03ff402c3e3f2 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Query optimizer traits +//! [`Optimizer`] and [`OptimizerRule`] use std::collections::HashSet; use std::sync::Arc; @@ -54,7 +54,7 @@ use datafusion_expr::logical_plan::LogicalPlan; use chrono::{DateTime, Utc}; use log::{debug, warn}; -/// `OptimizerRule` transforms one [`LogicalPlan`] into another which +/// `OptimizerRule`s transforms one [`LogicalPlan`] into another which /// computes the same results, but in a potentially more efficient /// way. If there are no suitable transformations for the input plan, /// the optimizer should simply return it unmodified. 
diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 55fb982d2a875..2aca6f93254ad 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +//! [`PropagateEmptyRelation`] eliminates nodes fed by `EmptyRelation` use datafusion_common::{plan_err, Result}; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{EmptyRelation, JoinType, Projection, Union}; diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 83db4b0640a49..ff24df259adfd 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -12,8 +12,7 @@ // specific language governing permissions and limitations // under the License. -//! [`PushDownFilter`] Moves filters so they are applied as early as possible in -//! the plan. +//! [`PushDownFilter`] applies filters as early as possible use std::collections::{HashMap, HashSet}; use std::sync::Arc; diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 33d02d5c5628e..cca6c3fd9bd17 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Optimizer rule to push down LIMIT in the query plan -//! It will push down through projection, limits (taking the smaller limit) +//! [`PushDownLimit`] pushes `LIMIT` earlier in the query plan use std::sync::Arc; @@ -29,7 +28,9 @@ use datafusion_expr::logical_plan::{ }; use datafusion_expr::CrossJoin; -/// Optimization rule that tries to push down LIMIT. +/// Optimization rule that tries to push down `LIMIT`. +/// +//. It will push down through projection, limits (taking the smaller limit) #[derive(Default)] pub struct PushDownLimit {} diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index ccdcf2f65bc8f..ae57ed9e5a345 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -15,9 +15,6 @@ // specific language governing permissions and limitations // under the License. -//! Projection Push Down optimizer rule ensures that only referenced columns are -//! loaded into memory - #[cfg(test)] mod tests { use std::collections::HashMap; diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 0055e329c29d9..752915be69c04 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +//! 
[`ReplaceDistinctWithAggregate`] replaces `DISTINCT ...` with `GROUP BY ...` use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp}; use crate::{OptimizerConfig, OptimizerRule}; diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs index 90c96b4b8b8cb..059b1452ff3dc 100644 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! [`RewriteDisjunctivePredicate`] rewrites predicates to reduce redundancy + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 8acc36e479cab..a2c4eabcaae6e 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! [`ScalarSubqueryToJoin`] rewriting scalar subquery filters to `JOIN`s + use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index 5244f9a5af881..d0399fef07e64 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +//! [`SimplifyExpressions`] simplifies expressions in the logical plan, +//! [`ExprSimplifier`] simplifies individual `Expr`s. + pub mod expr_simplifier; mod guarantees; mod inlist_simplifier; diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 5b47abb308d0d..076bf4e24296d 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! single distinct to group by optimizer rule +//! [`SingleDistinctToGroupBy`] replaces `AGG(DISTINCT ..)` with `AGG(..) GROUP BY ..` use std::sync::Arc; diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index f573ac69377ba..fda390f379610 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -15,9 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Unwrap-cast binary comparison rule can be used to the binary/inlist comparison expr now, and other type -//! of expr can be added if needed. -//! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. +//! [`UnwrapCastInComparison`] rewrites `CAST(col) = lit` to `col = CAST(lit)` use std::cmp::Ordering; use std::sync::Arc; diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 0df79550f143a..560c63b18882a 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! 
Collection of utility functions that are leveraged by the query optimizer rules +//! Utility functions leveraged by the query optimizer rules use std::collections::{BTreeSet, HashMap}; From 7acc8f16cf0776a4112a5e62214a44ad20c4c673 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sun, 7 Apr 2024 18:03:27 +0200 Subject: [PATCH 05/14] use `Expr::apply()` instead of `inspect_expr_pre()` (#9984) --- datafusion/expr/src/utils.rs | 13 +++++++------ datafusion/optimizer/src/optimize_projections.rs | 6 +++--- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index a93282574e8a2..0d99d0b5028ee 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -264,7 +264,7 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { /// Recursively walk an expression tree, collecting the unique set of columns /// referenced in the expression pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { - inspect_expr_pre(expr, |expr| { + expr.apply(&mut |expr| { match expr { Expr::Column(qc) => { accum.insert(qc.clone()); @@ -307,8 +307,9 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. } => {} } - Ok(()) + Ok(TreeNodeRecursion::Continue) }) + .map(|_| ()) } /// Find excluded columns in the schema, if any @@ -838,11 +839,11 @@ pub fn find_column_exprs(exprs: &[Expr]) -> Vec { pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec { let mut exprs = vec![]; - inspect_expr_pre(e, |expr| { + e.apply(&mut |expr| { if let Expr::Column(c) = expr { exprs.push(c.clone()) } - Ok(()) as Result<()> + Ok(TreeNodeRecursion::Continue) }) // As the closure always returns Ok, this "can't" error .expect("Unexpected error"); @@ -867,7 +868,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( schema: &DFSchemaRef, ) -> Vec { let mut indexes = vec![]; - inspect_expr_pre(e, |expr| { + e.apply(&mut |expr| { match expr { Expr::Column(qc) => { if let Ok(idx) = schema.index_of_column(qc) { @@ -879,7 +880,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( } _ => {} } - Ok(()) as Result<()> + Ok(TreeNodeRecursion::Continue) }) .unwrap(); indexes diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 147702cc04411..69905c990a7f3 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -34,7 +34,7 @@ use datafusion_expr::{ Expr, Projection, TableScan, Window, }; -use datafusion_expr::utils::inspect_expr_pre; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use hashbrown::HashMap; use itertools::{izip, Itertools}; @@ -613,7 +613,7 @@ fn rewrite_expr(expr: &Expr, input: &Projection) -> Result> { /// columns are collected. 
fn outer_columns(expr: &Expr, columns: &mut HashSet) { // inspect_expr_pre doesn't handle subquery references, so find them explicitly - inspect_expr_pre(expr, |expr| { + expr.apply(&mut |expr| { match expr { Expr::OuterReferenceColumn(_, col) => { columns.insert(col.clone()); @@ -632,7 +632,7 @@ fn outer_columns(expr: &Expr, columns: &mut HashSet) { } _ => {} }; - Ok(()) as Result<()> + Ok(TreeNodeRecursion::Continue) }) // unwrap: closure above never returns Err, so can not be Err here .unwrap(); From 85b4e40df9e9a5a71c08760452c2059a271313d1 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 7 Apr 2024 13:39:44 -0400 Subject: [PATCH 06/14] Update documentation for COPY command (#9931) * Update documentation for COPY command * Fix example * prettier --- docs/source/user-guide/sql/dml.md | 39 +++++++++++++++--- docs/source/user-guide/sql/write_options.md | 45 ++++++++++----------- 2 files changed, 54 insertions(+), 30 deletions(-) diff --git a/docs/source/user-guide/sql/dml.md b/docs/source/user-guide/sql/dml.md index 79c36092fd3d3..666e86b460023 100644 --- a/docs/source/user-guide/sql/dml.md +++ b/docs/source/user-guide/sql/dml.md @@ -35,8 +35,22 @@ TO 'file_name' [ OPTIONS( option [, ... ] ) ] +`STORED AS` specifies the file format the `COPY` command will write. If this +clause is not specified, it will be inferred from the file extension if possible. + +`PARTITIONED BY` specifies the columns to use for partitioning the output files into +separate hive-style directories. + +The output format is determined by the first match of the following rules: + +1. Value of `STORED AS` +2. Value of the `OPTION (FORMAT ..)` +3. Filename extension (e.g. `foo.parquet` implies `PARQUET` format) + For a detailed list of valid OPTIONS, see [Write Options](write_options). +### Examples + Copy the contents of `source_table` to `file_name.json` in JSON format: ```sql @@ -72,6 +86,23 @@ of hive-style partitioned parquet files: +-------+ ``` +If the the data contains values of `x` and `y` in column1 and only `a` in +column2, output files will appear in the following directory structure: + +``` +dir_name/ + column1=x/ + column2=a/ + .parquet + .parquet + ... + column1=y/ + column2=a/ + .parquet + .parquet + ... +``` + Run the query `SELECT * from source ORDER BY time` and write the results (maintaining the order) to a parquet file named `output.parquet` with a maximum parquet row group size of 10MB: @@ -85,14 +116,10 @@ results (maintaining the order) to a parquet file named +-------+ ``` -The output format is determined by the first match of the following rules: - -1. Value of `STORED AS` -2. Value of the `OPTION (FORMAT ..)` -3. Filename extension (e.g. `foo.parquet` implies `PARQUET` format) - ## INSERT +### Examples + Insert values into a table.
diff --git a/docs/source/user-guide/sql/write_options.md b/docs/source/user-guide/sql/write_options.md
index ac0a41a97f07c..5c204d8fc0e61 100644
--- a/docs/source/user-guide/sql/write_options.md
+++ b/docs/source/user-guide/sql/write_options.md
@@ -35,44 +35,41 @@ If inserting to an external table, table specific write options can be specified
 
 ```sql
 CREATE EXTERNAL TABLE
-my_table(a bigint, b bigint)
-STORED AS csv
-COMPRESSION TYPE gzip
-WITH HEADER ROW
-DELIMITER ';'
-LOCATION '/test/location/my_csv_table/'
-OPTIONS(
-NULL_VALUE 'NAN'
-);
+  my_table(a bigint, b bigint)
+  STORED AS csv
+  COMPRESSION TYPE gzip
+  WITH HEADER ROW
+  DELIMITER ';'
+  LOCATION '/test/location/my_csv_table/'
+  OPTIONS(
+    NULL_VALUE 'NAN'
+  )
 ```
 
 When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (gzip compression, special delimiter, and header row included). There will be a single output file if the output path doesn't have folder format, i.e. doesn't end with a `/`. Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified the `OPTIONS` setting will be ignored. NULL_VALUE is a CSV-format-specific option that determines how null values should be encoded within the CSV file.
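For illustration, the same behavior can be exercised from Rust through `SessionContext::sql`. This is a minimal sketch, not part of the patch: the table definition simply mirrors the statement above, and the `tokio` runtime and inserted values are assumptions.

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Mirrors the CREATE EXTERNAL TABLE statement above.
    ctx.sql(
        "CREATE EXTERNAL TABLE my_table(a bigint, b bigint) \
         STORED AS csv \
         COMPRESSION TYPE gzip \
         WITH HEADER ROW \
         DELIMITER ';' \
         LOCATION '/test/location/my_csv_table/' \
         OPTIONS(NULL_VALUE 'NAN')",
    )
    .await?;
    // The insert inherits the table's write options: gzip-compressed,
    // ';'-delimited CSV with a header row, and NULLs encoded as 'NAN'.
    ctx.sql("INSERT INTO my_table VALUES (1, 2), (3, NULL)")
        .await?
        .collect()
        .await?;
    Ok(())
}
```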
 
 Finally, options can be passed when running a `COPY` command.
 
+
+
 ```sql
 COPY source_table
-TO 'test/table_with_options'
-(format parquet,
-compression snappy,
-'compression::col1' 'zstd(5)',
-partition_by 'column3, column4'
-)
+  TO 'test/table_with_options'
+  PARTITIONED BY (column3, column4)
+  OPTIONS (
+    format parquet,
+    compression snappy,
+    'compression::column1' 'zstd(5)'
+  )
 ```
 
 In this example, we write the entirety of `source_table` out to a folder of parquet files. One parquet file will be written in parallel to the folder for each partition in the query. The option `compression` set to `snappy` indicates that unless otherwise specified all columns should use the snappy compression codec. The option `compression::column1` sets an override, so that the column `column1` in the parquet file will use the `ZSTD` compression codec with compression level `5`. In general, parquet options which support column-specific settings can be specified with the syntax `OPTION::COLUMN.NESTED.PATH`.
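A minimal Rust sketch of the same statement, assuming a hypothetical `source.csv` containing the column names used above:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Hypothetical source file; column names match the COPY example above.
    ctx.register_csv("source_table", "source.csv", CsvReadOptions::new())
        .await?;
    // snappy is the file-wide default; 'compression::column1' overrides it
    // for a single column using the OPTION::COLUMN syntax described above.
    ctx.sql(
        "COPY source_table TO 'test/table_with_options' \
         PARTITIONED BY (column3, column4) \
         OPTIONS (format parquet, \
                  compression snappy, \
                  'compression::column1' 'zstd(5)')",
    )
    .await?
    .collect()
    .await?;
    Ok(())
}
```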
 
 ## Available Options
 
-### COPY Specific Options
-
-The following special options are specific to the `COPY` command.
-
-| Option       | Description                                                                                                                                                                         | Default Value |
-| ------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------- |
-| FORMAT       | Specifies the file format COPY query will write out. If there're more than one output file or the format cannot be inferred from the file extension, then FORMAT must be specified. | N/A           |
-| PARTITION_BY | Specifies the columns that the output files should be partitioned by into separate hive-style directories. Value should be a comma separated string literal, e.g. 'col1,col2'       | N/A           |
-
 ### JSON Format Specific Options
 
 The following options are available when writing JSON files. Note: If any unsupported option is specified, an error will be raised and the query will fail.
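As a quick sketch of the JSON path (it assumes `compression` is among the supported JSON write options and that a `source.csv` exists):

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.register_csv("source_table", "source.csv", CsvReadOptions::new())
        .await?;
    // Newline-delimited JSON output; format is inferred from the extension
    // and the assumed `compression` option requests gzip'd output.
    ctx.sql("COPY source_table TO 'output.json' OPTIONS (compression gzip)")
        .await?
        .collect()
        .await?;
    Ok(())
}
```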

From 9ce21e11ea8574ea2b650d80bf09327db343887f Mon Sep 17 00:00:00 2001
From: Andrew Lamb 
Date: Sun, 7 Apr 2024 21:53:41 -0400
Subject: [PATCH 07/14] Minor: fix bug in pruning predicate doc (#9986)

* Minor: fix bug in pruning predicate doc

* formatting
---
 datafusion/core/src/physical_optimizer/pruning.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs
index 80bb5ad42e814..19e71a92a7066 100644
--- a/datafusion/core/src/physical_optimizer/pruning.rs
+++ b/datafusion/core/src/physical_optimizer/pruning.rs
@@ -330,7 +330,7 @@ pub trait PruningStatistics {
 /// `x = 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END`
 /// `x < 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_max < 5 END`
 /// `x = 5 AND y = 10` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END AND CASE WHEN y_null_count = y_row_count THEN false ELSE y_min <= 10 AND 10 <= y_max END`
-/// `x IS NULL`  | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_null_count > 0 END`
+/// `x IS NULL`  | `x_null_count > 0`
 /// `CAST(x as int) = 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE CAST(x_min as int) <= 5 AND 5 <= CAST(x_max as int) END`
 ///
 /// ## Predicate Evaluation
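To make the rewritten row concrete, the simplified predicate can be built with DataFusion's `Expr` API. A sketch, not part of this patch; `x_null_count` follows the naming convention in the table above:

```rust
use datafusion::prelude::{col, lit, Expr};

/// `x IS NULL` can prune with the null-count statistic alone: no
/// `CASE WHEN x_null_count = x_row_count ...` guard is needed, because a
/// container that is entirely null trivially satisfies both `x IS NULL`
/// and `x_null_count > 0`.
fn is_null_pruning_predicate() -> Expr {
    col("x_null_count").gt(lit(0))
}

fn main() {
    // Prints something like: x_null_count > Int32(0)
    println!("{}", is_null_pruning_predicate());
}
```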

From 215f30f74a12e91fd7dca0d30e37014c8c3493ed Mon Sep 17 00:00:00 2001
From: Jonah Gao 
Date: Mon, 8 Apr 2024 11:08:06 +0800
Subject: [PATCH 08/14] fix: improve `unnest_generic_list` handling of null
 list (#9975)

* fix: improve `unnest_generic_list` handling of null list

* fix clippy

* fix comment
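For context, a standalone sketch of the failure mode this patch fixes, using only `arrow` APIs (the array values and names are made up): a NULL list entry can still span a non-empty slice of the child values, so returning `values.clone()` when nulls are dropped would leak the masked values. Building take indices from the offsets of valid rows, as the new code does, avoids that.

```rust
use std::sync::Arc;

use arrow::array::{Array, Int32Array, ListArray, UInt32Array};
use arrow::buffer::{NullBuffer, OffsetBuffer};
use arrow::compute::take;
use arrow::datatypes::{DataType, Field};

fn main() {
    // Lists: [1, 2], NULL (its slot still covers the value 99), [3]
    let field = Arc::new(Field::new("item", DataType::Int32, true));
    let values = Int32Array::from(vec![1, 2, 99, 3]);
    let offsets = OffsetBuffer::new(vec![0i32, 2, 3, 4].into());
    let nulls = NullBuffer::from(vec![true, false, true]);
    let list = ListArray::new(field, offsets, Arc::new(values), Some(nulls));

    // Build take indices from the offsets of valid rows only, as the fixed
    // unnest_generic_list does (here with preserve_nulls = false).
    let mut indices = vec![];
    let offsets = list.value_offsets();
    for row in 0..list.len() {
        if list.is_null(row) {
            continue;
        }
        for idx in offsets[row] as u32..offsets[row + 1] as u32 {
            indices.push(idx);
        }
    }
    let indices = UInt32Array::from(indices);
    let unnested = take(list.values(), &indices, None).unwrap();
    // The three valid values survive; the masked 99 never appears.
    assert_eq!(unnested.len(), 3);
    println!("{unnested:?}");
}
```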
---
 datafusion/physical-plan/src/unnest.rs | 139 +++++++++++++++++++++----
 1 file changed, 117 insertions(+), 22 deletions(-)

diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs
index 324e2ea2d7733..6ea1b3c40c83d 100644
--- a/datafusion/physical-plan/src/unnest.rs
+++ b/datafusion/physical-plan/src/unnest.rs
@@ -364,32 +364,31 @@ fn unnest_generic_list<T: OffsetSizeTrait, P: ArrowPrimitiveType>(
     options: &UnnestOptions,
 ) -> Result<ArrayRef> {
     let values = list_array.values();
-    if list_array.null_count() == 0 || !options.preserve_nulls {
-        Ok(values.clone())
-    } else {
-        let mut take_indicies_builder =
-            PrimitiveArray::<P>::builder(values.len() + list_array.null_count());
-        let mut take_offset = 0;
+    if list_array.null_count() == 0 {
+        return Ok(values.clone());
+    }
 
-        list_array.iter().for_each(|elem| match elem {
-            Some(array) => {
-                for i in 0..array.len() {
-                    // take_offset + i is always positive
-                    let take_index = P::Native::from_usize(take_offset + i).unwrap();
-                    take_indicies_builder.append_value(take_index);
-                }
-                take_offset += array.len();
-            }
-            None => {
+    let mut take_indicies_builder =
+        PrimitiveArray::<P>::builder(values.len() + list_array.null_count());
+    let offsets = list_array.value_offsets();
+    for row in 0..list_array.len() {
+        if list_array.is_null(row) {
+            if options.preserve_nulls {
                 take_indicies_builder.append_null();
             }
-        });
-        Ok(kernels::take::take(
-            &values,
-            &take_indicies_builder.finish(),
-            None,
-        )?)
+        } else {
+            let start = offsets[row].as_usize();
+            let end = offsets[row + 1].as_usize();
+            for idx in start..end {
+                take_indicies_builder.append_value(P::Native::from_usize(idx).unwrap());
+            }
+        }
     }
+    Ok(kernels::take::take(
+        &values,
+        &take_indicies_builder.finish(),
+        None,
+    )?)
 }
 
 fn build_batch_fixedsize_list(
@@ -596,3 +595,99 @@ where
 
     Ok(RecordBatch::try_new(schema.clone(), arrays.to_vec())?)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::{
+        array::AsArray,
+        datatypes::{DataType, Field},
+    };
+    use arrow_array::StringArray;
+    use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer};
+
+    // Create a ListArray with the following list values:
+    // [A, B, C], [], NULL, [D], NULL, [NULL, F]
+    fn make_test_array() -> ListArray {
+        let mut values = vec![];
+        let mut offsets = vec![0];
+        let mut valid = BooleanBufferBuilder::new(2);
+
+        // [A, B, C]
+        values.extend_from_slice(&[Some("A"), Some("B"), Some("C")]);
+        offsets.push(values.len() as i32);
+        valid.append(true);
+
+        // []
+        offsets.push(values.len() as i32);
+        valid.append(true);
+
+        // NULL with non-zero value length
+        // Issue https://github.com/apache/arrow-datafusion/issues/9932
+        values.push(Some("?"));
+        offsets.push(values.len() as i32);
+        valid.append(false);
+
+        // [D]
+        values.push(Some("D"));
+        offsets.push(values.len() as i32);
+        valid.append(true);
+
+        // Another NULL with zero value length
+        offsets.push(values.len() as i32);
+        valid.append(false);
+
+        // [NULL, F]
+        values.extend_from_slice(&[None, Some("F")]);
+        offsets.push(values.len() as i32);
+        valid.append(true);
+
+        let field = Arc::new(Field::new("item", DataType::Utf8, true));
+        ListArray::new(
+            field,
+            OffsetBuffer::new(offsets.into()),
+            Arc::new(StringArray::from(values)),
+            Some(NullBuffer::new(valid.finish())),
+        )
+    }
+
+    #[test]
+    fn test_unnest_generic_list() -> datafusion_common::Result<()> {
+        let list_array = make_test_array();
+
+        // Test with preserve_nulls = false
+        let options = UnnestOptions {
+            preserve_nulls: false,
+        };
+        let unnested_array =
+            unnest_generic_list::<i32, Int32Type>(&list_array, &options)?;
+        let strs = unnested_array.as_string::<i32>().iter().collect::<Vec<_>>();
+        assert_eq!(
+            strs,
+            vec![Some("A"), Some("B"), Some("C"), Some("D"), None, Some("F")]
+        );
+
+        // Test with preserve_nulls = true
+        let options = UnnestOptions {
+            preserve_nulls: true,
+        };
+        let unnested_array =
+            unnest_generic_list::<i32, Int32Type>(&list_array, &options)?;
+        let strs = unnested_array.as_string::<i32>().iter().collect::<Vec<_>>();
+        assert_eq!(
+            strs,
+            vec![
+                Some("A"),
+                Some("B"),
+                Some("C"),
+                None,
+                Some("D"),
+                None,
+                None,
+                Some("F")
+            ]
+        );
+
+        Ok(())
+    }
+}

From 0a4d9a6c788c1e4ad340943492abb823bd31c4f9 Mon Sep 17 00:00:00 2001
From: Peter Toth
Date: Mon, 8 Apr 2024 11:28:59 +0200
Subject: [PATCH 09/14] Consistent LogicalPlan subquery handling in
 TreeNode::apply and TreeNode::visit (#9913)

* fix

* clippy

* remove accidental extra apply

* minor fixes

* fix `LogicalPlan::apply_expressions()` and `LogicalPlan::map_subqueries()`

* fix `LogicalPlan::visit_with_subqueries()`

* Add deprecated LogicalPlan::inspect_expressions

---------

Co-authored-by: Andrew Lamb
---
 datafusion/common/src/tree_node.rs | 3 +-
datafusion/core/src/execution/context/mod.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 558 ++++++++++++++---- datafusion/expr/src/tree_node/expr.rs | 2 +- datafusion/expr/src/tree_node/plan.rs | 53 +- datafusion/optimizer/src/analyzer/mod.rs | 15 +- datafusion/optimizer/src/analyzer/subquery.rs | 2 +- datafusion/optimizer/src/plan_signature.rs | 4 +- 8 files changed, 475 insertions(+), 166 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 8e088e7a0b567..42514537e28d8 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -25,10 +25,9 @@ use crate::Result; /// These macros are used to determine continuation during transforming traversals. macro_rules! handle_transform_recursion { ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ - #[allow(clippy::redundant_closure_call)] $F_DOWN? .transform_children(|n| n.map_children($F_CHILD))? - .transform_parent(|n| $F_UP(n)) + .transform_parent($F_UP) }}; } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index f15c1c218db6d..9e48c7b8a6f24 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -67,7 +67,7 @@ use datafusion_common::{ alias::AliasGenerator, config::{ConfigExtension, TableOptions}, exec_err, not_impl_err, plan_datafusion_err, plan_err, - tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, + tree_node::{TreeNodeRecursion, TreeNodeVisitor}, SchemaReference, TableReference, }; use datafusion_execution::registry::SerializerRegistry; @@ -2298,7 +2298,7 @@ impl SQLOptions { /// Return an error if the [`LogicalPlan`] has any nodes that are /// incompatible with this [`SQLOptions`]. pub fn verify_plan(&self, plan: &LogicalPlan) -> Result<()> { - plan.visit(&mut BadPlanVisitor::new(self))?; + plan.visit_with_subqueries(&mut BadPlanVisitor::new(self))?; Ok(()) } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 3d40dcae0e4bd..4f55bbfe3f4d6 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -34,8 +34,7 @@ use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, - grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre, - split_conjunction, + grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction, }; use crate::{ build_join_schema, expr_vec_fmt, BinaryExpr, BuiltInWindowFunction, @@ -45,16 +44,19 @@ use crate::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor, + Transformed, TransformedResult, TreeNode, TreeNodeIterator, TreeNodeRecursion, + TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{ - aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, - DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, - FunctionalDependencies, ParamValues, Result, TableReference, UnnestOptions, + aggregate_functional_dependencies, internal_err, map_until_stop_and_collect, + plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, + FunctionalDependence, FunctionalDependencies, ParamValues, Result, TableReference, + UnnestOptions, }; // backwards compatibility use 
crate::display::PgJsonVisitor; +use crate::tree_node::expr::transform_option_vec; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; @@ -248,9 +250,9 @@ impl LogicalPlan { /// DataFusion's optimizer attempts to optimize them away. pub fn expressions(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.inspect_expressions(|e| { + self.apply_expressions(|e| { exprs.push(e.clone()); - Ok(()) as Result<()> + Ok(TreeNodeRecursion::Continue) }) // closure always returns OK .unwrap(); @@ -261,13 +263,13 @@ impl LogicalPlan { /// logical plan nodes and all its descendant nodes. pub fn all_out_ref_exprs(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; - self.inspect_expressions(|e| { + self.apply_expressions(|e| { find_out_reference_exprs(e).into_iter().for_each(|e| { if !exprs.contains(&e) { exprs.push(e) } }); - Ok(()) as Result<(), DataFusionError> + Ok(TreeNodeRecursion::Continue) }) // closure always returns OK .unwrap(); @@ -282,60 +284,81 @@ impl LogicalPlan { exprs } - /// Calls `f` on all expressions (non-recursively) in the current - /// logical plan node. This does not include expressions in any - /// children. + #[deprecated(since = "37.0.0", note = "Use `apply_expressions` instead")] pub fn inspect_expressions(self: &LogicalPlan, mut f: F) -> Result<(), E> where F: FnMut(&Expr) -> Result<(), E>, { + let mut err = Ok(()); + self.apply_expressions(|e| { + if let Err(e) = f(e) { + // save the error for later (it may not be a DataFusionError + err = Err(e); + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + // The closure always returns OK, so this will always too + .expect("no way to return error during recursion"); + + err + } + + /// Calls `f` on all expressions (non-recursively) in the current + /// logical plan node. This does not include expressions in any + /// children. + pub fn apply_expressions Result>( + &self, + mut f: F, + ) -> Result { match self { LogicalPlan::Projection(Projection { expr, .. }) => { - expr.iter().try_for_each(f) - } - LogicalPlan::Values(Values { values, .. }) => { - values.iter().flatten().try_for_each(f) + expr.iter().apply_until_stop(f) } + LogicalPlan::Values(Values { values, .. }) => values + .iter() + .apply_until_stop(|value| value.iter().apply_until_stop(&mut f)), LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), LogicalPlan::Repartition(Repartition { partitioning_scheme, .. }) => match partitioning_scheme { - Partitioning::Hash(expr, _) => expr.iter().try_for_each(f), - Partitioning::DistributeBy(expr) => expr.iter().try_for_each(f), - Partitioning::RoundRobinBatch(_) => Ok(()), + Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => { + expr.iter().apply_until_stop(f) + } + Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), }, LogicalPlan::Window(Window { window_expr, .. }) => { - window_expr.iter().try_for_each(f) + window_expr.iter().apply_until_stop(f) } LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. - }) => group_expr.iter().chain(aggr_expr.iter()).try_for_each(f), + }) => group_expr + .iter() + .chain(aggr_expr.iter()) + .apply_until_stop(f), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. // 2. the second part is non-equijoin(filter). LogicalPlan::Join(Join { on, filter, .. 
}) => { on.iter() + // TODO: why we need to create an `Expr::eq`? Cloning `Expr` is costly... // it not ideal to create an expr here to analyze them, but could cache it on the Join itself .map(|(l, r)| Expr::eq(l.clone(), r.clone())) - .try_for_each(|e| f(&e))?; - - if let Some(filter) = filter.as_ref() { - f(filter) - } else { - Ok(()) - } + .apply_until_stop(|e| f(&e))? + .visit_sibling(|| filter.iter().apply_until_stop(f)) } - LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().try_for_each(f), + LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().apply_until_stop(f), LogicalPlan::Extension(extension) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - extension.node.expressions().iter().try_for_each(f) + extension.node.expressions().iter().apply_until_stop(f) } LogicalPlan::TableScan(TableScan { filters, .. }) => { - filters.iter().try_for_each(f) + filters.iter().apply_until_stop(f) } LogicalPlan::Unnest(Unnest { column, .. }) => { f(&Expr::Column(column.clone())) @@ -348,8 +371,8 @@ impl LogicalPlan { })) => on_expr .iter() .chain(select_expr.iter()) - .chain(sort_expr.clone().unwrap_or(vec![]).iter()) - .try_for_each(f), + .chain(sort_expr.iter().flatten()) + .apply_until_stop(f), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) @@ -366,10 +389,225 @@ impl LogicalPlan { | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) - | LogicalPlan::Prepare(_) => Ok(()), + | LogicalPlan::Prepare(_) => Ok(TreeNodeRecursion::Continue), } } + pub fn map_expressions Result>>( + self, + mut f: F, + ) -> Result> { + Ok(match self { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) + }), + LogicalPlan::Values(Values { schema, values }) => values + .into_iter() + .map_until_stop_and_collect(|value| { + value.into_iter().map_until_stop_and_collect(&mut f) + })? + .update_data(|values| LogicalPlan::Values(Values { schema, values })), + LogicalPlan::Filter(Filter { predicate, input }) => f(predicate)? + .update_data(|predicate| { + LogicalPlan::Filter(Filter { predicate, input }) + }), + LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + }) => match partitioning_scheme { + Partitioning::Hash(expr, usize) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| Partitioning::Hash(expr, usize)), + Partitioning::DistributeBy(expr) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(Partitioning::DistributeBy), + Partitioning::RoundRobinBatch(_) => Transformed::no(partitioning_scheme), + } + .update_data(|partitioning_scheme| { + LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + }) + }), + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) => window_expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|window_expr| { + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) + }), + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) => map_until_stop_and_collect!( + group_expr.into_iter().map_until_stop_and_collect(&mut f), + aggr_expr, + aggr_expr.into_iter().map_until_stop_and_collect(&mut f) + )? 
+ .update_data(|(group_expr, aggr_expr)| { + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) + }), + + // There are two part of expression for join, equijoin(on) and non-equijoin(filter). + // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. + // 2. the second part is non-equijoin(filter). + LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }) => map_until_stop_and_collect!( + on.into_iter().map_until_stop_and_collect( + |on| map_until_stop_and_collect!(f(on.0), on.1, f(on.1)) + ), + filter, + filter.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { + Ok(f(e)?.update_data(Some)) + }) + )? + .update_data(|(on, filter)| { + LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }) + }), + LogicalPlan::Sort(Sort { expr, input, fetch }) => expr + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), + LogicalPlan::Extension(Extension { node }) => { + // would be nice to avoid this copy -- maybe can + // update extension to just observer Exprs + node.expressions() + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|exprs| { + LogicalPlan::Extension(Extension { + node: UserDefinedLogicalNode::from_template( + node.as_ref(), + exprs.as_slice(), + node.inputs() + .into_iter() + .cloned() + .collect::>() + .as_slice(), + ), + }) + }) + } + LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }) => filters + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|filters| { + LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }) + }), + LogicalPlan::Unnest(Unnest { + input, + column, + schema, + options, + }) => f(Expr::Column(column))?.map_data(|column| match column { + Expr::Column(column) => Ok(LogicalPlan::Unnest(Unnest { + input, + column, + schema, + options, + })), + _ => internal_err!("Transformation should return Column"), + })?, + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema, + })) => map_until_stop_and_collect!( + on_expr.into_iter().map_until_stop_and_collect(&mut f), + select_expr, + select_expr.into_iter().map_until_stop_and_collect(&mut f), + sort_expr, + transform_option_vec(sort_expr, &mut f) + )? + .update_data(|(on_expr, select_expr, sort_expr)| { + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema, + })) + }), + // plans without expressions + LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Statement(_) + | LogicalPlan::CrossJoin(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Union(_) + | LogicalPlan::Distinct(Distinct::All(_)) + | LogicalPlan::Dml(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Copy(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Prepare(_) => Transformed::no(self), + }) + } + /// returns all inputs of this `LogicalPlan` node. Does not /// include inputs to inputs, or subqueries. 
pub fn inputs(&self) -> Vec<&LogicalPlan> { @@ -417,7 +655,7 @@ impl LogicalPlan { pub fn using_columns(&self) -> Result>, DataFusionError> { let mut using_columns: Vec> = vec![]; - self.apply(&mut |plan| { + self.apply_with_subqueries(&mut |plan| { if let LogicalPlan::Join(Join { join_constraint: JoinConstraint::Using, on, @@ -1079,57 +1317,178 @@ impl LogicalPlan { } } +/// This macro is used to determine continuation during combined transforming +/// traversals. +macro_rules! handle_transform_recursion { + ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ + $F_DOWN? + .transform_children(|n| n.map_subqueries($F_CHILD))? + .transform_sibling(|n| n.map_children($F_CHILD))? + .transform_parent($F_UP) + }}; +} + +macro_rules! handle_transform_recursion_down { + ($F_DOWN:expr, $F_CHILD:expr) => {{ + $F_DOWN? + .transform_children(|n| n.map_subqueries($F_CHILD))? + .transform_sibling(|n| n.map_children($F_CHILD)) + }}; +} + +macro_rules! handle_transform_recursion_up { + ($SELF:expr, $F_CHILD:expr, $F_UP:expr) => {{ + $SELF + .map_subqueries($F_CHILD)? + .transform_sibling(|n| n.map_children($F_CHILD))? + .transform_parent(|n| $F_UP(n)) + }}; +} + impl LogicalPlan { - /// applies `op` to any subqueries in the plan - pub(crate) fn apply_subqueries(&self, op: &mut F) -> Result<()> - where - F: FnMut(&Self) -> Result, - { - self.inspect_expressions(|expr| { - // recursively look for subqueries - inspect_expr_pre(expr, |expr| { - match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the collector sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.apply(op)?; - } - _ => {} + pub fn visit_with_subqueries>( + &self, + visitor: &mut V, + ) -> Result { + visitor + .f_down(self)? + .visit_children(|| { + self.apply_subqueries(|c| c.visit_with_subqueries(visitor)) + })? + .visit_sibling(|| self.apply_children(|c| c.visit_with_subqueries(visitor)))? + .visit_parent(|| visitor.f_up(self)) + } + + pub fn rewrite_with_subqueries>( + self, + rewriter: &mut R, + ) -> Result> { + handle_transform_recursion!( + rewriter.f_down(self), + |c| c.rewrite_with_subqueries(rewriter), + |n| rewriter.f_up(n) + ) + } + + pub fn apply_with_subqueries Result>( + &self, + f: &mut F, + ) -> Result { + f(self)? + .visit_children(|| self.apply_subqueries(|c| c.apply_with_subqueries(f)))? 
+ .visit_sibling(|| self.apply_children(|c| c.apply_with_subqueries(f))) + } + + pub fn transform_with_subqueries<F: Fn(Self) -> Result<Transformed<Self>>>( + self, + f: &F, + ) -> Result<Transformed<Self>> { + self.transform_up_with_subqueries(f) + } + + pub fn transform_down_with_subqueries<F: Fn(Self) -> Result<Transformed<Self>>>( + self, + f: &F, + ) -> Result<Transformed<Self>> { + handle_transform_recursion_down!(f(self), |c| c.transform_down_with_subqueries(f)) + } + + pub fn transform_down_mut_with_subqueries< + F: FnMut(Self) -> Result<Transformed<Self>>, + >( + self, + f: &mut F, + ) -> Result<Transformed<Self>> { + handle_transform_recursion_down!(f(self), |c| c + .transform_down_mut_with_subqueries(f)) + } + + pub fn transform_up_with_subqueries<F: Fn(Self) -> Result<Transformed<Self>>>( + self, + f: &F, + ) -> Result<Transformed<Self>> { + handle_transform_recursion_up!(self, |c| c.transform_up_with_subqueries(f), f) + } + + pub fn transform_up_mut_with_subqueries< + F: FnMut(Self) -> Result<Transformed<Self>>, + >( + self, + f: &mut F, + ) -> Result<Transformed<Self>> { + handle_transform_recursion_up!(self, |c| c.transform_up_mut_with_subqueries(f), f) + } + + pub fn transform_down_up_with_subqueries< + FD: FnMut(Self) -> Result<Transformed<Self>>, + FU: FnMut(Self) -> Result<Transformed<Self>>, + >( + self, + f_down: &mut FD, + f_up: &mut FU, + ) -> Result<Transformed<Self>> { + handle_transform_recursion!( + f_down(self), + |c| c.transform_down_up_with_subqueries(f_down, f_up), + f_up + ) + } + + fn apply_subqueries<F: FnMut(&Self) -> Result<TreeNodeRecursion>>( + &self, + mut f: F, + ) -> Result<TreeNodeRecursion> { + self.apply_expressions(|expr| { + expr.apply(&mut |expr| match expr { + Expr::Exists(Exists { subquery, .. }) + | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::ScalarSubquery(subquery) => { + // use a synthetic plan so the collector sees a + // LogicalPlan::Subquery (even though it is + // actually a Subquery alias) + f(&LogicalPlan::Subquery(subquery.clone())) } - Ok::<(), DataFusionError>(()) + _ => Ok(TreeNodeRecursion::Continue), }) - })?; - Ok(()) + }) } - /// applies visitor to any subqueries in the plan - pub(crate) fn visit_subqueries<V>(&self, v: &mut V) -> Result<()> - where - V: TreeNodeVisitor<Node = Self>, - { - self.inspect_expressions(|expr| { - // recursively look for subqueries - inspect_expr_pre(expr, |expr| { - match expr { - Expr::Exists(Exists { subquery, .. }) - | Expr::InSubquery(InSubquery { subquery, .. }) - | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the visitor sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) - let synthetic_plan = LogicalPlan::Subquery(subquery.clone()); - synthetic_plan.visit(v)?; - } - _ => {} + fn map_subqueries<F: FnMut(Self) -> Result<Transformed<Self>>>( + self, + mut f: F, + ) -> Result<Transformed<Self>> { + self.map_expressions(|expr| { + expr.transform_down_mut(&mut |expr| match expr { + Expr::Exists(Exists { subquery, negated }) => { + f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery) => { + Ok(Expr::Exists(Exists { subquery, negated })) + } + _ => internal_err!("Transformation should return Subquery"), + }) } - Ok::<(), DataFusionError>(()) + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery) => Ok(Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + })), + _ => internal_err!("Transformation should return Subquery"), + }), + Expr::ScalarSubquery(subquery) => f(LogicalPlan::Subquery(subquery))?
+ .map_data(|s| match s { + LogicalPlan::Subquery(subquery) => { + Ok(Expr::ScalarSubquery(subquery)) + } + _ => internal_err!("Transformation should return Subquery"), + }), + _ => Ok(Transformed::no(expr)), }) - })?; - Ok(()) + }) } /// Return a `LogicalPlan` with all placeholders (e.g $1 $2, @@ -1165,8 +1524,8 @@ impl LogicalPlan { ) -> Result<HashMap<String, Option<DataType>>, DataFusionError> { let mut param_types: HashMap<String, Option<DataType>> = HashMap::new(); - self.apply(&mut |plan| { - plan.inspect_expressions(|expr| { + self.apply_with_subqueries(&mut |plan| { + plan.apply_expressions(|expr| { expr.apply(&mut |expr| { if let Expr::Placeholder(Placeholder { id, data_type }) = expr { let prev = param_types.get(id); @@ -1183,13 +1542,10 @@ impl LogicalPlan { } } Ok(TreeNodeRecursion::Continue) - })?; - Ok::<(), DataFusionError>(()) - })?; - Ok(TreeNodeRecursion::Continue) - })?; - - Ok(param_types) + }) + }) + }) + .map(|_| param_types) } /// Return an Expr with all placeholders replaced with their @@ -1257,7 +1613,7 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let with_schema = false; let mut visitor = IndentVisitor::new(f, with_schema); - match self.0.visit(&mut visitor) { + match self.0.visit_with_subqueries(&mut visitor) { Ok(_) => Ok(()), Err(_) => Err(fmt::Error), } @@ -1300,7 +1656,7 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let with_schema = true; let mut visitor = IndentVisitor::new(f, with_schema); - match self.0.visit(&mut visitor) { + match self.0.visit_with_subqueries(&mut visitor) { Ok(_) => Ok(()), Err(_) => Err(fmt::Error), } @@ -1320,7 +1676,7 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut visitor = PgJsonVisitor::new(f); visitor.with_schema(true); - match self.0.visit(&mut visitor) { + match self.0.visit_with_subqueries(&mut visitor) { Ok(_) => Ok(()), Err(_) => Err(fmt::Error), } @@ -1369,12 +1725,16 @@ impl LogicalPlan { visitor.start_graph()?; visitor.pre_visit_plan("LogicalPlan")?; - self.0.visit(&mut visitor).map_err(|_| fmt::Error)?; + self.0 + .visit_with_subqueries(&mut visitor) + .map_err(|_| fmt::Error)?; visitor.post_visit_plan()?; visitor.set_with_schema(true); visitor.pre_visit_plan("Detailed LogicalPlan")?; - self.0.visit(&mut visitor).map_err(|_| fmt::Error)?; + self.0 + .visit_with_subqueries(&mut visitor) + .map_err(|_| fmt::Error)?; visitor.post_visit_plan()?; visitor.end_graph()?; @@ -2908,7 +3268,7 @@ digraph { fn visit_order() { let mut visitor = OkVisitor::default(); let plan = test_plan(); - let res = plan.visit(&mut visitor); + let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); assert_eq!( @@ -2984,7 +3344,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor); + let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); assert_eq!( @@ -3000,7 +3360,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor); + let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); assert_eq!( @@ -3051,7 +3411,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor).unwrap_err(); + let res = plan.visit_with_subqueries(&mut visitor).unwrap_err(); assert_eq!( "This feature is not implemented: Error in pre_visit", res.strip_backtrace() @@ -3069,7 +3429,7 @@ digraph { ..Default::default() }; let plan = test_plan(); - let res = plan.visit(&mut visitor).unwrap_err(); + let res = plan.visit_with_subqueries(&mut visitor).unwrap_err(); assert_eq!(
"This feature is not implemented: Error in post_visit", res.strip_backtrace() diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 97331720ce7d0..85097f6249e13 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -412,7 +412,7 @@ where } /// &mut transform a Option<`Vec` of `Expr`s> -fn transform_option_vec( +pub fn transform_option_vec( ove: Option>, f: &mut F, ) -> Result>>> diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 7a6b1005fedec..482fc96b519b1 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -20,58 +20,11 @@ use crate::LogicalPlan; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeVisitor, + Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, }; use datafusion_common::Result; impl TreeNode for LogicalPlan { - fn apply Result>( - &self, - f: &mut F, - ) -> Result { - // Compared to the default implementation, we need to invoke - // [`Self::apply_subqueries`] before visiting its children - f(self)?.visit_children(|| { - self.apply_subqueries(f)?; - self.apply_children(|n| n.apply(f)) - }) - } - - /// To use, define a struct that implements the trait [`TreeNodeVisitor`] and then invoke - /// [`LogicalPlan::visit`]. - /// - /// For example, for a logical plan like: - /// - /// ```text - /// Projection: id - /// Filter: state Eq Utf8(\"CO\")\ - /// CsvScan: employee.csv projection=Some([0, 3])"; - /// ``` - /// - /// The sequence of visit operations would be: - /// ```text - /// visitor.pre_visit(Projection) - /// visitor.pre_visit(Filter) - /// visitor.pre_visit(CsvScan) - /// visitor.post_visit(CsvScan) - /// visitor.post_visit(Filter) - /// visitor.post_visit(Projection) - /// ``` - fn visit>( - &self, - visitor: &mut V, - ) -> Result { - // Compared to the default implementation, we need to invoke - // [`Self::visit_subqueries`] before visiting its children - visitor - .f_down(self)? - .visit_children(|| { - self.visit_subqueries(visitor)?; - self.apply_children(|n| n.visit(visitor)) - })? - .visit_parent(|| visitor.f_up(self)) - } - fn apply_children Result>( &self, f: F, @@ -85,8 +38,8 @@ impl TreeNode for LogicalPlan { ) -> Result> { let new_children = self .inputs() - .iter() - .map(|&c| c.clone()) + .into_iter() + .cloned() .map_until_stop_and_collect(f)?; // Propagate up `new_children.transformed` and `new_children.tnr` // along with the node containing transformed children. 
diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index b446fe2f320ec..d0b83d24299b1 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -155,8 +155,8 @@ impl Analyzer { /// Do necessary check and fail the invalid plan fn check_plan(plan: &LogicalPlan) -> Result<()> { - plan.apply(&mut |plan: &LogicalPlan| { - plan.inspect_expressions(|expr| { + plan.apply_with_subqueries(&mut |plan: &LogicalPlan| { + plan.apply_expressions(|expr| { // recursively look for subqueries expr.apply(&mut |expr| { match expr { @@ -168,11 +168,8 @@ fn check_plan(plan: &LogicalPlan) -> Result<()> { _ => {} }; Ok(TreeNodeRecursion::Continue) - })?; - Ok::<(), DataFusionError>(()) - })?; - Ok(TreeNodeRecursion::Continue) - })?; - - Ok(()) + }) + }) + }) + .map(|_| ()) } diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 038361c3ee8c3..79375e52da1f9 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -283,7 +283,7 @@ fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan { fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result<Vec<Expr>> { let mut exprs = vec![]; - inner_plan.apply(&mut |plan| { + inner_plan.apply_with_subqueries(&mut |plan| { if let LogicalPlan::Filter(Filter { predicate, .. }) = plan { let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) .into_iter() diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 4143d52a053eb..a8e323ff429f2 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -21,7 +21,7 @@ use std::{ num::NonZeroUsize, }; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_expr::LogicalPlan; /// Non-unique identifier of a [`LogicalPlan`]. @@ -73,7 +73,7 @@ impl LogicalPlanSignature { /// Get total number of [`LogicalPlan`]s in the plan. fn get_node_number(plan: &LogicalPlan) -> NonZeroUsize { let mut node_number = 0; - plan.apply(&mut |_plan| { + plan.apply_with_subqueries(&mut |_plan| { node_number += 1; Ok(TreeNodeRecursion::Continue) }) From fc29c3e67d43e82e4d1d49b44d150ff710ad7004 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Mon, 8 Apr 2024 18:29:32 +0800 Subject: [PATCH 10/14] Remove unnecessary result (#9990) --- datafusion/common/src/dfschema.rs | 29 +++++++++++++---------------- datafusion/expr/src/utils.rs | 2 +- datafusion/sql/src/statement.rs | 2 +- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 9f167fd1f6d94..83e53b3cc6ff1 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -319,7 +319,7 @@ impl DFSchema { &self, qualifier: Option<&TableReference>, name: &str, - ) -> Result<Option<usize>> { + ) -> Option<usize> { let mut matches = self .iter() .enumerate() @@ -345,19 +345,19 @@ impl DFSchema { (None, Some(_)) | (None, None) => f.name() == name, }) .map(|(idx, _)| idx); - Ok(matches.next()) + matches.next() } /// Find the index of the column with the given qualifier and name pub fn index_of_column(&self, col: &Column) -> Result<usize> { - self.index_of_column_by_name(col.relation.as_ref(), &col.name)?
+ self.index_of_column_by_name(col.relation.as_ref(), &col.name) .ok_or_else(|| field_not_found(col.relation.clone(), &col.name, self)) } /// Check if the column is in the current schema - pub fn is_column_from_schema(&self, col: &Column) -> Result<bool> { + pub fn is_column_from_schema(&self, col: &Column) -> bool { self.index_of_column_by_name(col.relation.as_ref(), &col.name) - .map(|idx| idx.is_some()) + .is_some() } /// Find the field with the given name @@ -381,7 +381,7 @@ impl DFSchema { ) -> Result<(Option<&TableReference>, &Field)> { if let Some(qualifier) = qualifier { let idx = self - .index_of_column_by_name(Some(qualifier), name)? + .index_of_column_by_name(Some(qualifier), name) .ok_or_else(|| field_not_found(Some(qualifier.clone()), name, self))?; Ok((self.field_qualifiers[idx].as_ref(), self.field(idx))) } else { @@ -519,7 +519,7 @@ impl DFSchema { name: &str, ) -> Result<&Field> { let idx = self - .index_of_column_by_name(Some(qualifier), name)? + .index_of_column_by_name(Some(qualifier), name) .ok_or_else(|| field_not_found(Some(qualifier.clone()), name, self))?; Ok(self.field(idx)) @@ -1190,11 +1190,8 @@ mod tests { .to_string(), expected_help ); - assert!(schema.index_of_column_by_name(None, "y").unwrap().is_none()); - assert!(schema - .index_of_column_by_name(None, "t1.c0") - .unwrap() - .is_none()); + assert!(schema.index_of_column_by_name(None, "y").is_none()); + assert!(schema.index_of_column_by_name(None, "t1.c0").is_none()); Ok(()) } @@ -1284,28 +1281,28 @@ mod tests { { let col = Column::from_qualified_name("t1.c0"); let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - assert!(schema.is_column_from_schema(&col)?); + assert!(schema.is_column_from_schema(&col)); } // qualified not exists { let col = Column::from_qualified_name("t1.c2"); let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - assert!(!schema.is_column_from_schema(&col)?); + assert!(!schema.is_column_from_schema(&col)); } // unqualified exists { let col = Column::from_name("c0"); let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - assert!(schema.is_column_from_schema(&col)?); + assert!(schema.is_column_from_schema(&col)); } // unqualified not exists { let col = Column::from_name("c2"); let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; - assert!(!schema.is_column_from_schema(&col)?); + assert!(!schema.is_column_from_schema(&col)); } Ok(()) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 0d99d0b5028ee..8c6b98f17933d 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -933,7 +933,7 @@ pub fn check_all_columns_from_schema( schema: DFSchemaRef, ) -> Result<bool> { for col in columns.iter() { - let exist = schema.is_column_from_schema(col)?; + let exist = schema.is_column_from_schema(col); if !exist { return Ok(false); } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index b8c9172621c34..6b89f89aaccf8 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1350,7 +1350,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .enumerate() .map(|(i, c)| { let column_index = table_schema - .index_of_column_by_name(None, &c)?
+ .index_of_column_by_name(None, &c) .ok_or_else(|| unqualified_field_not_found(&c, &table_schema))?; if value_indices[column_index].is_some() { return schema_err!(SchemaError::DuplicateUnqualifiedField { From 820843ff597161c9cdacd0e79cecf20d05755081 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Mon, 8 Apr 2024 06:31:10 -0400 Subject: [PATCH 11/14] Removes Bloom filter for Int8/Int16/Uint8/Uint16 (#9969) * Removing broken tests * Simplifying tests / removing support for failed tests * Revert "Simplifying tests / removing support for failed tests" This reverts commit 6e50a8064436943d9f42d313cef2c2b017d196f1. * Fixing tests for real * Apply suggestions from code review Thanks @alamb ! Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- .../physical_plan/parquet/row_groups.rs | 4 -- .../core/tests/parquet/row_group_pruning.rs | 54 +++++++++---------- 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 8df4925fc5667..6600dd07d7fd4 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -232,12 +232,8 @@ impl PruningStatistics for BloomFilterStatistics { ScalarValue::Float32(Some(v)) => sbbf.check(v), ScalarValue::Int64(Some(v)) => sbbf.check(v), ScalarValue::Int32(Some(v)) => sbbf.check(v), - ScalarValue::Int16(Some(v)) => sbbf.check(v), - ScalarValue::Int8(Some(v)) => sbbf.check(v), ScalarValue::UInt64(Some(v)) => sbbf.check(v), ScalarValue::UInt32(Some(v)) => sbbf.check(v), - ScalarValue::UInt16(Some(v)) => sbbf.check(v), - ScalarValue::UInt8(Some(v)) => sbbf.check(v), ScalarValue::Decimal128(Some(v), p, s) => match parquet_type { Type::INT32 => { //https://github.com/apache/parquet-format/blob/eb4b31c1d64a01088d02a2f9aefc6c17c54cc6fc/Encodings.md?plain=1#L35-L42 diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index b7b434d1c3d3b..8fc7936552af3 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -290,7 +290,7 @@ async fn prune_disabled() { // https://github.com/apache/arrow-datafusion/issues/9779 bug so that tests pass // if and only if Bloom filters on Int8 and Int16 columns are still buggy. macro_rules! int_tests { - ($bits:expr, correct_bloom_filters: $correct_bloom_filters:expr) => { + ($bits:expr) => { paste::item! { #[tokio::test] async fn []() { @@ -329,9 +329,9 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -343,9 +343,9 @@ macro_rules! 
int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -404,9 +404,9 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -447,17 +447,16 @@ macro_rules! int_tests { }; } -int_tests!(8, correct_bloom_filters: false); -int_tests!(16, correct_bloom_filters: false); -int_tests!(32, correct_bloom_filters: true); -int_tests!(64, correct_bloom_filters: true); +// int8/int16 are incorrect: https://github.com/apache/arrow-datafusion/issues/9779 +int_tests!(32); +int_tests!(64); // $bits: number of bits of the integer to test (8, 16, 32, 64) // $correct_bloom_filters: if false, replicates the // https://github.com/apache/arrow-datafusion/issues/9779 bug so that tests pass // if and only if Bloom filters on UInt8 and UInt16 columns are still buggy. macro_rules! uint_tests { - ($bits:expr, correct_bloom_filters: $correct_bloom_filters:expr) => { + ($bits:expr) => { paste::item! { #[tokio::test] async fn []() { @@ -482,9 +481,9 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -496,9 +495,9 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -542,9 +541,9 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -585,10 +584,9 @@ macro_rules! 
uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) - .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) - .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) .test_row_group_prune() .await; } @@ -585,10 +584,9 @@ macro_rules! uint_tests { }; } -uint_tests!(8, correct_bloom_filters: false); -uint_tests!(16, correct_bloom_filters: false); -uint_tests!(32, correct_bloom_filters: true); -uint_tests!(64, correct_bloom_filters: true); +// uint8/uint16 are incorrect: https://github.com/apache/arrow-datafusion/issues/9779 +uint_tests!(32); +uint_tests!(64); #[tokio::test] async fn prune_int32_eq_large_in_list() { From 86ad8a580863218f9fb123f09e1d058094ec3ef8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 8 Apr 2024 09:46:04 -0400 Subject: [PATCH 12/14] Move LogicalPlan tree_node module (#9995) --- datafusion/expr/src/logical_plan/mod.rs | 1 + datafusion/expr/src/logical_plan/plan.rs | 2 +- .../plan.rs => logical_plan/tree_node.rs} | 0 .../src/{tree_node/expr.rs => tree_node.rs} | 0 datafusion/expr/src/tree_node/mod.rs | 21 ------------------- 5 files changed, 2 insertions(+), 22 deletions(-) rename datafusion/expr/src/{tree_node/plan.rs => logical_plan/tree_node.rs} (100%) rename datafusion/expr/src/{tree_node/expr.rs => tree_node.rs} (100%) delete mode 100644 datafusion/expr/src/tree_node/mod.rs diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 84781cb2e9ec5..a1fe7a6f0a51e 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -22,6 +22,7 @@ pub mod dml; mod extension; mod plan; mod statement; +mod tree_node; pub use builder::{ build_join_schema, table_scan, union, wrap_projection_for_join_if_necessary, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 4f55bbfe3f4d6..860fd7daafbf5 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -56,7 +56,7 @@ use datafusion_common::{ // backwards compatibility use crate::display::PgJsonVisitor; -use crate::tree_node::expr::transform_option_vec; +use crate::tree_node::transform_option_vec; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/logical_plan/tree_node.rs similarity index 100% rename from datafusion/expr/src/tree_node/plan.rs rename to datafusion/expr/src/logical_plan/tree_node.rs diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node.rs similarity index 100% rename from datafusion/expr/src/tree_node/expr.rs rename to datafusion/expr/src/tree_node.rs diff --git a/datafusion/expr/src/tree_node/mod.rs b/datafusion/expr/src/tree_node/mod.rs deleted file mode 100644 index 3f8bb6d3257e6..0000000000000 --- a/datafusion/expr/src/tree_node/mod.rs +++ /dev/null @@ -1,21 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -//!
Tree node implementation for logical expr and logical plan - -pub mod expr; -pub mod plan; From 8c9e5678228557aff370b137e9029462230df68a Mon Sep 17 00:00:00 2001 From: Kevin Mingtarja <69668484+kevinmingtarja@users.noreply.github.com> Date: Tue, 9 Apr 2024 00:01:26 +0800 Subject: [PATCH 13/14] Optimize performance of substr_index and add tests (#9973) * Optimize performance of substr_index --- .../functions/src/unicode/substrindex.rs | 153 +++++++++++++++--- .../sqllogictest/test_files/functions.slt | 11 +- 2 files changed, 143 insertions(+), 21 deletions(-) diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index d00108a68fc99..da4ff55828e9b 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ArrayRef, OffsetSizeTrait, StringBuilder}; use arrow::datatypes::DataType; use datafusion_common::cast::{as_generic_string_array, as_int64_array}; @@ -101,38 +101,151 @@ pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { let delimiter_array = as_generic_string_array::<T>(&args[1])?; let count_array = as_int64_array(&args[2])?; - let result = string_array + let mut builder = StringBuilder::new(); + string_array .iter() .zip(delimiter_array.iter()) .zip(count_array.iter()) - .map(|((string, delimiter), n)| match (string, delimiter, n) { + .for_each(|((string, delimiter), n)| match (string, delimiter, n) { (Some(string), Some(delimiter), Some(n)) => { // In MySQL, these cases will return an empty string. if n == 0 || string.is_empty() || delimiter.is_empty() { - return Some(String::new()); + builder.append_value(""); + return; } - let splitted: Box<dyn Iterator<Item = &str>> = if n > 0 { - Box::new(string.split(delimiter)) + let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); + let length = if n > 0 { + let splitted = string.split(delimiter); + splitted + .take(occurrences) + .map(|s| s.len() + delimiter.len()) + .sum::<usize>() - delimiter.len() + } else { - Box::new(string.rsplit(delimiter)) + let splitted = string.rsplit(delimiter); + splitted + .take(occurrences) + .map(|s| s.len() + delimiter.len()) + .sum::<usize>() + - delimiter.len() }; - let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); - // The length of the substring covered by substr_index. - let length = splitted - .take(occurrences) // at least 1 element, since n != 0 - .map(|s| s.len() + delimiter.len()) - .sum::<usize>() - - delimiter.len(); if n > 0 { - Some(string[..length].to_owned()) + match string.get(..length) { + Some(substring) => builder.append_value(substring), + None => builder.append_null(), + } } else { - Some(string[string.len() - length..].to_owned()) + match string.get(string.len().saturating_sub(length)..)
{ + Some(substring) => builder.append_value(substring), + None => builder.append_null(), + } + } + } + _ => builder.append_null(), + }); + + Ok(Arc::new(builder.finish()) as ArrayRef) +} - }) - .collect::<GenericStringArray<T>>(); +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::substrindex::SubstrIndexFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(Some("www")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("www.apache")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(-2i64)), + ], + Ok(Some("apache.org")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(-1i64)), + ], + Ok(Some("org")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("")), + ColumnarValue::Scalar(ScalarValue::from(".")), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrIndexFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), + ColumnarValue::Scalar(ScalarValue::from("")), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 21433ba16810f..38ebedf5654a5 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -940,7 +940,8 @@ SELECT str, n, substring_index(str, '.', n) AS c FROM (VALUES ROW('arrow.apache.org'), ROW('.'), - ROW('...') + ROW('...'), + ROW(NULL) ) AS strings(str), (VALUES ROW(1), @@ -954,6 +955,14 @@ SELECT str, n, substring_index(str, '.', n) AS c FROM ) AS occurrences(n) ORDER BY str DESC, n; ---- +NULL -100 NULL +NULL -3 NULL +NULL -2 NULL +NULL -1 NULL +NULL 1 NULL +NULL 2 NULL +NULL 3 NULL +NULL 100 NULL arrow.apache.org -100 arrow.apache.org arrow.apache.org -3 arrow.apache.org arrow.apache.org -2 apache.org From bece785174c199f4fde4343a27c2213fae11bfb8 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Mon, 8 Apr 2024 12:03:00 -0400 Subject: [PATCH 14/14] move Floor, Gcd, Lcm, Pi to datafusion-functions (#9976) * move Floor, Gcd, Lcm, Pi to
datafusion-functions --- datafusion/expr/src/built_in_function.rs | 29 +--- datafusion/expr/src/expr_fn.rs | 17 +- datafusion/functions/src/math/gcd.rs | 145 ++++++++++++++++++ datafusion/functions/src/math/lcm.rs | 126 +++++++++++++++ datafusion/functions/src/math/mod.rs | 14 +- datafusion/functions/src/math/pi.rs | 76 +++++++++ .../optimizer/src/analyzer/type_coercion.rs | 30 ++-- .../physical-expr/src/equivalence/ordering.rs | 44 ++++-- .../src/equivalence/projection.rs | 41 +++-- .../src/equivalence/properties.rs | 41 ++--- datafusion/physical-expr/src/functions.rs | 10 +- .../physical-expr/src/math_expressions.rs | 118 -------------- datafusion/physical-expr/src/udf.rs | 55 ++----- datafusion/physical-expr/src/utils/mod.rs | 99 +++++++++++- datafusion/proto/proto/datafusion.proto | 8 +- datafusion/proto/src/generated/pbjson.rs | 12 -- datafusion/proto/src/generated/prost.rs | 16 +- .../proto/src/logical_plan/from_proto.rs | 20 +-- datafusion/proto/src/logical_plan/to_proto.rs | 4 - datafusion/sql/src/expr/function.rs | 20 ++- datafusion/sql/src/expr/mod.rs | 7 +- 21 files changed, 588 insertions(+), 344 deletions(-) create mode 100644 datafusion/functions/src/math/gcd.rs create mode 100644 datafusion/functions/src/math/lcm.rs create mode 100644 datafusion/functions/src/math/pi.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index dc1fc98a5c02d..7426ccd938e77 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -45,20 +45,12 @@ pub enum BuiltinScalarFunction { Exp, /// factorial Factorial, - /// floor - Floor, - /// gcd, Greatest common divisor - Gcd, - /// lcm, Least common multiple - Lcm, /// iszero Iszero, /// log, same as log10 Log, /// nanvl Nanvl, - /// pi - Pi, /// power Power, /// round @@ -135,13 +127,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => Volatility::Immutable, BuiltinScalarFunction::Exp => Volatility::Immutable, BuiltinScalarFunction::Factorial => Volatility::Immutable, - BuiltinScalarFunction::Floor => Volatility::Immutable, - BuiltinScalarFunction::Gcd => Volatility::Immutable, BuiltinScalarFunction::Iszero => Volatility::Immutable, - BuiltinScalarFunction::Lcm => Volatility::Immutable, BuiltinScalarFunction::Log => Volatility::Immutable, BuiltinScalarFunction::Nanvl => Volatility::Immutable, - BuiltinScalarFunction::Pi => Volatility::Immutable, BuiltinScalarFunction::Power => Volatility::Immutable, BuiltinScalarFunction::Round => Volatility::Immutable, BuiltinScalarFunction::Cot => Volatility::Immutable, @@ -183,13 +171,10 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::InitCap => { utf8_to_str_type(&input_expr_types[0], "initcap") } - BuiltinScalarFunction::Pi => Ok(Float64), BuiltinScalarFunction::Random => Ok(Float64), BuiltinScalarFunction::EndsWith => Ok(Boolean), - BuiltinScalarFunction::Factorial - | BuiltinScalarFunction::Gcd - | BuiltinScalarFunction::Lcm => Ok(Int64), + BuiltinScalarFunction::Factorial => Ok(Int64), BuiltinScalarFunction::Power => match &input_expr_types[0] { Int64 => Ok(Int64), @@ -210,7 +195,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Floor | BuiltinScalarFunction::Round | BuiltinScalarFunction::Trunc | BuiltinScalarFunction::Cot => match input_expr_types[0] { @@ -248,7 +232,6 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - BuiltinScalarFunction::Pi => Signature::exact(vec![], self.volatility()), 
BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()), BuiltinScalarFunction::Power => Signature::one_of( vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])], self.volatility(), ), @@ -289,12 +272,8 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Factorial => { Signature::uniform(1, vec![Int64], self.volatility()) } - BuiltinScalarFunction::Gcd | BuiltinScalarFunction::Lcm => { - Signature::uniform(2, vec![Int64], self.volatility()) - } BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp - | BuiltinScalarFunction::Floor | BuiltinScalarFunction::Cot => { // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we @@ -319,10 +298,8 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp | BuiltinScalarFunction::Factorial - | BuiltinScalarFunction::Floor | BuiltinScalarFunction::Round | BuiltinScalarFunction::Trunc - | BuiltinScalarFunction::Pi ) { Some(vec![Some(true)]) } else if *self == BuiltinScalarFunction::Log { @@ -339,13 +316,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Cot => &["cot"], BuiltinScalarFunction::Exp => &["exp"], BuiltinScalarFunction::Factorial => &["factorial"], - BuiltinScalarFunction::Floor => &["floor"], - BuiltinScalarFunction::Gcd => &["gcd"], BuiltinScalarFunction::Iszero => &["iszero"], - BuiltinScalarFunction::Lcm => &["lcm"], BuiltinScalarFunction::Log => &["log"], BuiltinScalarFunction::Nanvl => &["nanvl"], - BuiltinScalarFunction::Pi => &["pi"], BuiltinScalarFunction::Power => &["power", "pow"], BuiltinScalarFunction::Random => &["random"], BuiltinScalarFunction::Round => &["round"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index f68685a87f13c..6c811ff064185 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -297,11 +297,6 @@ pub fn concat_ws(sep: Expr, values: Vec<Expr>) -> Expr { )) } -/// Returns an approximate value of π -pub fn pi() -> Expr { - Expr::ScalarFunction(ScalarFunction::new(BuiltinScalarFunction::Pi, vec![])) -} - /// Returns a random value in the range 0.0 <= x < 1.0 pub fn random() -> Expr { Expr::ScalarFunction(ScalarFunction::new(BuiltinScalarFunction::Random, vec![])) @@ -537,12 +532,6 @@ macro_rules!
nary_scalar_expr { // math functions scalar_expr!(Cot, cot, num, "cotangent of a number"); scalar_expr!(Factorial, factorial, num, "factorial"); -scalar_expr!( - Floor, - floor, - num, - "nearest integer less than or equal to argument" -); scalar_expr!( Ceil, ceil, @@ -556,8 +545,7 @@ nary_scalar_expr!( "truncate toward zero, with optional precision" ); scalar_expr!(Exp, exp, num, "exponential"); -scalar_expr!(Gcd, gcd, arg_1 arg_2, "greatest common divisor"); -scalar_expr!(Lcm, lcm, arg_1 arg_2, "least common multiple"); + scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); @@ -974,7 +962,6 @@ mod test { fn scalar_function_definitions() { test_unary_scalar_expr!(Cot, cot); test_unary_scalar_expr!(Factorial, factorial); - test_unary_scalar_expr!(Floor, floor); test_unary_scalar_expr!(Ceil, ceil); test_nary_scalar_expr!(Round, round, input); test_nary_scalar_expr!(Round, round, input, decimal_places); @@ -984,8 +971,6 @@ mod test { test_scalar_expr!(Nanvl, nanvl, x, y); test_scalar_expr!(Iszero, iszero, input); - test_scalar_expr!(Gcd, gcd, arg_1, arg_2); - test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); test_scalar_expr!(EndsWith, ends_with, string, characters); } diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs new file mode 100644 index 0000000000000..41c9e4e233147 --- /dev/null +++ b/datafusion/functions/src/math/gcd.rs @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{ArrayRef, Int64Array}; +use std::any::Any; +use std::mem::swap; +use std::sync::Arc; + +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Int64; + +use crate::utils::make_scalar_function; +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub struct GcdFunc { + signature: Signature, +} + +impl Default for GcdFunc { + fn default() -> Self { + Self::new() + } +} + +impl GcdFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform(2, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for GcdFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "gcd" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(Int64) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { + make_scalar_function(gcd, vec![])(args) + } +} + +/// Gcd SQL function +fn gcd(args: &[ArrayRef]) -> Result<ArrayRef> { + match args[0].data_type() { + Int64 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "x", + "y", + Int64Array, + Int64Array, + { compute_gcd } + )) as ArrayRef), + other => exec_err!("Unsupported data type {other:?} for function gcd"), + } +} + +/// Computes greatest common divisor using Binary GCD algorithm. +pub fn compute_gcd(x: i64, y: i64) -> i64 { + let mut a = x.wrapping_abs(); + let mut b = y.wrapping_abs(); + + if a == 0 { + return b; + } + if b == 0 { + return a; + } + + let shift = (a | b).trailing_zeros(); + a >>= shift; + b >>= shift; + a >>= a.trailing_zeros(); + + loop { + b >>= b.trailing_zeros(); + if a > b { + swap(&mut a, &mut b); + } + + b -= a; + + if b == 0 { + return a << shift; + } + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{ArrayRef, Int64Array}; + + use crate::math::gcd::gcd; + use datafusion_common::cast::as_int64_array; + + #[test] + fn test_gcd_i64() { + let args: Vec<ArrayRef> = vec![ + Arc::new(Int64Array::from(vec![0, 3, 25, -16])), // x + Arc::new(Int64Array::from(vec![0, -2, 15, 8])), // y + ]; + + let result = gcd(&args).expect("failed to initialize function gcd"); + let ints = as_int64_array(&result).expect("failed to initialize function gcd"); + + assert_eq!(ints.len(), 4); + assert_eq!(ints.value(0), 0); + assert_eq!(ints.value(1), 1); + assert_eq!(ints.value(2), 5); + assert_eq!(ints.value(3), 8); + } +} diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs new file mode 100644 index 0000000000000..3674f7371de2f --- /dev/null +++ b/datafusion/functions/src/math/lcm.rs @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
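`LcmFunc` in the file below delegates to the GCD just defined: its `compute_lcm` closure evaluates `a / compute_gcd(a, b) * b` on absolute values, dividing before multiplying so the intermediate result stays small, and treats any zero operand as `lcm = 0`. Expected values mirroring `test_lcm_i64` at the end of the file (a sketch of the closure's behaviour, not extra test code):

    // compute_lcm(3, -2)  == 6
    // compute_lcm(25, 15) == 75
    // compute_lcm(-16, 8) == 16
    // compute_lcm(0, 8)   == 0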
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int64Array}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Int64; + +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +use crate::math::gcd::compute_gcd; +use crate::utils::make_scalar_function; + +#[derive(Debug)] +pub struct LcmFunc { + signature: Signature, +} + +impl Default for LcmFunc { + fn default() -> Self { + LcmFunc::new() + } +} + +impl LcmFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform(2, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for LcmFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "lcm" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(Int64) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { + make_scalar_function(lcm, vec![])(args) + } +} + +/// Lcm SQL function +fn lcm(args: &[ArrayRef]) -> Result<ArrayRef> { + let compute_lcm = |x: i64, y: i64| { + let a = x.wrapping_abs(); + let b = y.wrapping_abs(); + + if a == 0 || b == 0 { + return 0; + } + a / compute_gcd(a, b) * b + }; + + match args[0].data_type() { + Int64 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "x", + "y", + Int64Array, + Int64Array, + { compute_lcm } + )) as ArrayRef), + other => exec_err!("Unsupported data type {other:?} for function lcm"), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{ArrayRef, Int64Array}; + + use datafusion_common::cast::as_int64_array; + + use crate::math::lcm::lcm; + + #[test] + fn test_lcm_i64() { + let args: Vec<ArrayRef> = vec![ + Arc::new(Int64Array::from(vec![0, 3, 25, -16])), // x + Arc::new(Int64Array::from(vec![0, -2, 15, 8])), // y + ]; + + let result = lcm(&args).expect("failed to initialize function lcm"); + let ints = as_int64_array(&result).expect("failed to initialize function lcm"); + + assert_eq!(ints.len(), 4); + assert_eq!(ints.value(0), 0); + assert_eq!(ints.value(1), 6); + assert_eq!(ints.value(2), 75); + assert_eq!(ints.value(3), 16); + } +} diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index f241c8b3250bd..3a1f7cc13bb7c 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -18,11 +18,17 @@ //!
"math" DataFusion functions pub mod abs; +pub mod gcd; +pub mod lcm; pub mod nans; +pub mod pi; // Create UDFs make_udf_function!(nans::IsNanFunc, ISNAN, isnan); make_udf_function!(abs::AbsFunc, ABS, abs); +make_udf_function!(gcd::GcdFunc, GCD, gcd); +make_udf_function!(lcm::LcmFunc, LCM, lcm); +make_udf_function!(pi::PiFunc, PI, pi); make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)])); make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)])); @@ -50,6 +56,8 @@ make_math_unary_udf!(CosFunc, COS, cos, cos, None); make_math_unary_udf!(CoshFunc, COSH, cosh, cosh, None); make_math_unary_udf!(DegreesFunc, DEGREES, degrees, to_degrees, None); +make_math_unary_udf!(FloorFunc, FLOOR, floor, floor, Some(vec![Some(true)])); + // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( ( @@ -86,5 +94,9 @@ export_functions!( (cbrt, num, "cube root of a number"), (cos, num, "cosine"), (cosh, num, "hyperbolic cosine"), - (degrees, num, "converts radians to degrees") + (degrees, num, "converts radians to degrees"), + (gcd, x y, "greatest common divisor"), + (lcm, x y, "least common multiple"), + (floor, num, "nearest integer less than or equal to argument"), + (pi, , "Returns an approximate value of π") ); diff --git a/datafusion/functions/src/math/pi.rs b/datafusion/functions/src/math/pi.rs new file mode 100644 index 0000000000000..0801e797511b5 --- /dev/null +++ b/datafusion/functions/src/math/pi.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::Float64Array; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Float64; + +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, FuncMonotonicity, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +#[derive(Debug)] +pub struct PiFunc { + signature: Signature, +} + +impl Default for PiFunc { + fn default() -> Self { + PiFunc::new() + } +} + +impl PiFunc { + pub fn new() -> Self { + Self { + signature: Signature::exact(vec![], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for PiFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "pi" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(Float64) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { + if !matches!(&args[0], ColumnarValue::Array(_)) { + return exec_err!("Expect pi function to take no param"); + } + let array = Float64Array::from_value(std::f64::consts::PI, 1); + Ok(ColumnarValue::Array(Arc::new(array))) + } + + fn monotonicity(&self) -> Result<Option<FuncMonotonicity>> { + Ok(Some(vec![Some(true)])) + } +} diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 04de243fba07a..1ea8b9534e808 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -19,9 +19,8 @@ use std::sync::Arc; -use crate::analyzer::AnalyzerRule; - use arrow::datatypes::{DataType, IntervalUnit}; + use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{ @@ -51,6 +50,8 @@ use datafusion_expr::{ WindowFrameUnits, }; +use crate::analyzer::AnalyzerRule; + #[derive(Default)] pub struct TypeCoercion {} @@ -758,25 +759,25 @@ mod test { use std::any::Any; use std::sync::{Arc, OnceLock}; - use crate::analyzer::type_coercion::{ - coerce_case_expression, TypeCoercion, TypeCoercionRewriter, - }; - use crate::test::assert_analyzed_plan_eq; - use arrow::datatypes::{DataType, Field, TimeUnit}; + use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; use datafusion_expr::logical_plan::{EmptyRelation, Projection}; use datafusion_expr::{ cast, col, concat, concat_ws, create_udaf, is_true, lit, - AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, - BuiltinScalarFunction, Case, ColumnarValue, Expr, ExprSchemable, Filter, - LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, - Subquery, Volatility, + AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, Case, + ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan, Operator, ScalarUDF, + ScalarUDFImpl, Signature, SimpleAggregateUDF, Subquery, Volatility, }; use datafusion_physical_expr::expressions::AvgAccumulator; + use crate::analyzer::type_coercion::{ + coerce_case_expression, TypeCoercion, TypeCoercionRewriter, + }; + use crate::test::assert_analyzed_plan_eq; + fn empty() -> Arc<LogicalPlan> { Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -875,14 +876,15 @@ mod test { // test that automatic argument type coercion for scalar functions work let empty = empty(); let lit_expr = lit(10i64); - let fun: BuiltinScalarFunction = BuiltinScalarFunction::Floor; + let
fun = ScalarUDF::new_from_impl(TestScalarUDF {}); let scalar_function_expr = - Expr::ScalarFunction(ScalarFunction::new(fun, vec![lit_expr])); + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![lit_expr])); let plan = LogicalPlan::Projection(Projection::try_new( vec![scalar_function_expr], empty, )?); - let expected = "Projection: floor(CAST(Int64(10) AS Float64))\n EmptyRelation"; + let expected = + "Projection: TestScalarUDF(CAST(Int64(10) AS Float32))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) } diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index 1364d3a8c0285..688cdf798bdd2 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -15,10 +15,11 @@ // specific language governing permissions and limitations // under the License. -use arrow_schema::SortOptions; use std::hash::Hash; use std::sync::Arc; +use arrow_schema::SortOptions; + use crate::equivalence::add_offset_to_expr; use crate::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; @@ -220,6 +221,16 @@ fn resolve_overlap(orderings: &mut [LexOrdering], idx: usize, pre_idx: usize) -> #[cfg(test)] mod tests { + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::SortOptions; + use itertools::Itertools; + + use datafusion_common::{DFSchema, Result}; + use datafusion_expr::execution_props::ExecutionProps; + use datafusion_expr::{BuiltinScalarFunction, Operator, ScalarUDF}; + use crate::equivalence::tests::{ convert_to_orderings, convert_to_sort_exprs, create_random_schema, create_test_params, generate_table_for_eq_properties, is_table_same_after_sort, @@ -231,14 +242,8 @@ mod tests { use crate::expressions::Column; use crate::expressions::{col, BinaryExpr}; use crate::functions::create_physical_expr; + use crate::utils::tests::TestScalarUDF; use crate::{PhysicalExpr, PhysicalSortExpr}; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::SortOptions; - use datafusion_common::Result; - use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{BuiltinScalarFunction, Operator}; - use itertools::Itertools; - use std::sync::Arc; #[test] fn test_ordering_satisfy() -> Result<()> { @@ -281,17 +286,20 @@ mod tests { let col_d = &col("d", &test_schema)?; let col_e = &col("e", &test_schema)?; let col_f = &col("f", &test_schema)?; - let floor_a = &create_physical_expr( - &BuiltinScalarFunction::Floor, + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = &crate::udf::create_physical_expr( + &test_fun, &[col("a", &test_schema)?], &test_schema, - &ExecutionProps::default(), + &[], + &DFSchema::empty(), )?; - let floor_f = &create_physical_expr( - &BuiltinScalarFunction::Floor, + let floor_f = &crate::udf::create_physical_expr( + &test_fun, &[col("f", &test_schema)?], &test_schema, - &ExecutionProps::default(), + &[], + &DFSchema::empty(), )?; let exp_a = &create_physical_expr( &BuiltinScalarFunction::Exp, @@ -804,11 +812,13 @@ mod tests { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = crate::udf::create_physical_expr( + &test_fun, &[col("a", &test_schema)?], &test_schema, - &ExecutionProps::default(), + &[], + &DFSchema::empty(), )?; let 
a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index b8231a74c2714..5efcf5942c396 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -17,13 +17,14 @@ use std::sync::Arc; -use crate::expressions::Column; -use crate::PhysicalExpr; - use arrow::datatypes::SchemaRef; + use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; +use crate::expressions::Column; +use crate::PhysicalExpr; + /// Stores the mapping between source expressions and target expressions for a /// projection. #[derive(Debug, Clone)] @@ -111,7 +112,14 @@ impl ProjectionMapping { mod tests { use std::sync::Arc; - use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{SortOptions, TimeUnit}; + use itertools::Itertools; + + use datafusion_common::{DFSchema, Result}; + use datafusion_expr::execution_props::ExecutionProps; + use datafusion_expr::{BuiltinScalarFunction, Operator, ScalarUDF}; + use crate::equivalence::tests::{ apply_projection, convert_to_orderings, convert_to_orderings_owned, create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, @@ -119,16 +127,11 @@ mod tests { }; use crate::equivalence::EquivalenceProperties; use crate::expressions::{col, BinaryExpr}; - use crate::functions::create_physical_expr; + use crate::udf::create_physical_expr; + use crate::utils::tests::TestScalarUDF; use crate::PhysicalSortExpr; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::{SortOptions, TimeUnit}; - use datafusion_common::Result; - use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{BuiltinScalarFunction, Operator}; - - use itertools::Itertools; + use super::*; #[test] fn project_orderings() -> Result<()> { @@ -646,7 +649,7 @@ mod tests { col_b.clone(), )) as Arc; - let round_c = &create_physical_expr( + let round_c = &crate::functions::create_physical_expr( &BuiltinScalarFunction::Round, &[col_c.clone()], &schema, @@ -973,11 +976,13 @@ mod tests { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; // Floor(a) + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, + &test_fun, &[col("a", &test_schema)?], &test_schema, - &ExecutionProps::default(), + &[], + &DFSchema::empty(), )?; // a + b let a_plus_b = Arc::new(BinaryExpr::new( @@ -1049,11 +1054,13 @@ mod tests { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; // Floor(a) + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, + &test_fun, &[col("a", &test_schema)?], &test_schema, - &ExecutionProps::default(), + &[], + &DFSchema::empty(), )?; // a + b let a_plus_b = Arc::new(BinaryExpr::new( diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 7ce540b267b26..c14c88d6c69bf 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -18,7 +18,13 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; -use super::ordering::collapse_lex_ordering; +use arrow_schema::{SchemaRef, SortOptions}; +use 
indexmap::{IndexMap, IndexSet}; +use itertools::Itertools; + +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{JoinSide, JoinType, Result}; + use crate::equivalence::{ collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; @@ -30,12 +36,7 @@ use crate::{ PhysicalSortRequirement, }; -use arrow_schema::{SchemaRef, SortOptions}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{JoinSide, JoinType, Result}; - -use indexmap::{IndexMap, IndexSet}; -use itertools::Itertools; +use super::ordering::collapse_lex_ordering; /// A `EquivalenceProperties` object stores useful information related to a schema. /// Currently, it keeps track of: @@ -1296,7 +1297,13 @@ mod tests { use std::ops::Not; use std::sync::Arc; - use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{Fields, SortOptions, TimeUnit}; + use itertools::Itertools; + + use datafusion_common::{DFSchema, Result}; + use datafusion_expr::{Operator, ScalarUDF}; + use crate::equivalence::add_offset_to_expr; use crate::equivalence::tests::{ convert_to_orderings, convert_to_sort_exprs, convert_to_sort_reqs, @@ -1304,16 +1311,10 @@ mod tests { generate_table_for_eq_properties, is_table_same_after_sort, output_schema, }; use crate::expressions::{col, BinaryExpr, Column}; - use crate::functions::create_physical_expr; + use crate::utils::tests::TestScalarUDF; use crate::PhysicalSortExpr; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::{Fields, SortOptions, TimeUnit}; - use datafusion_common::Result; - use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{BuiltinScalarFunction, Operator}; - - use itertools::Itertools; + use super::*; #[test] fn project_equivalence_properties_test() -> Result<()> { @@ -1792,11 +1793,13 @@ mod tests { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let floor_a = create_physical_expr( - &BuiltinScalarFunction::Floor, + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = crate::udf::create_physical_expr( + &test_fun, &[col("a", &test_schema)?], &test_schema, - &ExecutionProps::default(), + &[], + &DFSchema::empty(), )?; let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 770d9184325a5..79d69b273d2c3 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -184,16 +184,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::Factorial => { Arc::new(|args| make_scalar_function_inner(math_expressions::factorial)(args)) } - BuiltinScalarFunction::Floor => Arc::new(math_expressions::floor), - BuiltinScalarFunction::Gcd => { - Arc::new(|args| make_scalar_function_inner(math_expressions::gcd)(args)) - } BuiltinScalarFunction::Iszero => { Arc::new(|args| make_scalar_function_inner(math_expressions::iszero)(args)) } - BuiltinScalarFunction::Lcm => { - Arc::new(|args| make_scalar_function_inner(math_expressions::lcm)(args)) - } BuiltinScalarFunction::Nanvl => { Arc::new(|args| make_scalar_function_inner(math_expressions::nanvl)(args)) } @@ -204,7 +197,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Trunc => { Arc::new(|args| make_scalar_function_inner(math_expressions::trunc)(args)) } - BuiltinScalarFunction::Pi => Arc::new(math_expressions::pi), 
BuiltinScalarFunction::Power => { Arc::new(|args| make_scalar_function_inner(math_expressions::power)(args)) } @@ -573,7 +565,7 @@ mod tests { let execution_props = ExecutionProps::new(); let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let funs = [BuiltinScalarFunction::Pi, BuiltinScalarFunction::Random]; + let funs = [BuiltinScalarFunction::Random]; for fun in funs.iter() { create_physical_expr_with_type_coercion(fun, &[], &schema, &execution_props)?; diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index f8244ad9525f1..384f8d87eb96b 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -19,7 +19,6 @@ use std::any::type_name; use std::iter; -use std::mem::swap; use std::sync::Arc; use arrow::array::ArrayRef; @@ -161,7 +160,6 @@ math_unary_function!("atan", atan); math_unary_function!("asinh", asinh); math_unary_function!("acosh", acosh); math_unary_function!("atanh", atanh); -math_unary_function!("floor", floor); math_unary_function!("ceil", ceil); math_unary_function!("exp", exp); math_unary_function!("ln", ln); @@ -181,79 +179,6 @@ pub fn factorial(args: &[ArrayRef]) -> Result<ArrayRef> { } } -/// Computes greatest common divisor using Binary GCD algorithm. -fn compute_gcd(x: i64, y: i64) -> i64 { - let mut a = x.wrapping_abs(); - let mut b = y.wrapping_abs(); - - if a == 0 { - return b; - } - if b == 0 { - return a; - } - - let shift = (a | b).trailing_zeros(); - a >>= shift; - b >>= shift; - a >>= a.trailing_zeros(); - - loop { - b >>= b.trailing_zeros(); - if a > b { - swap(&mut a, &mut b); - } - - b -= a; - - if b == 0 { - return a << shift; - } - } -} - -/// Gcd SQL function -pub fn gcd(args: &[ArrayRef]) -> Result<ArrayRef> { - match args[0].data_type() { - DataType::Int64 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Int64Array, - Int64Array, - { compute_gcd } - )) as ArrayRef), - other => exec_err!("Unsupported data type {other:?} for function gcd"), - } -} - -/// Lcm SQL function -pub fn lcm(args: &[ArrayRef]) -> Result<ArrayRef> { - let compute_lcm = |x: i64, y: i64| { - let a = x.wrapping_abs(); - let b = y.wrapping_abs(); - - if a == 0 || b == 0 { - return 0; - } - a / compute_gcd(a, b) * b - }; - - match args[0].data_type() { - DataType::Int64 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Int64Array, - Int64Array, - { compute_lcm } - )) as ArrayRef), - other => exec_err!("Unsupported data type {other:?} for function lcm"), - } -} - /// Nanvl SQL function pub fn nanvl(args: &[ArrayRef]) -> Result<ArrayRef> { match args[0].data_type() { @@ -345,15 +270,6 @@ pub fn iszero(args: &[ArrayRef]) -> Result<ArrayRef> { } } -/// Pi SQL function -pub fn pi(args: &[ColumnarValue]) -> Result<ColumnarValue> { - if !matches!(&args[0], ColumnarValue::Array(_)) { - return exec_err!("Expect pi function to take no param"); - } - let array = Float64Array::from_value(std::f64::consts::PI, 1); - Ok(ColumnarValue::Array(Arc::new(array))) -} - /// Random SQL function pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> { let len: usize = match &args[0] { @@ -808,40 +724,6 @@ mod tests { assert_eq!(ints, &expected); }
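
For reference, the `compute_gcd` helper removed above implements binary GCD. The following is a minimal, std-only sketch of the same algorithm (illustrative only, not part of this patch); the assertions reuse the input/output pairs from the removed `test_gcd_i64` just below:

    // Sketch of the binary GCD routine deleted above; semantics match the
    // removed `compute_gcd`.
    fn compute_gcd(x: i64, y: i64) -> i64 {
        let (mut a, mut b) = (x.wrapping_abs(), y.wrapping_abs());
        if a == 0 {
            return b;
        }
        if b == 0 {
            return a;
        }
        // Pull out the common power of two once, then reduce the odd parts
        // by subtraction (classic binary GCD).
        let shift = (a | b).trailing_zeros();
        a >>= a.trailing_zeros();
        loop {
            b >>= b.trailing_zeros();
            if a > b {
                std::mem::swap(&mut a, &mut b);
            }
            b -= a;
            if b == 0 {
                return a << shift;
            }
        }
    }

    fn main() {
        // Expected values taken from the removed test_gcd_i64 below.
        assert_eq!(compute_gcd(0, 0), 0);
        assert_eq!(compute_gcd(3, -2), 1);
        assert_eq!(compute_gcd(25, 15), 5);
        assert_eq!(compute_gcd(-16, 8), 8);
    }
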
- #[test] - fn test_gcd_i64() { - let args: Vec<ArrayRef> = vec![ - Arc::new(Int64Array::from(vec![0, 3, 25, -16])), // x - Arc::new(Int64Array::from(vec![0, -2, 15, 8])), // y - ]; - - let result = gcd(&args).expect("failed to initialize function gcd"); - let ints = as_int64_array(&result).expect("failed to initialize function gcd"); - - assert_eq!(ints.len(), 4); - assert_eq!(ints.value(0), 0); - assert_eq!(ints.value(1), 1); - assert_eq!(ints.value(2), 5); - assert_eq!(ints.value(3), 8); - } - - #[test] - fn test_lcm_i64() { - let args: Vec<ArrayRef> = vec![ - Arc::new(Int64Array::from(vec![0, 3, 25, -16])), // x - Arc::new(Int64Array::from(vec![0, -2, 15, 8])), // y - ]; - - let result = lcm(&args).expect("failed to initialize function lcm"); - let ints = as_int64_array(&result).expect("failed to initialize function lcm"); - - assert_eq!(ints.len(), 4); - assert_eq!(ints.value(0), 0); - assert_eq!(ints.value(1), 6); - assert_eq!(ints.value(2), 75); - assert_eq!(ints.value(3), 16); - } - #[test] fn test_cot_f32() { let args: Vec<ArrayRef> = diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index 4fc94bfa15eca..368dfdf92f454 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -16,14 +16,17 @@ // under the License. //! UDF support -use crate::{PhysicalExpr, ScalarFunctionExpr}; +use std::sync::Arc; + use arrow_schema::Schema; + use datafusion_common::{DFSchema, Result}; pub use datafusion_expr::ScalarUDF; use datafusion_expr::{ type_coercion::functions::data_types, Expr, ScalarFunctionDefinition, }; -use std::sync::Arc; + +use crate::{PhysicalExpr, ScalarFunctionExpr}; /// Create a physical expression of the UDF. /// @@ -60,58 +63,18 @@ pub fn create_physical_expr( #[cfg(test)] mod tests { - use arrow_schema::{DataType, Schema}; + use arrow_schema::Schema; + use datafusion_common::{DFSchema, Result}; - use datafusion_expr::{ - ColumnarValue, FuncMonotonicity, ScalarUDF, ScalarUDFImpl, Signature, Volatility, - }; + use datafusion_expr::ScalarUDF; + use crate::utils::tests::TestScalarUDF; use crate::ScalarFunctionExpr; use super::create_physical_expr; #[test] fn test_functions() -> Result<()> { - #[derive(Debug, Clone)] - struct TestScalarUDF { - signature: Signature, - } - - impl TestScalarUDF { - fn new() -> Self { - let signature = - Signature::exact(vec![DataType::Float64], Volatility::Immutable); - - Self { signature } - } - } - - impl ScalarUDFImpl for TestScalarUDF { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn name(&self) -> &str { - "my_fn" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { - Ok(DataType::Float64) - } - - fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { - unimplemented!("my_fn is not implemented") - } - - fn monotonicity(&self) -> Result<Option<FuncMonotonicity>> { - Ok(Some(vec![Some(true)])) - } - } - // create and register the udf let udf = ScalarUDF::from(TestScalarUDF::new());
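
The call sites updated throughout this patch all follow one shape: build a `ScalarUDF`, then hand it to `udf::create_physical_expr` along with the physical inputs and both schemas. A condensed sketch of that shape (illustrative, not part of the patch; assumes the crate-internal paths `crate::udf` and `crate::expressions::col`, and the `TestScalarUDF` fixture defined in the next hunk):

    use std::sync::Arc;

    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::{DFSchema, Result};
    use datafusion_expr::ScalarUDF;

    use crate::expressions::col;
    use crate::utils::tests::TestScalarUDF;
    use crate::PhysicalExpr;

    fn floor_of_a() -> Result<Arc<dyn PhysicalExpr>> {
        let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
        // The UDF stands in for the removed BuiltinScalarFunction::Floor.
        let udf = ScalarUDF::new_from_impl(TestScalarUDF::new());
        crate::udf::create_physical_expr(
            &udf,
            &[col("a", &schema)?], // physical input expressions
            &schema,               // physical input schema
            &[],                   // logical Expr arguments (empty in these tests)
            &DFSchema::empty(),    // logical input schema
        )
    }
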
Signature, + } + + impl TestScalarUDF { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Float64, Float32], + Volatility::Immutable, + ), + } + } + } + + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "test-scalar-udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_type = &arg_types[0]; + + match arg_type { + DataType::Float32 => Ok(DataType::Float32), + _ => Ok(DataType::Float64), + } + } + + fn monotonicity(&self) -> Result> { + Ok(Some(vec![Some(true)])) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => Arc::new({ + let arg = &args[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + self.name(), + std::any::type_name::() + )) + })?; + + arg.iter() + .map(|a| a.map(f64::floor)) + .collect::() + }), + DataType::Float32 => Arc::new({ + let arg = &args[0] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + self.name(), + std::any::type_name::() + )) + })?; + + arg.iter() + .map(|a| a.map(f32::floor)) + .collect::() + }), + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + Ok(ColumnarValue::Array(arr)) + } + } + #[derive(Clone)] struct DummyProperty { expr_type: String, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 7f967657f573f..b656bededc07e 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -550,7 +550,7 @@ enum ScalarFunction { // 6 was Cos // 7 was Digest Exp = 8; - Floor = 9; + // 9 was Floor // 10 was Ln Log = 11; // 12 was Log10 @@ -621,12 +621,12 @@ enum ScalarFunction { // 77 was Sinh // 78 was Cosh // Tanh = 79 - Pi = 80; + // 80 was Pi // 81 was Degrees // 82 was Radians Factorial = 83; - Lcm = 84; - Gcd = 85; + // 84 was Lcm + // 85 was Gcd // 86 was ArrayAppend // 87 was ArrayConcat // 88 was ArrayDims diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 966d7f7f7487f..c13ae045bdb51 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22794,7 +22794,6 @@ impl serde::Serialize for ScalarFunction { Self::Unknown => "unknown", Self::Ceil => "Ceil", Self::Exp => "Exp", - Self::Floor => "Floor", Self::Log => "Log", Self::Round => "Round", Self::Trunc => "Trunc", @@ -22804,10 +22803,7 @@ impl serde::Serialize for ScalarFunction { Self::Random => "Random", Self::Coalesce => "Coalesce", Self::Power => "Power", - Self::Pi => "Pi", Self::Factorial => "Factorial", - Self::Lcm => "Lcm", - Self::Gcd => "Gcd", Self::Cot => "Cot", Self::Nanvl => "Nanvl", Self::Iszero => "Iszero", @@ -22826,7 +22822,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "unknown", "Ceil", "Exp", - "Floor", "Log", "Round", "Trunc", @@ -22836,10 +22831,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Random", "Coalesce", "Power", - "Pi", "Factorial", - "Lcm", - "Gcd", "Cot", "Nanvl", "Iszero", @@ -22887,7 +22879,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "unknown" => Ok(ScalarFunction::Unknown), "Ceil" => Ok(ScalarFunction::Ceil), "Exp" => 
Ok(ScalarFunction::Exp), - "Floor" => Ok(ScalarFunction::Floor), "Log" => Ok(ScalarFunction::Log), "Round" => Ok(ScalarFunction::Round), "Trunc" => Ok(ScalarFunction::Trunc), @@ -22897,10 +22888,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Random" => Ok(ScalarFunction::Random), "Coalesce" => Ok(ScalarFunction::Coalesce), "Power" => Ok(ScalarFunction::Power), - "Pi" => Ok(ScalarFunction::Pi), "Factorial" => Ok(ScalarFunction::Factorial), - "Lcm" => Ok(ScalarFunction::Lcm), - "Gcd" => Ok(ScalarFunction::Gcd), "Cot" => Ok(ScalarFunction::Cot), "Nanvl" => Ok(ScalarFunction::Nanvl), "Iszero" => Ok(ScalarFunction::Iszero), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index c94aa1f4ed934..092d5c59d081b 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2849,7 +2849,7 @@ pub enum ScalarFunction { /// 6 was Cos /// 7 was Digest Exp = 8, - Floor = 9, + /// 9 was Floor /// 10 was Ln Log = 11, /// 12 was Log10 @@ -2920,12 +2920,12 @@ pub enum ScalarFunction { /// 77 was Sinh /// 78 was Cosh /// Tanh = 79 - Pi = 80, + /// 80 was Pi /// 81 was Degrees /// 82 was Radians Factorial = 83, - Lcm = 84, - Gcd = 85, + /// 84 was Lcm + /// 85 was Gcd /// 86 was ArrayAppend /// 87 was ArrayConcat /// 88 was ArrayDims @@ -2989,7 +2989,6 @@ impl ScalarFunction { ScalarFunction::Unknown => "unknown", ScalarFunction::Ceil => "Ceil", ScalarFunction::Exp => "Exp", - ScalarFunction::Floor => "Floor", ScalarFunction::Log => "Log", ScalarFunction::Round => "Round", ScalarFunction::Trunc => "Trunc", @@ -2999,10 +2998,7 @@ impl ScalarFunction { ScalarFunction::Random => "Random", ScalarFunction::Coalesce => "Coalesce", ScalarFunction::Power => "Power", - ScalarFunction::Pi => "Pi", ScalarFunction::Factorial => "Factorial", - ScalarFunction::Lcm => "Lcm", - ScalarFunction::Gcd => "Gcd", ScalarFunction::Cot => "Cot", ScalarFunction::Nanvl => "Nanvl", ScalarFunction::Iszero => "Iszero", @@ -3015,7 +3011,6 @@ impl ScalarFunction { "unknown" => Some(Self::Unknown), "Ceil" => Some(Self::Ceil), "Exp" => Some(Self::Exp), - "Floor" => Some(Self::Floor), "Log" => Some(Self::Log), "Round" => Some(Self::Round), "Trunc" => Some(Self::Trunc), @@ -3025,10 +3020,7 @@ impl ScalarFunction { "Random" => Some(Self::Random), "Coalesce" => Some(Self::Coalesce), "Power" => Some(Self::Power), - "Pi" => Some(Self::Pi), "Factorial" => Some(Self::Factorial), - "Lcm" => Some(Self::Lcm), - "Gcd" => Some(Self::Gcd), "Cot" => Some(Self::Cot), "Nanvl" => Some(Self::Nanvl), "Iszero" => Some(Self::Iszero), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 96b3b5942ec36..9c24a39418957 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -39,9 +39,9 @@ use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_ use datafusion_expr::{ ceil, coalesce, concat_expr, concat_ws_expr, cot, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, floor, gcd, initcap, iszero, lcm, log, + factorial, initcap, iszero, log, logical_plan::{PlanType, StringifiedPlan}, - nanvl, pi, power, random, round, trunc, AggregateFunction, Between, BinaryExpr, + nanvl, power, random, round, trunc, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, @@ -423,9 +423,6 @@ impl 
From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Exp => Self::Exp, ScalarFunction::Log => Self::Log, ScalarFunction::Factorial => Self::Factorial, - ScalarFunction::Gcd => Self::Gcd, - ScalarFunction::Lcm => Self::Lcm, - ScalarFunction::Floor => Self::Floor, ScalarFunction::Ceil => Self::Ceil, ScalarFunction::Round => Self::Round, ScalarFunction::Trunc => Self::Trunc, @@ -435,7 +432,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::InitCap => Self::InitCap, ScalarFunction::Random => Self::Random, ScalarFunction::Coalesce => Self::Coalesce, - ScalarFunction::Pi => Self::Pi, ScalarFunction::Power => Self::Power, ScalarFunction::Nanvl => Self::Nanvl, ScalarFunction::Iszero => Self::Iszero, @@ -1301,9 +1297,6 @@ pub fn parse_expr( match scalar_function { ScalarFunction::Unknown => Err(proto_error("Unknown scalar function")), ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Floor => { - Ok(floor(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::Factorial => { Ok(factorial(parse_expr(&args[0], registry, codec)?)) } @@ -1313,14 +1306,6 @@ pub fn parse_expr( ScalarFunction::InitCap => { Ok(initcap(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Gcd => Ok(gcd( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), - ScalarFunction::Lcm => Ok(lcm( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::Random => Ok(random()), ScalarFunction::Concat => { Ok(concat_expr(parse_exprs(args, registry, codec)?)) @@ -1335,7 +1320,6 @@ pub fn parse_expr( ScalarFunction::Coalesce => { Ok(coalesce(parse_exprs(args, registry, codec)?)) } - ScalarFunction::Pi => Ok(pi()), ScalarFunction::Power => Ok(power( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index a10edb393241e..bd964b43d4189 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1410,10 +1410,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Cot => Self::Cot, BuiltinScalarFunction::Exp => Self::Exp, BuiltinScalarFunction::Factorial => Self::Factorial, - BuiltinScalarFunction::Gcd => Self::Gcd, - BuiltinScalarFunction::Lcm => Self::Lcm, BuiltinScalarFunction::Log => Self::Log, - BuiltinScalarFunction::Floor => Self::Floor, BuiltinScalarFunction::Ceil => Self::Ceil, BuiltinScalarFunction::Round => Self::Round, BuiltinScalarFunction::Trunc => Self::Trunc, @@ -1423,7 +1420,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::InitCap => Self::InitCap, BuiltinScalarFunction::Random => Self::Random, BuiltinScalarFunction::Coalesce => Self::Coalesce, - BuiltinScalarFunction::Pi => Self::Pi, BuiltinScalarFunction::Power => Self::Power, BuiltinScalarFunction::Nanvl => Self::Nanvl, BuiltinScalarFunction::Iszero => Self::Iszero, diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index e97eb1a32b121..4bf0906685cae 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -18,7 +18,8 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::DataType; use datafusion_common::{ - not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result, + 
internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, + Dependency, Result, }; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ @@ -264,6 +265,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?") } + pub(super) fn sql_fn_name_to_expr( + &self, + expr: SQLExpr, + fn_name: &str, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result<Expr> { + let fun = self + .context_provider + .get_function_meta(fn_name) + .ok_or_else(|| { + internal_datafusion_err!("Unable to find expected '{fn_name}' function") + })?; + let args = vec![self.sql_expr_to_logical_expr(expr, schema, planner_context)?]; + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) + } + pub(super) fn sql_named_function_to_expr( &self, expr: SQLExpr, diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index c2f72720afcbe..7763fa2d8dab1 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -518,12 +518,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Floor { expr, field: _field, - } => self.sql_named_function_to_expr( - *expr, - BuiltinScalarFunction::Floor, - schema, - planner_context, - ), + } => self.sql_fn_name_to_expr(*expr, "floor", schema, planner_context), SQLExpr::Ceil { expr, field: _field,
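
With `SQLExpr::Floor` now planned through `sql_fn_name_to_expr`, SQL `FLOOR` resolves by name against the function registry instead of a `BuiltinScalarFunction` variant. A hedged end-to-end sketch of the observable behavior (illustrative only; assumes the top-level `datafusion` crate with its default function packages registered, plus tokio):

    use datafusion::error::Result;
    use datafusion::prelude::*;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();
        // The SQL planner now looks up "floor" via get_function_meta and
        // builds Expr::ScalarFunction(ScalarFunction::new_udf(..)) rather
        // than the removed BuiltinScalarFunction::Floor variant.
        let df = ctx.sql("SELECT floor(10.7) AS f").await?;
        df.show().await?; // f = 10.0
        Ok(())
    }
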