Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move the Log, Power functions to datafusion-functions #9983

Merged
merged 4 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use datafusion_common::cast::as_int32_array;
use datafusion_common::ScalarValue;
use datafusion_common::{DFSchemaRef, ToDFSchema};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::logical_plan::builder::table_scan_with_filters;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::{
expr, table_scan, BuiltinScalarFunction, Cast, ColumnarValue, Expr, ExprSchemable,
Expand Down Expand Up @@ -294,6 +295,45 @@ fn select_date_plus_interval() -> Result<()> {
Ok(())
}

#[test]
fn simplify_project_scalar_fn() -> Result<()> {
// Issue https://github.com/apache/arrow-datafusion/issues/5996
let schema = Schema::new(vec![Field::new("f", DataType::Float64, false)]);
let plan = table_scan(Some("test"), &schema, None)?
.project(vec![power(col("f"), lit(1.0))])?
.build()?;

// before simplify: power(t.f, 1.0)
// after simplify: t.f as "power(t.f, 1.0)"
let expected = "Projection: test.f AS power(test.f,Float64(1))\
\n TableScan: test";
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
assert_eq!(expected, actual);
Ok(())
}

#[test]
fn simplify_scan_predicate() -> Result<()> {
let schema = Schema::new(vec![
Field::new("f", DataType::Float64, false),
Field::new("g", DataType::Float64, false),
]);
let plan = table_scan_with_filters(
Some("test"),
&schema,
None,
vec![col("g").eq(power(col("f"), lit(1.0)))],
)?
.build()?;

// before simplify: t.g = power(t.f, 1.0)
// after simplify: (t.g = t.f) as "t.g = power(t.f, 1.0)"
let expected = "TableScan: test, full_filters=[g = f AS g = power(f,Float64(1))]";
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
assert_eq!(expected, actual);
Ok(())
}

#[test]
fn test_const_evaluator() {
// true --> true
Expand Down Expand Up @@ -431,3 +471,99 @@ fn multiple_now() -> Result<()> {
assert_eq!(expected, actual);
Ok(())
}

// ------------------------------
// --- Simplifier tests -----
// ------------------------------

fn expr_test_schema() -> DFSchemaRef {
Schema::new(vec![
Field::new("c1", DataType::Utf8, true),
Field::new("c2", DataType::Boolean, true),
Field::new("c3", DataType::Int64, true),
Field::new("c4", DataType::UInt32, true),
Field::new("c1_non_null", DataType::Utf8, false),
Field::new("c2_non_null", DataType::Boolean, false),
Field::new("c3_non_null", DataType::Int64, false),
Field::new("c4_non_null", DataType::UInt32, false),
])
.to_dfschema_ref()
.unwrap()
}

fn test_simplify(input_expr: Expr, expected_expr: Expr) {
let info: MyInfo = MyInfo {
schema: expr_test_schema(),
execution_props: ExecutionProps::new(),
};
let simplifier = ExprSimplifier::new(info);
let simplified_expr = simplifier
.simplify(input_expr.clone())
.expect("successfully evaluated");

assert_eq!(
simplified_expr, expected_expr,
"Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
);
}

#[test]
fn test_simplify_log() {
// Log(c3, 1) ===> 0
{
let expr = log(col("c3_non_null"), lit(1));
test_simplify(expr, lit(0i64));
}
// Log(c3, c3) ===> 1
{
let expr = log(col("c3_non_null"), col("c3_non_null"));
let expected = lit(1i64);
test_simplify(expr, expected);
}
// Log(c3, Power(c3, c4)) ===> c4
{
let expr = log(
col("c3_non_null"),
power(col("c3_non_null"), col("c4_non_null")),
);
let expected = col("c4_non_null");
test_simplify(expr, expected);
}
// Log(c3, c4) ===> Log(c3, c4)
{
let expr = log(col("c3_non_null"), col("c4_non_null"));
let expected = log(col("c3_non_null"), col("c4_non_null"));
test_simplify(expr, expected);
}
}

#[test]
fn test_simplify_power() {
// Power(c3, 0) ===> 1
{
let expr = power(col("c3_non_null"), lit(0));
let expected = lit(1i64);
test_simplify(expr, expected)
}
// Power(c3, 1) ===> c3
{
let expr = power(col("c3_non_null"), lit(1));
let expected = col("c3_non_null");
test_simplify(expr, expected)
}
// Power(c3, Log(c3, c4)) ===> c4
{
let expr = power(
col("c3_non_null"),
log(col("c3_non_null"), col("c4_non_null")),
);
let expected = col("c4_non_null");
test_simplify(expr, expected)
}
// Power(c3, c4) ===> Power(c3, c4)
{
let expr = power(col("c3_non_null"), col("c4_non_null"));
let expected = power(col("c3_non_null"), col("c4_non_null"));
test_simplify(expr, expected)
}
}
34 changes: 0 additions & 34 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,10 @@ pub enum BuiltinScalarFunction {
Lcm,
/// iszero
Iszero,
/// log, same as log10
Log,
/// nanvl
Nanvl,
/// pi
Pi,
/// power
Power,
/// round
Round,
/// trunc
Expand Down Expand Up @@ -139,10 +135,8 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Gcd => Volatility::Immutable,
BuiltinScalarFunction::Iszero => Volatility::Immutable,
BuiltinScalarFunction::Lcm => Volatility::Immutable,
BuiltinScalarFunction::Log => Volatility::Immutable,
BuiltinScalarFunction::Nanvl => Volatility::Immutable,
BuiltinScalarFunction::Pi => Volatility::Immutable,
BuiltinScalarFunction::Power => Volatility::Immutable,
BuiltinScalarFunction::Round => Volatility::Immutable,
BuiltinScalarFunction::Cot => Volatility::Immutable,
BuiltinScalarFunction::Trunc => Volatility::Immutable,
Expand Down Expand Up @@ -191,16 +185,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Gcd
| BuiltinScalarFunction::Lcm => Ok(Int64),

BuiltinScalarFunction::Power => match &input_expr_types[0] {
Int64 => Ok(Int64),
_ => Ok(Float64),
},

BuiltinScalarFunction::Log => match &input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
},

BuiltinScalarFunction::Nanvl => match &input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
Expand Down Expand Up @@ -250,10 +234,6 @@ impl BuiltinScalarFunction {
),
BuiltinScalarFunction::Pi => Signature::exact(vec![], self.volatility()),
BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()),
BuiltinScalarFunction::Power => Signature::one_of(
vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])],
self.volatility(),
),
BuiltinScalarFunction::Round => Signature::one_of(
vec![
Exact(vec![Float64, Int64]),
Expand All @@ -272,16 +252,6 @@ impl BuiltinScalarFunction {
],
self.volatility(),
),

BuiltinScalarFunction::Log => Signature::one_of(
vec![
Exact(vec![Float32]),
Exact(vec![Float64]),
Exact(vec![Float32, Float32]),
Exact(vec![Float64, Float64]),
],
self.volatility(),
),
BuiltinScalarFunction::Nanvl => Signature::one_of(
vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
self.volatility(),
Expand Down Expand Up @@ -325,8 +295,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Pi
) {
Some(vec![Some(true)])
} else if *self == BuiltinScalarFunction::Log {
Some(vec![Some(true), Some(false)])
} else {
None
}
Expand All @@ -343,10 +311,8 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Gcd => &["gcd"],
BuiltinScalarFunction::Iszero => &["iszero"],
BuiltinScalarFunction::Lcm => &["lcm"],
BuiltinScalarFunction::Log => &["log"],
BuiltinScalarFunction::Nanvl => &["nanvl"],
BuiltinScalarFunction::Pi => &["pi"],
BuiltinScalarFunction::Power => &["power", "pow"],
BuiltinScalarFunction::Random => &["random"],
BuiltinScalarFunction::Round => &["round"],
BuiltinScalarFunction::Trunc => &["trunc"],
Expand Down
3 changes: 0 additions & 3 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,9 +559,6 @@ nary_scalar_expr!(
scalar_expr!(Exp, exp, num, "exponential");
scalar_expr!(Gcd, gcd, arg_1 arg_2, "greatest common divisor");
scalar_expr!(Lcm, lcm, arg_1 arg_2, "least common multiple");
scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`");
scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`");

scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase");
scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`");
nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL");
Expand Down
13 changes: 13 additions & 0 deletions datafusion/functions/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,19 @@ macro_rules! make_math_binary_udf {
};
}

macro_rules! make_function_scalar_inputs {
($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE);

arg.iter()
.map(|a| match a {
Some(a) => Some($FUNC(a)),
_ => None,
})
.collect::<$ARRAY_TYPE>()
}};
}

macro_rules! make_function_inputs2 {
($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE);
Expand Down
Loading
Loading