Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for AggregateExpr, WindowExpr rewrite. #10742

Merged
merged 4 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,40 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq<dyn Any> {
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
not_impl_err!("Retractable Accumulator hasn't been implemented for {self:?} yet")
}

/// Returns all expressions used in the [`AggregateExpr`].
/// These expressions are (1)function arguments, (2) order by expressions.
fn all_expressions(&self) -> AggregatePhysicalExpressions {
let args = self.expressions();
let order_bys = self.order_bys().unwrap_or(&[]);
let order_by_exprs = order_bys
.iter()
.map(|sort_expr| sort_expr.expr.clone())
.collect::<Vec<_>>();
AggregatePhysicalExpressions {
args,
order_by_exprs,
}
}

/// Rewrites [`AggregateExpr`], with new expressions given. The argument should be consistent
/// with the return value of the [`AggregateExpr::all_expressions`] method.
/// Returns `Some(Arc<dyn AggregateExpr>)` if re-write is supported, otherwise returns `None`.
fn with_new_expressions(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only thing that might be worth considering is that this API forces cloning. However, since everything is Arcd that seems relatively minor. Maybe something to keep in mind for the future.

&self,
_args: Vec<Arc<dyn PhysicalExpr>>,
_order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
) -> Option<Arc<dyn AggregateExpr>> {
None
}
}

/// Stores the physical expressions used inside the `AggregateExpr`.
pub struct AggregatePhysicalExpressions {
/// Aggregate function arguments
pub args: Vec<Arc<dyn PhysicalExpr>>,
/// Order by expressions
pub order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
}

/// Physical aggregate expression of a UDAF.
Expand Down
15 changes: 15 additions & 0 deletions datafusion/physical-expr/src/aggregate/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,21 @@ impl AggregateExpr for Count {
// instantiate specialized accumulator
Ok(Box::new(CountGroupsAccumulator::new()))
}

fn with_new_expressions(
&self,
args: Vec<Arc<dyn PhysicalExpr>>,
order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
) -> Option<Arc<dyn AggregateExpr>> {
debug_assert_eq!(self.exprs.len(), args.len());
debug_assert!(order_by_exprs.is_empty());
Some(Arc::new(Count {
name: self.name.clone(),
data_type: self.data_type.clone(),
nullable: self.nullable,
exprs: args,
}))
}
}

impl PartialEq<dyn Any> for Count {
Expand Down
4 changes: 3 additions & 1 deletion datafusion/physical-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ pub mod execution_props {

pub use aggregate::groups_accumulator::{GroupsAccumulatorAdapter, NullState};
pub use analysis::{analyze, AnalysisContext, ExprBoundaries};
pub use datafusion_physical_expr_common::aggregate::AggregateExpr;
pub use datafusion_physical_expr_common::aggregate::{
AggregateExpr, AggregatePhysicalExpressions,
};
pub use equivalence::EquivalenceProperties;
pub use partitioning::{Distribution, Partitioning};
pub use physical_expr::{
Expand Down
25 changes: 25 additions & 0 deletions datafusion/physical-expr/src/window/sliding_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,31 @@ impl WindowExpr for SlidingAggregateWindowExpr {
fn uses_bounded_memory(&self) -> bool {
!self.window_frame.end_bound.is_unbounded()
}

fn with_new_expressions(
&self,
args: Vec<Arc<dyn PhysicalExpr>>,
partition_bys: Vec<Arc<dyn PhysicalExpr>>,
order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
) -> Option<Arc<dyn WindowExpr>> {
debug_assert_eq!(self.order_by.len(), order_by_exprs.len());

let new_order_by = self
.order_by
.iter()
.zip(order_by_exprs)
.map(|(req, new_expr)| PhysicalSortExpr {
expr: new_expr,
options: req.options,
})
.collect::<Vec<_>>();
Some(Arc::new(SlidingAggregateWindowExpr {
aggregate: self.aggregate.with_new_expressions(args, vec![])?,
partition_by: partition_bys,
order_by: new_order_by,
window_frame: self.window_frame.clone(),
}))
}
}

impl AggregateWindowExpr for SlidingAggregateWindowExpr {
Expand Down
39 changes: 39 additions & 0 deletions datafusion/physical-expr/src/window/window_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,45 @@ pub trait WindowExpr: Send + Sync + Debug {

/// Get the reverse expression of this [WindowExpr].
fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>>;

/// Returns all expressions used in the [`WindowExpr`].
/// These expressions are (1) function arguments, (2) partition by expressions, (3) order by expressions.
fn all_expressions(&self) -> WindowPhysicalExpressions {
let args = self.expressions();
let partition_by_exprs = self.partition_by().to_vec();
let order_by_exprs = self
.order_by()
.iter()
.map(|sort_expr| sort_expr.expr.clone())
.collect::<Vec<_>>();
WindowPhysicalExpressions {
args,
partition_by_exprs,
order_by_exprs,
}
}

/// Rewrites [`WindowExpr`], with new expressions given. The argument should be consistent
/// with the return value of the [`WindowExpr::all_expressions`] method.
/// Returns `Some(Arc<dyn WindowExpr>)` if re-write is supported, otherwise returns `None`.
fn with_new_expressions(
&self,
_args: Vec<Arc<dyn PhysicalExpr>>,
_partition_bys: Vec<Arc<dyn PhysicalExpr>>,
_order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
) -> Option<Arc<dyn WindowExpr>> {
None
}
}

/// Stores the physical expressions used inside the `WindowExpr`.
pub struct WindowPhysicalExpressions {
/// Window function arguments
pub args: Vec<Arc<dyn PhysicalExpr>>,
/// PARTITION BY expressions
pub partition_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
/// ORDER BY expressions
pub order_by_exprs: Vec<Arc<dyn PhysicalExpr>>,
}

/// Extension trait that adds common functionality to [`AggregateWindowExpr`]s
Expand Down