Skip to content

Commit

Permalink
update find_window_fn to search built-ins first
Browse files Browse the repository at this point in the history
The behavior of `first_value` and `last_value` UDAFs currently does not match the built-in behavior.
This allowed me to remove `marks=pytest.mark.xfail` from the window tests.
  • Loading branch information
Michael-J-Ward committed Jul 25, 2024
1 parent b6eee28 commit 39893fc
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
2 changes: 0 additions & 2 deletions python/datafusion/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,11 @@ def test_distinct():
order_by=[f.order_by(column("b"))]
),
[1, 1, 1],
marks=pytest.mark.xfail,
),
pytest.param(
"last_value",
f.window("last_value", [column("b")], order_by=[f.order_by(column("b"))]),
[4, 5, 6],
marks=pytest.mark.xfail,
),
pytest.param(
"2nd_value",
Expand Down
22 changes: 12 additions & 10 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,13 +613,21 @@ fn case(expr: PyExpr) -> PyResult<PyCaseBuilder> {
/// Helper function to find the appropriate window function.
///
/// Search procedure:
/// 1) Search built in window functions, which are being deprecated.
/// 1) If a session context is provided:
/// 1) search User Defined Aggregate Functions (UDAFs)
/// 2) search registered window functions
/// 3) search registered aggregate functions
/// 2) If no function has been found, search default aggregate functions.
/// 3) Lastly, as a fall back attempt, search built in window functions, which are being deprecated.
/// 1) search registered window functions
/// 1) search registered aggregate functions
/// 1) If no function has been found, search default aggregate functions.
///
/// NOTE: we search the built-ins first because the `UDAF` versions currently do not have the same behavior.
fn find_window_fn(name: &str, ctx: Option<PySessionContext>) -> PyResult<WindowFunctionDefinition> {
// search built in window functions (soon to be deprecated)
let df_window_func = find_df_window_func(name);
if let Some(df_window_func) = df_window_func {
return Ok(df_window_func);
}

if let Some(ctx) = ctx {
// search UDAFs
let udaf = ctx
Expand Down Expand Up @@ -665,12 +673,6 @@ fn find_window_fn(name: &str, ctx: Option<PySessionContext>) -> PyResult<WindowF
return Ok(agg_fn);
}

// search built in window functions (soon to be deprecated)
let df_window_func = find_df_window_func(name);
if let Some(df_window_func) = df_window_func {
return Ok(df_window_func);
}

Err(DataFusionError::Common(format!("window function `{name}` not found")).into())
}

Expand Down

0 comments on commit 39893fc

Please sign in to comment.