Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed last usages of scalar_inputs, scalar_input_types and inputs2 to use arrow unary/binary for performance #12972

Merged
merged 10 commits into from
Oct 21, 2024
13 changes: 0 additions & 13 deletions datafusion/functions/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,19 +383,6 @@ macro_rules! make_math_binary_udf {
};
}

macro_rules! make_function_scalar_inputs {
($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE);

arg.iter()
.map(|a| match a {
Some(a) => Some($FUNC(a)),
_ => None,
})
.collect::<$ARRAY_TYPE>()
}};
}

macro_rules! make_function_inputs2 {
($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE);
Expand Down
22 changes: 10 additions & 12 deletions datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use std::sync::{Arc, OnceLock};

use super::power::PowerFunc;

use arrow::array::{ArrayRef, Float32Array, Float64Array};
use arrow::datatypes::DataType;
use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array};
use arrow::datatypes::{DataType, Float32Type, Float64Type};
use datafusion_common::{
exec_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, Result,
ScalarValue,
Expand Down Expand Up @@ -139,11 +139,10 @@ impl ScalarUDFImpl for LogFunc {
// note in f64::log params order is different than in sql. e.g in sql log(base, x) == f64::log(x, base)
let arr: ArrayRef = match args[0].data_type() {
DataType::Float64 => match base {
ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => {
Arc::new(make_function_scalar_inputs!(x, "x", Float64Array, {
|value: f64| f64::log(value, base as f64)
}))
}
ColumnarValue::Scalar(ScalarValue::Float64(Some(base))) => Arc::new(
x.as_primitive::<Float64Type>()
.unary::<_, Float64Type>(|value: f64| f64::log(value, base)),
),
ColumnarValue::Array(base) => Arc::new(make_function_inputs2!(
x,
base,
Expand All @@ -158,11 +157,10 @@ impl ScalarUDFImpl for LogFunc {
},

DataType::Float32 => match base {
ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => {
Arc::new(make_function_scalar_inputs!(x, "x", Float32Array, {
|value: f32| f32::log(value, base)
}))
}
ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => Arc::new(
x.as_primitive::<Float32Type>()
.unary::<_, Float32Type>(|value: f32| f32::log(value, base)),
),
ColumnarValue::Array(base) => Arc::new(make_function_inputs2!(
x,
base,
Expand Down
50 changes: 26 additions & 24 deletions datafusion/functions/src/math/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ use std::sync::{Arc, OnceLock};

use crate::utils::make_scalar_function;

use arrow::array::{ArrayRef, Float32Array, Float64Array, Int32Array};
use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array, Int32Array};
use arrow::compute::{cast_with_options, CastOptions};
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{Float32, Float64, Int32};
use arrow::datatypes::{DataType, Float32Type, Float64Type};
use datafusion_common::{
exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue,
};
Expand Down Expand Up @@ -148,17 +148,18 @@ pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
)
})?;

Ok(Arc::new(make_function_scalar_inputs!(
&args[0],
"value",
Float64Array,
{
|value: f64| {
(value * 10.0_f64.powi(decimal_places)).round()
/ 10.0_f64.powi(decimal_places)
}
}
)) as ArrayRef)
Ok(Arc::new(
args[0]
.as_primitive::<Float64Type>()
.unary::<_, Float64Type>(|value: f64| {
if value == 0_f64 {
0_f64
} else {
(value * 10.0_f64.powi(decimal_places)).round()
/ 10.0_f64.powi(decimal_places)
}
}),
) as ArrayRef)
}
ColumnarValue::Array(decimal_places) => {
let options = CastOptions {
Expand Down Expand Up @@ -197,17 +198,18 @@ pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
)
})?;

Ok(Arc::new(make_function_scalar_inputs!(
&args[0],
"value",
Float32Array,
{
|value: f32| {
(value * 10.0_f32.powi(decimal_places)).round()
/ 10.0_f32.powi(decimal_places)
}
}
)) as ArrayRef)
Ok(Arc::new(
args[0]
.as_primitive::<Float32Type>()
.unary::<_, Float32Type>(|value: f32| {
if value == 0_f32 {
0_f32
} else {
(value * 10.0_f32.powi(decimal_places)).round()
/ 10.0_f32.powi(decimal_places)
}
}),
) as ArrayRef)
}
ColumnarValue::Array(_) => {
let ColumnarValue::Array(decimal_places) =
Expand Down