From a8d24d615f0adc9dc3cd9d855ca94a838eec71cf Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Mon, 8 Apr 2024 14:39:21 -0400
Subject: [PATCH] Avoid cloning in log::simplify

---
Review notes (text between `---` and the diffstat is not part of the commit
message):

* Fixed in this revision: when nothing can be simplified, the arguments are
  handed back through `ExprSimplifyResult::Original` in the order they were
  received, `[base, number]`. The previous revision rebuilt the two-argument
  case as `[number, base]`, so a caller that reuses the returned arguments
  would see `log(base, number)` turned into `log(number, base)`.
* Generic parameters (`Vec<Expr>`, `Result<ExprSimplifyResult>`,
  `downcast_ref::<PowerFunc>()`) and line breaks were restored from a
  whitespace-collapsed copy of this patch. Blank context lines were placed
  only where the hunk line counts require them; check the context against the
  tree before applying.
* The post-image blob id on the `index` line predates the fix above.

 datafusion/functions/src/math/log.rs | 88 ++++++++++++++++++----------
 1 file changed, 58 insertions(+), 30 deletions(-)

diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs
index 2131b6aa6705..2adf9cc00c2a 100644
--- a/datafusion/functions/src/math/log.rs
+++ b/datafusion/functions/src/math/log.rs
@@ -18,10 +18,15 @@
 //! Math function: `log()`.
 
 use arrow::datatypes::DataType;
-use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
+use datafusion_common::{
+    exec_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, Result,
+    ScalarValue,
+};
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::{ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition};
+use datafusion_expr::{
+    lit, ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition,
+};
 
 use arrow::array::{ArrayRef, Float32Array, Float64Array};
 use datafusion_expr::TypeSignature::*;
@@ -146,44 +151,54 @@ impl ScalarUDFImpl for LogFunc {
     /// 3. Log(a, a) ===> 1
     fn simplify(
         &self,
-        args: Vec<Expr>,
+        mut args: Vec<Expr>,
         info: &dyn SimplifyInfo,
     ) -> Result<ExprSimplifyResult> {
-        let mut number = &args[0];
-        let mut base =
-            &Expr::Literal(ScalarValue::new_ten(&info.get_data_type(number)?)?);
-        if args.len() == 2 {
-            base = &args[0];
-            number = &args[1];
+        // Args are either
+        // log(number)
+        // log(base, number)
+        let num_args = args.len();
+        if args.len() > 2 {
+            return plan_err!("Expected log to have 1 or 2 arguments, got {num_args}");
         }
+        let number = args.pop().ok_or_else(|| {
+            plan_datafusion_err!("Expected log to have 1 or 2 arguments, got 0")
+        })?;
+        let number_datatype = info.get_data_type(&number)?;
+        // default to base 10
+        let base = if let Some(base) = args.pop() {
+            base
+        } else {
+            lit(ScalarValue::new_ten(&number_datatype)?)
+        };
         match number {
-            Expr::Literal(value)
-                if value == &ScalarValue::new_one(&info.get_data_type(number)?)? =>
-            {
-                Ok(ExprSimplifyResult::Simplified(Expr::Literal(
-                    ScalarValue::new_zero(&info.get_data_type(base)?)?,
-                )))
+            Expr::Literal(value) if value == ScalarValue::new_one(&number_datatype)? => {
+                Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero(
+                    &info.get_data_type(&base)?,
+                )?)))
             }
-            Expr::ScalarFunction(ScalarFunction {
-                func_def: ScalarFunctionDefinition::UDF(fun),
-                args,
-            }) if base == &args[0]
-                && fun
-                    .as_ref()
-                    .inner()
-                    .as_any()
-                    .downcast_ref::<PowerFunc>()
-                    .is_some() =>
+            Expr::ScalarFunction(ScalarFunction { func_def, mut args })
+                if is_pow(&func_def) && args.len() == 2 && base == args[0] =>
             {
-                Ok(ExprSimplifyResult::Simplified(args[1].clone()))
+                let b = args.pop().unwrap(); // length checked above
+                Ok(ExprSimplifyResult::Simplified(b))
             }
-            _ => {
+            number => {
                 if number == base {
-                    Ok(ExprSimplifyResult::Simplified(Expr::Literal(
-                        ScalarValue::new_one(&info.get_data_type(number)?)?,
-                    )))
+                    Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one(
+                        &number_datatype,
+                    )?)))
                 } else {
+                    let args = match num_args {
+                        1 => vec![number],
+                        2 => vec![base, number], // same order the caller passed
+                        _ => {
+                            return internal_err!(
+                                "Unexpected number of arguments in log::simplify"
+                            )
+                        }
+                    };
                     Ok(ExprSimplifyResult::Original(args))
                 }
             }
         }
@@ -191,6 +206,19 @@ impl ScalarUDFImpl for LogFunc {
     }
 }
 
+/// Returns true if the function is `PowerFunc`
+fn is_pow(func_def: &ScalarFunctionDefinition) -> bool {
+    if let ScalarFunctionDefinition::UDF(fun) = func_def {
+        fun.as_ref()
+            .inner()
+            .as_any()
+            .downcast_ref::<PowerFunc>()
+            .is_some()
+    } else {
+        false
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use datafusion_common::cast::{as_float32_array, as_float64_array};