Skip to content

Commit

Permalink
Add unit test to catch errors in udwf with multiple column arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
timsaucer committed Oct 30, 2024
1 parent 3541c34 commit f8fa38c
Showing 1 changed file with 86 additions and 32 deletions.
118 changes: 86 additions & 32 deletions datafusion/core/tests/user_defined/user_defined_window_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,20 @@ use std::{

use arrow::array::AsArray;
use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field};
use arrow_schema::{DataType, Field, Schema};
use datafusion::{assert_batches_eq, prelude::SessionContext};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, WindowUDFImpl,
};
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
use datafusion_functions_window_common::{
expr::ExpressionArgs, field::WindowUDFFieldArgs,
};
use datafusion_physical_expr::expressions::lit;
use datafusion_physical_expr::{
expressions::{col, lit},
PhysicalExpr,
};

/// A query with a window function evaluated over the entire partition
const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \
Expand Down Expand Up @@ -650,29 +653,33 @@ fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> arrow_array::ArrayRef {
}

#[derive(Debug)]
struct ThreeArgWindowUDF {
struct VariadicWindowUDF {
signature: Signature,
}

impl ThreeArgWindowUDF {
impl VariadicWindowUDF {
fn new() -> Self {
Self {
signature: Signature::uniform(
3,
vec![DataType::Int32, DataType::Boolean, DataType::Float32],
signature: Signature::one_of(
vec![
TypeSignature::Any(0),
TypeSignature::Any(1),
TypeSignature::Any(2),
TypeSignature::Any(3),
],
Volatility::Immutable,
),
}
}
}

impl WindowUDFImpl for ThreeArgWindowUDF {
impl WindowUDFImpl for VariadicWindowUDF {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"three_arg_window_udf"
"variadic_window_udf"
}

fn signature(&self) -> &Signature {
Expand All @@ -683,36 +690,83 @@ impl WindowUDFImpl for ThreeArgWindowUDF {
&self,
_: PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
todo!()
unimplemented!("unnecessary for testing");
}

fn field(&self, _: WindowUDFFieldArgs) -> Result<Field> {
todo!()
unimplemented!("unnecessary for testing");
}
}

#[test]
fn test_input_expressions() -> Result<()> {
let udwf = WindowUDF::from(ThreeArgWindowUDF::new());

let input_exprs = vec![lit(1), lit(false), lit(0.5)]; // Vec<Arc<dyn PhysicalExpr>>
let input_types = [DataType::Int32, DataType::Boolean, DataType::Float32]; // Vec<DataType>
let actual = udwf.expressions(ExpressionArgs::new(&input_exprs, &input_types));

assert_eq!(actual.len(), 3);
// Regression test: the default implementation of
// `WindowUDFImpl::expressions` must return all input expressions
// passed to the user-defined window function unmodified.
//
// See: https://github.com/apache/datafusion/pull/13169
fn test_default_expressions() -> Result<()> {
let udwf = WindowUDF::from(VariadicWindowUDF::new());

let field_a = Field::new("a", DataType::Int32, false);
let field_b = Field::new("b", DataType::Float32, false);
let field_c = Field::new("c", DataType::Boolean, false);
let schema = Schema::new(vec![field_a, field_b, field_c]);

let test_cases = vec![
//
// Zero arguments
//
vec![],
//
// Single argument
//
vec![col("a", &schema)?],
vec![lit(1)],
//
// Two arguments
//
vec![col("a", &schema)?, col("b", &schema)?],
vec![col("a", &schema)?, lit(2)],
vec![lit(false), col("a", &schema)?],
//
// Three arguments
//
vec![col("a", &schema)?, col("b", &schema)?, col("c", &schema)?],
vec![col("a", &schema)?, col("b", &schema)?, lit(false)],
vec![col("a", &schema)?, lit(0.5), col("c", &schema)?],
vec![lit(3), col("b", &schema)?, col("c", &schema)?],
];

assert_eq!(
format!("{:?}", actual.first().unwrap()),
format!("{:?}", input_exprs.first().unwrap()),
);
assert_eq!(
format!("{:?}", actual.get(1).unwrap()),
format!("{:?}", input_exprs.get(1).unwrap())
);
assert_eq!(
format!("{:?}", actual.get(2).unwrap()),
format!("{:?}", input_exprs.get(2).unwrap())
);
for input_exprs in &test_cases {
let input_types = input_exprs
.iter()
.map(|expr: &std::sync::Arc<dyn PhysicalExpr>| {
expr.data_type(&schema).unwrap()
})
.collect::<Vec<_>>();
let expr_args = ExpressionArgs::new(input_exprs, &input_types);

let ret_exprs = udwf.expressions(expr_args);

// Verify same number of input expressions are returned
assert_eq!(
input_exprs.len(),
ret_exprs.len(),
"\nInput expressions: {:?}\nReturned expressions: {:?}",
input_exprs,
ret_exprs
);

// Compares each returned expression with original input expressions
for (expected, actual) in input_exprs.iter().zip(&ret_exprs) {
assert_eq!(
format!("{expected:?}"),
format!("{actual:?}"),
"\nInput expressions: {:?}\nReturned expressions: {:?}",
input_exprs,
ret_exprs
);
}
}
Ok(())
}

0 comments on commit f8fa38c

Please sign in to comment.