From d2b3d1c7538b9fb7ab9cfc0c4c6a238b0dcd91e6 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Mon, 1 Jan 2024 14:09:41 -0500 Subject: [PATCH] Rename `expr::window_function::WindowFunction` to `WindowFunctionDefinition`, make structure consistent with ScalarFunction (#8382) * Refactoring WindowFunction into coherent structure with AggregateFunction * One more cargo fmt --------- Co-authored-by: Andrew Lamb --- datafusion/core/src/dataframe/mod.rs | 6 +- .../core/src/physical_optimizer/test_utils.rs | 4 +- datafusion/core/tests/dataframe/mod.rs | 4 +- .../core/tests/fuzz_cases/window_fuzz.rs | 46 +- .../expr/src/built_in_window_function.rs | 207 ++++++++ datafusion/expr/src/expr.rs | 291 ++++++++++- datafusion/expr/src/lib.rs | 6 +- datafusion/expr/src/udwf.rs | 2 +- datafusion/expr/src/utils.rs | 22 +- datafusion/expr/src/window_function.rs | 483 ------------------ .../src/analyzer/count_wildcard_rule.rs | 10 +- .../optimizer/src/analyzer/type_coercion.rs | 8 +- .../optimizer/src/push_down_projection.rs | 6 +- datafusion/physical-plan/src/windows/mod.rs | 28 +- .../proto/src/logical_plan/from_proto.rs | 8 +- datafusion/proto/src/logical_plan/to_proto.rs | 10 +- .../proto/src/physical_plan/from_proto.rs | 10 +- .../tests/cases/roundtrip_logical_plan.rs | 20 +- datafusion/sql/src/expr/function.rs | 19 +- .../substrait/src/logical_plan/consumer.rs | 4 +- 20 files changed, 613 insertions(+), 581 deletions(-) create mode 100644 datafusion/expr/src/built_in_window_function.rs delete mode 100644 datafusion/expr/src/window_function.rs diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 3c3bcd497b7f..5a8c706e32cd 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1360,7 +1360,7 @@ mod tests { use datafusion_expr::{ avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum, BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, - WindowFunction, + WindowFunctionDefinition, }; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::get_plan_string; @@ -1525,7 +1525,9 @@ mod tests { // build plan using Table API let t = test_table().await?; let first_row = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::FirstValue, + ), vec![col("aggregate_test_100.c1")], vec![col("aggregate_test_100.c2")], vec![], diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 6e14cca21fed..debafefe39ab 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -41,7 +41,7 @@ use crate::prelude::{CsvReadOptions, SessionContext}; use arrow_schema::{Schema, SchemaRef, SortOptions}; use datafusion_common::{JoinType, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction}; +use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunctionDefinition}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -234,7 +234,7 @@ pub fn bounded_window_exec( Arc::new( crate::physical_plan::windows::BoundedWindowAggExec::try_new( vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), "count".to_owned(), &[col(col_name, &schema).unwrap()], &[], diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index ba661aa2445c..cca23ac6847c 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -45,7 +45,7 @@ use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::{ array_agg, avg, col, count, exists, expr, in_subquery, lit, max, out_ref_col, scalar_subquery, sum, wildcard, AggregateFunction, Expr, ExprSchemable, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunction, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_physical_expr::var_provider::{VarProvider, VarType}; @@ -170,7 +170,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { .table("t1") .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 44ff71d02392..3037b4857a3b 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -33,7 +33,7 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::coerce_types; use datafusion_expr::{ AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunction, + WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -143,7 +143,7 @@ fn get_random_function( schema: &SchemaRef, rng: &mut StdRng, is_linear: bool, -) -> (WindowFunction, Vec>, String) { +) -> (WindowFunctionDefinition, Vec>, String) { let mut args = if is_linear { // In linear test for the test version with WindowAggExec we use insert SortExecs to the plan to be able to generate // same result with BoundedWindowAggExec which doesn't use any SortExec. To make result @@ -159,28 +159,28 @@ fn get_random_function( window_fn_map.insert( "sum", ( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![], ), ); window_fn_map.insert( "count", ( - WindowFunction::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), vec![], ), ); window_fn_map.insert( "min", ( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![], ), ); window_fn_map.insert( "max", ( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![], ), ); @@ -191,28 +191,36 @@ fn get_random_function( window_fn_map.insert( "row_number", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::RowNumber), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::RowNumber, + ), vec![], ), ); window_fn_map.insert( "rank", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Rank), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Rank, + ), vec![], ), ); window_fn_map.insert( "dense_rank", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::DenseRank), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::DenseRank, + ), vec![], ), ); window_fn_map.insert( "lead", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Lead, + ), vec![ lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), @@ -222,7 +230,9 @@ fn get_random_function( window_fn_map.insert( "lag", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::Lag, + ), vec![ lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), @@ -233,21 +243,27 @@ fn get_random_function( window_fn_map.insert( "first_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::FirstValue, + ), vec![], ), ); window_fn_map.insert( "last_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::LastValue, + ), vec![], ), ); window_fn_map.insert( "nth_value", ( - WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::NthValue), + WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::NthValue, + ), vec![lit(ScalarValue::Int64(Some(rng.gen_range(1..10))))], ), ); @@ -255,7 +271,7 @@ fn get_random_function( let rand_fn_idx = rng.gen_range(0..window_fn_map.len()); let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, new_args) = window_fn_map.values().collect::>()[rand_fn_idx]; - if let WindowFunction::AggregateFunction(f) = window_fn { + if let WindowFunctionDefinition::AggregateFunction(f) = window_fn { let a = args[0].clone(); let dt = a.data_type(schema.as_ref()).unwrap(); let sig = f.signature(); diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs new file mode 100644 index 000000000000..a03e3d2d24a9 --- /dev/null +++ b/datafusion/expr/src/built_in_window_function.rs @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Built-in functions module contains all the built-in functions definitions. + +use std::fmt; +use std::str::FromStr; + +use crate::type_coercion::functions::data_types; +use crate::utils; +use crate::{Signature, TypeSignature, Volatility}; +use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; + +use arrow::datatypes::DataType; + +use strum_macros::EnumIter; + +impl fmt::Display for BuiltInWindowFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.name()) + } +} + +/// A [window function] built in to DataFusion +/// +/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) +#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)] +pub enum BuiltInWindowFunction { + /// number of the current row within its partition, counting from 1 + RowNumber, + /// rank of the current row with gaps; same as row_number of its first peer + Rank, + /// rank of the current row without gaps; this function counts peer groups + DenseRank, + /// relative rank of the current row: (rank - 1) / (total rows - 1) + PercentRank, + /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) + CumeDist, + /// integer ranging from 1 to the argument value, dividing the partition as equally as possible + Ntile, + /// returns value evaluated at the row that is offset rows before the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. + /// If omitted, offset defaults to 1 and default to null + Lag, + /// returns value evaluated at the row that is offset rows after the current row within the partition; + /// if there is no such row, instead return default (which must be of the same type as value). + /// Both offset and default are evaluated with respect to the current row. + /// If omitted, offset defaults to 1 and default to null + Lead, + /// returns value evaluated at the row that is the first row of the window frame + FirstValue, + /// returns value evaluated at the row that is the last row of the window frame + LastValue, + /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row + NthValue, +} + +impl BuiltInWindowFunction { + fn name(&self) -> &str { + use BuiltInWindowFunction::*; + match self { + RowNumber => "ROW_NUMBER", + Rank => "RANK", + DenseRank => "DENSE_RANK", + PercentRank => "PERCENT_RANK", + CumeDist => "CUME_DIST", + Ntile => "NTILE", + Lag => "LAG", + Lead => "LEAD", + FirstValue => "FIRST_VALUE", + LastValue => "LAST_VALUE", + NthValue => "NTH_VALUE", + } + } +} + +impl FromStr for BuiltInWindowFunction { + type Err = DataFusionError; + fn from_str(name: &str) -> Result { + Ok(match name.to_uppercase().as_str() { + "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, + "RANK" => BuiltInWindowFunction::Rank, + "DENSE_RANK" => BuiltInWindowFunction::DenseRank, + "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, + "CUME_DIST" => BuiltInWindowFunction::CumeDist, + "NTILE" => BuiltInWindowFunction::Ntile, + "LAG" => BuiltInWindowFunction::Lag, + "LEAD" => BuiltInWindowFunction::Lead, + "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, + "LAST_VALUE" => BuiltInWindowFunction::LastValue, + "NTH_VALUE" => BuiltInWindowFunction::NthValue, + _ => return plan_err!("There is no built-in window function named {name}"), + }) + } +} + +/// Returns the datatype of the built-in window function +impl BuiltInWindowFunction { + pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + // Note that this function *must* return the same type that the respective physical expression returns + // or the execution panics. + + // verify that this is a valid set of data types for this function + data_types(input_expr_types, &self.signature()) + // original errors are all related to wrong function signature + // aggregate them for better error message + .map_err(|_| { + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + &format!("{self}"), + self.signature(), + input_expr_types, + ) + ) + })?; + + match self { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), + BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { + Ok(DataType::Float64) + } + BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), + BuiltInWindowFunction::Lag + | BuiltInWindowFunction::Lead + | BuiltInWindowFunction::FirstValue + | BuiltInWindowFunction::LastValue + | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), + } + } + + /// the signatures supported by the built-in window function `fun`. + pub fn signature(&self) -> Signature { + // note: the physical expression must accept the type returned by this function or the execution panics. + match self { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank + | BuiltInWindowFunction::PercentRank + | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), + BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { + Signature::one_of( + vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + Volatility::Immutable, + ) + } + BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { + Signature::any(1, Volatility::Immutable) + } + BuiltInWindowFunction::Ntile => Signature::uniform( + 1, + vec![ + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + ], + Volatility::Immutable, + ), + BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use strum::IntoEnumIterator; + #[test] + // Test for BuiltInWindowFunction's Display and from_str() implementations. + // For each variant in BuiltInWindowFunction, it converts the variant to a string + // and then back to a variant. The test asserts that the original variant and + // the reconstructed variant are the same. This assertion is also necessary for + // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 + fn test_display_and_from_str() { + for func_original in BuiltInWindowFunction::iter() { + let func_name = func_original.to_string(); + let func_from_str = BuiltInWindowFunction::from_str(&func_name).unwrap(); + assert_eq!(func_from_str, func_original); + } + } +} diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0ec19bcadbf6..ebf4d3143c12 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -19,13 +19,13 @@ use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; -use crate::udaf; use crate::utils::{expr_to_columns, find_out_reference_exprs}; use crate::window_frame; -use crate::window_function; + use crate::Operator; use crate::{aggregate_function, ExprSchemable}; use crate::{built_in_function, BuiltinScalarFunction}; +use crate::{built_in_window_function, udaf}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; @@ -34,8 +34,11 @@ use std::collections::HashSet; use std::fmt; use std::fmt::{Display, Formatter, Write}; use std::hash::{BuildHasher, Hash, Hasher}; +use std::str::FromStr; use std::sync::Arc; +use crate::Signature; + /// `Expr` is a central struct of DataFusion's query API, and /// represent logical expressions such as `A + 1`, or `CAST(c1 AS /// int)`. @@ -566,11 +569,64 @@ impl AggregateFunction { } } +/// WindowFunction +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of an aggregate function DataFusion should call. +pub enum WindowFunctionDefinition { + /// A built in aggregate function that leverages an aggregate function + AggregateFunction(aggregate_function::AggregateFunction), + /// A a built-in window function + BuiltInWindowFunction(built_in_window_function::BuiltInWindowFunction), + /// A user defined aggregate function + AggregateUDF(Arc), + /// A user defined aggregate function + WindowUDF(Arc), +} + +impl WindowFunctionDefinition { + /// Returns the datatype of the window function + pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + match self { + WindowFunctionDefinition::AggregateFunction(fun) => { + fun.return_type(input_expr_types) + } + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + fun.return_type(input_expr_types) + } + WindowFunctionDefinition::AggregateUDF(fun) => { + fun.return_type(input_expr_types) + } + WindowFunctionDefinition::WindowUDF(fun) => fun.return_type(input_expr_types), + } + } + + /// the signatures supported by the function `fun`. + pub fn signature(&self) -> Signature { + match self { + WindowFunctionDefinition::AggregateFunction(fun) => fun.signature(), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.signature(), + WindowFunctionDefinition::AggregateUDF(fun) => fun.signature().clone(), + WindowFunctionDefinition::WindowUDF(fun) => fun.signature().clone(), + } + } +} + +impl fmt::Display for WindowFunctionDefinition { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + WindowFunctionDefinition::AggregateFunction(fun) => fun.fmt(f), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.fmt(f), + WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), + WindowFunctionDefinition::WindowUDF(fun) => fun.fmt(f), + } + } +} + /// Window function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct WindowFunction { /// Name of the function - pub fun: window_function::WindowFunction, + pub fun: WindowFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, /// List of partition by expressions @@ -584,7 +640,7 @@ pub struct WindowFunction { impl WindowFunction { /// Create a new Window expression pub fn new( - fun: window_function::WindowFunction, + fun: WindowFunctionDefinition, args: Vec, partition_by: Vec, order_by: Vec, @@ -600,6 +656,50 @@ impl WindowFunction { } } +/// Find DataFusion's built-in window function by name. +pub fn find_df_window_func(name: &str) -> Option { + let name = name.to_lowercase(); + // Code paths for window functions leveraging ordinary aggregators and + // built-in window functions are quite different, and the same function + // may have different implementations for these cases. If the sought + // function is not found among built-in window functions, we search for + // it among aggregate functions. + if let Ok(built_in_function) = + built_in_window_function::BuiltInWindowFunction::from_str(name.as_str()) + { + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_function, + )) + } else if let Ok(aggregate) = + aggregate_function::AggregateFunction::from_str(name.as_str()) + { + Some(WindowFunctionDefinition::AggregateFunction(aggregate)) + } else { + None + } +} + +/// Returns the datatype of the window function +#[deprecated( + since = "27.0.0", + note = "please use `WindowFunction::return_type` instead" +)] +pub fn return_type( + fun: &WindowFunctionDefinition, + input_expr_types: &[DataType], +) -> Result { + fun.return_type(input_expr_types) +} + +/// the signatures supported by the function `fun`. +#[deprecated( + since = "27.0.0", + note = "please use `WindowFunction::signature` instead" +)] +pub fn signature(fun: &WindowFunctionDefinition) -> Signature { + fun.signature() +} + // Exists expression. #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Exists { @@ -1890,4 +1990,187 @@ mod test { .is_volatile() .expect_err("Shouldn't determine volatility of unresolved function"); } + + use super::*; + + #[test] + fn test_count_return_type() -> Result<()> { + let fun = find_df_window_func("count").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Int64, observed); + + let observed = fun.return_type(&[DataType::UInt64])?; + assert_eq!(DataType::Int64, observed); + + Ok(()) + } + + #[test] + fn test_first_value_return_type() -> Result<()> { + let fun = find_df_window_func("first_value").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::UInt64])?; + assert_eq!(DataType::UInt64, observed); + + Ok(()) + } + + #[test] + fn test_last_value_return_type() -> Result<()> { + let fun = find_df_window_func("last_value").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_lead_return_type() -> Result<()> { + let fun = find_df_window_func("lead").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_lag_return_type() -> Result<()> { + let fun = find_df_window_func("lag").unwrap(); + let observed = fun.return_type(&[DataType::Utf8])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_nth_value_return_type() -> Result<()> { + let fun = find_df_window_func("nth_value").unwrap(); + let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; + assert_eq!(DataType::Utf8, observed); + + let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_percent_rank_return_type() -> Result<()> { + let fun = find_df_window_func("percent_rank").unwrap(); + let observed = fun.return_type(&[])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_cume_dist_return_type() -> Result<()> { + let fun = find_df_window_func("cume_dist").unwrap(); + let observed = fun.return_type(&[])?; + assert_eq!(DataType::Float64, observed); + + Ok(()) + } + + #[test] + fn test_ntile_return_type() -> Result<()> { + let fun = find_df_window_func("ntile").unwrap(); + let observed = fun.return_type(&[DataType::Int16])?; + assert_eq!(DataType::UInt64, observed); + + Ok(()) + } + + #[test] + fn test_window_function_case_insensitive() -> Result<()> { + let names = vec![ + "row_number", + "rank", + "dense_rank", + "percent_rank", + "cume_dist", + "ntile", + "lag", + "lead", + "first_value", + "last_value", + "nth_value", + "min", + "max", + "count", + "avg", + "sum", + ]; + for name in names { + let fun = find_df_window_func(name).unwrap(); + let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); + assert_eq!(fun, fun2); + assert_eq!(fun.to_string(), name.to_uppercase()); + } + Ok(()) + } + + #[test] + fn test_find_df_window_function() { + assert_eq!( + find_df_window_func("max"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Max + )) + ); + assert_eq!( + find_df_window_func("min"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Min + )) + ); + assert_eq!( + find_df_window_func("avg"), + Some(WindowFunctionDefinition::AggregateFunction( + aggregate_function::AggregateFunction::Avg + )) + ); + assert_eq!( + find_df_window_func("cume_dist"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::CumeDist + )) + ); + assert_eq!( + find_df_window_func("first_value"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::FirstValue + )) + ); + assert_eq!( + find_df_window_func("LAST_value"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::LastValue + )) + ); + assert_eq!( + find_df_window_func("LAG"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::Lag + )) + ); + assert_eq!( + find_df_window_func("LEAD"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + built_in_window_function::BuiltInWindowFunction::Lead + )) + ); + assert_eq!(find_df_window_func("not_exist"), None) + } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index bf8e9e2954f4..ab213a19a352 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -27,6 +27,7 @@ mod accumulator; mod built_in_function; +mod built_in_window_function; mod columnar_value; mod literal; mod nullif; @@ -53,16 +54,16 @@ pub mod tree_node; pub mod type_coercion; pub mod utils; pub mod window_frame; -pub mod window_function; pub mod window_state; pub use accumulator::Accumulator; pub use aggregate_function::AggregateFunction; pub use built_in_function::BuiltinScalarFunction; +pub use built_in_window_function::BuiltInWindowFunction; pub use columnar_value::ColumnarValue; pub use expr::{ Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, - Like, ScalarFunctionDefinition, TryCast, + Like, ScalarFunctionDefinition, TryCast, WindowFunctionDefinition, }; pub use expr_fn::*; pub use expr_schema::ExprSchemable; @@ -83,7 +84,6 @@ pub use udaf::AggregateUDF; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::WindowUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; -pub use window_function::{BuiltInWindowFunction, WindowFunction}; #[cfg(test)] #[ctor::ctor] diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index c233ee84b32d..a97a68341f5c 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -107,7 +107,7 @@ impl WindowUDF { order_by: Vec, window_frame: WindowFrame, ) -> Expr { - let fun = crate::WindowFunction::WindowUDF(Arc::new(self.clone())); + let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); Expr::WindowFunction(crate::expr::WindowFunction { fun, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 09f4842c9e64..e3ecdf154e61 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1234,7 +1234,7 @@ mod tests { use super::*; use crate::{ col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, AggregateFunction, - WindowFrame, WindowFunction, + WindowFrame, WindowFunctionDefinition, }; #[test] @@ -1248,28 +1248,28 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![], @@ -1291,28 +1291,28 @@ mod tests { let created_at_desc = Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], WindowFrame::new(true), )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![], WindowFrame::new(false), )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], WindowFrame::new(true), )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], @@ -1343,7 +1343,7 @@ mod tests { fn test_find_sort_exprs() -> Result<()> { let exprs = &[ Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], vec![], vec![ @@ -1353,7 +1353,7 @@ mod tests { WindowFrame::new(true), )), Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), vec![col("age")], vec![], vec![ diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs deleted file mode 100644 index 610f1ecaeae9..000000000000 --- a/datafusion/expr/src/window_function.rs +++ /dev/null @@ -1,483 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Window functions provide the ability to perform calculations across -//! sets of rows that are related to the current query row. -//! -//! see also - -use crate::aggregate_function::AggregateFunction; -use crate::type_coercion::functions::data_types; -use crate::utils; -use crate::{AggregateUDF, Signature, TypeSignature, Volatility, WindowUDF}; -use arrow::datatypes::DataType; -use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; -use std::sync::Arc; -use std::{fmt, str::FromStr}; -use strum_macros::EnumIter; - -/// WindowFunction -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum WindowFunction { - /// A built in aggregate function that leverages an aggregate function - AggregateFunction(AggregateFunction), - /// A a built-in window function - BuiltInWindowFunction(BuiltInWindowFunction), - /// A user defined aggregate function - AggregateUDF(Arc), - /// A user defined aggregate function - WindowUDF(Arc), -} - -/// Find DataFusion's built-in window function by name. -pub fn find_df_window_func(name: &str) -> Option { - let name = name.to_lowercase(); - // Code paths for window functions leveraging ordinary aggregators and - // built-in window functions are quite different, and the same function - // may have different implementations for these cases. If the sought - // function is not found among built-in window functions, we search for - // it among aggregate functions. - if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) { - Some(WindowFunction::BuiltInWindowFunction(built_in_function)) - } else if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) { - Some(WindowFunction::AggregateFunction(aggregate)) - } else { - None - } -} - -impl fmt::Display for BuiltInWindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.name()) - } -} - -impl fmt::Display for WindowFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - WindowFunction::AggregateFunction(fun) => fun.fmt(f), - WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f), - WindowFunction::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), - WindowFunction::WindowUDF(fun) => fun.fmt(f), - } - } -} - -/// A [window function] built in to DataFusion -/// -/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) -#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)] -pub enum BuiltInWindowFunction { - /// number of the current row within its partition, counting from 1 - RowNumber, - /// rank of the current row with gaps; same as row_number of its first peer - Rank, - /// rank of the current row without gaps; this function counts peer groups - DenseRank, - /// relative rank of the current row: (rank - 1) / (total rows - 1) - PercentRank, - /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) - CumeDist, - /// integer ranging from 1 to the argument value, dividing the partition as equally as possible - Ntile, - /// returns value evaluated at the row that is offset rows before the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lag, - /// returns value evaluated at the row that is offset rows after the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lead, - /// returns value evaluated at the row that is the first row of the window frame - FirstValue, - /// returns value evaluated at the row that is the last row of the window frame - LastValue, - /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row - NthValue, -} - -impl BuiltInWindowFunction { - fn name(&self) -> &str { - use BuiltInWindowFunction::*; - match self { - RowNumber => "ROW_NUMBER", - Rank => "RANK", - DenseRank => "DENSE_RANK", - PercentRank => "PERCENT_RANK", - CumeDist => "CUME_DIST", - Ntile => "NTILE", - Lag => "LAG", - Lead => "LEAD", - FirstValue => "FIRST_VALUE", - LastValue => "LAST_VALUE", - NthValue => "NTH_VALUE", - } - } -} - -impl FromStr for BuiltInWindowFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name.to_uppercase().as_str() { - "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, - "RANK" => BuiltInWindowFunction::Rank, - "DENSE_RANK" => BuiltInWindowFunction::DenseRank, - "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, - "CUME_DIST" => BuiltInWindowFunction::CumeDist, - "NTILE" => BuiltInWindowFunction::Ntile, - "LAG" => BuiltInWindowFunction::Lag, - "LEAD" => BuiltInWindowFunction::Lead, - "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, - "LAST_VALUE" => BuiltInWindowFunction::LastValue, - "NTH_VALUE" => BuiltInWindowFunction::NthValue, - _ => return plan_err!("There is no built-in window function named {name}"), - }) - } -} - -/// Returns the datatype of the window function -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::return_type` instead" -)] -pub fn return_type( - fun: &WindowFunction, - input_expr_types: &[DataType], -) -> Result { - fun.return_type(input_expr_types) -} - -impl WindowFunction { - /// Returns the datatype of the window function - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { - match self { - WindowFunction::AggregateFunction(fun) => fun.return_type(input_expr_types), - WindowFunction::BuiltInWindowFunction(fun) => { - fun.return_type(input_expr_types) - } - WindowFunction::AggregateUDF(fun) => fun.return_type(input_expr_types), - WindowFunction::WindowUDF(fun) => fun.return_type(input_expr_types), - } - } -} - -/// Returns the datatype of the built-in window function -impl BuiltInWindowFunction { - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. - - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &self.signature()) - // original errors are all related to wrong function signature - // aggregate them for better error message - .map_err(|_| { - plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types, - ) - ) - })?; - - match self { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), - BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { - Ok(DataType::Float64) - } - BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue - | BuiltInWindowFunction::LastValue - | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), - } - } -} - -/// the signatures supported by the function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::signature` instead" -)] -pub fn signature(fun: &WindowFunction) -> Signature { - fun.signature() -} - -impl WindowFunction { - /// the signatures supported by the function `fun`. - pub fn signature(&self) -> Signature { - match self { - WindowFunction::AggregateFunction(fun) => fun.signature(), - WindowFunction::BuiltInWindowFunction(fun) => fun.signature(), - WindowFunction::AggregateUDF(fun) => fun.signature().clone(), - WindowFunction::WindowUDF(fun) => fun.signature().clone(), - } - } -} - -/// the signatures supported by the built-in window function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `BuiltInWindowFunction::signature` instead" -)] -pub fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature { - fun.signature() -} - -impl BuiltInWindowFunction { - /// the signatures supported by the built-in window function `fun`. - pub fn signature(&self) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. - match self { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::PercentRank - | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), - BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { - Signature::one_of( - vec![ - TypeSignature::Any(1), - TypeSignature::Any(2), - TypeSignature::Any(3), - ], - Volatility::Immutable, - ) - } - BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { - Signature::any(1, Volatility::Immutable) - } - BuiltInWindowFunction::Ntile => Signature::uniform( - 1, - vec![ - DataType::UInt64, - DataType::UInt32, - DataType::UInt16, - DataType::UInt8, - DataType::Int64, - DataType::Int32, - DataType::Int16, - DataType::Int8, - ], - Volatility::Immutable, - ), - BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use strum::IntoEnumIterator; - - #[test] - fn test_count_return_type() -> Result<()> { - let fun = find_df_window_func("count").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::Int64, observed); - - Ok(()) - } - - #[test] - fn test_first_value_return_type() -> Result<()> { - let fun = find_df_window_func("first_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::UInt64, observed); - - Ok(()) - } - - #[test] - fn test_last_value_return_type() -> Result<()> { - let fun = find_df_window_func("last_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lead_return_type() -> Result<()> { - let fun = find_df_window_func("lead").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lag_return_type() -> Result<()> { - let fun = find_df_window_func("lag").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_nth_value_return_type() -> Result<()> { - let fun = find_df_window_func("nth_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_percent_rank_return_type() -> Result<()> { - let fun = find_df_window_func("percent_rank").unwrap(); - let observed = fun.return_type(&[])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_cume_dist_return_type() -> Result<()> { - let fun = find_df_window_func("cume_dist").unwrap(); - let observed = fun.return_type(&[])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_ntile_return_type() -> Result<()> { - let fun = find_df_window_func("ntile").unwrap(); - let observed = fun.return_type(&[DataType::Int16])?; - assert_eq!(DataType::UInt64, observed); - - Ok(()) - } - - #[test] - fn test_window_function_case_insensitive() -> Result<()> { - let names = vec![ - "row_number", - "rank", - "dense_rank", - "percent_rank", - "cume_dist", - "ntile", - "lag", - "lead", - "first_value", - "last_value", - "nth_value", - "min", - "max", - "count", - "avg", - "sum", - ]; - for name in names { - let fun = find_df_window_func(name).unwrap(); - let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); - assert_eq!(fun, fun2); - assert_eq!(fun.to_string(), name.to_uppercase()); - } - Ok(()) - } - - #[test] - fn test_find_df_window_function() { - assert_eq!( - find_df_window_func("max"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Max)) - ); - assert_eq!( - find_df_window_func("min"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Min)) - ); - assert_eq!( - find_df_window_func("avg"), - Some(WindowFunction::AggregateFunction(AggregateFunction::Avg)) - ); - assert_eq!( - find_df_window_func("cume_dist"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::CumeDist - )) - ); - assert_eq!( - find_df_window_func("first_value"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::FirstValue - )) - ); - assert_eq!( - find_df_window_func("LAST_value"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::LastValue - )) - ); - assert_eq!( - find_df_window_func("LAG"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::Lag - )) - ); - assert_eq!( - find_df_window_func("LEAD"), - Some(WindowFunction::BuiltInWindowFunction( - BuiltInWindowFunction::Lead - )) - ); - assert_eq!(find_df_window_func("not_exist"), None) - } - - #[test] - // Test for BuiltInWindowFunction's Display and from_str() implementations. - // For each variant in BuiltInWindowFunction, it converts the variant to a string - // and then back to a variant. The test asserts that the original variant and - // the reconstructed variant are the same. This assertion is also necessary for - // function suggestion. See https://github.com/apache/arrow-datafusion/issues/8082 - fn test_display_and_from_str() { - for func_original in BuiltInWindowFunction::iter() { - let func_name = func_original.to_string(); - let func_from_str = BuiltInWindowFunction::from_str(&func_name).unwrap(); - assert_eq!(func_from_str, func_original); - } - } -} diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index fd84bb80160b..953716713e41 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -24,7 +24,7 @@ use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::Expr::ScalarSubquery; use datafusion_expr::{ - aggregate_function, expr, lit, window_function, Aggregate, Expr, Filter, LogicalPlan, + aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, Sort, Subquery, }; use std::sync::Arc; @@ -121,7 +121,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { let new_expr = match old_expr.clone() { Expr::WindowFunction(expr::WindowFunction { fun: - window_function::WindowFunction::AggregateFunction( + expr::WindowFunctionDefinition::AggregateFunction( aggregate_function::AggregateFunction::Count, ), args, @@ -131,7 +131,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { Expr::WindowFunction(expr::WindowFunction { - fun: window_function::WindowFunction::AggregateFunction( + fun: expr::WindowFunctionDefinition::AggregateFunction( aggregate_function::AggregateFunction::Count, ), args: vec![lit(COUNT_STAR_EXPANSION)], @@ -229,7 +229,7 @@ mod tests { use datafusion_expr::{ col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder, max, out_ref_col, scalar_subquery, wildcard, AggregateFunction, Expr, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { @@ -342,7 +342,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index b6298f5b552f..4d54dad99670 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -45,9 +45,9 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, - type_coercion, window_function, AggregateFunction, BuiltinScalarFunction, Expr, - ExprSchemable, LogicalPlan, Operator, Projection, ScalarFunctionDefinition, - Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, + type_coercion, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, + LogicalPlan, Operator, Projection, ScalarFunctionDefinition, Signature, WindowFrame, + WindowFrameBound, WindowFrameUnits, }; use crate::analyzer::AnalyzerRule; @@ -390,7 +390,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { coerce_window_frame(window_frame, &self.schema, &order_by)?; let args = match &fun { - window_function::WindowFunction::AggregateFunction(fun) => { + expr::WindowFunctionDefinition::AggregateFunction(fun) => { coerce_agg_exprs_for_signature( fun, &args, diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 10cc1879aeeb..4ee4f7e417a6 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -37,7 +37,7 @@ mod tests { }; use datafusion_expr::{ col, count, lit, max, min, AggregateFunction, Expr, LogicalPlan, Projection, - WindowFrame, WindowFunction, + WindowFrame, WindowFunctionDefinition, }; #[test] @@ -582,7 +582,7 @@ mod tests { let table_scan = test_table_scan()?; let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.a")], vec![col("test.b")], vec![], @@ -590,7 +590,7 @@ mod tests { )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.b")], vec![], vec![], diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 3187e6b0fbd3..fec168fabf48 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -34,8 +34,8 @@ use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ - window_function::{BuiltInWindowFunction, WindowFunction}, - PartitionEvaluator, WindowFrame, WindowUDF, + BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, + WindowUDF, }; use datafusion_physical_expr::equivalence::collapse_lex_req; use datafusion_physical_expr::{ @@ -56,7 +56,7 @@ pub use datafusion_physical_expr::window::{ /// Create a physical expression for window function pub fn create_window_expr( - fun: &WindowFunction, + fun: &WindowFunctionDefinition, name: String, args: &[Arc], partition_by: &[Arc], @@ -65,7 +65,7 @@ pub fn create_window_expr( input_schema: &Schema, ) -> Result> { Ok(match fun { - WindowFunction::AggregateFunction(fun) => { + WindowFunctionDefinition::AggregateFunction(fun) => { let aggregate = aggregates::create_aggregate_expr( fun, false, @@ -81,13 +81,15 @@ pub fn create_window_expr( aggregate, ) } - WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr::new( - create_built_in_window_expr(fun, args, input_schema, name)?, - partition_by, - order_by, - window_frame, - )), - WindowFunction::AggregateUDF(fun) => { + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + Arc::new(BuiltInWindowExpr::new( + create_built_in_window_expr(fun, args, input_schema, name)?, + partition_by, + order_by, + window_frame, + )) + } + WindowFunctionDefinition::AggregateUDF(fun) => { let aggregate = udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?; window_expr_from_aggregate_expr( @@ -97,7 +99,7 @@ pub fn create_window_expr( aggregate, ) } - WindowFunction::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( + WindowFunctionDefinition::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( create_udwf_window_expr(fun, args, input_schema, name)?, partition_by, order_by, @@ -647,7 +649,7 @@ mod tests { let refs = blocking_exec.refs(); let window_agg_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( - &WindowFunction::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), "count".to_owned(), &[col("a", &schema)?], &[], diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index c582e92dc11c..36c5b44f00b9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1112,7 +1112,7 @@ pub fn parse_expr( let aggr_function = parse_i32_to_aggregate_function(i)?; Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::AggregateFunction( + datafusion_expr::expr::WindowFunctionDefinition::AggregateFunction( aggr_function, ), vec![parse_required_expr(expr.expr.as_deref(), registry, "expr")?], @@ -1131,7 +1131,7 @@ pub fn parse_expr( .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::BuiltInWindowFunction( + datafusion_expr::expr::WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, ), args, @@ -1146,7 +1146,7 @@ pub fn parse_expr( .map(|e| vec![e]) .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::AggregateUDF( + datafusion_expr::expr::WindowFunctionDefinition::AggregateUDF( udaf_function, ), args, @@ -1161,7 +1161,7 @@ pub fn parse_expr( .map(|e| vec![e]) .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::window_function::WindowFunction::WindowUDF( + datafusion_expr::expr::WindowFunctionDefinition::WindowUDF( udwf_function, ), args, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index b9987ff6c727..a162b2389cd1 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -51,7 +51,7 @@ use datafusion_expr::expr::{ use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction, Expr, JoinConstraint, JoinType, - TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; #[derive(Debug)] @@ -605,22 +605,22 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { ref window_frame, }) => { let window_function = match fun { - WindowFunction::AggregateFunction(fun) => { + WindowFunctionDefinition::AggregateFunction(fun) => { protobuf::window_expr_node::WindowFunction::AggrFunction( protobuf::AggregateFunction::from(fun).into(), ) } - WindowFunction::BuiltInWindowFunction(fun) => { + WindowFunctionDefinition::BuiltInWindowFunction(fun) => { protobuf::window_expr_node::WindowFunction::BuiltInFunction( protobuf::BuiltInWindowFunction::from(fun).into(), ) } - WindowFunction::AggregateUDF(aggr_udf) => { + WindowFunctionDefinition::AggregateUDF(aggr_udf) => { protobuf::window_expr_node::WindowFunction::Udaf( aggr_udf.name().to_string(), ) } - WindowFunction::WindowUDF(window_udf) => { + WindowFunctionDefinition::WindowUDF(window_udf) => { protobuf::window_expr_node::WindowFunction::Udwf( window_udf.name().to_string(), ) diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 8ad6d679df4d..23ab813ca739 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -31,7 +31,7 @@ use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; use datafusion::execution::context::ExecutionProps; use datafusion::execution::FunctionRegistry; -use datafusion::logical_expr::window_function::WindowFunction; +use datafusion::logical_expr::WindowFunctionDefinition; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, @@ -414,7 +414,9 @@ fn parse_required_physical_expr( }) } -impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFunction { +impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> + for WindowFunctionDefinition +{ type Error = DataFusionError; fn try_from( @@ -428,7 +430,7 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun )) })?; - Ok(WindowFunction::AggregateFunction(f.into())) + Ok(WindowFunctionDefinition::AggregateFunction(f.into())) } protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => { let f = protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| { @@ -437,7 +439,7 @@ impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFun )) })?; - Ok(WindowFunction::BuiltInWindowFunction(f.into())) + Ok(WindowFunctionDefinition::BuiltInWindowFunction(f.into())) } } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 2d7d85abda96..dea99f91e392 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -53,7 +53,7 @@ use datafusion_expr::{ col, create_udaf, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::{Sqrt, Substr}, Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, WindowUDF, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -1663,8 +1663,8 @@ fn roundtrip_window() { // 1. without window_frame let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1674,8 +1674,8 @@ fn roundtrip_window() { // 2. with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1691,8 +1691,8 @@ fn roundtrip_window() { }; let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, + WindowFunctionDefinition::BuiltInWindowFunction( + datafusion_expr::BuiltInWindowFunction::Rank, ), vec![], vec![col("col1")], @@ -1708,7 +1708,7 @@ fn roundtrip_window() { }; let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("col1")], vec![col("col1")], vec![col("col2")], @@ -1759,7 +1759,7 @@ fn roundtrip_window() { ); let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateUDF(Arc::new(dummy_agg.clone())), + WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], vec![col("col1")], vec![col("col2")], @@ -1808,7 +1808,7 @@ fn roundtrip_window() { ); let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::WindowUDF(Arc::new(dummy_window_udf.clone())), + WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], vec![col("col1")], vec![col("col2")], diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 3934d6701c63..395f10b6f783 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -23,8 +23,8 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::function::suggest_valid_function; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, - WindowFunction, + expr, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, + WindowFunctionDefinition, }; use sqlparser::ast::{ Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType, @@ -121,12 +121,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if let Ok(fun) = self.find_window_func(&name) { let expr = match fun { - WindowFunction::AggregateFunction(aggregate_fun) => { + WindowFunctionDefinition::AggregateFunction(aggregate_fun) => { let args = self.function_args_to_expr(args, schema, planner_context)?; Expr::WindowFunction(expr::WindowFunction::new( - WindowFunction::AggregateFunction(aggregate_fun), + WindowFunctionDefinition::AggregateFunction(aggregate_fun), args, partition_by, order_by, @@ -191,19 +191,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) } - pub(super) fn find_window_func(&self, name: &str) -> Result { - window_function::find_df_window_func(name) + pub(super) fn find_window_func( + &self, + name: &str, + ) -> Result { + expr::find_df_window_func(name) // next check user defined aggregates .or_else(|| { self.context_provider .get_aggregate_meta(name) - .map(WindowFunction::AggregateUDF) + .map(WindowFunctionDefinition::AggregateUDF) }) // next check user defined window functions .or_else(|| { self.context_provider .get_window_meta(name) - .map(WindowFunction::WindowUDF) + .map(WindowFunctionDefinition::WindowUDF) }) .ok_or_else(|| { plan_datafusion_err!("There is no window function named {name}") diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 9931dd15aec8..a4ec3e7722a2 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -23,8 +23,8 @@ use datafusion::common::{ use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ - aggregate_function, window_function::find_df_window_func, BinaryExpr, - BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, + aggregate_function, expr::find_df_window_func, BinaryExpr, BuiltinScalarFunction, + Case, Expr, LogicalPlan, Operator, }; use datafusion::logical_expr::{ expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,