Skip to content

Commit

Permalink
Implement PhysicalExpr CSE
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed Nov 28, 2024
1 parent f6c92fe commit 2d27ff7
Show file tree
Hide file tree
Showing 26 changed files with 745 additions and 70 deletions.
2 changes: 1 addition & 1 deletion datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion datafusion/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ arrow-schema = { workspace = true }
chrono = { workspace = true }
half = { workspace = true }
hashbrown = { workspace = true }
indexmap = { workspace = true }
libc = "0.2.140"
num_cpus = { workspace = true }
object_store = { workspace = true, optional = true }
Expand Down
4 changes: 4 additions & 0 deletions datafusion/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ use std::collections::{BTreeMap, HashMap};
use std::fmt::{self, Display};
use std::str::FromStr;

use crate::alias::AliasGenerator;
use crate::error::_config_err;
use crate::parsers::CompressionTypeVariant;
use crate::{DataFusionError, Result};
use std::sync::Arc;

/// A macro that wraps a configuration struct and automatically derives
/// [`Default`] and [`ConfigField`] for it, allowing it to be used
Expand Down Expand Up @@ -693,6 +695,8 @@ pub struct ConfigOptions {
pub explain: ExplainOptions,
/// Optional extensions registered using [`Extensions::insert`]
pub extensions: Extensions,
/// Alias generator used to generate unique aliases
pub alias_generator: Arc<AliasGenerator>,
}

impl ConfigField for ConfigOptions {
Expand Down
53 changes: 31 additions & 22 deletions datafusion/common/src/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use crate::tree_node::{
TreeNodeVisitor,
};
use crate::Result;
use indexmap::IndexMap;
use std::collections::HashMap;
use std::hash::{BuildHasher, Hash, Hasher, RandomState};
use std::marker::PhantomData;
Expand Down Expand Up @@ -131,11 +130,13 @@ enum NodeEvaluation {
}

/// A map that contains the evaluation stats of [`TreeNode`]s by their identifiers.
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, NodeEvaluation>;
/// It also contains the position of [`TreeNode`]s in [`CommonNodes`] once a node is
/// found to be common and has been extracted.
type NodeStats<'n, N> = HashMap<Identifier<'n, N>, (NodeEvaluation, Option<usize>)>;

/// A map that contains the common [`TreeNode`]s and their alias by their identifiers,
/// extracted during the second, rewriting traversal.
type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;
/// A list that contains the common [`TreeNode`]s and their alias, extracted during the
/// second, rewriting traversal.
type CommonNodes<'n, N> = Vec<(N, String)>;

type ChildrenList<N> = (Vec<N>, Vec<N>);

Expand Down Expand Up @@ -163,7 +164,7 @@ pub trait CSEController {
fn generate_alias(&self) -> String;

// Replaces a node with the generated alias.
fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node;
fn rewrite(&mut self, node: &Self::Node, alias: &str, index: usize) -> Self::Node;

// A helper method called on each node during top-down traversal during the second,
// rewriting traversal of CSE.
Expand Down Expand Up @@ -341,7 +342,7 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
self.id_array[down_index].1 = Some(node_id);
self.node_stats
.entry(node_id)
.and_modify(|evaluation| {
.and_modify(|(evaluation, _)| {
if *evaluation == NodeEvaluation::SurelyOnce
|| *evaluation == NodeEvaluation::ConditionallyAtLeastOnce
&& !self.conditional
Expand All @@ -351,11 +352,12 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
}
})
.or_insert_with(|| {
if self.conditional {
let evaluation = if self.conditional {
NodeEvaluation::ConditionallyAtLeastOnce
} else {
NodeEvaluation::SurelyOnce
}
};
(evaluation, None)
});
}
self.visit_stack
Expand All @@ -371,7 +373,7 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
/// replaced [`TreeNode`] tree.
struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
/// statistics of [`TreeNode`]s
node_stats: &'a NodeStats<'n, N>,
node_stats: &'a mut NodeStats<'n, N>,

/// cache to speed up second traversal
id_array: &'a IdArray<'n, N>,
Expand Down Expand Up @@ -399,7 +401,7 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter

// Handle nodes with identifiers only
if let Some(node_id) = node_id {
let evaluation = self.node_stats.get(&node_id).unwrap();
let (evaluation, common_index) = self.node_stats.get_mut(&node_id).unwrap();
if *evaluation == NodeEvaluation::Common {
// Step the index to skip all sub-nodes (which have smaller series numbers).
while self.down_index < self.id_array.len()
Expand All @@ -408,13 +410,15 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
self.down_index += 1;
}

let (node, alias) =
self.common_nodes.entry(node_id).or_insert_with(|| {
let node_alias = self.controller.generate_alias();
(node, node_alias)
});
let index = *common_index.get_or_insert_with(|| {
let index = self.common_nodes.len();
let node_alias = self.controller.generate_alias();
self.common_nodes.push((node, node_alias));
index
});
let (node, alias) = self.common_nodes.get(index).unwrap();

let rewritten = self.controller.rewrite(node, alias);
let rewritten = self.controller.rewrite(node, alias, index);

return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump));
}
Expand Down Expand Up @@ -507,7 +511,7 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
&mut self,
node: N,
id_array: &IdArray<'n, N>,
node_stats: &NodeStats<'n, N>,
node_stats: &mut NodeStats<'n, N>,
common_nodes: &mut CommonNodes<'n, N>,
) -> Result<N> {
if id_array.is_empty() {
Expand All @@ -530,7 +534,7 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
&mut self,
nodes_list: Vec<Vec<N>>,
arrays_list: &[Vec<IdArray<'n, N>>],
node_stats: &NodeStats<'n, N>,
node_stats: &mut NodeStats<'n, N>,
common_nodes: &mut CommonNodes<'n, N>,
) -> Result<Vec<Vec<N>>> {
nodes_list
Expand Down Expand Up @@ -575,13 +579,13 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
// nodes so we have to keep them intact.
nodes_list.clone(),
&id_arrays_list,
&node_stats,
&mut node_stats,
&mut common_nodes,
)?;
assert!(!common_nodes.is_empty());

Ok(FoundCommonNodes::Yes {
common_nodes: common_nodes.into_values().collect(),
common_nodes,
new_nodes_list,
original_nodes_list: nodes_list,
})
Expand Down Expand Up @@ -651,7 +655,12 @@ mod test {
self.alias_generator.next(CSE_PREFIX)
}

fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
fn rewrite(
&mut self,
node: &Self::Node,
alias: &str,
_index: usize,
) -> Self::Node {
TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias))
}
}
Expand Down
66 changes: 44 additions & 22 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,17 +399,15 @@ impl CommonSubexprEliminate {
// Since `group_expr` may have changed, schema may also.
// Use `try_new()` method.
Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)
.map(LogicalPlan::Aggregate)
.map(Transformed::no)
.map(|p| Transformed::no(LogicalPlan::Aggregate(p)))
} else {
Aggregate::try_new_with_schema(
new_input,
new_group_expr,
rewritten_aggr_expr,
schema,
)
.map(LogicalPlan::Aggregate)
.map(Transformed::no)
.map(|p| Transformed::no(LogicalPlan::Aggregate(p)))
}
}
}
Expand Down Expand Up @@ -628,9 +626,7 @@ impl CSEController for ExprCSEController<'_> {

fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> {
match node {
// In case of `ScalarFunction`s we don't know which children are surely
// executed so start visiting all children conditionally and stop the
// recursion with `TreeNodeRecursion::Jump`.
// In case of short-circuiting `ScalarFunction`s, all children can be conditionally executed.
Expr::ScalarFunction(ScalarFunction { func, args })
if func.short_circuits() =>
{
Expand Down Expand Up @@ -700,7 +696,7 @@ impl CSEController for ExprCSEController<'_> {
self.alias_generator.next(CSE_PREFIX)
}

fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node {
fn rewrite(&mut self, node: &Self::Node, alias: &str, _index: usize) -> Self::Node {
// alias the expressions without an `Alias` ancestor node
if self.alias_counter > 0 {
col(alias)
Expand Down Expand Up @@ -1030,10 +1026,14 @@ mod test {
fn subexpr_in_same_order() -> Result<()> {
let table_scan = test_table_scan()?;

let a = col("a");
let lit_1 = lit(1);
let _1_plus_a = lit_1 + a;

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
(lit(1) + col("a")).alias("first"),
(lit(1) + col("a")).alias("second"),
_1_plus_a.clone().alias("first"),
_1_plus_a.alias("second"),
])?
.build()?;

Expand All @@ -1050,8 +1050,13 @@ mod test {
fn subexpr_in_different_order() -> Result<()> {
let table_scan = test_table_scan()?;

let a = col("a");
let lit_1 = lit(1);
let _1_plus_a = lit_1.clone() + a.clone();
let a_plus_1 = a + lit_1;

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![lit(1) + col("a"), col("a") + lit(1)])?
.project(vec![_1_plus_a, a_plus_1])?
.build()?;

let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\
Expand All @@ -1066,6 +1071,8 @@ mod test {
fn cross_plans_subexpr() -> Result<()> {
let table_scan = test_table_scan()?;

let _1_plus_col_a = lit(1) + col("a");

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![lit(1) + col("a"), col("a")])?
.project(vec![lit(1) + col("a")])?
Expand Down Expand Up @@ -1318,9 +1325,12 @@ mod test {
fn test_volatile() -> Result<()> {
let table_scan = test_table_scan()?;

let extracted_child = col("a") + col("b");
let rand = rand_func().call(vec![]);
let a = col("a");
let b = col("b");
let extracted_child = a + b;
let rand = rand_expr();
let not_extracted_volatile = extracted_child + rand;

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
not_extracted_volatile.clone().alias("c1"),
Expand All @@ -1341,13 +1351,19 @@ mod test {
fn test_volatile_short_circuits() -> Result<()> {
let table_scan = test_table_scan()?;

let rand = rand_func().call(vec![]);
let extracted_short_circuit_leg_1 = col("a").eq(lit(0));
let a = col("a");
let b = col("b");
let rand = rand_expr();
let rand_eq_0 = rand.eq(lit(0));

let extracted_short_circuit_leg_1 = a.eq(lit(0));
let not_extracted_volatile_short_circuit_1 =
extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0)));
let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0));
extracted_short_circuit_leg_1.or(rand_eq_0.clone());

let not_extracted_short_circuit_leg_2 = b.eq(lit(0));
let not_extracted_volatile_short_circuit_2 =
rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2);
rand_eq_0.or(not_extracted_short_circuit_leg_2);

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
not_extracted_volatile_short_circuit_1.clone().alias("c1"),
Expand All @@ -1370,7 +1386,10 @@ mod test {
fn test_non_top_level_common_expression() -> Result<()> {
let table_scan = test_table_scan()?;

let common_expr = col("a") + col("b");
let a = col("a");
let b = col("b");
let common_expr = a + b;

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
common_expr.clone().alias("c1"),
Expand All @@ -1393,8 +1412,11 @@ mod test {
fn test_nested_common_expression() -> Result<()> {
let table_scan = test_table_scan()?;

let nested_common_expr = col("a") + col("b");
let a = col("a");
let b = col("b");
let nested_common_expr = a + b;
let common_expr = nested_common_expr.clone() * nested_common_expr;

let plan = LogicalPlanBuilder::from(table_scan)
.project(vec![
common_expr.clone().alias("c1"),
Expand All @@ -1417,8 +1439,8 @@ mod test {
///
/// Does not use datafusion_functions::rand to avoid introducing a
/// dependency on that crate.
fn rand_func() -> ScalarUDF {
ScalarUDF::new_from_impl(RandomStub::new())
fn rand_expr() -> Expr {
ScalarUDF::new_from_impl(RandomStub::new()).call(vec![])
}

#[derive(Debug)]
Expand Down
Loading

0 comments on commit 2d27ff7

Please sign in to comment.