Skip to content

Commit

Permalink
Merge pull request #2 from alamb/alamb/resolve_conflict
Browse files Browse the repository at this point in the history
Alamb/resolve conflict
  • Loading branch information
Omega359 authored Apr 8, 2024
2 parents cf1d5e1 + aaaa5d1 commit 83fdffc
Show file tree
Hide file tree
Showing 18 changed files with 671 additions and 425 deletions.
136 changes: 136 additions & 0 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use datafusion_common::cast::as_int32_array;
use datafusion_common::ScalarValue;
use datafusion_common::{DFSchemaRef, ToDFSchema};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::logical_plan::builder::table_scan_with_filters;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::{
expr, table_scan, BuiltinScalarFunction, Cast, ColumnarValue, Expr, ExprSchemable,
Expand Down Expand Up @@ -294,6 +295,45 @@ fn select_date_plus_interval() -> Result<()> {
Ok(())
}

#[test]
fn simplify_project_scalar_fn() -> Result<()> {
// Issue https://github.com/apache/arrow-datafusion/issues/5996
let schema = Schema::new(vec![Field::new("f", DataType::Float64, false)]);
let plan = table_scan(Some("test"), &schema, None)?
.project(vec![power(col("f"), lit(1.0))])?
.build()?;

// before simplify: power(t.f, 1.0)
// after simplify: t.f as "power(t.f, 1.0)"
let expected = "Projection: test.f AS power(test.f,Float64(1))\
\n TableScan: test";
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
assert_eq!(expected, actual);
Ok(())
}

#[test]
fn simplify_scan_predicate() -> Result<()> {
let schema = Schema::new(vec![
Field::new("f", DataType::Float64, false),
Field::new("g", DataType::Float64, false),
]);
let plan = table_scan_with_filters(
Some("test"),
&schema,
None,
vec![col("g").eq(power(col("f"), lit(1.0)))],
)?
.build()?;

// before simplify: t.g = power(t.f, 1.0)
// after simplify: (t.g = t.f) as "t.g = power(t.f, 1.0)"
let expected = "TableScan: test, full_filters=[g = f AS g = power(f,Float64(1))]";
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
assert_eq!(expected, actual);
Ok(())
}

#[test]
fn test_const_evaluator() {
// true --> true
Expand Down Expand Up @@ -431,3 +471,99 @@ fn multiple_now() -> Result<()> {
assert_eq!(expected, actual);
Ok(())
}

// ------------------------------
// --- Simplifier tests -----
// ------------------------------

fn expr_test_schema() -> DFSchemaRef {
Schema::new(vec![
Field::new("c1", DataType::Utf8, true),
Field::new("c2", DataType::Boolean, true),
Field::new("c3", DataType::Int64, true),
Field::new("c4", DataType::UInt32, true),
Field::new("c1_non_null", DataType::Utf8, false),
Field::new("c2_non_null", DataType::Boolean, false),
Field::new("c3_non_null", DataType::Int64, false),
Field::new("c4_non_null", DataType::UInt32, false),
])
.to_dfschema_ref()
.unwrap()
}

fn test_simplify(input_expr: Expr, expected_expr: Expr) {
let info: MyInfo = MyInfo {
schema: expr_test_schema(),
execution_props: ExecutionProps::new(),
};
let simplifier = ExprSimplifier::new(info);
let simplified_expr = simplifier
.simplify(input_expr.clone())
.expect("successfully evaluated");

assert_eq!(
simplified_expr, expected_expr,
"Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
);
}

#[test]
fn test_simplify_log() {
// Log(c3, 1) ===> 0
{
let expr = log(col("c3_non_null"), lit(1));
test_simplify(expr, lit(0i64));
}
// Log(c3, c3) ===> 1
{
let expr = log(col("c3_non_null"), col("c3_non_null"));
let expected = lit(1i64);
test_simplify(expr, expected);
}
// Log(c3, Power(c3, c4)) ===> c4
{
let expr = log(
col("c3_non_null"),
power(col("c3_non_null"), col("c4_non_null")),
);
let expected = col("c4_non_null");
test_simplify(expr, expected);
}
// Log(c3, c4) ===> Log(c3, c4)
{
let expr = log(col("c3_non_null"), col("c4_non_null"));
let expected = log(col("c3_non_null"), col("c4_non_null"));
test_simplify(expr, expected);
}
}

#[test]
fn test_simplify_power() {
// Power(c3, 0) ===> 1
{
let expr = power(col("c3_non_null"), lit(0));
let expected = lit(1i64);
test_simplify(expr, expected)
}
// Power(c3, 1) ===> c3
{
let expr = power(col("c3_non_null"), lit(1));
let expected = col("c3_non_null");
test_simplify(expr, expected)
}
// Power(c3, Log(c3, c4)) ===> c4
{
let expr = power(
col("c3_non_null"),
log(col("c3_non_null"), col("c4_non_null")),
);
let expected = col("c4_non_null");
test_simplify(expr, expected)
}
// Power(c3, c4) ===> Power(c3, c4)
{
let expr = power(col("c3_non_null"), col("c4_non_null"));
let expected = power(col("c3_non_null"), col("c4_non_null"));
test_simplify(expr, expected)
}
}
34 changes: 0 additions & 34 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,8 @@ pub enum BuiltinScalarFunction {
Exp,
/// factorial
Factorial,
/// log, same as log10
Log,
/// nanvl
Nanvl,
/// power
Power,
// string functions
/// concat
Concat,
Expand Down Expand Up @@ -118,9 +114,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
BuiltinScalarFunction::Exp => Volatility::Immutable,
BuiltinScalarFunction::Factorial => Volatility::Immutable,
BuiltinScalarFunction::Log => Volatility::Immutable,
BuiltinScalarFunction::Nanvl => Volatility::Immutable,
BuiltinScalarFunction::Power => Volatility::Immutable,
BuiltinScalarFunction::Concat => Volatility::Immutable,
BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable,
BuiltinScalarFunction::EndsWith => Volatility::Immutable,
Expand Down Expand Up @@ -163,16 +157,6 @@ impl BuiltinScalarFunction {

BuiltinScalarFunction::Factorial => Ok(Int64),

BuiltinScalarFunction::Power => match &input_expr_types[0] {
Int64 => Ok(Int64),
_ => Ok(Float64),
},

BuiltinScalarFunction::Log => match &input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
},

BuiltinScalarFunction::Nanvl => match &input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
Expand Down Expand Up @@ -216,20 +200,6 @@ impl BuiltinScalarFunction {
self.volatility(),
),
BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()),
BuiltinScalarFunction::Power => Signature::one_of(
vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])],
self.volatility(),
),

BuiltinScalarFunction::Log => Signature::one_of(
vec![
Exact(vec![Float32]),
Exact(vec![Float64]),
Exact(vec![Float32, Float32]),
Exact(vec![Float64, Float64]),
],
self.volatility(),
),
BuiltinScalarFunction::Nanvl => Signature::one_of(
vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
self.volatility(),
Expand Down Expand Up @@ -259,8 +229,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Factorial
) {
Some(vec![Some(true)])
} else if *self == BuiltinScalarFunction::Log {
Some(vec![Some(true), Some(false)])
} else {
None
}
Expand All @@ -272,9 +240,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Ceil => &["ceil"],
BuiltinScalarFunction::Exp => &["exp"],
BuiltinScalarFunction::Factorial => &["factorial"],
BuiltinScalarFunction::Log => &["log"],
BuiltinScalarFunction::Nanvl => &["nanvl"],
BuiltinScalarFunction::Power => &["power", "pow"],
BuiltinScalarFunction::Random => &["random"],

// conditional functions
Expand Down
3 changes: 0 additions & 3 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,9 +540,6 @@ scalar_expr!(

scalar_expr!(Exp, exp, num, "exponential");

scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`");
scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`");

scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase");
scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`");
nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL");
Expand Down
26 changes: 21 additions & 5 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,9 @@ impl LogicalPlan {
err
}

/// Calls `f` on all expressions (non-recursively) in the current
/// logical plan node. This does not include expressions in any
/// children.
/// Calls `f` on all expressions in the current `LogicalPlan` node.
///
/// Note this does not include expressions in child `LogicalPlan` nodes.
pub fn apply_expressions<F: FnMut(&Expr) -> Result<TreeNodeRecursion>>(
&self,
mut f: F,
Expand Down Expand Up @@ -393,6 +393,11 @@ impl LogicalPlan {
}
}

/// Rewrites all expressions in the current `LogicalPlan` node using `f`.
///
/// Returns the current node.
///
/// Note this does not include expressions in child `LogicalPlan` nodes.
pub fn map_expressions<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
self,
mut f: F,
Expand Down Expand Up @@ -608,8 +613,9 @@ impl LogicalPlan {
})
}

/// returns all inputs of this `LogicalPlan` node. Does not
/// include inputs to inputs, or subqueries.
/// Returns all inputs / children of this `LogicalPlan` node.
///
/// Note does not include inputs to inputs, or subqueries.
pub fn inputs(&self) -> Vec<&LogicalPlan> {
match self {
LogicalPlan::Projection(Projection { input, .. }) => vec![input],
Expand Down Expand Up @@ -1370,6 +1376,10 @@ impl LogicalPlan {
)
}

/// Calls `f` recursively on all children of the `LogicalPlan` node.
///
/// Unlike [`Self::apply`], this method *does* includes `LogicalPlan`s that
/// are referenced in `Expr`s
pub fn apply_with_subqueries<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
f: &mut F,
Expand Down Expand Up @@ -1434,6 +1444,8 @@ impl LogicalPlan {
)
}

/// Calls `f` on all subqueries referenced in expressions of the current
/// `LogicalPlan` node.
fn apply_subqueries<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
mut f: F,
Expand All @@ -1453,6 +1465,10 @@ impl LogicalPlan {
})
}

/// Rewrites all subquery `LogicalPlan` in the current `LogicalPlan` node
/// using `f`.
///
/// Returns the current node.
fn map_subqueries<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
mut f: F,
Expand Down
26 changes: 13 additions & 13 deletions datafusion/functions/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,6 @@ macro_rules! make_stub_package {
};
}

macro_rules! make_function_scalar_inputs {
($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE);

arg.iter()
.map(|a| match a {
Some(a) => Some($FUNC(a)),
_ => None,
})
.collect::<$ARRAY_TYPE>()
}};
}

/// Invokes a function on each element of an array and returns the result as a new array
///
/// $ARG: ArrayRef
Expand Down Expand Up @@ -370,6 +357,19 @@ macro_rules! make_math_binary_udf {
};
}

macro_rules! make_function_scalar_inputs {
($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE);

arg.iter()
.map(|a| match a {
Some(a) => Some($FUNC(a)),
_ => None,
})
.collect::<$ARRAY_TYPE>()
}};
}

macro_rules! make_function_inputs2 {
($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE);
Expand Down
Loading

0 comments on commit 83fdffc

Please sign in to comment.