Skip to content

Commit

Permalink
Sketch out TreeNodeMutator API
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Mar 29, 2024
1 parent 127fe5e commit 233d8e9
Show file tree
Hide file tree
Showing 5 changed files with 466 additions and 55 deletions.
140 changes: 137 additions & 3 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
use std::sync::Arc;

use crate::Result;
use crate::{error::_not_impl_err, Result};

/// This macro is used to control continuation behaviors during tree traversals
/// based on the specified direction. Depending on `$DIRECTION` and the value of
Expand Down Expand Up @@ -174,6 +174,66 @@ pub trait TreeNode: Sized {
})
}

/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for
/// recursively mutating / rewriting [`TreeNode`]s in place
///
/// Consider the following tree structure:
/// ```text
/// ParentNode
/// left: ChildNode1
/// right: ChildNode2
/// ```
///
/// Here, the nodes would be mutated using the following order:
/// ```text
/// TreeNodeMutator::f_down(ParentNode)
/// TreeNodeMutator::f_down(ChildNode1)
/// TreeNodeMutator::f_up(ChildNode1)
/// TreeNodeMutator::f_down(ChildNode2)
/// TreeNodeMutator::f_up(ChildNode2)
/// TreeNodeMutator::f_up(ParentNode)
/// ```
///
/// See [`TreeNodeRecursion`] for more details on controlling the traversal.
///
/// # Error Handling
///
/// If [`TreeNodeMutator::f_down()`] or [`TreeNodeMutator::f_up()`] returns [`Err`],
/// the recursion stops immediately and the tree may be left partially changed
///
/// # Changing Children During Traversal
///
/// If `f_down` changes the node's children, the new children are visited
/// (not the old children prior to rewrite)
fn mutate<M: TreeNodeMutator<Node = Self>>(
    &mut self,
    mutator: &mut M,
) -> Result<Transformed<()>> {
    // This mirrors handle_transform_recursion!, written out by hand.
    // Step 1: give the mutator a chance to change this node on the way down.
    let down = mutator.f_down(self)?;

    // Step 2: decide, based on what `f_down` asked for, whether to descend
    // into the children and whether `f_up` runs on this node.
    let mut up = match down.tnr {
        // Stop: hand back the `f_down` result untouched, no `f_up` call.
        TreeNodeRecursion::Stop => return Ok(down),
        // Jump: do not descend into the children, go straight to `f_up`.
        TreeNodeRecursion::Jump => mutator.f_up(self)?,
        // Continue: mutate the children first, then let
        // `try_transform_node_with` decide whether `f_up` is invoked.
        TreeNodeRecursion::Continue => {
            let children = self.mutate_children(|child| child.mutate(mutator))?;
            children.try_transform_node_with(
                |_: ()| mutator.f_up(self),
                TreeNodeRecursion::Jump,
            )?
        }
    };

    // Step 3: the node counts as transformed if either pass changed it.
    up.transformed |= down.transformed;
    Ok(up)
}

/// Applies `f` to the node and its children. `f` is applied in a pre-order
/// way, and it is controlled by [`TreeNodeRecursion`], which means result
/// of the `f` on a node can cause an early return.
Expand Down Expand Up @@ -353,13 +413,34 @@ pub trait TreeNode: Sized {
}

/// Apply the closure `F` to the node's children.
///
/// `f` is handed a shared reference (`&Self`) to a child and reports, via
/// the returned [`TreeNodeRecursion`], how the traversal should proceed.
/// An `Err` from `f` is expressible through the `Result` return type.
///
/// This is a required method: the trait provides no default body, so each
/// implementor defines which nodes count as children.
///
/// See [`Self::mutate_children`] for rewriting in place
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
f: &mut F,
) -> Result<TreeNodeRecursion>;

/// Apply transform `F` to the node's children. Note that the transform `F`
/// might have a direction (pre-order or post-order).
/// Rewrite the node's children in place using `F`.
///
/// Using [`Self::map_children`], the owned API, is more idiomatic and
/// has clearer semantics on error (the node is consumed). However, it requires
/// copying the interior fields of the tree node during rewrite.
///
/// This API writes the nodes in place, which can be faster as it avoids
/// copying. However, one downside is that the tree node can be left in a
/// partially rewritten state when an error occurs.
///
/// # Default implementation
///
/// The default body never calls `_f`. It only evaluates `_not_impl_err!`
/// with a message that names the implementing type (via
/// [`std::any::type_name`]), so types that want to support
/// [`Self::mutate`], which recurses through this method, must override it.
fn mutate_children<F: FnMut(&mut Self) -> Result<Transformed<()>>>(
&mut self,
_f: F,
) -> Result<Transformed<()>> {
// Placeholder until implementors provide an in-place rewrite; the type
// name is included so the message identifies which node type is missing it.
_not_impl_err!(
"mutate_children not implemented for {} yet",
std::any::type_name::<Self>()
)
}

/// Apply transform `F` to potentially rewrite the node's children. Note
/// that the transform `F` might have a direction (pre-order or post-order).
fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: F,
Expand Down Expand Up @@ -411,6 +492,36 @@ pub trait TreeNodeRewriter: Sized {
}
}

/// Trait for potentially rewriting a tree of [`TreeNode`]s in place
///
/// Both callbacks receive the node as `&mut Self::Node` and have default
/// bodies, so an implementor only needs to override the direction(s) it
/// cares about.
///
/// See [`TreeNodeRewriter`] for rewriting owned tree nodes
/// See [`TreeNodeVisitor`] for visiting, but not changing, tree nodes
pub trait TreeNodeMutator: Sized {
/// The node type to rewrite.
type Node: TreeNode;

/// Invoked while traversing down the tree before any children are rewritten.
/// Default implementation returns the node as is and continues recursion.
///
/// Since this mutates the nodes in place, the returned Transformed object
/// returns `()` (no data).
///
/// If the node's children are changed by `f_down`, the *new* children are
/// visited, not the original.
fn f_down(&mut self, _node: &mut Self::Node) -> Result<Transformed<()>> {
// Default: leave `_node` untouched and report `Transformed::no(())`.
Ok(Transformed::no(()))
}

/// Invoked while traversing up the tree after all children have been rewritten.
/// Default implementation returns the node as is and continues recursion.
///
/// Since this mutates the nodes in place, the returned Transformed object
/// returns `()` (no data).
fn f_up(&mut self, _node: &mut Self::Node) -> Result<Transformed<()>> {
// Default: leave `_node` untouched and report `Transformed::no(())`.
Ok(Transformed::no(()))
}
}

/// Controls how [`TreeNode`] recursions should proceed.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum TreeNodeRecursion {
Expand Down Expand Up @@ -489,6 +600,11 @@ impl<T> Transformed<T> {
f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr))
}

/// Invokes f(), depending on the value of self.tnr.
///
/// This is used to conditionally apply a function during a f_up tree
/// traversal, if the result of children traversal was `Continue`.
///
/// Handling [`TreeNodeRecursion::Continue`] and [`TreeNodeRecursion::Stop`]
/// is straightforward, but [`TreeNodeRecursion::Jump`] can behave differently
/// when we are traversing down or up on a tree. If [`TreeNodeRecursion`] of
Expand Down Expand Up @@ -532,6 +648,24 @@ impl<T> Transformed<T> {
}
}

impl Transformed<()> {
    /// Runs `f` and folds the `transformed` flag of `self` into its result.
    ///
    /// * If `f()` returns an `Err`, that error is returned as is.
    /// * If `f()` returns `Ok`, that value is returned with its `transformed`
    ///   flag set when either `self` or the result of `f()` was transformed.
    ///   Every other field comes from the result of `f()`.
    pub fn and_then<F>(self, f: F) -> Result<Transformed<()>>
    where
        F: FnOnce() -> Result<Transformed<()>>,
    {
        let mut combined = f()?;
        combined.transformed |= self.transformed;
        Ok(combined)
    }
}

/// Transformation helper to process tree nodes that are siblings.
pub trait TransformedIterator: Iterator {
fn map_until_stop_and_collect<
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod ddl;
pub mod display;
pub mod dml;
mod extension;
mod mutate;
mod plan;
mod statement;

Expand Down
Loading

0 comments on commit 233d8e9

Please sign in to comment.