Skip to content

Commit

Permalink
Avoid cloning in power::simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Apr 8, 2024
1 parent 7d7b28b commit 1ba73e0
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 27 deletions.
2 changes: 1 addition & 1 deletion datafusion/functions/src/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
pub mod abs;
pub mod gcd;
pub mod lcm;
pub mod log;
pub mod nans;
pub mod pi;
pub mod log;
pub mod power;

// Create UDFs
Expand Down
59 changes: 34 additions & 25 deletions datafusion/functions/src/math/power.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
//! Math function: `power()`.
use arrow::datatypes::DataType;
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
use datafusion_common::{
exec_err, plan_datafusion_err, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionDefinition};
Expand Down Expand Up @@ -118,43 +120,50 @@ impl ScalarUDFImpl for PowerFunc {
/// 3. Power(a, Log(a, b)) ===> b
fn simplify(
&self,
args: Vec<Expr>,
mut args: Vec<Expr>,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
let base = &args[0];
let exponent = &args[1];

let exponent = args.pop().ok_or_else(|| {
plan_datafusion_err!("Expected power to have 2 arguments, got 0")
})?;
let base = args.pop().ok_or_else(|| {
plan_datafusion_err!("Expected power to have 2 arguments, got 1")
})?;

let exponent_type = info.get_data_type(&exponent)?;
match exponent {
Expr::Literal(value)
if value == &ScalarValue::new_zero(&info.get_data_type(exponent)?)? =>
{
Expr::Literal(value) if value == ScalarValue::new_zero(&exponent_type)? => {
Ok(ExprSimplifyResult::Simplified(Expr::Literal(
ScalarValue::new_one(&info.get_data_type(base)?)?,
ScalarValue::new_one(&info.get_data_type(&base)?)?,
)))
}
Expr::Literal(value)
if value == &ScalarValue::new_one(&info.get_data_type(exponent)?)? =>
{
Ok(ExprSimplifyResult::Simplified(base.clone()))
Expr::Literal(value) if value == ScalarValue::new_one(&exponent_type)? => {
Ok(ExprSimplifyResult::Simplified(base))
}
Expr::ScalarFunction(ScalarFunction {
func_def: ScalarFunctionDefinition::UDF(fun),
args,
}) if base == &args[0]
&& fun
.as_ref()
.inner()
.as_any()
.downcast_ref::<LogFunc>()
.is_some() =>
Expr::ScalarFunction(ScalarFunction { func_def, mut args })
if is_log(&func_def) && args.len() == 2 && base == args[0] =>
{
Ok(ExprSimplifyResult::Simplified(args[1].clone()))
let b = args.pop().unwrap(); // length checked above
Ok(ExprSimplifyResult::Simplified(b))
}
_ => Ok(ExprSimplifyResult::Original(args)),
_ => Ok(ExprSimplifyResult::Original(vec![base, exponent])),
}
}
}

/// Return true if this function call is a call to `Log`
fn is_log(func_def: &ScalarFunctionDefinition) -> bool {
if let ScalarFunctionDefinition::UDF(fun) = func_def {
fun.as_ref()
.inner()
.as_any()
.downcast_ref::<LogFunc>()
.is_some()
} else {
false
}
}

#[cfg(test)]
mod tests {
use datafusion_common::cast::{as_float64_array, as_int64_array};
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_
use datafusion_expr::{
ceil, coalesce, concat_expr, concat_ws_expr, cot, ends_with, exp,
expr::{self, InList, Sort, WindowFunction},
factorial, initcap, iszero,
factorial, initcap, iszero,
logical_plan::{PlanType, StringifiedPlan},
nanvl, random, round, trunc, AggregateFunction, Between, BinaryExpr,
BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess,
Expand Down

0 comments on commit 1ba73e0

Please sign in to comment.