Skip to content

Commit

Permalink
fix unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiayu Liu committed Jun 10, 2021
1 parent 4b29444 commit 6ff2361
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 19 deletions.
52 changes: 45 additions & 7 deletions datafusion/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use super::{
functions, hash_join::PartitionMode, udaf, union::UnionExec, windows,
};
use crate::execution::context::ExecutionContextState;
use crate::logical_plan::window_frames::WindowFrame;
use crate::logical_plan::{
DFSchema, Expr, LogicalPlan, Operator, Partitioning as LogicalPartitioning, PlanType,
StringifiedPlan, UserDefinedLogicalNode,
Expand Down Expand Up @@ -740,19 +741,56 @@ impl DefaultPhysicalPlanner {
ctx_state: &ExecutionContextState,
) -> Result<Arc<dyn WindowExpr>> {
match e {
Expr::WindowFunction { fun, args, .. } => {
Expr::WindowFunction {
fun,
args,
partition_by,
order_by,
window_frame,
} => {
let args = args
.iter()
.map(|e| {
self.create_physical_expr(e, physical_input_schema, ctx_state)
})
.collect::<Result<Vec<_>>>()?;
// if !order_by.is_empty() {
// return Err(DataFusionError::NotImplemented(
// "Window function with order by is not yet implemented".to_owned(),
// ));
// }
windows::create_window_expr(fun, &args, physical_input_schema, name)
let partition_by = partition_by
.iter()
.map(|e| {
self.create_physical_expr(e, physical_input_schema, ctx_state)
})
.collect::<Result<Vec<_>>>()?;
let order_by = order_by
.iter()
.map(|e| match e {
Expr::Sort {
expr,
asc,
nulls_first,
} => self.create_physical_sort_expr(
expr,
&physical_input_schema,
SortOptions {
descending: !*asc,
nulls_first: *nulls_first,
},
&ctx_state,
),
_ => Err(DataFusionError::Plan(
"Sort only accepts sort expressions".to_string(),
)),
})
.collect::<Result<Vec<_>>>()?;
let window_frame = window_frame.unwrap_or_else(WindowFrame::default);
windows::create_window_expr(
fun,
name,
&args,
&partition_by,
&order_by,
window_frame,
physical_input_schema,
)
}
other => Err(DataFusionError::Internal(format!(
"Invalid window expression '{:?}'",
Expand Down
41 changes: 29 additions & 12 deletions datafusion/src/physical_plan/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
//! Execution plan for window functions
use crate::error::{DataFusionError, Result};

use crate::logical_plan::window_frames::WindowFrame;
use crate::physical_plan::{
aggregates, common,
expressions::{Literal, NthValue, RowNumber},
expressions::{Literal, NthValue, PhysicalSortExpr, RowNumber},
type_coercion::coerce,
window_functions::signature_for_built_in,
window_functions::BuiltInWindowFunctionExpr,
Expand Down Expand Up @@ -61,24 +63,27 @@ pub struct WindowAggExec {
/// Create a physical expression for window function
pub fn create_window_expr(
fun: &WindowFunction,
name: String,
args: &[Arc<dyn PhysicalExpr>],
_partition_by: &[Arc<dyn PhysicalExpr>],
_order_by: &[PhysicalSortExpr],
_window_frame: WindowFrame,
input_schema: &Schema,
name: String,
) -> Result<Arc<dyn WindowExpr>> {
match fun {
WindowFunction::AggregateFunction(fun) => Ok(Arc::new(AggregateWindowExpr {
Ok(match fun {
WindowFunction::AggregateFunction(fun) => Arc::new(AggregateWindowExpr {
aggregate: aggregates::create_aggregate_expr(
fun,
false,
args,
input_schema,
name,
)?,
})),
WindowFunction::BuiltInWindowFunction(fun) => Ok(Arc::new(BuiltInWindowExpr {
}),
WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr {
window: create_built_in_window_expr(fun, args, input_schema, name)?,
})),
}
}),
})
}

fn create_built_in_window_expr(
Expand Down Expand Up @@ -538,9 +543,12 @@ mod tests {
let window_exec = Arc::new(WindowAggExec::try_new(
vec![create_window_expr(
&WindowFunction::AggregateFunction(AggregateFunction::Count),
"count".to_owned(),
&[col("c3")],
&[],
&[],
WindowFrame::default(),
schema.as_ref(),
"count".to_owned(),
)?],
input,
schema.clone(),
Expand Down Expand Up @@ -568,21 +576,30 @@ mod tests {
vec![
create_window_expr(
&WindowFunction::AggregateFunction(AggregateFunction::Count),
"count".to_owned(),
&[col("c3")],
&[],
&[],
WindowFrame::default(),
schema.as_ref(),
"count".to_owned(),
)?,
create_window_expr(
&WindowFunction::AggregateFunction(AggregateFunction::Max),
"max".to_owned(),
&[col("c3")],
&[],
&[],
WindowFrame::default(),
schema.as_ref(),
"max".to_owned(),
)?,
create_window_expr(
&WindowFunction::AggregateFunction(AggregateFunction::Min),
"min".to_owned(),
&[col("c3")],
&[],
&[],
WindowFrame::default(),
schema.as_ref(),
"min".to_owned(),
)?,
],
input,
Expand Down

0 comments on commit 6ff2361

Please sign in to comment.