Skip to content

Commit

Permalink
Add unit test to check for window function inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
timsaucer committed Oct 30, 2024
1 parent b085024 commit 3541c34
Showing 1 changed file with 72 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
use datafusion_functions_window_common::{
expr::ExpressionArgs, field::WindowUDFFieldArgs,
};
use datafusion_physical_expr::expressions::lit;

/// A query with a window function evaluated over the entire partition
const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \
Expand Down Expand Up @@ -645,3 +648,71 @@ fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> arrow_array::ArrayRef {
let array: Int64Array = std::iter::repeat(odd_count(arr)).take(num_rows).collect();
Arc::new(array)
}

#[derive(Debug)]
struct ThreeArgWindowUDF {
signature: Signature,
}

impl ThreeArgWindowUDF {
fn new() -> Self {
Self {
signature: Signature::uniform(
3,
vec![DataType::Int32, DataType::Boolean, DataType::Float32],
Volatility::Immutable,
),
}
}
}

impl WindowUDFImpl for ThreeArgWindowUDF {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"three_arg_window_udf"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn partition_evaluator(
&self,
_: PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
todo!()
}

fn field(&self, _: WindowUDFFieldArgs) -> Result<Field> {
todo!()
}
}

#[test]
fn test_input_expressions() -> Result<()> {
let udwf = WindowUDF::from(ThreeArgWindowUDF::new());

let input_exprs = vec![lit(1), lit(false), lit(0.5)]; // Vec<Arc<dyn PhysicalExpr>>
let input_types = [DataType::Int32, DataType::Boolean, DataType::Float32]; // Vec<DataType>
let actual = udwf.expressions(ExpressionArgs::new(&input_exprs, &input_types));

assert_eq!(actual.len(), 3);

assert_eq!(
format!("{:?}", actual.first().unwrap()),
format!("{:?}", input_exprs.first().unwrap()),
);
assert_eq!(
format!("{:?}", actual.get(1).unwrap()),
format!("{:?}", input_exprs.get(1).unwrap())
);
assert_eq!(
format!("{:?}", actual.get(2).unwrap()),
format!("{:?}", input_exprs.get(2).unwrap())
);

Ok(())
}

0 comments on commit 3541c34

Please sign in to comment.