Skip to content

Commit

Permalink
Avoid adding datafusion function dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jul 6, 2024
1 parent 4a0425a commit 21917ed
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 7 deletions.
1 change: 0 additions & 1 deletion datafusion/optimizer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ async-trait = { workspace = true }
chrono = { workspace = true }
datafusion-common = { workspace = true, default-features = true }
datafusion-expr = { workspace = true }
datafusion-functions = { workspace = true }
datafusion-physical-expr = { workspace = true }
hashbrown = { workspace = true }
indexmap = { workspace = true }
Expand Down
56 changes: 50 additions & 6 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1127,18 +1127,19 @@ fn replace_common_expr<'n>(

#[cfg(test)]
mod test {
use std::any::Any;
use std::collections::HashSet;
use std::iter;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_expr::expr::{AggregateFunction, ScalarFunction};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::logical_plan::{table_scan, JoinType};
use datafusion_expr::{
grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Signature,
SimpleAggregateUDF, Volatility,
grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr,
ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF,
Volatility,
};
use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};
use datafusion_functions::math;

use crate::optimizer::OptimizerContext;
use crate::test::*;
Expand Down Expand Up @@ -1871,7 +1872,7 @@ mod test {
let table_scan = test_table_scan()?;

let extracted_child = col("a") + col("b");
let rand = Expr::ScalarFunction(ScalarFunction::new_udf(math::random(), vec![]));
let rand = rand_func().call(vec![]);
let not_extracted_volatile = extracted_child + rand;
let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![
Expand All @@ -1893,7 +1894,7 @@ mod test {
fn test_volatile_short_circuits() -> Result<()> {
let table_scan = test_table_scan()?;

let rand = Expr::ScalarFunction(ScalarFunction::new_udf(math::random(), vec![]));
let rand = rand_func().call(vec![]);
let not_extracted_volatile_short_circuit_2 =
rand.clone().eq(lit(0)).or(col("b").eq(lit(0)));
let not_extracted_volatile_short_circuit_1 =
Expand All @@ -1914,4 +1915,47 @@ mod test {

Ok(())
}

/// Builds a stand-in `random` scalar function for the tests in this module.
///
/// The returned UDF wraps a [`RandomStub`], whose signature is declared
/// `Volatility::Volatile` (each invocation may produce a different value).
///
/// A local stub is used instead of `datafusion_functions::math::random` so
/// that this crate does not need a dependency on `datafusion-functions`.
fn rand_func() -> ScalarUDF {
    let stub = RandomStub::new();
    ScalarUDF::new_from_impl(stub)
}

/// Test-only scalar UDF implementation that reports itself as `random`.
///
/// It only carries the metadata exposed through `ScalarUDFImpl`
/// (name, signature, return type); invoking it is not implemented.
#[derive(Debug)]
struct RandomStub {
/// Signature handed out by `ScalarUDFImpl::signature`.
signature: Signature,
}

impl RandomStub {
    /// Creates the stub with a signature that accepts exactly zero arguments
    /// and is marked `Volatility::Volatile`.
    fn new() -> Self {
        let signature = Signature::exact(vec![], Volatility::Volatile);
        Self { signature }
    }
}
/// Metadata-only `ScalarUDFImpl` for [`RandomStub`]: enough to build and
/// inspect expressions, but not to evaluate them.
impl ScalarUDFImpl for RandomStub {
    /// Exposes the stub as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    /// The function name reported by the stub: `"random"`.
    fn name(&self) -> &str {
        "random"
    }

    /// Borrows the signature created in `RandomStub::new`.
    fn signature(&self) -> &Signature {
        let Self { signature } = self;
        signature
    }

    /// Always reports `Float64`, regardless of the argument types passed in.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        let float_type = DataType::Float64;
        Ok(float_type)
    }

    /// Evaluation is deliberately not provided.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
        unimplemented!()
    }
}
}

0 comments on commit 21917ed

Please sign in to comment.