Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reference visitor TreeNode APIs, change ExecutionPlan::children() and PhysicalExpr::children() return references #10543

Merged
merged 6 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion-examples/examples/custom_datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ impl ExecutionPlan for CustomExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}

Expand Down
135 changes: 111 additions & 24 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ pub trait TreeNode: Sized {
/// TreeNodeVisitor::f_up(ChildNode2)
/// TreeNodeVisitor::f_up(ParentNode)
/// ```
fn visit<V: TreeNodeVisitor<Node = Self>>(
&self,
fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>(
&'n self,
visitor: &mut V,
) -> Result<TreeNodeRecursion> {
visitor
Expand Down Expand Up @@ -190,12 +190,12 @@ pub trait TreeNode: Sized {
/// # See Also
/// * [`Self::transform_down`] for the equivalent transformation API.
/// * [`Self::visit`] for both top-down and bottom up traversal.
fn apply<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
mut f: F,
) -> Result<TreeNodeRecursion> {
fn apply_impl<N: TreeNode, F: FnMut(&N) -> Result<TreeNodeRecursion>>(
node: &N,
fn apply_impl<'n, N: TreeNode, F: FnMut(&'n N) -> Result<TreeNodeRecursion>>(
node: &'n N,
f: &mut F,
) -> Result<TreeNodeRecursion> {
f(node)?.visit_children(|| node.apply_children(|c| apply_impl(c, f)))
Expand Down Expand Up @@ -427,8 +427,8 @@ pub trait TreeNode: Sized {
///
/// Description: Apply `f` to inspect node's children (but not the node
/// itself).
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion>;

Expand Down Expand Up @@ -466,19 +466,19 @@ pub trait TreeNode: Sized {
///
/// # See Also:
/// * [`TreeNode::rewrite`] to rewrite owned `TreeNode`s
pub trait TreeNodeVisitor: Sized {
pub trait TreeNodeVisitor<'n>: Sized {
/// The node type which is visitable.
type Node: TreeNode;

/// Invoked while traversing down the tree, before any children are visited.
/// Default implementation continues the recursion.
fn f_down(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
fn f_down(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
Ok(TreeNodeRecursion::Continue)
}

/// Invoked while traversing up the tree after children are visited. Default
/// implementation continues the recursion.
fn f_up(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
fn f_up(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
Ok(TreeNodeRecursion::Continue)
}
}
Expand Down Expand Up @@ -855,7 +855,7 @@ impl<T> TransformedResult<T> for Result<Transformed<T>> {
/// its related `Arc<dyn T>` will automatically implement [`TreeNode`].
pub trait DynTreeNode {
/// Returns all children of the specified `TreeNode`.
fn arc_children(&self) -> Vec<Arc<Self>>;
fn arc_children(&self) -> Vec<&Arc<Self>>;

/// Constructs a new node with the specified children.
fn with_new_arc_children(
Expand All @@ -868,11 +868,11 @@ pub trait DynTreeNode {
/// Blanket implementation for any `Arc<T>` where `T` implements [`DynTreeNode`]
/// (such as [`Arc<dyn PhysicalExpr>`]).
impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.arc_children().iter().apply_until_stop(f)
self.arc_children().into_iter().apply_until_stop(f)
}

fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
Expand All @@ -881,7 +881,10 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
) -> Result<Transformed<Self>> {
let children = self.arc_children();
if !children.is_empty() {
let new_children = children.into_iter().map_until_stop_and_collect(f)?;
let new_children = children
.into_iter()
.cloned()
.map_until_stop_and_collect(f)?;
// Propagate up `new_children.transformed` and `new_children.tnr`
// along with the node containing transformed children.
if new_children.transformed {
Expand Down Expand Up @@ -913,8 +916,8 @@ pub trait ConcreteTreeNode: Sized {
}

impl<T: ConcreteTreeNode> TreeNode for T {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.children().iter().apply_until_stop(f)
Expand All @@ -938,6 +941,7 @@ impl<T: ConcreteTreeNode> TreeNode for T {

#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::fmt::Display;

use crate::tree_node::{
Expand All @@ -946,7 +950,7 @@ mod tests {
};
use crate::Result;

#[derive(PartialEq, Debug)]
#[derive(Debug, Eq, Hash, PartialEq)]
struct TestTreeNode<T> {
children: Vec<TestTreeNode<T>>,
data: T,
Expand All @@ -959,8 +963,8 @@ mod tests {
}

impl<T> TreeNode for TestTreeNode<T> {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.children.iter().apply_until_stop(f)
Expand Down Expand Up @@ -1459,15 +1463,15 @@ mod tests {
}
}

impl<T: Display> TreeNodeVisitor for TestVisitor<T> {
impl<'n, T: Display> TreeNodeVisitor<'n> for TestVisitor<T> {
type Node = TestTreeNode<T>;

fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
self.visits.push(format!("f_down({})", node.data));
(*self.f_down)(node)
}

fn f_up(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
self.visits.push(format!("f_up({})", node.data));
(*self.f_up)(node)
}
Expand Down Expand Up @@ -1912,4 +1916,87 @@ mod tests {
TreeNodeRecursion::Stop
)
);

// F
// / | \
// / | \
// E C A
// | / \
// C B D
// / \ |
// B D A
// |
// A
#[test]
fn test_apply_and_visit_references() -> Result<()> {
let node_a = TestTreeNode::new(vec![], "a".to_string());
let node_b = TestTreeNode::new(vec![], "b".to_string());
let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
let node_a_2 = TestTreeNode::new(vec![], "a".to_string());
let node_b_2 = TestTreeNode::new(vec![], "b".to_string());
let node_d_2 = TestTreeNode::new(vec![node_a_2], "d".to_string());
let node_c_2 = TestTreeNode::new(vec![node_b_2, node_d_2], "c".to_string());
let node_a_3 = TestTreeNode::new(vec![], "a".to_string());
let tree = TestTreeNode::new(vec![node_e, node_c_2, node_a_3], "f".to_string());

let node_f_ref = &tree;
let node_e_ref = &node_f_ref.children[0];
let node_c_ref = &node_e_ref.children[0];
let node_b_ref = &node_c_ref.children[0];
let node_d_ref = &node_c_ref.children[1];
let node_a_ref = &node_d_ref.children[0];

let mut m: HashMap<&TestTreeNode<String>, usize> = HashMap::new();
tree.apply(|e| {
*m.entry(e).or_insert(0) += 1;
Ok(TreeNodeRecursion::Continue)
})?;

let expected = HashMap::from([
(node_f_ref, 1),
(node_e_ref, 1),
(node_c_ref, 2),
(node_d_ref, 2),
(node_b_ref, 2),
(node_a_ref, 3),
]);
assert_eq!(m, expected);

struct TestVisitor<'n> {
m: HashMap<&'n TestTreeNode<String>, (usize, usize)>,
}

impl<'n> TreeNodeVisitor<'n> for TestVisitor<'n> {
type Node = TestTreeNode<String>;

fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
let (down_count, _) = self.m.entry(node).or_insert((0, 0));
*down_count += 1;
Ok(TreeNodeRecursion::Continue)
}

fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
let (_, up_count) = self.m.entry(node).or_insert((0, 0));
*up_count += 1;
Ok(TreeNodeRecursion::Continue)
}
}

let mut visitor = TestVisitor { m: HashMap::new() };
tree.visit(&mut visitor)?;

let expected = HashMap::from([
(node_f_ref, (1, 1)),
(node_e_ref, (1, 1)),
(node_c_ref, (2, 2)),
(node_d_ref, (2, 2)),
(node_b_ref, (2, 2)),
(node_a_ref, (3, 3)),
]);
assert_eq!(visitor.m, expected);

Ok(())
}
}
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/physical_plan/arrow_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl ExecutionPlan for ArrowExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
Vec::new()
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/physical_plan/avro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl ExecutionPlan for AvroExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
Vec::new()
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/physical_plan/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ impl ExecutionPlan for CsvExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
// this is a leaf node and has no children
vec![]
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/physical_plan/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ impl ExecutionPlan for NdJsonExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
Vec::new()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ impl ExecutionPlan for ParquetExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
// this is a leaf node and has no children
vec![]
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2465,10 +2465,10 @@ impl<'a> BadPlanVisitor<'a> {
}
}

impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> {
impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> {
type Node = LogicalPlan;

fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
match node {
LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => {
plan_err!("DDL not supported: {}", ddl.name())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>>
return Some(child);
}
}
if let [ref childrens_child] = child.children().as_slice() {
if let [childrens_child] = child.children().as_slice() {
child = Arc::clone(childrens_child);
} else {
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1375,8 +1375,8 @@ pub(crate) mod tests {
vec![false]
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.input.clone()]
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}

// model that it requires the output ordering of its input
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/enforce_sorting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ fn remove_corresponding_sort_from_sub_plan(
// Replace with variants that do not preserve order.
if is_sort_preserving_merge(&node.plan) {
node.children = node.children.swap_remove(0).children;
node.plan = node.plan.children().swap_remove(0);
node.plan = node.plan.children().swap_remove(0).clone();
} else if let Some(repartition) =
node.plan.as_any().downcast_ref::<RepartitionExec>()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,15 @@ impl LimitedDistinctAggregation {
let mut is_global_limit = false;
if let Some(local_limit) = plan.as_any().downcast_ref::<LocalLimitExec>() {
limit = local_limit.fetch();
children = local_limit.children();
children = local_limit.children().into_iter().cloned().collect();
} else if let Some(global_limit) = plan.as_any().downcast_ref::<GlobalLimitExec>()
{
global_fetch = global_limit.fetch();
global_fetch?;
global_skip = global_limit.skip();
// the aggregate must read at least fetch+skip number of rows
limit = global_fetch.unwrap() + global_skip;
children = global_limit.children();
children = global_limit.children().into_iter().cloned().collect();
is_global_limit = true
} else {
return None;
Expand Down
6 changes: 3 additions & 3 deletions datafusion/core/src/physical_optimizer/output_requirements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ impl ExecutionPlan for OutputRequirementExec {
vec![true]
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.input.clone()]
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}

fn required_input_ordering(&self) -> Vec<Option<Vec<PhysicalSortRequirement>>> {
Expand Down Expand Up @@ -273,7 +273,7 @@ fn require_top_ordering_helper(
// When an operator requires an ordering, any `SortExec` below can not
// be responsible for (i.e. the originator of) the global ordering.
let (new_child, is_changed) =
require_top_ordering_helper(children.swap_remove(0))?;
require_top_ordering_helper(children.swap_remove(0).clone())?;
Ok((plan.with_new_children(vec![new_child])?, is_changed))
} else {
// Stop searching, there is no global ordering desired for the query.
Expand Down
Loading