diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 744a0189125c..9bc038e71edc 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -112,26 +112,6 @@ macro_rules! make_stub_package { }; } -/// Invokes a function on each element of an array and returns the result as a new array -/// -/// $ARG: ArrayRef -/// $NAME: name of the function (for error messages) -/// $ARGS_TYPE: the type of array to cast the argument to -/// $RETURN_TYPE: the type of array to return -/// $FUNC: the function to apply to each element of $ARG -macro_rules! make_function_scalar_inputs_return_type { - ($ARG: expr, $NAME:expr, $ARG_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARG_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$RETURN_TYPE>() - }}; -} - /// Downcast an argument to a specific array type, returning an internal error /// if the cast fails /// @@ -168,9 +148,9 @@ macro_rules! make_math_unary_udf { use std::any::Any; use std::sync::Arc; - use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; - use datafusion_common::{exec_err, DataFusionError, Result}; + use arrow::array::{ArrayRef, AsArray}; + use arrow::datatypes::{DataType, Float32Type, Float64Type}; + use datafusion_common::{exec_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ @@ -231,24 +211,16 @@ macro_rules! make_math_unary_udf { fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => { - Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float64Array, - Float64Array, - { f64::$UNARY_FUNC } - )) - } - DataType::Float32 => { - Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float32Array, - Float32Array, - { f32::$UNARY_FUNC } - )) - } + DataType::Float64 => Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float64Type>(|x: f64| f64::$UNARY_FUNC(x)), + ) as ArrayRef, + DataType::Float32 => Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float32Type>(|x: f32| f32::$UNARY_FUNC(x)), + ) as ArrayRef, other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -286,9 +258,9 @@ macro_rules! make_math_binary_udf { use std::any::Any; use std::sync::Arc; - use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; - use datafusion_common::{exec_err, DataFusionError, Result}; + use arrow::array::{ArrayRef, AsArray}; + use arrow::datatypes::{DataType, Float32Type, Float64Type}; + use datafusion_common::{exec_err, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature; use datafusion_expr::{ @@ -347,23 +319,26 @@ macro_rules! make_math_binary_udf { fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float64Array, - { f64::$BINARY_FUNC } - )), - - DataType::Float32 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float32Array, - { f32::$BINARY_FUNC } - )), + DataType::Float64 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + y, + x, + |y, x| f64::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } + DataType::Float32 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float32Type>( + y, + x, + |y, x| f32::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -382,43 +357,3 @@ macro_rules! make_math_binary_udf { } }; } - -macro_rules! make_function_scalar_inputs { - ($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; -} - -macro_rules! make_function_inputs2 { - ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); - let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE); - - arg1.iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; - ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{ - let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1); - let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2); - - arg1.iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), - _ => None, - }) - .collect::<$ARRAY_TYPE1>() - }}; -} diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 89ba14e32ba0..f82c0df34e27 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -22,11 +22,10 @@ use std::sync::{Arc, OnceLock}; use super::power::PowerFunc; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{ - exec_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, - ScalarValue, + exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue, }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; @@ -140,37 +139,40 @@ impl ScalarUDFImpl for LogFunc { let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => match base { ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float64Array, { - |value: f64| f64::log(value, base as f64) - })) + Arc::new(x.as_primitive::().unary::<_, Float64Type>( + |value: f64| f64::log(value, base as f64), + )) + } + ColumnarValue::Array(base) => { + let x = x.as_primitive::(); + let base = base.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + x, + base, + f64::log, + )?; + Arc::new(result) as _ } - ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float64Array, - { f64::log } - )), _ => { return exec_err!("log function requires a scalar or array for base") } }, DataType::Float32 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float32Array, { - |value: f32| f32::log(value, base) - })) + ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => Arc::new( + x.as_primitive::() + .unary::<_, Float32Type>(|value: f32| f32::log(value, base)), + ), + ColumnarValue::Array(base) => { + let x = x.as_primitive::(); + let base = base.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float32Type>( + x, + base, + f32::log, + )?; + Arc::new(result) as _ } - ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float32Array, - { f32::log } - )), _ => { return exec_err!("log function requires a scalar or array for base") } @@ -259,6 +261,7 @@ mod tests { use super::*; + use arrow::array::{Float32Array, Float64Array}; use arrow::compute::SortOptions; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::DFSchema; diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index b82ee0d45744..cfd21256dd96 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -18,10 +18,11 @@ use std::any::Any; use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Float32, Float64}; +use crate::utils::make_scalar_function; +use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array}; +use arrow::datatypes::DataType::{Float32, Float64}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::TypeSignature::Exact; @@ -29,8 +30,6 @@ use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; -use crate::utils::make_scalar_function; - #[derive(Debug)] pub struct NanvlFunc { signature: Signature, @@ -113,14 +112,11 @@ fn nanvl(args: &[ArrayRef]) -> Result { } }; - Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Float64Array, - { compute_nanvl } - )) as ArrayRef) + let x = args[0].as_primitive() as &Float64Array; + let y = args[1].as_primitive() as &Float64Array; + arrow::compute::binary::<_, _, _, Float64Type>(x, y, compute_nanvl) + .map(|res| Arc::new(res) as _) + .map_err(DataFusionError::from) } Float32 => { let compute_nanvl = |x: f32, y: f32| { @@ -131,14 +127,11 @@ fn nanvl(args: &[ArrayRef]) -> Result { } }; - Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Float32Array, - { compute_nanvl } - )) as ArrayRef) + let x = args[0].as_primitive() as &Float32Array; + let y = args[1].as_primitive() as &Float32Array; + arrow::compute::binary::<_, _, _, Float32Type>(x, y, compute_nanvl) + .map(|res| Arc::new(res) as _) + .map_err(DataFusionError::from) } other => exec_err!("Unsupported data type {other:?} for function nanvl"), } @@ -146,10 +139,12 @@ fn nanvl(args: &[ArrayRef]) -> Result { #[cfg(test)] mod test { + use std::sync::Arc; + use crate::math::nanvl::nanvl; + use arrow::array::{ArrayRef, Float32Array, Float64Array}; use datafusion_common::cast::{as_float32_array, as_float64_array}; - use std::sync::Arc; #[test] fn test_nanvl_f64() { diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 9125f9b0fecd..9bb6006d55b9 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -16,9 +16,13 @@ // under the License. //! Math function: `power()`. +use std::any::Any; +use std::sync::{Arc, OnceLock}; -use arrow::datatypes::{ArrowNativeTypeOp, DataType}; +use super::log::LogFunc; +use arrow::array::{ArrayRef, AsArray, Int64Array}; +use arrow::datatypes::{ArrowNativeTypeOp, DataType, Float64Type}; use datafusion_common::{ arrow_datafusion_err, exec_datafusion_err, exec_err, plan_datafusion_err, DataFusionError, Result, ScalarValue, @@ -27,13 +31,7 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ColumnarValue, Documentation, Expr, ScalarUDF, TypeSignature}; - -use arrow::array::{ArrayRef, Float64Array, Int64Array}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::{Arc, OnceLock}; - -use super::log::LogFunc; #[derive(Debug)] pub struct PowerFunc { @@ -90,15 +88,16 @@ impl ScalarUDFImpl for PowerFunc { let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "base", - "exponent", - Float64Array, - { f64::powf } - )), - + DataType::Float64 => { + let bases = args[0].as_primitive::(); + let exponents = args[1].as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + bases, + exponents, + f64::powf, + )?; + Arc::new(result) as _ + } DataType::Int64 => { let bases = downcast_arg!(&args[0], "base", Int64Array); let exponents = downcast_arg!(&args[1], "exponent", Int64Array); @@ -116,7 +115,7 @@ impl ScalarUDFImpl for PowerFunc { _ => Ok(None), }) .collect::>() - .map(Arc::new)? as ArrayRef + .map(Arc::new)? as _ } other => { @@ -195,6 +194,7 @@ fn is_log(func: &ScalarUDF) -> bool { #[cfg(test)] mod tests { + use arrow::array::Float64Array; use datafusion_common::cast::{as_float64_array, as_int64_array}; use super::*; diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index fec1f1ce1aee..cf0f53a80a43 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -20,13 +20,11 @@ use std::sync::{Arc, OnceLock}; use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, Float32Array, Float64Array, Int32Array}; +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Float32, Float64, Int32}; -use datafusion_common::{ - exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue, -}; +use arrow::datatypes::{DataType, Float32Type, Float64Type, Int32Type}; +use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; @@ -148,17 +146,13 @@ pub fn round(args: &[ArrayRef]) -> Result { ) })?; - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float64Array, - { - |value: f64| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - } - } - )) as ArrayRef) + let result = args[0] + .as_primitive::() + .unary::<_, Float64Type>(|value: f64| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + }); + Ok(Arc::new(result) as _) } ColumnarValue::Array(decimal_places) => { let options = CastOptions { @@ -169,20 +163,18 @@ pub fn round(args: &[ArrayRef]) -> Result { .map_err(|e| { exec_datafusion_err!("Invalid values for decimal places: {e}") })?; - Ok(Arc::new(make_function_inputs2!( - &args[0], + + let values = args[0].as_primitive::(); + let decimal_places = decimal_places.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + values, decimal_places, - "value", - "decimal_places", - Float64Array, - Int32Array, - { - |value: f64, decimal_places: i32| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - } - } - )) as ArrayRef) + |value, decimal_places| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + }, + )?; + Ok(Arc::new(result) as _) } _ => { exec_err!("round function requires a scalar or array for decimal_places") @@ -196,18 +188,13 @@ pub fn round(args: &[ArrayRef]) -> Result { "Invalid value for decimal places: {decimal_places}: {e}" ) })?; - - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float32Array, - { - |value: f32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - } - } - )) as ArrayRef) + let result = args[0] + .as_primitive::() + .unary::<_, Float32Type>(|value: f32| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + }); + Ok(Arc::new(result) as _) } ColumnarValue::Array(_) => { let ColumnarValue::Array(decimal_places) = @@ -218,20 +205,17 @@ pub fn round(args: &[ArrayRef]) -> Result { panic!("Unexpected result of ColumnarValue::Array.cast") }; - Ok(Arc::new(make_function_inputs2!( - &args[0], + let values = args[0].as_primitive::(); + let decimal_places = decimal_places.as_primitive::(); + let result: PrimitiveArray = arrow::compute::binary( + values, decimal_places, - "value", - "decimal_places", - Float32Array, - Int32Array, - { - |value: f32, decimal_places: i32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - } - } - )) as ArrayRef) + |value, decimal_places| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + }, + )?; + Ok(Arc::new(result) as _) } _ => { exec_err!("round function requires a scalar or array for decimal_places")