Skip to content

Commit

Permalink
refactor TreeNode::rewrite()
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed Jan 28, 2024
1 parent ff7dfc3 commit 729c9d2
Show file tree
Hide file tree
Showing 37 changed files with 355 additions and 351 deletions.
6 changes: 3 additions & 3 deletions datafusion-examples/examples/rewrite_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule {

impl MyAnalyzerRule {
fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
plan.transform(&|plan| {
plan.transform_up(&|plan| {
Ok(match plan {
LogicalPlan::Filter(filter) => {
let predicate = Self::analyze_expr(filter.predicate.clone())?;
Expand All @@ -106,7 +106,7 @@ impl MyAnalyzerRule {
}

fn analyze_expr(expr: Expr) -> Result<Expr> {
expr.transform(&|expr| {
expr.transform_up(&|expr| {
// closure is invoked for all sub expressions
Ok(match expr {
Expr::Literal(ScalarValue::Int64(i)) => {
Expand Down Expand Up @@ -161,7 +161,7 @@ impl OptimizerRule for MyOptimizerRule {

/// use rewrite_expr to modify the expression tree.
fn my_rewrite(expr: Expr) -> Result<Expr> {
expr.transform(&|expr| {
expr.transform_up(&|expr| {
// closure is invoked for all sub expressions
Ok(match expr {
Expr::Between(Between {
Expand Down
192 changes: 104 additions & 88 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ use crate::Result;
macro_rules! handle_tree_recursion {
($EXPR:expr) => {
match $EXPR {
VisitRecursion::Continue => {}
TreeNodeRecursion::Continue => {}
// If the recursion should skip, do not apply to its children, let
// the recursion continue:
VisitRecursion::Skip => return Ok(VisitRecursion::Continue),
TreeNodeRecursion::Skip => return Ok(TreeNodeRecursion::Continue),
// If the recursion should stop, do not apply to its children:
VisitRecursion::Stop => return Ok(VisitRecursion::Stop),
TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop),
}
};
}
Expand All @@ -58,10 +58,10 @@ pub trait TreeNode: Sized {
///
/// The `op` closure can be used to collect some info from the
/// tree node or do some checking for the tree node.
fn apply<F: FnMut(&Self) -> Result<VisitRecursion>>(
fn apply<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
op: &mut F,
) -> Result<VisitRecursion> {
) -> Result<TreeNodeRecursion> {
handle_tree_recursion!(op(self)?);
self.apply_children(&mut |node| node.apply(op))
}
Expand All @@ -88,7 +88,7 @@ pub trait TreeNode: Sized {
///
/// If an Err result is returned, recursion is stopped immediately
///
/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no
/// If [`TreeNodeRecursion::Stop`] is returned on a call to pre_visit, no
/// children of that node will be visited, nor is post_visit
/// called on that node. Details see [`TreeNodeVisitor`]
///
Expand All @@ -97,20 +97,53 @@ pub trait TreeNode: Sized {
fn visit<V: TreeNodeVisitor<N = Self>>(
&self,
visitor: &mut V,
) -> Result<VisitRecursion> {
) -> Result<TreeNodeRecursion> {
handle_tree_recursion!(visitor.pre_visit(self)?);
handle_tree_recursion!(self.apply_children(&mut |node| node.visit(visitor))?);
visitor.post_visit(self)
}

/// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree.
/// When `op` does not apply to a given node, it is left unchanged.
/// The default tree traversal direction is transform_up(Postorder Traversal).
fn transform<F>(self, op: &F) -> Result<Self>
/// Transforms the tree using `f_down` while traversing the tree top-down
/// (pre-preorder) and using `f_up` while traversing the tree bottom-up (post-order).
///
/// E.g. for an tree such as:
/// ```text
/// ParentNode
/// left: ChildNode1
/// right: ChildNode2
/// ```
///
/// The nodes are visited using the following order:
/// ```text
/// f_down(ParentNode)
/// f_down(ChildNode1)
/// f_up(ChildNode1)
/// f_down(ChildNode2)
/// f_up(ChildNode2)
/// f_up(ParentNode)
/// ```
///
/// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled.
///
/// If `f_down` or `f_up` returns [`Err`], recursion is stopped immediately.
fn transform<FD, FU>(self, f_down: &mut FD, f_up: &mut FU) -> Result<Self>
where
F: Fn(Self) -> Result<Transformed<Self>>,
FD: FnMut(Self) -> Result<(Transformed<Self>, TreeNodeRecursion)>,
FU: FnMut(Self) -> Result<Self>,
{
self.transform_up(op)
let (new_node, tnr) = f_down(self).map(|(t, tnr)| (t.into(), tnr))?;
match tnr {
TreeNodeRecursion::Continue => {}
// If the recursion should skip, do not apply to its children. And let the recursion continue
TreeNodeRecursion::Skip => return Ok(new_node),
// If the recursion should stop, do not apply to its children
TreeNodeRecursion::Stop => {
panic!("Stop can't be used in TreeNode::transform()")
}
}
let node_with_new_children =
new_node.map_children(|node| node.transform(f_down, f_up))?;
f_up(node_with_new_children)
}

/// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its
Expand Down Expand Up @@ -159,56 +192,50 @@ pub trait TreeNode: Sized {
Ok(new_node)
}

/// Transform the tree node using the given [TreeNodeRewriter]
/// It performs a depth first walk of an node and its children.
/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for
/// recursively transforming [`TreeNode`]s.
///
/// For an node tree such as
/// E.g. for an tree such as:
/// ```text
/// ParentNode
/// left: ChildNode1
/// right: ChildNode2
/// ```
///
/// The nodes are visited using the following order
/// The nodes are visited using the following order:
/// ```text
/// pre_visit(ParentNode)
/// pre_visit(ChildNode1)
/// mutate(ChildNode1)
/// pre_visit(ChildNode2)
/// mutate(ChildNode2)
/// mutate(ParentNode)
/// TreeNodeRewriter::f_down(ParentNode)
/// TreeNodeRewriter::f_down(ChildNode1)
/// TreeNodeRewriter::f_up(ChildNode1)
/// TreeNodeRewriter::f_down(ChildNode2)
/// TreeNodeRewriter::f_up(ChildNode2)
/// TreeNodeRewriter::f_up(ParentNode)
/// ```
///
/// If an Err result is returned, recursion is stopped immediately
///
/// If [`false`] is returned on a call to pre_visit, no
/// children of that node will be visited, nor is mutate
/// called on that node
/// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled.
///
/// If using the default [`TreeNodeRewriter::pre_visit`] which
/// returns `true`, [`Self::transform`] should be preferred.
fn rewrite<R: TreeNodeRewriter<N = Self>>(self, rewriter: &mut R) -> Result<Self> {
let need_mutate = match rewriter.pre_visit(&self)? {
RewriteRecursion::Mutate => return rewriter.mutate(self),
RewriteRecursion::Stop => return Ok(self),
RewriteRecursion::Continue => true,
RewriteRecursion::Skip => false,
};

let after_op_children = self.map_children(|node| node.rewrite(rewriter))?;

// now rewrite this node itself
if need_mutate {
rewriter.mutate(after_op_children)
} else {
Ok(after_op_children)
/// If [`TreeNodeRewriter::f_down()`] or [`TreeNodeRewriter::f_up()`] returns [`Err`],
/// recursion is stopped immediately.
fn rewrite<R: TreeNodeRewriter<Node = Self>>(self, rewriter: &mut R) -> Result<Self> {
let (new_node, tnr) = rewriter.f_down(self)?;
match tnr {
TreeNodeRecursion::Continue => {}
// If the recursion should skip, do not apply to its children. And let the recursion continue
TreeNodeRecursion::Skip => return Ok(new_node),
// If the recursion should stop, do not apply to its children
TreeNodeRecursion::Stop => {
panic!("Stop can't be used in TreeNode::rewrite()")
}
}
let node_with_new_children =
new_node.map_children(|node| node.rewrite(rewriter))?;
rewriter.f_up(node_with_new_children)
}

/// Apply the closure `F` to the node's children
fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
fn apply_children<F>(&self, op: &mut F) -> Result<TreeNodeRecursion>
where
F: FnMut(&Self) -> Result<VisitRecursion>;
F: FnMut(&Self) -> Result<TreeNodeRecursion>;

/// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder)
fn map_children<F>(self, transform: F) -> Result<Self>
Expand All @@ -231,69 +258,58 @@ pub trait TreeNode: Sized {
/// If an [`Err`] result is returned, recursion is stopped
/// immediately.
///
/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no
/// If [`TreeNodeRecursion::Stop`] is returned on a call to pre_visit, no
/// children of that tree node are visited, nor is post_visit
/// called on that tree node
///
/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no
/// If [`TreeNodeRecursion::Stop`] is returned on a call to post_visit, no
/// siblings of that tree node are visited, nor is post_visit
/// called on its parent tree node
///
/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no
/// If [`TreeNodeRecursion::Skip`] is returned on a call to pre_visit, no
/// children of that tree node are visited.
pub trait TreeNodeVisitor: Sized {
/// The node type which is visitable.
type N: TreeNode;

/// Invoked before any children of `node` are visited.
fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion>;
fn pre_visit(&mut self, node: &Self::N) -> Result<TreeNodeRecursion>;

/// Invoked after all children of `node` are visited. Default
/// implementation does nothing.
fn post_visit(&mut self, _node: &Self::N) -> Result<VisitRecursion> {
Ok(VisitRecursion::Continue)
fn post_visit(&mut self, _node: &Self::N) -> Result<TreeNodeRecursion> {
Ok(TreeNodeRecursion::Continue)
}
}

/// Trait for potentially recursively transform an [`TreeNode`] node
/// tree. When passed to `TreeNode::rewrite`, `TreeNodeRewriter::mutate` is
/// invoked recursively on all nodes of a tree.
/// Trait for potentially recursively transform a [`TreeNode`] node tree.
pub trait TreeNodeRewriter: Sized {
/// The node type which is rewritable.
type N: TreeNode;
type Node: TreeNode;

/// Invoked before (Preorder) any children of `node` are rewritten /
/// visited. Default implementation returns `Ok(Recursion::Continue)`
fn pre_visit(&mut self, _node: &Self::N) -> Result<RewriteRecursion> {
Ok(RewriteRecursion::Continue)
/// Invoked while traversing down the tree before any children are rewritten /
/// visited.
/// Default implementation returns the node unmodified and continues recursion.
fn f_down(&mut self, node: Self::Node) -> Result<(Self::Node, TreeNodeRecursion)> {
Ok((node, TreeNodeRecursion::Continue))
}

/// Invoked after (Postorder) all children of `node` have been mutated and
/// returns a potentially modified node.
fn mutate(&mut self, node: Self::N) -> Result<Self::N>;
}

/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::rewrite`].
#[derive(Debug)]
pub enum RewriteRecursion {
/// Continue rewrite this node tree.
Continue,
/// Call 'op' immediately and return.
Mutate,
/// Do not rewrite the children of this node.
Stop,
/// Keep recursive but skip apply op on this node
Skip,
/// Invoked while traversing up the tree after all children have been rewritten /
/// visited.
/// Default implementation returns the node unmodified.
fn f_up(&mut self, node: Self::Node) -> Result<Self::Node> {
Ok(node)
}
}

/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::visit`].
/// Controls how [`TreeNode`] recursions should proceed.
#[derive(Debug)]
pub enum VisitRecursion {
/// Continue the visit to this node tree.
pub enum TreeNodeRecursion {
/// Continue recursion with the next node.
Continue,
/// Keep recursive but skip applying op on the children
/// Skip the current subtree.
Skip,
/// Stop the visit to this node tree.
/// Stop recursion.
Stop,
}

Expand Down Expand Up @@ -340,14 +356,14 @@ pub trait DynTreeNode {
/// [`DynTreeNode`] (such as [`Arc<dyn PhysicalExpr>`])
impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
/// Apply the closure `F` to the node's children
fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
fn apply_children<F>(&self, op: &mut F) -> Result<TreeNodeRecursion>
where
F: FnMut(&Self) -> Result<VisitRecursion>,
F: FnMut(&Self) -> Result<TreeNodeRecursion>,
{
for child in self.arc_children() {
handle_tree_recursion!(op(&child)?)
}
Ok(VisitRecursion::Continue)
Ok(TreeNodeRecursion::Continue)
}

fn map_children<F>(self, transform: F) -> Result<Self>
Expand Down Expand Up @@ -382,14 +398,14 @@ pub trait ConcreteTreeNode: Sized {

impl<T: ConcreteTreeNode> TreeNode for T {
/// Apply the closure `F` to the node's children
fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
fn apply_children<F>(&self, op: &mut F) -> Result<TreeNodeRecursion>
where
F: FnMut(&Self) -> Result<VisitRecursion>,
F: FnMut(&Self) -> Result<TreeNodeRecursion>,
{
for child in self.children() {
handle_tree_recursion!(op(child)?)
}
Ok(VisitRecursion::Continue)
Ok(TreeNodeRecursion::Continue)
}

fn map_children<F>(self, transform: F) -> Result<Self>
Expand Down
Loading

0 comments on commit 729c9d2

Please sign in to comment.