diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 5064ad8d5c487..cde1b5cae8b18 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -25,7 +25,7 @@ use super::utils::{ convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op, }; use crate::expressions::Literal; -use crate::utils::{build_dag, ExprTreeNode}; +use crate::utils::build_dag; use crate::PhysicalExpr; use arrow_schema::{DataType, Schema}; @@ -179,8 +179,8 @@ impl ExprIntervalGraphNode { /// This function creates a DAEG node from Datafusion's [`ExprTreeNode`] /// object. Literals are created with definite, singleton intervals while /// any other expression starts with an indefinite interval ([-∞, ∞]). - pub fn make_node(node: &ExprTreeNode, schema: &Schema) -> Result { - let expr = node.expression().clone(); + pub fn make_node(expr: &Arc, schema: &Schema) -> Result { + let expr = expr.clone(); if let Some(literal) = expr.as_any().downcast_ref::() { let value = literal.value(); Interval::try_new(value.clone(), value.clone()) @@ -353,7 +353,7 @@ impl ExprIntervalGraph { pub fn try_new(expr: Arc, schema: &Schema) -> Result { // Build the full graph: let (root, graph) = - build_dag(expr, &|node| ExprIntervalGraphNode::make_node(node, schema))?; + build_dag(expr, &|expr| ExprIntervalGraphNode::make_node(expr, schema))?; Ok(Self { graph, root }) } diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 64a62dc7820d8..1ab51e50d1192 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -18,7 +18,7 @@ mod guarantee; pub use guarantee::{Guarantee, LiteralGuarantee}; -use std::borrow::{Borrow, Cow}; +use std::borrow::Borrow; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -28,9 +28,7 @@ use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::SchemaRef; -use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRewriter, VisitRecursion, -}; +use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::Result; use datafusion_expr::Operator; @@ -127,56 +125,11 @@ pub fn get_indices_of_exprs_strict>>( .collect() } -#[derive(Clone, Debug)] -pub struct ExprTreeNode { - expr: Arc, - data: Option, - child_nodes: Vec>, -} - -impl ExprTreeNode { - pub fn new(expr: Arc) -> Self { - let children = expr.children(); - ExprTreeNode { - expr, - data: None, - child_nodes: children.into_iter().map(Self::new).collect_vec(), - } - } - - pub fn expression(&self) -> &Arc { - &self.expr - } - - pub fn children(&self) -> &[ExprTreeNode] { - &self.child_nodes - } -} - -impl TreeNode for ExprTreeNode { - fn children_nodes(&self) -> Vec> { - self.children().iter().map(Cow::Borrowed).collect() - } - - fn map_children(mut self, transform: F) -> Result - where - F: FnMut(Self) -> Result, - { - self.child_nodes = self - .child_nodes - .into_iter() - .map(transform) - .collect::>>()?; - Ok(self) - } -} - /// This struct facilitates the [TreeNodeRewriter] mechanism to convert a /// [PhysicalExpr] tree into a DAEG (i.e. an expression DAG) by collecting /// identical expressions in one node. Caller specifies the node type in the -/// DAEG via the `constructor` argument, which constructs nodes in the DAEG -/// from the [ExprTreeNode] ancillary object. -struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result> { +/// DAEG via the `constructor` argument. +struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&Arc) -> Result> { // The resulting DAEG (expression DAG). graph: StableGraph, // A vector of visited expression nodes and their corresponding node indices. @@ -185,19 +138,16 @@ struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result< constructor: &'a F, } -impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter - for PhysicalExprDAEGBuilder<'a, T, F> +impl<'a, T, F: Fn(&Arc) -> Result> + PhysicalExprDAEGBuilder<'a, T, F> { - type N = ExprTreeNode; - // This method mutates an expression node by transforming it to a physical expression - // and adding it to the graph. The method returns the mutated expression node. - fn mutate( + // This method adds an expression to the graph and returns the corresponding node + // index. + fn calculate_node_index( &mut self, - mut node: ExprTreeNode, - ) -> Result> { - // Get the expression associated with the input expression node. - let expr = &node.expr; - + expr: Arc, + children_node_indices: Vec, + ) -> Result<(Transformed>, NodeIndex)> { // Check if the expression has already been visited. let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) { // If the expression has been visited, return the corresponding node index. @@ -206,18 +156,15 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter // add edges to its child nodes. Add the visited expression to the vector // of visited expressions and return the newly created node index. None => { - let node_idx = self.graph.add_node((self.constructor)(&node)?); - for expr_node in node.child_nodes.iter() { - self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0); + let node_idx = self.graph.add_node((self.constructor)(&expr)?); + for child_node_index in children_node_indices.into_iter() { + self.graph.add_edge(node_idx, child_node_index, 0); } self.visited_plans.push((expr.clone(), node_idx)); node_idx } }; - // Set the data field of the input expression node to the corresponding node index. - node.data = Some(node_idx); - // Return the mutated expression node. - Ok(node) + Ok((Transformed::No(expr), node_idx)) } } @@ -227,10 +174,8 @@ pub fn build_dag( constructor: &F, ) -> Result<(NodeIndex, StableGraph)> where - F: Fn(&ExprTreeNode) -> Result, + F: Fn(&Arc) -> Result, { - // Create a new expression tree node from the input expression. - let init = ExprTreeNode::new(expr); // Create a new `PhysicalExprDAEGBuilder` instance. let mut builder = PhysicalExprDAEGBuilder { graph: StableGraph::::new(), @@ -238,9 +183,12 @@ where constructor, }; // Use the builder to transform the expression tree node into a DAG. - let root = init.rewrite(&mut builder)?; + let (_, node_index) = + expr.transform_up_with_payload(&mut |expr, children_node_indices| { + builder.calculate_node_index(expr, children_node_indices) + })?; // Return a tuple containing the root node index and the DAG. - Ok((root.data.unwrap(), builder.graph)) + Ok((node_index, builder.graph)) } /// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`]. @@ -387,8 +335,8 @@ mod tests { } } - fn make_dummy_node(node: &ExprTreeNode) -> Result { - let expr = node.expression().clone(); + fn make_dummy_node(expr: &Arc) -> Result { + let expr = expr.clone(); let dummy_property = if expr.as_any().is::() { "Binary" } else if expr.as_any().is::() {