Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert nth_value builtIn function to User Defined Window Function #13201

Merged
merged 23 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1946,12 +1946,12 @@ mod tests {
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{
cast, create_udf, lit, BuiltInWindowFunction, ExprFunctionExt,
ScalarFunctionImplementation, Volatility, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
cast, create_udf, lit, ExprFunctionExt, ScalarFunctionImplementation, Volatility,
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct};
use datafusion_functions_window::expr_fn::row_number;
use datafusion_functions_window::nth_value::first_value_udwf;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties};
use sqlparser::ast::NullTreatment;
Expand Down Expand Up @@ -2177,9 +2177,7 @@ mod tests {
// build plan using Table API
let t = test_table().await?;
let first_row = Expr::WindowFunction(WindowFunction::new(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::FirstValue,
),
buraksenn marked this conversation as resolved.
Show resolved Hide resolved
WindowFunctionDefinition::WindowUDF(first_value_udwf()),
vec![col("aggregate_test_100.c1")],
))
.partition_by(vec![col("aggregate_test_100.c2")])
Expand Down
18 changes: 7 additions & 11 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ use datafusion_common::{Result, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf;
use datafusion_expr::{
BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf};
Expand All @@ -47,6 +46,9 @@ use test_utils::add_empty_batches;
use datafusion::functions_window::row_number::row_number_udwf;
use datafusion_common::HashMap;
use datafusion_functions_window::lead_lag::{lag_udwf, lead_udwf};
use datafusion_functions_window::nth_value::{
first_value_udwf, last_value_udwf, nth_value_udwf,
};
use datafusion_functions_window::rank::{dense_rank_udwf, rank_udwf};
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use rand::distributions::Alphanumeric;
Expand Down Expand Up @@ -418,27 +420,21 @@ fn get_random_function(
window_fn_map.insert(
"first_value",
(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::FirstValue,
),
WindowFunctionDefinition::WindowUDF(first_value_udwf()),
vec![arg.clone()],
),
);
window_fn_map.insert(
"last_value",
(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::LastValue,
),
WindowFunctionDefinition::WindowUDF(last_value_udwf()),
vec![arg.clone()],
),
);
window_fn_map.insert(
"nth_value",
(
WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::NthValue,
),
WindowFunctionDefinition::WindowUDF(nth_value_udwf()),
vec![
arg.clone(),
lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
Expand Down
89 changes: 0 additions & 89 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ use std::collections::HashSet;
use std::fmt::{self, Display, Formatter, Write};
use std::hash::{Hash, Hasher};
use std::mem;
use std::str::FromStr;
use std::sync::Arc;

use crate::expr_fn::binary_expr;
Expand Down Expand Up @@ -832,23 +831,6 @@ impl WindowFunction {
}
}

/// Find DataFusion's built-in window function by name.
pub fn find_df_window_func(name: &str) -> Option<WindowFunctionDefinition> {
let name = name.to_lowercase();
// Code paths for window functions leveraging ordinary aggregators and
// built-in window functions are quite different, and the same function
// may have different implementations for these cases. If the sought
// function is not found among built-in window functions, we search for
// it among aggregate functions.
if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) {
Some(WindowFunctionDefinition::BuiltInWindowFunction(
built_in_function,
))
} else {
None
}
}

/// EXISTS expression
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct Exists {
Expand Down Expand Up @@ -2540,77 +2522,6 @@ mod test {

use super::*;

#[test]
fn test_first_value_return_type() -> Result<()> {
let fun = find_df_window_func("first_value").unwrap();
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::UInt64], &[true], "")?;
assert_eq!(DataType::UInt64, observed);

Ok(())
}

#[test]
fn test_last_value_return_type() -> Result<()> {
let fun = find_df_window_func("last_value").unwrap();
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::Float64], &[true], "")?;
assert_eq!(DataType::Float64, observed);

Ok(())
}

#[test]
fn test_nth_value_return_type() -> Result<()> {
let fun = find_df_window_func("nth_value").unwrap();
let observed =
fun.return_type(&[DataType::Utf8, DataType::UInt64], &[true, true], "")?;
assert_eq!(DataType::Utf8, observed);

let observed =
fun.return_type(&[DataType::Float64, DataType::UInt64], &[true, true], "")?;
assert_eq!(DataType::Float64, observed);

Ok(())
}

#[test]
fn test_window_function_case_insensitive() -> Result<()> {
let names = vec!["first_value", "last_value", "nth_value"];
for name in names {
let fun = find_df_window_func(name).unwrap();
let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap();
assert_eq!(fun, fun2);
if fun.to_string() == "first_value" || fun.to_string() == "last_value" {
assert_eq!(fun.to_string(), name);
} else {
assert_eq!(fun.to_string(), name.to_uppercase());
}
}
Ok(())
}

#[test]
fn test_find_df_window_function() {
assert_eq!(
find_df_window_func("first_value"),
Some(WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::FirstValue
))
);
assert_eq!(
find_df_window_func("LAST_value"),
Some(WindowFunctionDefinition::BuiltInWindowFunction(
BuiltInWindowFunction::LastValue
))
);
assert_eq!(find_df_window_func("not_exist"), None)
}

#[test]
fn test_display_wildcard() {
assert_eq!(format!("{}", wildcard()), "*");
Expand Down
1 change: 0 additions & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ pub mod type_coercion;
pub mod utils;
pub mod var_provider;
pub mod window_frame;
pub mod window_function;
pub mod window_state;

pub use built_in_window_function::BuiltInWindowFunction;
Expand Down
26 changes: 0 additions & 26 deletions datafusion/expr/src/window_function.rs

This file was deleted.

6 changes: 6 additions & 0 deletions datafusion/functions-window/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
//!
//! [DataFusion]: https://crates.io/crates/datafusion
//!

use std::sync::Arc;

use log::debug;
Expand All @@ -34,6 +35,7 @@ pub mod macros;

pub mod cume_dist;
pub mod lead_lag;
pub mod nth_value;
pub mod ntile;
pub mod rank;
pub mod row_number;
Expand All @@ -44,6 +46,7 @@ pub mod expr_fn {
pub use super::cume_dist::cume_dist;
pub use super::lead_lag::lag;
pub use super::lead_lag::lead;
pub use super::nth_value::{first_value, last_value, nth_value};
pub use super::ntile::ntile;
pub use super::rank::{dense_rank, percent_rank, rank};
pub use super::row_number::row_number;
Expand All @@ -60,6 +63,9 @@ pub fn all_default_window_functions() -> Vec<Arc<WindowUDF>> {
rank::dense_rank_udwf(),
rank::percent_rank_udwf(),
ntile::ntile_udwf(),
nth_value::first_value_udwf(),
nth_value::last_value_udwf(),
nth_value::nth_value_udwf(),
]
}
/// Registers all enabled packages with a [`FunctionRegistry`]
Expand Down
Loading
Loading