Skip to content

Commit

Permalink
refactor ExprTreeNode using TreeNode.transform_up_with_payload()
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed Jan 10, 2024
1 parent 88a43c1 commit e852485
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 80 deletions.
8 changes: 4 additions & 4 deletions datafusion/physical-expr/src/intervals/cp_solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use super::utils::{
convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op,
};
use crate::expressions::Literal;
use crate::utils::{build_dag, ExprTreeNode};
use crate::utils::build_dag;
use crate::PhysicalExpr;

use arrow_schema::{DataType, Schema};
Expand Down Expand Up @@ -179,8 +179,8 @@ impl ExprIntervalGraphNode {
/// This function creates a DAEG node from a DataFusion [`PhysicalExpr`].
/// Literals are created with definite, singleton intervals while
/// any other expression starts with an indefinite interval ([-∞, ∞]).
pub fn make_node(node: &ExprTreeNode<NodeIndex>, schema: &Schema) -> Result<Self> {
let expr = node.expression().clone();
pub fn make_node(expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> Result<Self> {
let expr = expr.clone();
if let Some(literal) = expr.as_any().downcast_ref::<Literal>() {
let value = literal.value();
Interval::try_new(value.clone(), value.clone())
Expand Down Expand Up @@ -353,7 +353,7 @@ impl ExprIntervalGraph {
pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &Schema) -> Result<Self> {
// Build the full graph:
let (root, graph) =
build_dag(expr, &|node| ExprIntervalGraphNode::make_node(node, schema))?;
build_dag(expr, &|expr| ExprIntervalGraphNode::make_node(expr, schema))?;
Ok(Self { graph, root })
}

Expand Down
100 changes: 24 additions & 76 deletions datafusion/physical-expr/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
mod guarantee;
pub use guarantee::{Guarantee, LiteralGuarantee};

use std::borrow::{Borrow, Cow};
use std::borrow::Borrow;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

Expand All @@ -28,9 +28,7 @@ use crate::{PhysicalExpr, PhysicalSortExpr};
use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData};
use arrow::compute::{and_kleene, is_not_null, SlicesIterator};
use arrow::datatypes::SchemaRef;
use datafusion_common::tree_node::{
Transformed, TreeNode, TreeNodeRewriter, VisitRecursion,
};
use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
use datafusion_common::Result;
use datafusion_expr::Operator;

Expand Down Expand Up @@ -127,56 +125,11 @@ pub fn get_indices_of_exprs_strict<T: Borrow<Arc<dyn PhysicalExpr>>>(
.collect()
}

#[derive(Clone, Debug)]
pub struct ExprTreeNode<T> {
expr: Arc<dyn PhysicalExpr>,
data: Option<T>,
child_nodes: Vec<ExprTreeNode<T>>,
}

impl<T> ExprTreeNode<T> {
pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
let children = expr.children();
ExprTreeNode {
expr,
data: None,
child_nodes: children.into_iter().map(Self::new).collect_vec(),
}
}

pub fn expression(&self) -> &Arc<dyn PhysicalExpr> {
&self.expr
}

pub fn children(&self) -> &[ExprTreeNode<T>] {
&self.child_nodes
}
}

impl<T: Clone> TreeNode for ExprTreeNode<T> {
fn children_nodes(&self) -> Vec<Cow<Self>> {
self.children().iter().map(Cow::Borrowed).collect()
}

fn map_children<F>(mut self, transform: F) -> Result<Self>
where
F: FnMut(Self) -> Result<Self>,
{
self.child_nodes = self
.child_nodes
.into_iter()
.map(transform)
.collect::<Result<Vec<_>>>()?;
Ok(self)
}
}

/// This struct facilitates the conversion of a
/// [PhysicalExpr] tree into a DAEG (i.e. an expression DAG) by collecting
/// identical expressions in one node. Caller specifies the node type in the
/// DAEG via the `constructor` argument, which constructs nodes in the DAEG
/// from the [ExprTreeNode] ancillary object.
struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> {
/// DAEG via the `constructor` argument.
struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&Arc<dyn PhysicalExpr>) -> Result<T>> {
// The resulting DAEG (expression DAG).
graph: StableGraph<T, usize>,
// A vector of visited expression nodes and their corresponding node indices.
Expand All @@ -185,19 +138,16 @@ struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<
constructor: &'a F,
}

impl<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> TreeNodeRewriter
for PhysicalExprDAEGBuilder<'a, T, F>
impl<'a, T, F: Fn(&Arc<dyn PhysicalExpr>) -> Result<T>>
PhysicalExprDAEGBuilder<'a, T, F>
{
type N = ExprTreeNode<NodeIndex>;
// This method mutates an expression node by transforming it to a physical expression
// and adding it to the graph. The method returns the mutated expression node.
fn mutate(
// This method adds an expression to the graph, unless an equal expression has
// already been visited, and returns the corresponding node index.
fn calculate_node_index(
&mut self,
mut node: ExprTreeNode<NodeIndex>,
) -> Result<ExprTreeNode<NodeIndex>> {
// Get the expression associated with the input expression node.
let expr = &node.expr;

expr: Arc<dyn PhysicalExpr>,
children_node_indices: Vec<NodeIndex>,
) -> Result<(Transformed<Arc<dyn PhysicalExpr>>, NodeIndex)> {
// Check if the expression has already been visited.
let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) {
// If the expression has been visited, return the corresponding node index.
Expand All @@ -206,18 +156,15 @@ impl<'a, T, F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>> TreeNodeRewriter
// add edges to its child nodes. Add the visited expression to the vector
// of visited expressions and return the newly created node index.
None => {
let node_idx = self.graph.add_node((self.constructor)(&node)?);
for expr_node in node.child_nodes.iter() {
self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0);
let node_idx = self.graph.add_node((self.constructor)(&expr)?);
for child_node_index in children_node_indices.into_iter() {
self.graph.add_edge(node_idx, child_node_index, 0);
}
self.visited_plans.push((expr.clone(), node_idx));
node_idx
}
};
// Set the data field of the input expression node to the corresponding node index.
node.data = Some(node_idx);
// Return the mutated expression node.
Ok(node)
Ok((Transformed::No(expr), node_idx))
}
}

Expand All @@ -227,20 +174,21 @@ pub fn build_dag<T, F>(
constructor: &F,
) -> Result<(NodeIndex, StableGraph<T, usize>)>
where
F: Fn(&ExprTreeNode<NodeIndex>) -> Result<T>,
F: Fn(&Arc<dyn PhysicalExpr>) -> Result<T>,
{
// Create a new expression tree node from the input expression.
let init = ExprTreeNode::new(expr);
// Create a new `PhysicalExprDAEGBuilder` instance.
let mut builder = PhysicalExprDAEGBuilder {
graph: StableGraph::<T, usize>::new(),
visited_plans: Vec::<(Arc<dyn PhysicalExpr>, NodeIndex)>::new(),
constructor,
};
// Use the builder to transform the expression tree into a DAG.
let root = init.rewrite(&mut builder)?;
let (_, node_index) =
expr.transform_up_with_payload(&mut |expr, children_node_indices| {
builder.calculate_node_index(expr, children_node_indices)
})?;
// Return a tuple containing the root node index and the DAG.
Ok((root.data.unwrap(), builder.graph))
Ok((node_index, builder.graph))
}

/// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`].
Expand Down Expand Up @@ -387,8 +335,8 @@ mod tests {
}
}

fn make_dummy_node(node: &ExprTreeNode<NodeIndex>) -> Result<PhysicalExprDummyNode> {
let expr = node.expression().clone();
fn make_dummy_node(expr: &Arc<dyn PhysicalExpr>) -> Result<PhysicalExprDummyNode> {
let expr = expr.clone();
let dummy_property = if expr.as_any().is::<BinaryExpr>() {
"Binary"
} else if expr.as_any().is::<Column>() {
Expand Down

0 comments on commit e852485

Please sign in to comment.