diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index d2a92fea311ef..9a20af69add7f 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1294,7 +1294,6 @@ dependencies = [
  "chrono",
  "half",
  "hashbrown 0.14.5",
- "indexmap",
  "libc",
  "num_cpus",
  "object_store",
@@ -1539,6 +1538,7 @@ dependencies = [
  "arrow",
  "datafusion-common",
  "datafusion-execution",
+ "datafusion-expr",
  "datafusion-expr-common",
  "datafusion-physical-expr",
  "datafusion-physical-plan",
diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml
index 9f2db95721f55..3eee7ce2c2105 100644
--- a/datafusion/common/Cargo.toml
+++ b/datafusion/common/Cargo.toml
@@ -56,7 +56,6 @@ arrow-schema = { workspace = true }
 chrono = { workspace = true }
 half = { workspace = true }
 hashbrown = { workspace = true }
-indexmap = { workspace = true }
 libc = "0.2.140"
 num_cpus = { workspace = true }
 object_store = { workspace = true, optional = true }
diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs
index 1ad10d1648685..7a399146518e3 100644
--- a/datafusion/common/src/config.rs
+++ b/datafusion/common/src/config.rs
@@ -22,9 +22,11 @@ use std::collections::{BTreeMap, HashMap};
 use std::fmt::{self, Display};
 use std::str::FromStr;

+use crate::alias::AliasGenerator;
 use crate::error::_config_err;
 use crate::parsers::CompressionTypeVariant;
 use crate::{DataFusionError, Result};
+use std::sync::Arc;

 /// A macro that wraps a configuration struct and automatically derives
 /// [`Default`] and [`ConfigField`] for it, allowing it to be used
@@ -693,6 +695,8 @@ pub struct ConfigOptions {
     pub explain: ExplainOptions,
     /// Optional extensions registered using [`Extensions::insert`]
     pub extensions: Extensions,
+    /// Return alias generator used to generate unique aliases
+    pub alias_generator: Arc<AliasGenerator>,
 }

 impl ConfigField for ConfigOptions {
diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs
index ab02915858cd2..67b52af511fd0 100644
--- a/datafusion/common/src/cse.rs
+++ b/datafusion/common/src/cse.rs
@@ -25,7 +25,6 @@ use crate::tree_node::{
     TreeNodeVisitor,
 };
 use crate::Result;
-use indexmap::IndexMap;
 use std::collections::HashMap;
 use std::hash::{BuildHasher, Hash, Hasher, RandomState};
 use std::marker::PhantomData;
@@ -131,11 +130,13 @@ enum NodeEvaluation {
 }

 /// A map that contains the evaluation stats of [`TreeNode`]s by their identifiers.
-type NodeStats<'n, N> = HashMap<Identifier<'n, N>, NodeEvaluation>;
+/// It also contains the position of [`TreeNode`]s in [`CommonNodes`] once a node is
+/// found to be common and got extracted.
+type NodeStats<'n, N> = HashMap<Identifier<'n, N>, (NodeEvaluation, Option<usize>)>;

-/// A map that contains the common [`TreeNode`]s and their alias by their identifiers,
-/// extracted during the second, rewriting traversal.
-type CommonNodes<'n, N> = IndexMap<Identifier<'n, N>, (N, String)>;
+/// A list that contains the common [`TreeNode`]s and their alias, extracted during the
+/// second, rewriting traversal.
+type CommonNodes<'n, N> = Vec<(N, String)>;

 type ChildrenList<N> = (Vec<N>, Vec<N>);

@@ -163,7 +164,7 @@ pub trait CSEController {
     fn generate_alias(&self) -> String;

     // Replaces a node to the generated alias.
-    fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node;
+    fn rewrite(&mut self, node: &Self::Node, alias: &str, index: usize) -> Self::Node;

     // A helper method called on each node during top-down traversal during the second,
     // rewriting traversal of CSE.
@@ -341,7 +342,7 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
             self.id_array[down_index].1 = Some(node_id);
             self.node_stats
                 .entry(node_id)
-                .and_modify(|evaluation| {
+                .and_modify(|(evaluation, _)| {
                     if *evaluation == NodeEvaluation::SurelyOnce
                         || *evaluation == NodeEvaluation::ConditionallyAtLeastOnce
                             && !self.conditional
@@ -351,11 +352,12 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
                     }
                 })
                 .or_insert_with(|| {
-                    if self.conditional {
+                    let evaluation = if self.conditional {
                         NodeEvaluation::ConditionallyAtLeastOnce
                     } else {
                         NodeEvaluation::SurelyOnce
-                    }
+                    };
+                    (evaluation, None)
                 });
         }
         self.visit_stack
@@ -371,7 +373,7 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisito
 /// replaced [`TreeNode`] tree.
 struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
     /// statistics of [`TreeNode`]s
-    node_stats: &'a NodeStats<'n, N>,
+    node_stats: &'a mut NodeStats<'n, N>,

     /// cache to speed up second traversal
     id_array: &'a IdArray<'n, N>,
@@ -399,7 +401,7 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter

         // Handle nodes with identifiers only
         if let Some(node_id) = node_id {
-            let evaluation = self.node_stats.get(&node_id).unwrap();
+            let (evaluation, common_index) = self.node_stats.get_mut(&node_id).unwrap();
             if *evaluation == NodeEvaluation::Common {
                 // step index to skip all sub-node (which has smaller series number).
                 while self.down_index < self.id_array.len()
@@ -408,13 +410,15 @@ impl<N: TreeNode + Eq, C: CSEController<Node = N>> TreeNodeRewriter
                     self.down_index += 1;
                 }

-                let (node, alias) =
-                    self.common_nodes.entry(node_id).or_insert_with(|| {
-                        let node_alias = self.controller.generate_alias();
-                        (node, node_alias)
-                    });
+                let index = *common_index.get_or_insert_with(|| {
+                    let index = self.common_nodes.len();
+                    let node_alias = self.controller.generate_alias();
+                    self.common_nodes.push((node, node_alias));
+                    index
+                });
+                let (node, alias) = self.common_nodes.get(index).unwrap();

-                let rewritten = self.controller.rewrite(node, alias);
+                let rewritten = self.controller.rewrite(node, alias, index);

                 return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump));
             }
@@ -507,7 +511,7 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
         &mut self,
         node: N,
         id_array: &IdArray<'n, N>,
-        node_stats: &NodeStats<'n, N>,
+        node_stats: &mut NodeStats<'n, N>,
         common_nodes: &mut CommonNodes<'n, N>,
     ) -> Result<N> {
         if id_array.is_empty() {
@@ -530,7 +534,7 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
         &mut self,
         nodes_list: Vec<Vec<N>>,
         arrays_list: &[Vec<IdArray<'n, N>>],
-        node_stats: &NodeStats<'n, N>,
+        node_stats: &mut NodeStats<'n, N>,
         common_nodes: &mut CommonNodes<'n, N>,
     ) -> Result<Vec<Vec<N>>> {
         nodes_list
@@ -575,13 +579,13 @@ impl<N: TreeNode + HashNode + Clone + Eq, C: CSEController<Node = N>> CSE<N, C>
                 // nodes so we have to keep them intact.
nodes_list.clone(), &id_arrays_list, - &node_stats, + &mut node_stats, &mut common_nodes, )?; assert!(!common_nodes.is_empty()); Ok(FoundCommonNodes::Yes { - common_nodes: common_nodes.into_values().collect(), + common_nodes, new_nodes_list, original_nodes_list: nodes_list, }) @@ -651,7 +655,12 @@ mod test { self.alias_generator.next(CSE_PREFIX) } - fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { + fn rewrite( + &mut self, + node: &Self::Node, + alias: &str, + _index: usize, + ) -> Self::Node { TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias)) } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 0ea2d24effbb4..e7c2ce56d391b 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -399,8 +399,7 @@ impl CommonSubexprEliminate { // Since `group_expr` may have changed, schema may also. // Use `try_new()` method. Aggregate::try_new(new_input, new_group_expr, new_aggr_expr) - .map(LogicalPlan::Aggregate) - .map(Transformed::no) + .map(|p| Transformed::no(LogicalPlan::Aggregate(p))) } else { Aggregate::try_new_with_schema( new_input, @@ -408,8 +407,7 @@ impl CommonSubexprEliminate { rewritten_aggr_expr, schema, ) - .map(LogicalPlan::Aggregate) - .map(Transformed::no) + .map(|p| Transformed::no(LogicalPlan::Aggregate(p))) } } } @@ -628,9 +626,7 @@ impl CSEController for ExprCSEController<'_> { fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> { match node { - // In case of `ScalarFunction`s we don't know which children are surely - // executed so start visiting all children conditionally and stop the - // recursion with `TreeNodeRecursion::Jump`. + // In case of `ScalarFunction`s all children can be conditionally executed. Expr::ScalarFunction(ScalarFunction { func, args }) if func.short_circuits() => { @@ -700,7 +696,7 @@ impl CSEController for ExprCSEController<'_> { self.alias_generator.next(CSE_PREFIX) } - fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { + fn rewrite(&mut self, node: &Self::Node, alias: &str, _index: usize) -> Self::Node { // alias the expressions without an `Alias` ancestor node if self.alias_counter > 0 { col(alias) @@ -1030,10 +1026,14 @@ mod test { fn subexpr_in_same_order() -> Result<()> { let table_scan = test_table_scan()?; + let a = col("a"); + let lit_1 = lit(1); + let _1_plus_a = lit_1 + a; + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ - (lit(1) + col("a")).alias("first"), - (lit(1) + col("a")).alias("second"), + _1_plus_a.clone().alias("first"), + _1_plus_a.alias("second"), ])? .build()?; @@ -1050,8 +1050,13 @@ mod test { fn subexpr_in_different_order() -> Result<()> { let table_scan = test_table_scan()?; + let a = col("a"); + let lit_1 = lit(1); + let _1_plus_a = lit_1.clone() + a.clone(); + let a_plus_1 = a + lit_1; + let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![lit(1) + col("a"), col("a") + lit(1)])? + .project(vec![_1_plus_a, a_plus_1])? .build()?; let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\ @@ -1066,6 +1071,8 @@ mod test { fn cross_plans_subexpr() -> Result<()> { let table_scan = test_table_scan()?; + let _1_plus_col_a = lit(1) + col("a"); + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![lit(1) + col("a"), col("a")])? .project(vec![lit(1) + col("a")])? 
@@ -1318,9 +1325,12 @@ mod test { fn test_volatile() -> Result<()> { let table_scan = test_table_scan()?; - let extracted_child = col("a") + col("b"); - let rand = rand_func().call(vec![]); + let a = col("a"); + let b = col("b"); + let extracted_child = a + b; + let rand = rand_expr(); let not_extracted_volatile = extracted_child + rand; + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ not_extracted_volatile.clone().alias("c1"), @@ -1341,13 +1351,19 @@ mod test { fn test_volatile_short_circuits() -> Result<()> { let table_scan = test_table_scan()?; - let rand = rand_func().call(vec![]); - let extracted_short_circuit_leg_1 = col("a").eq(lit(0)); + let a = col("a"); + let b = col("b"); + let rand = rand_expr(); + let rand_eq_0 = rand.eq(lit(0)); + + let extracted_short_circuit_leg_1 = a.eq(lit(0)); let not_extracted_volatile_short_circuit_1 = - extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0))); - let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0)); + extracted_short_circuit_leg_1.or(rand_eq_0.clone()); + + let not_extracted_short_circuit_leg_2 = b.eq(lit(0)); let not_extracted_volatile_short_circuit_2 = - rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2); + rand_eq_0.or(not_extracted_short_circuit_leg_2); + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ not_extracted_volatile_short_circuit_1.clone().alias("c1"), @@ -1370,7 +1386,10 @@ mod test { fn test_non_top_level_common_expression() -> Result<()> { let table_scan = test_table_scan()?; - let common_expr = col("a") + col("b"); + let a = col("a"); + let b = col("b"); + let common_expr = a + b; + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ common_expr.clone().alias("c1"), @@ -1393,8 +1412,11 @@ mod test { fn test_nested_common_expression() -> Result<()> { let table_scan = test_table_scan()?; - let nested_common_expr = col("a") + col("b"); + let a = col("a"); + let b = col("b"); + let nested_common_expr = a + b; let common_expr = nested_common_expr.clone() * nested_common_expr; + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![ common_expr.clone().alias("c1"), @@ -1417,8 +1439,8 @@ mod test { /// /// Does not use datafusion_functions::rand to avoid introducing a /// dependency on that crate. 
- fn rand_func() -> ScalarUDF { - ScalarUDF::new_from_impl(RandomStub::new()) + fn rand_expr() -> Expr { + ScalarUDF::new_from_impl(RandomStub::new()).call(vec![]) } #[derive(Debug)] diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index aa816cfa4469e..2b6b30e29c37e 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -26,6 +26,7 @@ use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; @@ -52,7 +53,9 @@ use datafusion_expr_common::sort_properties::ExprProperties; /// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html /// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html /// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html -pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { +pub trait PhysicalExpr: + Send + Sync + Display + Debug + DynEq + DynHash + DynHashNode +{ /// Returns the physical expression as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -149,6 +152,10 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { fn get_properties(&self, _children: &[ExprProperties]) -> Result { Ok(ExprProperties::new_unknown()) } + + fn is_volatile(&self) -> bool { + false + } } /// [`PhysicalExpr`] can't be constrained by [`Eq`] directly because it must remain object @@ -193,6 +200,23 @@ impl Hash for dyn PhysicalExpr { } } +pub trait DynHashNode { + fn dyn_hash_node(&self, state: &mut dyn Hasher); +} + +impl DynHashNode for T { + fn dyn_hash_node(&self, mut state: &mut dyn Hasher) { + self.type_id().hash(&mut state); + self.hash_node(&mut state) + } +} + +impl HashNode for dyn PhysicalExpr { + fn hash_node(&self, state: &mut H) { + self.dyn_hash_node(state); + } +} + /// Returns a copy of this expr if we change any child according to the pointer comparison. /// The size of `children` must be equal to the size of `PhysicalExpr::children()`. 
pub fn with_new_children_if_necessary( diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 9ae12fa9f608f..6981919146e0f 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -42,6 +42,7 @@ use itertools::Itertools; /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; /// # use datafusion_common::Result; +/// # use datafusion_common::cse::HashNode; /// # use arrow::compute::SortOptions; /// # use arrow::datatypes::{DataType, Schema}; /// # use datafusion_expr_common::columnar_value::ColumnarValue; @@ -62,6 +63,9 @@ use itertools::Itertools; /// # impl Display for MyPhysicalExpr { /// # fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "a") } /// # } +/// # impl HashNode for MyPhysicalExpr { +/// # fn hash_node(&self, _state: &mut H) {} +/// # } /// # fn col(name: &str) -> Arc { Arc::new(MyPhysicalExpr) } /// // Sort by a ASC /// let options = SortOptions::default(); diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index ae2bfe5b0bd41..1e3ae7a3a36b9 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -17,7 +17,7 @@ mod kernels; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; @@ -32,6 +32,7 @@ use arrow::compute::{cast, ilike, like, nilike, nlike}; use arrow::datatypes::*; use arrow_schema::ArrowError; use datafusion_common::cast::as_boolean_array; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; use datafusion_expr::sort_properties::ExprProperties; @@ -66,7 +67,7 @@ impl PartialEq for BinaryExpr { } } impl Hash for BinaryExpr { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.left.hash(state); self.op.hash(state); self.right.hash(state); @@ -678,6 +679,13 @@ impl BinaryExpr { } } +impl HashNode for BinaryExpr { + fn hash_node(&self, state: &mut H) { + self.op.hash(state); + self.fail_on_overflow.hash(state); + } +} + fn concat_elements(left: Arc, right: Arc) -> Result { Ok(match left.data_type() { DataType::Utf8 => Arc::new(concat_elements_utf8( diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 0e307153341bf..1cf5c22e1579c 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -16,7 +16,7 @@ // under the License. 
use std::borrow::Cow; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::expressions::try_cast; @@ -31,6 +31,7 @@ use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarV use datafusion_expr::ColumnarValue; use super::{Column, Literal}; +use datafusion_common::cse::HashNode; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; @@ -507,6 +508,12 @@ impl PhysicalExpr for CaseExpr { } } +impl HashNode for CaseExpr { + fn hash_node(&self, state: &mut H) { + self.eval_method.hash(state); + } +} + /// Create a CASE expression pub fn case( expr: Option>, diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 7eda5fb4beaa8..4b8125150674d 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::fmt; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; @@ -25,6 +25,7 @@ use crate::physical_expr::PhysicalExpr; use arrow::compute::{can_cast_types, CastOptions}; use arrow::datatypes::{DataType, DataType::*, Schema}; use arrow::record_batch::RecordBatch; +use datafusion_common::cse::HashNode; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result}; use datafusion_expr_common::columnar_value::ColumnarValue; @@ -62,7 +63,7 @@ impl PartialEq for CastExpr { } impl Hash for CastExpr { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.expr.hash(state); self.cast_type.hash(state); self.cast_options.hash(state); @@ -196,6 +197,13 @@ impl PhysicalExpr for CastExpr { } } +impl HashNode for CastExpr { + fn hash_node(&self, state: &mut H) { + self.cast_type.hash(state); + self.cast_options.hash(state); + } +} + /// Return a PhysicalExpression representing `expr` casted to /// `cast_type`, if any casting is needed. /// diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 5f6932f6d7258..b0046ff9afdf4 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -18,7 +18,7 @@ //! 
Physical column reference: [`Column`] use std::any::Any; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; @@ -27,6 +27,7 @@ use arrow::{ record_batch::RecordBatch, }; use arrow_schema::SchemaRef; +use datafusion_common::cse::HashNode; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, plan_err, Result}; use datafusion_expr::ColumnarValue; @@ -140,6 +141,13 @@ impl PhysicalExpr for Column { } } +impl HashNode for Column { + fn hash_node(&self, state: &mut H) { + self.name.hash(state); + self.index.hash(state); + } +} + impl Column { fn bounds_check(&self, input_schema: &Schema) -> Result<()> { if self.index < input_schema.fields.len() { diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 663045fcad3fb..187e5e3b47e33 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -44,6 +44,7 @@ use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::compare_with_eq; use ahash::RandomState; +use datafusion_common::cse::HashNode; use datafusion_common::HashMap; use hashbrown::hash_map::RawEntryMut; @@ -419,6 +420,13 @@ impl Hash for InListExpr { } } +impl HashNode for InListExpr { + fn hash_node(&self, state: &mut H) { + self.negated.hash(state); + // Add `self.static_filter` when hash is available + } +} + /// Creates a unary expression InList pub fn in_list( expr: Arc, diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 4930865f4c989..66172dfd894d4 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -17,7 +17,7 @@ //! IS NOT NULL expression -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::PhysicalExpr; @@ -25,6 +25,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::cse::HashNode; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; @@ -44,7 +45,7 @@ impl PartialEq for IsNotNullExpr { } impl Hash for IsNotNullExpr { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.arg.hash(state); } } @@ -106,6 +107,10 @@ impl PhysicalExpr for IsNotNullExpr { } } +impl HashNode for IsNotNullExpr { + fn hash_node(&self, _state: &mut H) {} +} + /// Create an IS NOT NULL expression pub fn is_not_null(arg: Arc) -> Result> { Ok(Arc::new(IsNotNullExpr::new(arg))) diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 6a02d5ecc1f22..79a825fc8a7a2 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -17,7 +17,7 @@ //! 
IS NULL expression -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::PhysicalExpr; @@ -25,6 +25,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::cse::HashNode; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; @@ -44,7 +45,7 @@ impl PartialEq for IsNullExpr { } impl Hash for IsNullExpr { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.arg.hash(state); } } @@ -105,6 +106,10 @@ impl PhysicalExpr for IsNullExpr { } } +impl HashNode for IsNullExpr { + fn hash_node(&self, _state: &mut H) {} +} + /// Create an IS NULL expression pub fn is_null(arg: Arc) -> Result> { Ok(Arc::new(IsNullExpr::new(arg))) diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index d61cd63c35b1e..e9e3e2177667a 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; use crate::PhysicalExpr; use arrow::record_batch::RecordBatch; use arrow_schema::{DataType, Schema}; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::apply_cmp; @@ -45,7 +46,7 @@ impl PartialEq for LikeExpr { } impl Hash for LikeExpr { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.negated.hash(state); self.case_insensitive.hash(state); self.expr.hash(state); @@ -147,6 +148,13 @@ impl PhysicalExpr for LikeExpr { } } +impl HashNode for LikeExpr { + fn hash_node(&self, state: &mut H) { + self.negated.hash(state); + self.case_insensitive.hash(state); + } +} + /// used for optimize Dictionary like fn can_like_type(from_type: &DataType) -> bool { match from_type { diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index f0d02eb605b26..c1d42cafb2dd1 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -18,7 +18,7 @@ //! Literal expressions for physical operations use std::any::Any; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; @@ -27,6 +27,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::cse::HashNode; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Expr; use datafusion_expr_common::columnar_value::ColumnarValue; @@ -94,6 +95,12 @@ impl PhysicalExpr for Literal { } } +impl HashNode for Literal { + fn hash_node(&self, state: &mut H) { + self.value.hash(state); + } +} + /// Create a literal expression pub fn lit(value: T) -> Arc { match value.lit() { diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 6235845fc0285..7f2293e80ef38 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -18,7 +18,7 @@ //! 
Negation (-) expression use std::any::Any; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::PhysicalExpr; @@ -28,6 +28,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::cse::HashNode; use datafusion_common::{plan_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; @@ -51,7 +52,7 @@ impl PartialEq for NegativeExpr { } impl Hash for NegativeExpr { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.arg.hash(state); } } @@ -149,6 +150,10 @@ impl PhysicalExpr for NegativeExpr { } } +impl HashNode for NegativeExpr { + fn hash_node(&self, _state: &mut H) {} +} + /// Creates a unary expression NEGATIVE /// /// # Errors diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index c17b52f5cdfff..042bcb5b625d1 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -18,7 +18,7 @@ //! NoOp placeholder for physical operations use std::any::Any; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use arrow::{ @@ -27,6 +27,7 @@ use arrow::{ }; use crate::PhysicalExpr; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; @@ -78,3 +79,7 @@ impl PhysicalExpr for NoOp { Ok(self) } } + +impl HashNode for NoOp { + fn hash_node(&self, _state: &mut H) {} +} diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index cc35c91c98bc3..aa82890a3a43a 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -19,12 +19,13 @@ use std::any::Any; use std::fmt; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; +use datafusion_common::cse::HashNode; use datafusion_common::{cast::as_boolean_array, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; @@ -44,7 +45,7 @@ impl PartialEq for NotExpr { } impl Hash for NotExpr { - fn hash(&self, state: &mut H) { + fn hash(&self, state: &mut H) { self.arg.hash(state); } } @@ -117,6 +118,10 @@ impl PhysicalExpr for NotExpr { } } +impl HashNode for NotExpr { + fn hash_node(&self, _state: &mut H) {} +} + /// Creates a unary expression NOT pub fn not(arg: Arc) -> Result> { Ok(Arc::new(NotExpr::new(arg))) diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index 06f4e929992e5..adc6ff938805c 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::fmt; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::PhysicalExpr; @@ -26,6 +26,7 @@ use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use compute::can_cast_types; +use datafusion_common::cse::HashNode; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result, ScalarValue}; use datafusion_expr::ColumnarValue; @@ -47,7 +48,7 @@ impl PartialEq for TryCastExpr { } impl Hash for TryCastExpr { - fn hash(&self, state: &mut H) 
{ + fn hash(&self, state: &mut H) { self.expr.hash(state); self.cast_type.hash(state); } @@ -125,6 +126,12 @@ impl PhysicalExpr for TryCastExpr { } } +impl HashNode for TryCastExpr { + fn hash_node(&self, state: &mut H) { + self.cast_type().hash(state); + } +} + /// Return a PhysicalExpression representing `expr` casted to /// `cast_type`, if any casting is needed. /// diff --git a/datafusion/physical-expr/src/expressions/unknown_column.rs b/datafusion/physical-expr/src/expressions/unknown_column.rs index a63caf7e13056..eacdf2e1f70b5 100644 --- a/datafusion/physical-expr/src/expressions/unknown_column.rs +++ b/datafusion/physical-expr/src/expressions/unknown_column.rs @@ -27,6 +27,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; @@ -94,6 +95,12 @@ impl Hash for UnKnownColumn { } } +impl HashNode for UnKnownColumn { + fn hash_node(&self, state: &mut H) { + self.name.hash(state); + } +} + impl PartialEq for UnKnownColumn { fn eq(&self, _other: &Self) -> bool { // UnknownColumn is not a valid expression, so it should not be equal to any other expression. diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 138774d806f28..b9e649bad9dc8 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -31,7 +31,7 @@ use std::any::Any; use std::fmt::{self, Debug, Formatter}; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::PhysicalExpr; @@ -39,11 +39,13 @@ use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use arrow_array::Array; +use datafusion_common::cse::HashNode; use datafusion_common::{internal_err, DFSchema, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF}; +use datafusion_expr_common::signature::Volatility; /// Physical expression of a scalar function #[derive(Eq, PartialEq, Hash)] @@ -213,6 +215,19 @@ impl PhysicalExpr for ScalarFunctionExpr { range, }) } + + fn is_volatile(&self) -> bool { + self.fun.signature().volatility == Volatility::Volatile + } +} + +impl HashNode for ScalarFunctionExpr { + fn hash_node(&self, state: &mut H) { + self.name.hash(state); + self.return_type.hash(state); + self.nullable.hash(state); + self.fun.hash(state); + } } /// Create a physical expression for the UDF. 
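// Downstream impact sketch: because `PhysicalExpr` now lists `DynHashNode` as a
// supertrait (with a blanket `DynHashNode` impl for any `HashNode` type), custom
// physical expressions in other crates need a `HashNode` impl to keep compiling,
// mirroring the built-in expressions above. A minimal sketch, assuming a
// hypothetical `MyCustomExpr` type that is not part of this patch; its existing
// Display/PartialEq/Hash/PhysicalExpr impls stay unchanged:
use std::hash::{Hash, Hasher};
use datafusion_common::cse::HashNode;

#[derive(Debug, PartialEq, Eq, Hash)]
struct MyCustomExpr {
    negated: bool,
}

impl HashNode for MyCustomExpr {
    fn hash_node<H: Hasher>(&self, state: &mut H) {
        // Hash only this node's own state; child expressions are identified
        // separately by the CSE machinery, which is why the impls above also
        // skip their child expressions here.
        self.negated.hash(state);
    }
}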
diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index 838617ae9889f..cc18852b16bdd 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -35,6 +35,7 @@ workspace = true arrow = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true } datafusion-physical-plan = { workspace = true } diff --git a/datafusion/physical-optimizer/src/eliminate_common_physical_subexprs.rs b/datafusion/physical-optimizer/src/eliminate_common_physical_subexprs.rs new file mode 100644 index 0000000000000..3360532edd76d --- /dev/null +++ b/datafusion/physical-optimizer/src/eliminate_common_physical_subexprs.rs @@ -0,0 +1,497 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`EliminateCommonPhysicalSubexprs`] to avoid redundant computation of common physical +//! sub-expressions. + +use datafusion_common::alias::AliasGenerator; +use datafusion_common::config::ConfigOptions; +use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE}; +use datafusion_common::Result; +use datafusion_physical_plan::ExecutionPlan; +use std::sync::Arc; + +use crate::PhysicalOptimizerRule; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_expr_common::operator::Operator; +use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr, Column}; +use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr}; +use datafusion_physical_plan::projection::ProjectionExec; + +const CSE_PREFIX: &str = "__common_physical_expr"; + +// Optimizer rule to avoid redundant computation of common physical subexpressions +#[derive(Default, Debug)] +pub struct EliminateCommonPhysicalSubexprs {} + +impl EliminateCommonPhysicalSubexprs { + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for EliminateCommonPhysicalSubexprs { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + plan.transform_down(|plan| { + let plan_any = plan.as_any(); + if let Some(p) = plan_any.downcast_ref::() { + match CSE::new(PhysicalExprCSEController::new( + config.alias_generator.as_ref(), + p.input().schema().fields().len(), + )) + .extract_common_nodes(vec![p + .expr() + .iter() + .map(|(e, _)| e) + .cloned() + .collect()])? 
+ { + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: _, + } => { + let common_exprs = p + .input() + .schema() + .fields() + .iter() + .enumerate() + .map(|(i, field)| { + ( + Arc::new(Column::new(field.name(), i)) + as Arc, + field.name().to_string(), + ) + }) + .chain(common_exprs) + .collect(); + let common = Arc::new(ProjectionExec::try_new( + common_exprs, + Arc::clone(p.input()), + )?); + + let new_exprs = new_exprs_list + .pop() + .unwrap() + .into_iter() + .zip(p.expr().iter().map(|(_, alias)| alias.to_string())) + .collect(); + let new_project = + Arc::new(ProjectionExec::try_new(new_exprs, common)?) + as Arc; + + Ok(Transformed::yes(new_project)) + } + FoundCommonNodes::No { .. } => Ok(Transformed::no(plan)), + } + } else { + Ok(Transformed::no(plan)) + } + }) + .data() + } + + fn name(&self) -> &str { + "eliminate_common_physical_subexpressions" + } + + /// This rule will change the nullable properties of the schema, disable the schema check. + fn schema_check(&self) -> bool { + false + } +} + +pub struct PhysicalExprCSEController<'a> { + alias_generator: &'a AliasGenerator, + base_index: usize, +} + +impl<'a> PhysicalExprCSEController<'a> { + fn new(alias_generator: &'a AliasGenerator, base_index: usize) -> Self { + Self { + alias_generator, + base_index, + } + } +} + +impl CSEController for PhysicalExprCSEController<'_> { + type Node = Arc; + + fn conditional_children( + node: &Self::Node, + ) -> Option<(Vec<&Self::Node>, Vec<&Self::Node>)> { + if let Some(s) = node.as_any().downcast_ref::() { + // In case of `ScalarFunction`s all children can be conditionally executed. + if s.fun().short_circuits() { + Some((vec![], s.args().iter().collect())) + } else { + None + } + } else if let Some(b) = node.as_any().downcast_ref::() { + // In case of `And` and `Or` the first child is surely executed, but we + // account subexpressions as conditional in the second. + if *b.op() == Operator::And || *b.op() == Operator::Or { + Some((vec![b.left()], vec![b.right()])) + } else { + None + } + } else { + node.as_any().downcast_ref::().map(|c| { + ( + // In case of `Case` the optional base expression and the first when + // expressions are surely executed, but we account subexpressions as + // conditional in the others. 
+ c.expr() + .into_iter() + .chain(c.when_then_expr().iter().take(1).map(|(when, _)| when)) + .collect(), + c.when_then_expr() + .iter() + .take(1) + .map(|(_, then)| then) + .chain( + c.when_then_expr() + .iter() + .skip(1) + .flat_map(|(when, then)| [when, then]), + ) + .chain(c.else_expr()) + .collect(), + ) + }) + } + } + + fn is_valid(node: &Self::Node) -> bool { + !node.is_volatile() + } + + fn is_ignored(&self, node: &Self::Node) -> bool { + node.children().is_empty() + } + + fn generate_alias(&self) -> String { + self.alias_generator.next(CSE_PREFIX) + } + + fn rewrite(&mut self, _node: &Self::Node, alias: &str, index: usize) -> Self::Node { + Arc::new(Column::new(alias, self.base_index + index)) + } + + fn rewrite_f_down(&mut self, _node: &Self::Node) {} + + fn rewrite_f_up(&mut self, _node: &Self::Node) {} +} + +#[cfg(test)] +mod tests { + use crate::eliminate_common_physical_subexprs::EliminateCommonPhysicalSubexprs; + use crate::optimizer::PhysicalOptimizerRule; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::config::ConfigOptions; + use datafusion_common::Result; + use datafusion_expr::{ScalarUDF, ScalarUDFImpl}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_expr_common::operator::Operator; + use datafusion_expr_common::signature::{Signature, Volatility}; + use datafusion_physical_expr::expressions::{binary, col, lit}; + use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr}; + use datafusion_physical_plan::memory::MemoryExec; + use datafusion_physical_plan::projection::ProjectionExec; + use datafusion_physical_plan::{get_plan_string, ExecutionPlan}; + use std::any::Any; + use std::sync::Arc; + + fn mock_data() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + Arc::new(MemoryExec::try_new(&[vec![]], Arc::clone(&schema), None).unwrap()) + } + + #[test] + fn subexpr_in_same_order() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let lit_1 = lit(1); + let _1_plus_a = binary(lit_1, Operator::Plus, a, &table_scan.schema())?; + + let exprs = vec![ + (Arc::clone(&_1_plus_a), "first".to_string()), + (_1_plus_a, "second".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[__common_physical_expr_1@2 as first, __common_physical_expr_1@2 as second]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, 1 + a@0 as __common_physical_expr_1]", + " MemoryExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn subexpr_in_different_order() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let lit_1 = lit(1); + let _1_plus_a = binary( + Arc::clone(&lit_1), + Operator::Plus, + Arc::clone(&a), + &table_scan.schema(), + )?; + let a_plus_1 = binary(a, Operator::Plus, lit_1, &table_scan.schema())?; + + let exprs = vec![ + (_1_plus_a, "first".to_string()), + (a_plus_1, "second".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = 
get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[1 + a@0 as first, a@0 + 1 as second]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn test_volatile() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let b = col("b", &table_scan.schema())?; + let extracted_child = binary(a, Operator::Plus, b, &table_scan.schema())?; + let rand = rand_expr(); + let not_extracted_volatile = + binary(extracted_child, Operator::Plus, rand, &table_scan.schema())?; + + let exprs = vec![ + (Arc::clone(¬_extracted_volatile), "c1".to_string()), + (not_extracted_volatile, "c2".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[__common_physical_expr_1@2 + random() as c1, __common_physical_expr_1@2 + random() as c2]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, a@0 + b@1 as __common_physical_expr_1]", + " MemoryExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn test_volatile_short_circuits() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let b = col("b", &table_scan.schema())?; + let rand = rand_expr(); + let rand_eq_0 = binary(rand, Operator::Eq, lit(0), &table_scan.schema())?; + + let extracted_short_circuit_leg_1 = + binary(a, Operator::Eq, lit(0), &table_scan.schema())?; + let not_extracted_volatile_short_circuit_1 = binary( + extracted_short_circuit_leg_1, + Operator::Or, + Arc::clone(&rand_eq_0), + &table_scan.schema(), + )?; + + let not_extracted_short_circuit_leg_2 = + binary(b, Operator::Eq, lit(0), &table_scan.schema())?; + let not_extracted_volatile_short_circuit_2 = binary( + rand_eq_0, + Operator::Or, + not_extracted_short_circuit_leg_2, + &table_scan.schema(), + )?; + + let exprs = vec![ + ( + Arc::clone(¬_extracted_volatile_short_circuit_1), + "c1".to_string(), + ), + (not_extracted_volatile_short_circuit_1, "c2".to_string()), + ( + Arc::clone(¬_extracted_volatile_short_circuit_2), + "c3".to_string(), + ), + (not_extracted_volatile_short_circuit_2, "c4".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[__common_physical_expr_1@2 OR random() = 0 as c1, __common_physical_expr_1@2 OR random() = 0 as c2, random() = 0 OR b@1 = 0 as c3, random() = 0 OR b@1 = 0 as c4]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, a@0 = 0 as __common_physical_expr_1]", + " MemoryExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn test_non_top_level_common_expression() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let b = col("b", &table_scan.schema())?; + let common_expr = binary(a, Operator::Plus, b, &table_scan.schema())?; + + let exprs = vec![ + (Arc::clone(&common_expr), "c1".to_string()), + (common_expr, "c2".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let c1 = col("c1", 
&plan.schema())?; + let c2 = col("c2", &plan.schema())?; + + let exprs = vec![(c1, "c1".to_string()), (c2, "c2".to_string())]; + let plan = Arc::new(ProjectionExec::try_new(exprs, plan)?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2]", + " ProjectionExec: expr=[__common_physical_expr_1@2 as c1, __common_physical_expr_1@2 as c2]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, a@0 + b@1 as __common_physical_expr_1]", + " MemoryExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + #[test] + fn test_nested_common_expression() -> Result<()> { + let table_scan = mock_data(); + + let a = col("a", &table_scan.schema())?; + let b = col("b", &table_scan.schema())?; + let nested_common_expr = binary(a, Operator::Plus, b, &table_scan.schema())?; + let common_expr = binary( + Arc::clone(&nested_common_expr), + Operator::Multiply, + nested_common_expr, + &table_scan.schema(), + )?; + + let exprs = vec![ + (Arc::clone(&common_expr), "c1".to_string()), + (common_expr, "c2".to_string()), + ]; + let plan = Arc::new(ProjectionExec::try_new(exprs, mock_data())?); + + let config = ConfigOptions::new(); + let optimizer = EliminateCommonPhysicalSubexprs::new(); + let optimized = optimizer.optimize(plan, &config)?; + + let actual = get_plan_string(&optimized); + let expected = [ + "ProjectionExec: expr=[__common_physical_expr_1@2 as c1, __common_physical_expr_1@2 as c2]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, __common_physical_expr_2@2 * __common_physical_expr_2@2 as __common_physical_expr_1]", + " ProjectionExec: expr=[a@0 as a, b@1 as b, a@0 + b@1 as __common_physical_expr_2]", + " MemoryExec: partitions=1, partition_sizes=[0]"]; + assert_eq!(actual, expected); + + Ok(()) + } + + fn rand_expr() -> Arc { + let r = RandomStub::new(); + let n = r.name().to_string(); + let t = r.return_type(&[]).unwrap(); + Arc::new(ScalarFunctionExpr::new( + &n, + Arc::new(ScalarUDF::new_from_impl(r)), + vec![], + t, + )) + } + + #[derive(Debug)] + struct RandomStub { + signature: Signature, + } + + impl RandomStub { + fn new() -> Self { + Self { + signature: Signature::exact(vec![], Volatility::Volatile), + } + } + } + impl ScalarUDFImpl for RandomStub { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "random" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } + } +} diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index c4f5fa74e1225..8295bc9c67da6 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -21,6 +21,7 @@ pub mod aggregate_statistics; pub mod coalesce_batches; pub mod combine_partial_final_agg; +pub mod eliminate_common_physical_subexprs; pub mod limit_pushdown; pub mod limited_distinct_aggregation; mod optimizer; diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 3311b57f5d6b1..88d5409d426a3 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -17,6 +17,7 @@ use std::any::Any; use 
std::fmt::Display; +use std::hash::{Hash, Hasher}; use std::ops::Deref; use std::sync::Arc; use std::vec; @@ -88,6 +89,7 @@ use datafusion::physical_plan::{ use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; use datafusion_common::config::TableParquetOptions; +use datafusion_common::cse::HashNode; use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; @@ -799,8 +801,8 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { } } - impl std::hash::Hash for CustomPredicateExpr { - fn hash(&self, state: &mut H) { + impl Hash for CustomPredicateExpr { + fn hash(&self, state: &mut H) { self.inner.hash(state); } } @@ -840,6 +842,10 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { } } + impl HashNode for CustomPredicateExpr { + fn hash_node(&self, _state: &mut H) {} + } + #[derive(Debug)] struct CustomPhysicalExtensionCodec; impl PhysicalExtensionCodec for CustomPhysicalExtensionCodec {
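// Usage sketch: the new rule can be driven standalone against any ExecutionPlan,
// the same way the unit tests above do, by calling `PhysicalOptimizerRule::optimize`
// with a `ConfigOptions` whose new `alias_generator` field supplies the unique
// `__common_physical_expr_N` names. Assumptions: the module path added in this
// patch, and `PhysicalOptimizerRule` re-exported at the physical-optimizer crate
// root; adjust imports to your DataFusion version if they differ.
use std::sync::Arc;
use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;
use datafusion_physical_optimizer::eliminate_common_physical_subexprs::EliminateCommonPhysicalSubexprs;
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_plan::{get_plan_string, ExecutionPlan};

fn apply_physical_cse(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
    let config = ConfigOptions::new();
    let optimized = EliminateCommonPhysicalSubexprs::new().optimize(plan, &config)?;
    // Common subexpressions now appear once, projected as
    // `__common_physical_expr_N` columns in an injected ProjectionExec.
    println!("{}", get_plan_string(&optimized).join("\n"));
    Ok(optimized)
}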