Skip to content

Commit

Permalink
Temp: add LogicalPlan::transform_expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Apr 6, 2024
1 parent f0d4986 commit 0bccb70
Show file tree
Hide file tree
Showing 2 changed files with 233 additions and 8 deletions.
235 changes: 230 additions & 5 deletions datafusion/expr/src/logical_plan/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@

//! Methods for rewriting logical plans
use crate::tree_node::expr::transform_option_vec;
use crate::{
Aggregate, CrossJoin, Distinct, DistinctOn, EmptyRelation, Filter, Join, Limit,
LogicalPlan, Prepare, Projection, RecursiveQuery, Repartition, Sort, Subquery,
SubqueryAlias, Union, Unnest, UserDefinedLogicalNode, Window,
Aggregate, CrossJoin, Distinct, DistinctOn, EmptyRelation, Expr, Extension, Filter,
Join, Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery,
Repartition, Sort, Subquery, SubqueryAlias, TableScan, Union, Unnest,
UserDefinedLogicalNode, Values, Window,
};
use datafusion_common::tree_node::{Transformed, TreeNodeIterator, TreeNodeRecursion};
use datafusion_common::{
internal_err, map_until_stop_and_collect, DFSchema, DFSchemaRef, DataFusionError,
Result,
};
use datafusion_common::tree_node::{Transformed, TreeNodeIterator};
use datafusion_common::{DFSchema, DFSchemaRef, Result};
use std::sync::{Arc, OnceLock};

/// A temporary node that is left in place while rewriting the children of a
Expand Down Expand Up @@ -226,3 +231,223 @@ impl LogicalPlan {
Ok(children_result)
}
}

impl LogicalPlan {
/// Transforms the expressions in the logical plan using the provided closure.
/// from <https://github.com/apache/arrow-datafusion/pull/9913/files>
/// TODO: use when available
pub fn transform_expressions<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
self,
mut f: F,
) -> Result<Transformed<Self>> {
Ok(match self {
LogicalPlan::Projection(Projection {
expr,
input,
schema,
}) => expr
.into_iter()
.map_until_stop_and_collect(f)?
.update_data(|expr| {
LogicalPlan::Projection(Projection {
expr,
input,
schema,
})
}),
LogicalPlan::Values(Values { schema, values }) => values
.into_iter()
.map_until_stop_and_collect(|value| {
value.into_iter().map_until_stop_and_collect(&mut f)
})?
.update_data(|values| LogicalPlan::Values(Values { schema, values })),
LogicalPlan::Filter(Filter { predicate, input }) => f(predicate)?
.update_data(|predicate| {
LogicalPlan::Filter(Filter { predicate, input })
}),
LogicalPlan::Repartition(Repartition {
input,
partitioning_scheme,
}) => match partitioning_scheme {
Partitioning::Hash(expr, usize) => expr
.into_iter()
.map_until_stop_and_collect(f)?
.update_data(|expr| Partitioning::Hash(expr, usize)),
Partitioning::DistributeBy(expr) => expr
.into_iter()
.map_until_stop_and_collect(f)?
.update_data(Partitioning::DistributeBy),
Partitioning::RoundRobinBatch(_) => Transformed::no(partitioning_scheme),
}
.update_data(|partitioning_scheme| {
LogicalPlan::Repartition(Repartition {
input,
partitioning_scheme,
})
}),
LogicalPlan::Window(Window {
input,
window_expr,
schema,
}) => window_expr
.into_iter()
.map_until_stop_and_collect(f)?
.update_data(|window_expr| {
LogicalPlan::Window(Window {
input,
window_expr,
schema,
})
}),
LogicalPlan::Aggregate(Aggregate {
input,
group_expr,
aggr_expr,
schema,
}) => map_until_stop_and_collect!(
group_expr.into_iter().map_until_stop_and_collect(&mut f),
aggr_expr,
aggr_expr.into_iter().map_until_stop_and_collect(&mut f)
)?
.update_data(|(group_expr, aggr_expr)| {
LogicalPlan::Aggregate(Aggregate {
input,
group_expr,
aggr_expr,
schema,
})
}),

// There are two part of expression for join, equijoin(on) and non-equijoin(filter).
// 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`.
// 2. the second part is non-equijoin(filter).
LogicalPlan::Join(Join {
left,
right,
on,
filter,
join_type,
join_constraint,
schema,
null_equals_null,
}) => map_until_stop_and_collect!(
on.into_iter().map_until_stop_and_collect(
|on| map_until_stop_and_collect!(f(on.0), on.1, f(on.1))
),
filter,
filter.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| {
Ok(f(e)?.update_data(Some))
})
)?
.update_data(|(on, filter)| {
LogicalPlan::Join(Join {
left,
right,
on,
filter,
join_type,
join_constraint,
schema,
null_equals_null,
})
}),
LogicalPlan::Sort(Sort { expr, input, fetch }) => expr
.into_iter()
.map_until_stop_and_collect(f)?
.update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })),
LogicalPlan::Extension(Extension { node }) => {
// would be nice to avoid this copy -- maybe can
// update extension to just observer Exprs
node.expressions()
.into_iter()
.map_until_stop_and_collect(f)?
.update_data(|exprs| {
LogicalPlan::Extension(Extension {
node: UserDefinedLogicalNode::from_template(
node.as_ref(),
exprs.as_slice(),
node.inputs()
.into_iter()
.cloned()
.collect::<Vec<_>>()
.as_slice(),
),
})
})
}
LogicalPlan::TableScan(TableScan {
table_name,
source,
projection,
projected_schema,
filters,
fetch,
}) => filters
.into_iter()
.map_until_stop_and_collect(f)?
.update_data(|filters| {
LogicalPlan::TableScan(TableScan {
table_name,
source,
projection,
projected_schema,
filters,
fetch,
})
}),
LogicalPlan::Unnest(Unnest {
input,
column,
schema,
options,
}) => f(Expr::Column(column))?.map_data(|column| match column {
Expr::Column(column) => Ok(LogicalPlan::Unnest(Unnest {
input,
column,
schema,
options,
})),
_ => internal_err!("Transformation should return Column"),
})?,
LogicalPlan::Distinct(Distinct::On(DistinctOn {
on_expr,
select_expr,
sort_expr,
input,
schema,
})) => map_until_stop_and_collect!(
on_expr.into_iter().map_until_stop_and_collect(&mut f),
select_expr,
select_expr.into_iter().map_until_stop_and_collect(&mut f),
sort_expr,
transform_option_vec(sort_expr, &mut f)
)?
.update_data(|(on_expr, select_expr, sort_expr)| {
LogicalPlan::Distinct(Distinct::On(DistinctOn {
on_expr,
select_expr,
sort_expr,
input,
schema,
}))
}),
// plans without expressions
LogicalPlan::EmptyRelation(_)
| LogicalPlan::RecursiveQuery(_)
| LogicalPlan::Subquery(_)
| LogicalPlan::SubqueryAlias(_)
| LogicalPlan::Limit(_)
| LogicalPlan::Statement(_)
| LogicalPlan::CrossJoin(_)
| LogicalPlan::Analyze(_)
| LogicalPlan::Explain(_)
| LogicalPlan::Union(_)
| LogicalPlan::Distinct(Distinct::All(_))
| LogicalPlan::Dml(_)
| LogicalPlan::Ddl(_)
| LogicalPlan::Copy(_)
| LogicalPlan::DescribeTable(_)
| LogicalPlan::Prepare(_) => Transformed::no(self),
})
}
}
6 changes: 3 additions & 3 deletions datafusion/expr/src/tree_node/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,14 +392,14 @@ impl TreeNode for Expr {
}
}

fn transform_box<F>(be: Box<Expr>, f: &mut F) -> Result<Transformed<Box<Expr>>>
pub fn transform_box<F>(be: Box<Expr>, f: &mut F) -> Result<Transformed<Box<Expr>>>
where
F: FnMut(Expr) -> Result<Transformed<Expr>>,
{
Ok(f(*be)?.update_data(Box::new))
}

fn transform_option_box<F>(
pub fn transform_option_box<F>(
obe: Option<Box<Expr>>,
f: &mut F,
) -> Result<Transformed<Option<Box<Expr>>>>
Expand All @@ -412,7 +412,7 @@ where
}

/// &mut transform a Option<`Vec` of `Expr`s>
fn transform_option_vec<F>(
pub fn transform_option_vec<F>(
ove: Option<Vec<Expr>>,
f: &mut F,
) -> Result<Transformed<Option<Vec<Expr>>>>
Expand Down

0 comments on commit 0bccb70

Please sign in to comment.