diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 3760328934bc..20dbc69abce9 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -35,8 +35,11 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, }; -use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_functions_window_common::{ + expr::ExpressionArgs, field::WindowUDFFieldArgs, +}; +use datafusion_physical_expr::expressions::lit; /// A query with a window function evaluated over the entire partition const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \ @@ -645,3 +648,71 @@ fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> arrow_array::ArrayRef { let array: Int64Array = std::iter::repeat(odd_count(arr)).take(num_rows).collect(); Arc::new(array) } + +#[derive(Debug)] +struct ThreeArgWindowUDF { + signature: Signature, +} + +impl ThreeArgWindowUDF { + fn new() -> Self { + Self { + signature: Signature::uniform( + 3, + vec![DataType::Int32, DataType::Boolean, DataType::Float32], + Volatility::Immutable, + ), + } + } +} + +impl WindowUDFImpl for ThreeArgWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "three_arg_window_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _: PartitionEvaluatorArgs, + ) -> Result> { + todo!() + } + + fn field(&self, _: WindowUDFFieldArgs) -> Result { + todo!() + } +} + +#[test] +fn test_input_expressions() -> Result<()> { + let udwf = WindowUDF::from(ThreeArgWindowUDF::new()); + + let input_exprs = vec![lit(1), lit(false), lit(0.5)]; // Vec> + let input_types = [DataType::Int32, DataType::Boolean, DataType::Float32]; // Vec + let actual = udwf.expressions(ExpressionArgs::new(&input_exprs, &input_types)); + + assert_eq!(actual.len(), 3); + + assert_eq!( + format!("{:?}", actual.first().unwrap()), + format!("{:?}", input_exprs.first().unwrap()), + ); + assert_eq!( + format!("{:?}", actual.get(1).unwrap()), + format!("{:?}", input_exprs.get(1).unwrap()) + ); + assert_eq!( + format!("{:?}", actual.get(2).unwrap()), + format!("{:?}", input_exprs.get(2).unwrap()) + ); + + Ok(()) +}