diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 54d8c472f13f..1a9e9630c076 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -45,7 +45,6 @@ async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } -datafusion-functions = { workspace = true } datafusion-physical-expr = { workspace = true } hashbrown = { workspace = true } indexmap = { workspace = true } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index d03cc361b9bc..4a4933fe9cfd 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -1127,18 +1127,19 @@ fn replace_common_expr<'n>( #[cfg(test)] mod test { + use std::any::Any; use std::collections::HashSet; use std::iter; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_expr::expr::{AggregateFunction, ScalarFunction}; + use datafusion_expr::expr::AggregateFunction; use datafusion_expr::logical_plan::{table_scan, JoinType}; use datafusion_expr::{ - grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Signature, - SimpleAggregateUDF, Volatility, + grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, + ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, + Volatility, }; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; - use datafusion_functions::math; use crate::optimizer::OptimizerContext; use crate::test::*; @@ -1871,7 +1872,7 @@ mod test { let table_scan = test_table_scan()?; let extracted_child = col("a") + col("b"); - let rand = Expr::ScalarFunction(ScalarFunction::new_udf(math::random(), vec![])); + let rand = rand_func().call(vec![]); let not_extracted_volatile = extracted_child + rand; let plan = LogicalPlanBuilder::from(table_scan.clone()) 
.project(vec![ @@ -1893,7 +1894,7 @@ mod test { fn test_volatile_short_circuits() -> Result<()> { let table_scan = test_table_scan()?; - let rand = Expr::ScalarFunction(ScalarFunction::new_udf(math::random(), vec![])); + let rand = rand_func().call(vec![]); let not_extracted_volatile_short_circuit_2 = rand.clone().eq(lit(0)).or(col("b").eq(lit(0))); let not_extracted_volatile_short_circuit_1 = @@ -1914,4 +1915,47 @@ mod test { Ok(()) } + + /// Returns a "random" function that is marked volatile (i.e. each invocation + /// may return a different value). + /// + /// Deliberately does not use `datafusion_functions::math::random`, to avoid introducing a + /// dependency on that crate. + fn rand_func() -> ScalarUDF { + ScalarUDF::new_from_impl(RandomStub::new()) + } + + #[derive(Debug)] + struct RandomStub { + signature: Signature, + } + + impl RandomStub { + fn new() -> Self { + Self { + signature: Signature::exact(vec![], Volatility::Volatile), + } + } + } + impl ScalarUDFImpl for RandomStub { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "random" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(DataType::Float64) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { + unimplemented!() + } + } }