Skip to content

Commit

Permalink
Avoid cloning in log::simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Apr 8, 2024
1 parent 1ba73e0 commit a8d24d6
Showing 1 changed file with 58 additions and 30 deletions.
88 changes: 58 additions & 30 deletions datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,15 @@
//! Math function: `log()`.
use arrow::datatypes::DataType;
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
use datafusion_common::{
exec_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, Result,
ScalarValue,
};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition};
use datafusion_expr::{
lit, ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition,
};

use arrow::array::{ArrayRef, Float32Array, Float64Array};
use datafusion_expr::TypeSignature::*;
Expand Down Expand Up @@ -146,51 +151,74 @@ impl ScalarUDFImpl for LogFunc {
/// 3. Log(a, a) ===> 1
fn simplify(
&self,
args: Vec<Expr>,
mut args: Vec<Expr>,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
let mut number = &args[0];
let mut base =
&Expr::Literal(ScalarValue::new_ten(&info.get_data_type(number)?)?);
if args.len() == 2 {
base = &args[0];
number = &args[1];
// Args are either
// log(number)
// log(base, number)
let num_args = args.len();
if args.len() > 2 {
return plan_err!("Expected log to have 1 or 2 arguments, got {num_args}");
}
let number = args.pop().ok_or_else(|| {
plan_datafusion_err!("Expected log to have 1 or 2 arguments, got 0")
})?;
let number_datatype = info.get_data_type(&number)?;
// default to base 10
let base = if let Some(base) = args.pop() {
base
} else {
lit(ScalarValue::new_ten(&number_datatype)?)
};

match number {
Expr::Literal(value)
if value == &ScalarValue::new_one(&info.get_data_type(number)?)? =>
{
Ok(ExprSimplifyResult::Simplified(Expr::Literal(
ScalarValue::new_zero(&info.get_data_type(base)?)?,
)))
Expr::Literal(value) if value == ScalarValue::new_one(&number_datatype)? => {
Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero(
&info.get_data_type(&base)?,
)?)))
}
Expr::ScalarFunction(ScalarFunction {
func_def: ScalarFunctionDefinition::UDF(fun),
args,
}) if base == &args[0]
&& fun
.as_ref()
.inner()
.as_any()
.downcast_ref::<PowerFunc>()
.is_some() =>
Expr::ScalarFunction(ScalarFunction { func_def, mut args })
if is_pow(&func_def) && args.len() == 2 && base == args[0] =>
{
Ok(ExprSimplifyResult::Simplified(args[1].clone()))
let b = args.pop().unwrap(); // length checked above
Ok(ExprSimplifyResult::Simplified(b))
}
_ => {
number => {
if number == base {
Ok(ExprSimplifyResult::Simplified(Expr::Literal(
ScalarValue::new_one(&info.get_data_type(number)?)?,
)))
Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one(
&number_datatype,
)?)))
} else {
let args = match num_args {
1 => vec![number],
2 => vec![number, base],
_ => {
return internal_err!(
"Unexpected number of arguments in log::simplify"
)
}
};
Ok(ExprSimplifyResult::Original(args))
}
}
}
}
}

/// Reports whether `func_def` is a UDF whose inner implementation is a
/// `PowerFunc`; any other kind of function definition yields `false`.
fn is_pow(func_def: &ScalarFunctionDefinition) -> bool {
    match func_def {
        // Check the concrete type behind the UDF's type-erased implementation.
        ScalarFunctionDefinition::UDF(udf) => {
            udf.as_ref().inner().as_any().is::<PowerFunc>()
        }
        _ => false,
    }
}

#[cfg(test)]
mod tests {
use datafusion_common::cast::{as_float32_array, as_float64_array};
Expand Down

0 comments on commit a8d24d6

Please sign in to comment.