Skip to content

Commit

Permalink
improve dos
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Apr 1, 2024
1 parent 177ceea commit ffab86a
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 61 deletions.
68 changes: 45 additions & 23 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ pub trait TreeNode: Sized {
/// Visit the tree node using the given [`TreeNodeVisitor`], performing a
/// depth-first walk of the node and its children.
///
/// See also:
/// * [`Self::mutate`] to rewrite `TreeNode`s in place
/// * [`Self::rewrite`] to rewrite owned `TreeNode`s
///
/// Consider the following tree structure:
/// ```text
/// ParentNode
Expand Down Expand Up @@ -144,6 +148,10 @@ pub trait TreeNode: Sized {
/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for
/// recursively transforming [`TreeNode`]s.
///
/// See also:
/// * [`Self::mutate`] to rewrite `TreeNode`s in place
/// * [`Self::visit`] for inspecting (without modification) `TreeNode`s
///
/// Consider the following tree structure:
/// ```text
/// ParentNode
Expand Down Expand Up @@ -175,7 +183,11 @@ pub trait TreeNode: Sized {
}

/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for
/// recursively mutating / rewriting [`TreeNode`]s in place
/// recursively mutating / rewriting [`TreeNode`]s in place.
///
/// See also:
/// * [`Self::rewrite`] to rewrite owned `TreeNode`s
/// * [`Self::visit`] for inspecting (without modification) `TreeNode`s
///
/// Consider the following tree structure:
/// ```text
Expand All @@ -184,7 +196,7 @@ pub trait TreeNode: Sized {
/// right: ChildNode2
/// ```
///
/// Here, the nodes would be mutated using the following order:
/// Here, the nodes would be mutataed in the following order:
/// ```text
/// TreeNodeMutator::f_down(ParentNode)
/// TreeNodeMutator::f_down(ChildNode1)
Expand Down Expand Up @@ -422,13 +434,17 @@ pub trait TreeNode: Sized {

/// Rewrite the node's children in place using `F`.
///
/// Using [`Self::map_children`], the owned API, is more ideomatic and
/// has clearer semantics on error (the node is consumed). However, it requires
/// copying the interior fields of the tree node during rewrite
/// On error, `self` is left partially rewritten.
///
/// # Notes
///
/// Using [`Self::map_children`], the owned API, has clearer semantics on
/// error (the node is consumed). However, it requires copying the interior
/// fields of the tree node during rewrite.
///
/// This API writes the nodes in place, which can be faster as it avoids
/// copying. However, one downside is that the tree node can be left in an
/// partially rewritten state when an error occurs.
/// copying, but leaves the tree node in an partially rewritten state when
/// an error occurs.
fn mutate_children<F: FnMut(&mut Self) -> Result<Transformed<()>>>(
&mut self,
_f: F,
Expand Down Expand Up @@ -492,30 +508,35 @@ pub trait TreeNodeRewriter: Sized {
}
}

/// Trait for potentially rewriting tree of [`TreeNode`]s in place
/// Trait for mutating (rewriting in place) [`TreeNode`]s in place
///
/// See [`TreeNodeRewriter`] for rewriting owned tree ndoes
/// See [`TreeNodeVisitor`] for visiting, but not changing, tree nodes
/// # See Also:
/// * [`TreeNodeRewriter`] for rewriting owned `TreeNode`e
/// * [`TreeNodeVisitor`] for visiting, but not changing, `TreeNode`s
pub trait TreeNodeMutator: Sized {
/// The node type to rewrite.
/// The node type to mutating.
type Node: TreeNode;

/// Invoked while traversing down the tree before any children are rewritten.
/// Default implementation returns the node as is and continues recursion.
/// Invoked while traversing down the tree before any children are mutated.
/// Default implementation does nothing to the node and continues recursion.
///
/// # Notes
///
/// Since this mutates the nodes in place, the returned Transformed object
/// As the node maybe mutated in place, the returned [`Transformed`] object
/// returns `()` (no data).
///
/// If the node's children are changed by `f_down`, the *new* children are
/// visited, not the original.
/// visited, not the original children.
fn f_down(&mut self, _node: &mut Self::Node) -> Result<Transformed<()>> {
Ok(Transformed::no(()))
}

/// Invoked while traversing up the tree after all children have been rewritten.
/// Default implementation returns the node as is and continues recursion.
/// Invoked while traversing up the tree after all children have been mutated.
/// Default implementation does nothing to the node and continues recursion.
///
/// # Notes
///
/// Since this mutates the nodes in place, the returned Transformed object
/// As the node maybe mutated in place, the returned [`Transformed`] object
/// returns `()` (no data).
fn f_up(&mut self, _node: &mut Self::Node) -> Result<Transformed<()>> {
Ok(Transformed::no(()))
Expand Down Expand Up @@ -603,7 +624,7 @@ impl<T> Transformed<T> {
/// Invokes f(), depending on the value of self.tnr.
///
/// This is used to conditionally apply a function during a f_up tree
/// traversal, if the result of children traversal was `Continue`.
/// traversal, if the result of children traversal was `[`TreeNodeRecursion::Continue`].
///
/// Handling [`TreeNodeRecursion::Continue`] and [`TreeNodeRecursion::Stop`]
/// is straightforward, but [`TreeNodeRecursion::Jump`] can behave differently
Expand Down Expand Up @@ -650,11 +671,12 @@ impl<T> Transformed<T> {

impl Transformed<()> {
/// Invoke the given function `f` and combine the transformed state with
/// the current state,
/// the current state:
///
/// * if `f` returns an Err, returns that err
///
/// if f() returns an Err, returns that err
/// If f() returns Ok, returns a true transformed flag if either self or
/// the result of f() was transformed
/// * If `f` returns Ok, sets `self.transformed` to `true` if either self or
/// the result of `f` were transformed.
pub fn and_then<F>(self, f: F) -> Result<Transformed<()>>
where
F: FnOnce() -> Result<Transformed<()>>,
Expand Down
70 changes: 32 additions & 38 deletions datafusion/expr/src/logical_plan/mutate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ use datafusion_common::{Column, DFSchema, DFSchemaRef};
use std::sync::{Arc, OnceLock};

impl LogicalPlan {
/// applies the closure `f` to each expression of this node, potentially
/// rewriting it in place
/// applies `f` to each expression of this node, potentially rewriting it in
/// place
///
/// If the closure returns an error, the error is returned and the expressions
/// are left in a partially modified state
/// If `f` returns an error, the error is returned and the expressions are
/// left in a partially modified state
pub fn rewrite_exprs<F>(&mut self, mut f: F) -> Result<Transformed<()>>
where
F: FnMut(&mut Expr) -> Result<Transformed<()>>,
Expand Down Expand Up @@ -66,7 +66,6 @@ impl LogicalPlan {
// 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`.
// 2. the second part is non-equijoin(filter).
LogicalPlan::Join(Join { on, filter, .. }) => {
// don't look at the equijoin expressions as a whole
let exprs = on
.iter_mut()
.flat_map(|(e1, e2)| std::iter::once(e1).chain(std::iter::once(e2)));
Expand All @@ -88,10 +87,7 @@ impl LogicalPlan {
LogicalPlan::TableScan(TableScan { filters, .. }) => {
rewrite_expr_iter_mut(filters.iter_mut(), f)
}
LogicalPlan::Unnest(Unnest { column, .. }) => {
// it would be really nice to avoid this and instead have unnest have Expr
rewrite_column(column, f)
}
LogicalPlan::Unnest(Unnest { column, .. }) => rewrite_column(column, f),
LogicalPlan::Distinct(Distinct::On(DistinctOn {
on_expr,
select_expr,
Expand Down Expand Up @@ -125,12 +121,14 @@ impl LogicalPlan {
}
}

/// applies the closure `f` to each input of this node, in place.
/// applies `f` to each input of this node, rewriting them in place.
///
/// Inputs include both direct children as well as any embedded
/// `LogicalPlan`s embedded in `Expr::Exists`, etc.
/// # Notes
/// Inputs include both direct children as well as any embedded subquery
/// `LogicalPlan`s, for example such as are in [`Expr::Exists`].
///
/// If Err is returned, the inputs may be in a partially modified state
/// If `f` returns an `Err`, that Err is returned, and the inputs are left
/// in a partially modified state
pub fn rewrite_inputs<F>(&mut self, mut f: F) -> Result<Transformed<()>>
where
F: FnMut(&mut LogicalPlan) -> Result<Transformed<()>>,
Expand Down Expand Up @@ -202,7 +200,7 @@ impl LogicalPlan {
children_result.and_then(|| self.rewrite_subqueries(&mut f))
}

/// applies the closure `f` to LogicalPlans in any subquery expressions
/// applies `f` to LogicalPlans in any subquery expressions
///
/// If Err is returned, the plan may be left in a partially modified state
fn rewrite_subqueries<F>(&mut self, mut f: F) -> Result<Transformed<()>>
Expand All @@ -220,7 +218,7 @@ impl LogicalPlan {
}
}

/// writes each element in the iterator using `f`
/// writes each `&mut Expr` in the iterator using `f`
fn rewrite_expr_iter_mut<'a, F>(
i: impl IntoIterator<Item = &'a mut Expr>,
mut f: F,
Expand All @@ -233,17 +231,18 @@ where
}

/// A temporary node that is left in place while rewriting the children of a
/// logical plan to ensure that the logical plan is always in a valid state
/// [`LogicalPlan`]. This is necessary to ensure that the `LogicalPlan` is
/// always in a valid state (from the Rust perspective)
static PLACEHOLDER: OnceLock<Arc<LogicalPlan>> = OnceLock::new();

/// Applies f() to rewrite each child of the logical plan, replacing the existing
/// node with the result, trying to avoid clone'ing.
/// Applies `f` to rewrite the existing node, while avoiding `clone`'ing as much
/// as possiblw.
///
/// TODO we would remove the Arc<LogicalPlan> nonsense entirely from LogicalPlan
/// and have it own its inputs however, for now do a horrible hack and swap out the value
/// of the Arc to avoid cloning the entire plan
/// TODO eventually remove `Arc<LogicalPlan>` from `LogicalPlan` and have it own
/// its inputs, so this code would not be needed. However, for now we try and
/// unwrap the `Arc` which avoids `clone`ing in most cases.
///
/// On error, the node will be partially rewritten (left with a placeholder logical plan)
/// On error, node be left with a placeholder logical plan
fn rewrite_arc<F>(node: &mut Arc<LogicalPlan>, mut f: F) -> Result<Transformed<()>>
where
F: FnMut(&mut LogicalPlan) -> Result<Transformed<()>>,
Expand All @@ -263,16 +262,11 @@ where
std::mem::swap(node, &mut new_node);

// try to update existing node, if it isn't shared with others
let mut new_node = match Arc::try_unwrap(new_node) {
Ok(node) => {
//println!("Unwrapped arc yay");
node
}
Err(node) => {
//println!("Failed to unwrap arc boo");
node.as_ref().clone()
}
};
let mut new_node = Arc::try_unwrap(new_node)
// if None is returned, there is another reference to this
// LogicalPlan, so we must clone instead
.unwrap_or_else(|node| node.as_ref().clone());

// apply the actual transform
let result = f(&mut new_node)?;

Expand All @@ -283,24 +277,23 @@ where
Ok(result)
}

/// Rewrites a `Column` in place using the provided closure
/// Rewrites a [`Column`] in place using the provided closure
fn rewrite_column<F>(column: &mut Column, mut f: F) -> Result<Transformed<()>>
where
F: FnMut(&mut Expr) -> Result<Transformed<()>>,
{
// Column's isn't an Expr to visit, but the closure is to rewrite Exprs.
// So we need to make a temporary Expr to rewrite and then put it bac,
// Since `Column`'s isn't an `Expr`, but the closure in terms of Exprs,
// we make a temporary Expr to rewrite and then put it back

let mut swap_column = Column::new_unqualified("TEMP_unnest_column");
std::mem::swap(column, &mut swap_column);

let mut expr = Expr::Column(swap_column);
let result = f(&mut expr)?;
// put the column back
// Get the rewritten column
let Expr::Column(mut swap_column) = expr else {
return internal_err!(
"Rewrite of Unnest expr must return Column, returned {:?}",
expr
"Rewrite of Column Expr must return Column, returned {expr:?}"
);
};
// put the rewritten column back
Expand All @@ -310,6 +303,7 @@ where

/// Rewrites all expressions for an Extension node "in place"
/// (it currently has to copy values because there are no APIs for in place modification)
/// TODO file ticket for inplace modificiation of Extension nodes
///
/// Should be removed when we have an API for in place modifications of the
/// extension to avoid these copies
Expand Down

0 comments on commit ffab86a

Please sign in to comment.