Skip to content

Commit

Permalink
revert
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed Jan 12, 2024
1 parent ce0dd88 commit 31abd63
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 40 deletions.
47 changes: 47 additions & 0 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,27 @@ pub trait TreeNode: Sized + Clone {
Ok(new_node)
}

/// Rewrites this node and its subtree with `rewriter`, letting
/// [`TreeNodeRewriterOld::pre_visit`] decide how each node is handled.
///
/// For every node, `pre_visit` is consulted first and its answer selects one
/// of four paths:
///
/// * [`RewriteRecursion::Mutate`]: `mutate` is applied to the node right away
///   and its result is returned; the children are not traversed.
/// * [`RewriteRecursion::Stop`]: the node is returned untouched.
/// * [`RewriteRecursion::Continue`]: the children are rewritten first, then
///   `mutate` is applied to the node built from the rewritten children.
/// * [`RewriteRecursion::Skip`]: the children are rewritten, but `mutate` is
///   not applied to the node itself.
///
/// Any error produced by `pre_visit`, `mutate` or `map_children` is
/// propagated to the caller.
fn rewrite_old<R: TreeNodeRewriterOld<N = Self>>(
    self,
    rewriter: &mut R,
) -> Result<Self> {
    match rewriter.pre_visit(&self)? {
        // Replace this node immediately, without descending.
        RewriteRecursion::Mutate => rewriter.mutate(self),
        // Leave this subtree exactly as it is.
        RewriteRecursion::Stop => Ok(self),
        // Children first, then the node itself.
        RewriteRecursion::Continue => {
            let with_new_children =
                self.map_children(|child| child.rewrite_old(rewriter))?;
            rewriter.mutate(with_new_children)
        }
        // Children only; the node itself is not mutated.
        RewriteRecursion::Skip => {
            self.map_children(|child| child.rewrite_old(rewriter))
        }
    }
}

/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for
/// recursively transforming [`TreeNode`]s.
///
Expand Down Expand Up @@ -280,6 +301,21 @@ pub trait TreeNodeVisitor: Sized {
}
}

/// Rewriter interface consumed by [`TreeNode::rewrite_old`].
///
/// For each node, `rewrite_old` first calls [`Self::pre_visit`] and uses the
/// returned [`RewriteRecursion`] to decide whether to descend into the
/// children and whether to call [`Self::mutate`] on the node.
pub trait TreeNodeRewriterOld: Sized {
/// The node type which is rewritable.
type N: TreeNode;

/// Invoked before (Preorder) any children of `node` are rewritten /
/// visited. Default implementation returns `Ok(RewriteRecursion::Continue)`.
fn pre_visit(&mut self, _node: &Self::N) -> Result<RewriteRecursion> {
Ok(RewriteRecursion::Continue)
}

/// Invoked after (Postorder) all children of `node` have been mutated and
/// returns a potentially modified node.
///
/// When `pre_visit` returns `RewriteRecursion::Mutate`, `rewrite_old` calls
/// this on the node immediately, without rewriting its children first.
fn mutate(&mut self, node: Self::N) -> Result<Self::N>;
}

/// Trait for potentially recursively transform a [`TreeNode`] node tree.
pub trait TreeNodeRewriter: Sized {
/// The node type which is rewritable.
Expand All @@ -300,6 +336,17 @@ pub trait TreeNodeRewriter: Sized {
}
}

/// Controls how [`TreeNode::rewrite_old`] proceeds for a node, as decided by
/// [`TreeNodeRewriterOld::pre_visit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RewriteRecursion {
    /// Rewrite the children of this node, then call `mutate` on the node
    /// itself.
    Continue,
    /// Call `mutate` on this node immediately and return its result; the
    /// traversal does not descend into the children.
    Mutate,
    /// Return this node unchanged: neither the node nor its children are
    /// rewritten.
    Stop,
    /// Rewrite the children of this node, but do not call `mutate` on the
    /// node itself.
    Skip,
}

/// Controls how [`TreeNode`] recursions should proceed.
#[derive(Debug)]
pub enum TreeNodeRecursion {
Expand Down
83 changes: 43 additions & 40 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::{utils, OptimizerConfig, OptimizerRule};

use arrow::datatypes::DataType;
use datafusion_common::tree_node::{
TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
RewriteRecursion, TreeNode, TreeNodeRecursion, TreeNodeRewriterOld, TreeNodeVisitor,
};
use datafusion_common::{
internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
Expand Down Expand Up @@ -688,66 +688,69 @@ struct CommonSubexprRewriter<'a> {
curr_index: usize,
}

impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
type Node = Expr;
impl TreeNodeRewriterOld for CommonSubexprRewriter<'_> {
type N = Expr;

fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> {
fn pre_visit(&mut self, _: &Expr) -> Result<RewriteRecursion> {
if self.curr_index >= self.id_array.len()
|| self.max_series_number > self.id_array[self.curr_index].0
{
return Ok((expr, TreeNodeRecursion::Skip));
return Ok(RewriteRecursion::Stop);
}

let curr_id = &self.id_array[self.curr_index].1;
// skip `Expr`s without identifier (empty identifier).
if curr_id.is_empty() {
self.curr_index += 1;
return Ok((expr, TreeNodeRecursion::Continue));
return Ok(RewriteRecursion::Skip);
}
match self.expr_set.get(curr_id) {
Some((_, counter, _)) => {
if *counter > 1 {
self.affected_id.insert(curr_id.clone());

// This expr tree is finished.
if self.curr_index >= self.id_array.len() {
return Ok((expr, TreeNodeRecursion::Skip));
}

let (series_number, id) = &self.id_array[self.curr_index];
self.curr_index += 1;
// Skip sub-node of a replaced tree, or without identifier, or is not repeated expr.
let expr_set_item = self.expr_set.get(id).ok_or_else(|| {
DataFusionError::Internal("expr_set invalid state".to_string())
})?;
if *series_number < self.max_series_number
|| id.is_empty()
|| expr_set_item.1 <= 1
{
return Ok((expr, TreeNodeRecursion::Skip));
}

self.max_series_number = *series_number;
// step index to skip all sub-node (which has smaller series number).
while self.curr_index < self.id_array.len()
&& *series_number > self.id_array[self.curr_index].0
{
self.curr_index += 1;
}

let expr_name = expr.display_name()?;
// Alias this `Column` expr to its original "expr name": the
// `projection_push_down` optimizer uses the "expr name" to eliminate useless
// projections.
Ok((col(id).alias(expr_name), TreeNodeRecursion::Skip))
Ok(RewriteRecursion::Mutate)
} else {
self.curr_index += 1;
Ok((expr, TreeNodeRecursion::Continue))
Ok(RewriteRecursion::Skip)
}
}
_ => internal_err!("expr_set invalid state"),
}
}

/// Replaces `expr` with an aliased column reference to its identifier when the
/// identifier at the current position of `id_array` qualifies as a common
/// subexpression; otherwise returns `expr` unchanged.
///
/// The entry at `curr_index` is consumed on every call that reaches it. An
/// internal error is returned if its identifier is missing from `expr_set`.
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
    // All identifiers have been consumed: this expr tree is finished.
    if self.curr_index >= self.id_array.len() {
        return Ok(expr);
    }

    let (series_number, id) = &self.id_array[self.curr_index];
    let series_number = *series_number;
    self.curr_index += 1;

    // The lookup must happen before the checks below so that a missing entry
    // is always reported as an error.
    let (_, occurrences, _) = self.expr_set.get(id).ok_or_else(|| {
        DataFusionError::Internal("expr_set invalid state".to_string())
    })?;

    // Only replace an expr that is not a sub-node of an already replaced
    // tree, has an identifier, and is repeated.
    let replaceable = series_number >= self.max_series_number
        && !id.is_empty()
        && *occurrences > 1;
    if !replaceable {
        return Ok(expr);
    }

    self.max_series_number = series_number;
    // Advance the index past all sub-nodes, i.e. the following entries with a
    // smaller series number.
    let sub_nodes = self.id_array[self.curr_index..]
        .iter()
        .take_while(|entry| series_number > entry.0)
        .count();
    self.curr_index += sub_nodes;

    // Alias the replacement `Column` expr to the original "expr name"; the
    // `projection_push_down` optimizer uses the "expr name" to eliminate
    // useless projections.
    let expr_name = expr.display_name()?;
    Ok(col(id).alias(expr_name))
}
}

fn replace_common_expr(
Expand All @@ -756,7 +759,7 @@ fn replace_common_expr(
expr_set: &ExprSet,
affected_id: &mut BTreeSet<Identifier>,
) -> Result<Expr> {
expr.rewrite(&mut CommonSubexprRewriter {
expr.rewrite_old(&mut CommonSubexprRewriter {
expr_set,
id_array,
affected_id,
Expand Down

0 comments on commit 31abd63

Please sign in to comment.