Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alamb/resolve conflict #2

Merged
merged 3 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use datafusion_common::cast::as_int32_array;
use datafusion_common::ScalarValue;
use datafusion_common::{DFSchemaRef, ToDFSchema};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::logical_plan::builder::table_scan_with_filters;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::{
expr, table_scan, BuiltinScalarFunction, Cast, ColumnarValue, Expr, ExprSchemable,
Expand Down Expand Up @@ -294,6 +295,45 @@ fn select_date_plus_interval() -> Result<()> {
Ok(())
}

/// Checks that a scalar function in a projection is simplified while the
/// output column keeps its original name via an alias.
///
/// `power(f, 1.0)` is expected to be rewritten to plain `f`, aliased back to
/// the display name of the original expression.
#[test]
fn simplify_project_scalar_fn() -> Result<()> {
    // Issue https://github.com/apache/arrow-datafusion/issues/5996
    let fields = vec![Field::new("f", DataType::Float64, false)];
    let schema = Schema::new(fields);

    // Unoptimized plan: Projection(power(test.f, 1.0)) over a scan of `test`.
    let plan = table_scan(Some("test"), &schema, None)?
        .project(vec![power(col("f"), lit(1.0))])?
        .build()?;

    // before simplify: power(t.f, 1.0)
    // after simplify:  t.f as "power(t.f, 1.0)"
    let expected = "Projection: test.f AS power(test.f,Float64(1))\
\n TableScan: test";
    let actual = get_optimized_plan_formatted(&plan, &Utc::now());

    assert_eq!(expected, actual);
    Ok(())
}

/// Checks that a predicate attached directly to a table scan is simplified.
///
/// The filter `g = power(f, 1.0)` is expected to become `g = f`, aliased to
/// the display name of the original predicate.
#[test]
fn simplify_scan_predicate() -> Result<()> {
    let fields = vec![
        Field::new("f", DataType::Float64, false),
        Field::new("g", DataType::Float64, false),
    ];
    let schema = Schema::new(fields);

    // Scan of `test` carrying the single filter under test.
    let scan = table_scan_with_filters(
        Some("test"),
        &schema,
        None,
        vec![col("g").eq(power(col("f"), lit(1.0)))],
    )?;
    let plan = scan.build()?;

    // before simplify: t.g = power(t.f, 1.0)
    // after simplify:  (t.g = t.f) as "t.g = power(t.f, 1.0)"
    let expected = "TableScan: test, full_filters=[g = f AS g = power(f,Float64(1))]";
    let actual = get_optimized_plan_formatted(&plan, &Utc::now());

    assert_eq!(expected, actual);
    Ok(())
}

#[test]
fn test_const_evaluator() {
// true --> true
Expand Down Expand Up @@ -431,3 +471,99 @@ fn multiple_now() -> Result<()> {
assert_eq!(expected, actual);
Ok(())
}

// ------------------------------
// --- Simplifier tests -----
// ------------------------------

fn expr_test_schema() -> DFSchemaRef {
Schema::new(vec![
Field::new("c1", DataType::Utf8, true),
Field::new("c2", DataType::Boolean, true),
Field::new("c3", DataType::Int64, true),
Field::new("c4", DataType::UInt32, true),
Field::new("c1_non_null", DataType::Utf8, false),
Field::new("c2_non_null", DataType::Boolean, false),
Field::new("c3_non_null", DataType::Int64, false),
Field::new("c4_non_null", DataType::UInt32, false),
])
.to_dfschema_ref()
.unwrap()
}

/// Simplifies `input_expr` against the schema from `expr_test_schema` and
/// asserts that the result equals `expected_expr`.
///
/// Panics if simplification returns an error or if the simplified expression
/// differs from the expected one; the failure message shows the input, the
/// expected expression and the actual result.
fn test_simplify(input_expr: Expr, expected_expr: Expr) {
    let simplifier = ExprSimplifier::new(MyInfo {
        schema: expr_test_schema(),
        execution_props: ExecutionProps::new(),
    });

    // The input is cloned because it is still needed for the failure message.
    let outcome = simplifier.simplify(input_expr.clone());
    let simplified_expr = outcome.expect("successfully evaluated");

    assert_eq!(
        simplified_expr, expected_expr,
        "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
    );
}

/// Exercises the simplification rules for `log(base, x)`.
#[test]
fn test_simplify_log() {
    // Fresh column references for each use (expressions are consumed by value).
    let c3 = || col("c3_non_null");
    let c4 = || col("c4_non_null");

    // Log(c3, 1) ===> 0
    test_simplify(log(c3(), lit(1)), lit(0i64));

    // Log(c3, c3) ===> 1
    test_simplify(log(c3(), c3()), lit(1i64));

    // Log(c3, Power(c3, c4)) ===> c4
    test_simplify(log(c3(), power(c3(), c4())), c4());

    // Log(c3, c4) ===> Log(c3, c4)
    test_simplify(log(c3(), c4()), log(c3(), c4()));
}

/// Exercises the simplification rules for `power(base, exponent)`.
#[test]
fn test_simplify_power() {
    // Fresh column references for each use (expressions are consumed by value).
    let c3 = || col("c3_non_null");
    let c4 = || col("c4_non_null");

    // Power(c3, 0) ===> 1
    test_simplify(power(c3(), lit(0)), lit(1i64));

    // Power(c3, 1) ===> c3
    test_simplify(power(c3(), lit(1)), c3());

    // Power(c3, Log(c3, c4)) ===> c4
    test_simplify(power(c3(), log(c3(), c4())), c4());

    // Power(c3, c4) ===> Power(c3, c4)
    test_simplify(power(c3(), c4()), power(c3(), c4()));
}
34 changes: 0 additions & 34 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,8 @@ pub enum BuiltinScalarFunction {
Exp,
/// factorial
Factorial,
/// log, same as log10
Log,
/// nanvl
Nanvl,
/// power
Power,
// string functions
/// concat
Concat,
Expand Down Expand Up @@ -118,9 +114,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
BuiltinScalarFunction::Exp => Volatility::Immutable,
BuiltinScalarFunction::Factorial => Volatility::Immutable,
BuiltinScalarFunction::Log => Volatility::Immutable,
BuiltinScalarFunction::Nanvl => Volatility::Immutable,
BuiltinScalarFunction::Power => Volatility::Immutable,
BuiltinScalarFunction::Concat => Volatility::Immutable,
BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable,
BuiltinScalarFunction::EndsWith => Volatility::Immutable,
Expand Down Expand Up @@ -163,16 +157,6 @@ impl BuiltinScalarFunction {

BuiltinScalarFunction::Factorial => Ok(Int64),

BuiltinScalarFunction::Power => match &input_expr_types[0] {
Int64 => Ok(Int64),
_ => Ok(Float64),
},

BuiltinScalarFunction::Log => match &input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
},

BuiltinScalarFunction::Nanvl => match &input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
Expand Down Expand Up @@ -216,20 +200,6 @@ impl BuiltinScalarFunction {
self.volatility(),
),
BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()),
BuiltinScalarFunction::Power => Signature::one_of(
vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])],
self.volatility(),
),

BuiltinScalarFunction::Log => Signature::one_of(
vec![
Exact(vec![Float32]),
Exact(vec![Float64]),
Exact(vec![Float32, Float32]),
Exact(vec![Float64, Float64]),
],
self.volatility(),
),
BuiltinScalarFunction::Nanvl => Signature::one_of(
vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
self.volatility(),
Expand Down Expand Up @@ -259,8 +229,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Factorial
) {
Some(vec![Some(true)])
} else if *self == BuiltinScalarFunction::Log {
Some(vec![Some(true), Some(false)])
} else {
None
}
Expand All @@ -272,9 +240,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Ceil => &["ceil"],
BuiltinScalarFunction::Exp => &["exp"],
BuiltinScalarFunction::Factorial => &["factorial"],
BuiltinScalarFunction::Log => &["log"],
BuiltinScalarFunction::Nanvl => &["nanvl"],
BuiltinScalarFunction::Power => &["power", "pow"],
BuiltinScalarFunction::Random => &["random"],

// conditional functions
Expand Down
3 changes: 0 additions & 3 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,9 +540,6 @@ scalar_expr!(

scalar_expr!(Exp, exp, num, "exponential");

scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`");
scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`");

scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase");
scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`");
nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL");
Expand Down
26 changes: 21 additions & 5 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,9 @@ impl LogicalPlan {
err
}

/// Calls `f` on all expressions (non-recursively) in the current
/// logical plan node. This does not include expressions in any
/// children.
/// Calls `f` on all expressions in the current `LogicalPlan` node.
///
/// Note this does not include expressions in child `LogicalPlan` nodes.
pub fn apply_expressions<F: FnMut(&Expr) -> Result<TreeNodeRecursion>>(
&self,
mut f: F,
Expand Down Expand Up @@ -393,6 +393,11 @@ impl LogicalPlan {
}
}

/// Rewrites all expressions in the current `LogicalPlan` node using `f`.
///
/// Returns the current node.
///
/// Note this does not include expressions in child `LogicalPlan` nodes.
pub fn map_expressions<F: FnMut(Expr) -> Result<Transformed<Expr>>>(
self,
mut f: F,
Expand Down Expand Up @@ -608,8 +613,9 @@ impl LogicalPlan {
})
}

/// returns all inputs of this `LogicalPlan` node. Does not
/// include inputs to inputs, or subqueries.
/// Returns all inputs / children of this `LogicalPlan` node.
///
/// Note this does not include inputs to inputs, or subqueries.
pub fn inputs(&self) -> Vec<&LogicalPlan> {
match self {
LogicalPlan::Projection(Projection { input, .. }) => vec![input],
Expand Down Expand Up @@ -1370,6 +1376,10 @@ impl LogicalPlan {
)
}

/// Calls `f` recursively on all children of the `LogicalPlan` node.
///
/// Unlike [`Self::apply`], this method *does* include `LogicalPlan`s that
/// are referenced in `Expr`s.
pub fn apply_with_subqueries<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
f: &mut F,
Expand Down Expand Up @@ -1434,6 +1444,8 @@ impl LogicalPlan {
)
}

/// Calls `f` on all subqueries referenced in expressions of the current
/// `LogicalPlan` node.
fn apply_subqueries<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
mut f: F,
Expand All @@ -1453,6 +1465,10 @@ impl LogicalPlan {
})
}

/// Rewrites all subquery `LogicalPlan` in the current `LogicalPlan` node
/// using `f`.
///
/// Returns the current node.
fn map_subqueries<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
mut f: F,
Expand Down
26 changes: 13 additions & 13 deletions datafusion/functions/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,6 @@ macro_rules! make_stub_package {
};
}

/// Applies `$FUNC` to each present element of an array and collects the
/// results into a new `$ARRAY_TYPE`.
///
/// $ARG: the array expression, passed to `downcast_arg!` together with
///       `$NAME` and `$ARRAY_TYPE`
/// $NAME: name forwarded to `downcast_arg!`
/// $ARRAY_TYPE: the array type to downcast to and to collect into
/// $FUNC: a block evaluating to a callable that takes one element value
///
/// Elements that iterate as `None` are passed through as `None` without
/// invoking `$FUNC`.
macro_rules! make_function_scalar_inputs {
    ($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
        let typed_array = downcast_arg!($ARG, $NAME, $ARRAY_TYPE);

        typed_array
            .iter()
            .map(|element| match element {
                None => None,
                Some(value) => Some($FUNC(value)),
            })
            .collect::<$ARRAY_TYPE>()
    }};
}

/// Invokes a function on each element of an array and returns the result as a new array
///
/// $ARG: ArrayRef
Expand Down Expand Up @@ -370,6 +357,19 @@ macro_rules! make_math_binary_udf {
};
}

/// Applies `$FUNC` to each present element of an array and collects the
/// results into a new `$ARRAY_TYPE`.
///
/// $ARG: the array expression, passed to `downcast_arg!` together with
///       `$NAME` and `$ARRAY_TYPE`
/// $NAME: name forwarded to `downcast_arg!`
/// $ARRAY_TYPE: the array type to downcast to and to collect into
/// $FUNC: a block evaluating to a callable that takes one element value
///
/// Elements that iterate as `None` are passed through as `None` without
/// invoking `$FUNC`.
macro_rules! make_function_scalar_inputs {
    ($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
        let typed_array = downcast_arg!($ARG, $NAME, $ARRAY_TYPE);

        typed_array
            .iter()
            .map(|element| match element {
                None => None,
                Some(value) => Some($FUNC(value)),
            })
            .collect::<$ARRAY_TYPE>()
    }};
}

macro_rules! make_function_inputs2 {
($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE);
Expand Down
Loading
Loading