diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 20dbc69abce9..83368f5921b0 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -29,17 +29,20 @@ use std::{ use arrow::array::AsArray; use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray}; -use arrow_schema::{DataType, Field}; +use arrow_schema::{DataType, Field, Schema}; use datafusion::{assert_batches_eq, prelude::SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ - PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, + PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, WindowUDFImpl, }; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_functions_window_common::{ expr::ExpressionArgs, field::WindowUDFFieldArgs, }; -use datafusion_physical_expr::expressions::lit; +use datafusion_physical_expr::{ + expressions::{col, lit}, + PhysicalExpr, +}; /// A query with a window function evaluated over the entire partition const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \ @@ -650,29 +653,33 @@ fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> arrow_array::ArrayRef { } #[derive(Debug)] -struct ThreeArgWindowUDF { +struct VariadicWindowUDF { signature: Signature, } -impl ThreeArgWindowUDF { +impl VariadicWindowUDF { fn new() -> Self { Self { - signature: Signature::uniform( - 3, - vec![DataType::Int32, DataType::Boolean, DataType::Float32], + signature: Signature::one_of( + vec![ + TypeSignature::Any(0), + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], Volatility::Immutable, ), } } } -impl WindowUDFImpl for ThreeArgWindowUDF { +impl WindowUDFImpl for VariadicWindowUDF { fn as_any(&self) -> &dyn Any { self } fn name(&self) -> &str { - "three_arg_window_udf" + "variadic_window_udf" } fn signature(&self) -> &Signature { @@ -683,36 +690,83 @@ impl WindowUDFImpl for ThreeArgWindowUDF { &self, _: PartitionEvaluatorArgs, ) -> Result<Box<dyn PartitionEvaluator>> { - todo!() + unimplemented!("unnecessary for testing"); } fn field(&self, _: WindowUDFFieldArgs) -> Result<Field> { - todo!() + unimplemented!("unnecessary for testing"); } } #[test] -fn test_input_expressions() -> Result<()> { - let udwf = WindowUDF::from(ThreeArgWindowUDF::new()); - - let input_exprs = vec![lit(1), lit(false), lit(0.5)]; // Vec<Arc<dyn PhysicalExpr>> - let input_types = [DataType::Int32, DataType::Boolean, DataType::Float32]; // Vec<DataType> - let actual = udwf.expressions(ExpressionArgs::new(&input_exprs, &input_types)); - - assert_eq!(actual.len(), 3); +// Fixes: default implementation of `WindowUDFImpl::expressions` +// returns all input expressions to the user-defined window +// function unmodified. +// +// See: https://github.com/apache/datafusion/pull/13169 +fn test_default_expressions() -> Result<()> { + let udwf = WindowUDF::from(VariadicWindowUDF::new()); + + let field_a = Field::new("a", DataType::Int32, false); + let field_b = Field::new("b", DataType::Float32, false); + let field_c = Field::new("c", DataType::Boolean, false); + let schema = Schema::new(vec![field_a, field_b, field_c]); + + let test_cases = vec![ + // + // Zero arguments + // + vec![], + // + // Single argument + // + vec![col("a", &schema)?], + vec![lit(1)], + // + // Two arguments + // + vec![col("a", &schema)?, col("b", &schema)?], + vec![col("a", &schema)?, lit(2)], + vec![lit(false), col("a", &schema)?], + // + // Three arguments + // + vec![col("a", &schema)?, col("b", &schema)?, col("c", &schema)?], + vec![col("a", &schema)?, col("b", &schema)?, lit(false)], + vec![col("a", &schema)?, lit(0.5), col("c", &schema)?], + vec![lit(3), col("b", &schema)?, col("c", &schema)?], + ]; - assert_eq!( - format!("{:?}", actual.first().unwrap()), - format!("{:?}", input_exprs.first().unwrap()), - ); - assert_eq!( - format!("{:?}", actual.get(1).unwrap()), - format!("{:?}", input_exprs.get(1).unwrap()) - ); - assert_eq!( - format!("{:?}", actual.get(2).unwrap()), - format!("{:?}", input_exprs.get(2).unwrap()) - ); + for input_exprs in &test_cases { + let input_types = input_exprs + .iter() + .map(|expr: &std::sync::Arc<dyn PhysicalExpr>| { + expr.data_type(&schema).unwrap() + }) + .collect::<Vec<_>>(); + let expr_args = ExpressionArgs::new(input_exprs, &input_types); + + let ret_exprs = udwf.expressions(expr_args); + + // Verify same number of input expressions are returned + assert_eq!( + input_exprs.len(), + ret_exprs.len(), + "\nInput expressions: {:?}\nReturned expressions: {:?}", + input_exprs, + ret_exprs + ); + // Compares each returned expression with original input expressions + for (expected, actual) in input_exprs.iter().zip(&ret_exprs) { + assert_eq!( + format!("{expected:?}"), + format!("{actual:?}"), + "\nInput expressions: {:?}\nReturned expressions: {:?}", + input_exprs, + ret_exprs + ); + } + } Ok(()) }