Skip to content

Commit

Permalink
move Trunc, Cot, Round, iszero functions to datafusion-functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Omega359 committed Apr 8, 2024
1 parent ff7ac69 commit cf1d5e1
Show file tree
Hide file tree
Showing 18 changed files with 1,061 additions and 688 deletions.
61 changes: 7 additions & 54 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,12 @@ pub enum BuiltinScalarFunction {
Exp,
/// factorial
Factorial,
/// iszero
Iszero,
/// log, same as log10
Log,
/// nanvl
Nanvl,
/// power
Power,
/// round
Round,
/// trunc
Trunc,
/// cot
Cot,

// string functions
/// concat
Concat,
Expand Down Expand Up @@ -127,13 +118,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
BuiltinScalarFunction::Exp => Volatility::Immutable,
BuiltinScalarFunction::Factorial => Volatility::Immutable,
BuiltinScalarFunction::Iszero => Volatility::Immutable,
BuiltinScalarFunction::Log => Volatility::Immutable,
BuiltinScalarFunction::Nanvl => Volatility::Immutable,
BuiltinScalarFunction::Power => Volatility::Immutable,
BuiltinScalarFunction::Round => Volatility::Immutable,
BuiltinScalarFunction::Cot => Volatility::Immutable,
BuiltinScalarFunction::Trunc => Volatility::Immutable,
BuiltinScalarFunction::Concat => Volatility::Immutable,
BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable,
BuiltinScalarFunction::EndsWith => Volatility::Immutable,
Expand Down Expand Up @@ -191,16 +178,12 @@ impl BuiltinScalarFunction {
_ => Ok(Float64),
},

BuiltinScalarFunction::Iszero => Ok(Boolean),

BuiltinScalarFunction::Ceil
| BuiltinScalarFunction::Exp
| BuiltinScalarFunction::Round
| BuiltinScalarFunction::Trunc
| BuiltinScalarFunction::Cot => match input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
},
BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp => {
match input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
}
}
}
}

Expand Down Expand Up @@ -237,24 +220,6 @@ impl BuiltinScalarFunction {
vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])],
self.volatility(),
),
BuiltinScalarFunction::Round => Signature::one_of(
vec![
Exact(vec![Float64, Int64]),
Exact(vec![Float32, Int64]),
Exact(vec![Float64]),
Exact(vec![Float32]),
],
self.volatility(),
),
BuiltinScalarFunction::Trunc => Signature::one_of(
vec![
Exact(vec![Float32, Int64]),
Exact(vec![Float64, Int64]),
Exact(vec![Float64]),
Exact(vec![Float32]),
],
self.volatility(),
),

BuiltinScalarFunction::Log => Signature::one_of(
vec![
Expand All @@ -272,20 +237,14 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Factorial => {
Signature::uniform(1, vec![Int64], self.volatility())
}
BuiltinScalarFunction::Ceil
| BuiltinScalarFunction::Exp
| BuiltinScalarFunction::Cot => {
BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Exp => {
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
// return the best approximation for it (in f64).
// We accept f32 because in this case it is clear that the best approximation
// will be as good as the number of digits in the number
Signature::uniform(1, vec![Float64, Float32], self.volatility())
}
BuiltinScalarFunction::Iszero => Signature::one_of(
vec![Exact(vec![Float32]), Exact(vec![Float64])],
self.volatility(),
),
}
}

Expand All @@ -298,8 +257,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Ceil
| BuiltinScalarFunction::Exp
| BuiltinScalarFunction::Factorial
| BuiltinScalarFunction::Round
| BuiltinScalarFunction::Trunc
) {
Some(vec![Some(true)])
} else if *self == BuiltinScalarFunction::Log {
Expand All @@ -313,16 +270,12 @@ impl BuiltinScalarFunction {
pub fn aliases(&self) -> &'static [&'static str] {
match self {
BuiltinScalarFunction::Ceil => &["ceil"],
BuiltinScalarFunction::Cot => &["cot"],
BuiltinScalarFunction::Exp => &["exp"],
BuiltinScalarFunction::Factorial => &["factorial"],
BuiltinScalarFunction::Iszero => &["iszero"],
BuiltinScalarFunction::Log => &["log"],
BuiltinScalarFunction::Nanvl => &["nanvl"],
BuiltinScalarFunction::Power => &["power", "pow"],
BuiltinScalarFunction::Random => &["random"],
BuiltinScalarFunction::Round => &["round"],
BuiltinScalarFunction::Trunc => &["trunc"],

// conditional functions
BuiltinScalarFunction::Coalesce => &["coalesce"],
Expand Down
46 changes: 1 addition & 45 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -530,20 +530,14 @@ macro_rules! nary_scalar_expr {
// generate methods for creating the supported unary/binary expressions

// math functions
scalar_expr!(Cot, cot, num, "cotangent of a number");
scalar_expr!(Factorial, factorial, num, "factorial");
scalar_expr!(
Ceil,
ceil,
num,
"nearest integer greater than or equal to argument"
);
nary_scalar_expr!(Round, round, "round to nearest integer");
nary_scalar_expr!(
Trunc,
trunc,
"truncate toward zero, with optional precision"
);

scalar_expr!(Exp, exp, num, "exponential");

scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`");
Expand All @@ -560,12 +554,6 @@ nary_scalar_expr!(
);
nary_scalar_expr!(Concat, concat_expr, "concatenates several strings");
scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y");
scalar_expr!(
Iszero,
iszero,
num,
"returns true if a given number is +0.0 or -0.0 otherwise returns false"
);

/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
pub fn case(expr: Expr) -> CaseBuilder {
Expand Down Expand Up @@ -875,12 +863,6 @@ impl WindowUDFImpl for SimpleWindowUDF {
}

/// Calls a named built in function
/// ```
/// use datafusion_expr::{col, lit, call_fn};
///
/// // create the expression trunc(x) < 0.2
/// let expr = call_fn("trunc", vec![col("x")]).unwrap().lt(lit(0.2));
/// ```
pub fn call_fn(name: impl AsRef<str>, args: Vec<Expr>) -> Result<Expr> {
match name.as_ref().parse::<BuiltinScalarFunction>() {
Ok(fun) => Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))),
Expand Down Expand Up @@ -938,38 +920,12 @@ mod test {
};
}

macro_rules! test_nary_scalar_expr {
($ENUM:ident, $FUNC:ident, $($arg:ident),*) => {
let expected = [$(stringify!($arg)),*];
let result = $FUNC(
vec![
$(
col(stringify!($arg.to_string()))
),*
]
);
if let Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args }) = result {
let name = built_in_function::BuiltinScalarFunction::$ENUM;
assert_eq!(name, fun);
assert_eq!(expected.len(), args.len());
} else {
assert!(false, "unexpected: {:?}", result);
}
};
}

#[test]
fn scalar_function_definitions() {
test_unary_scalar_expr!(Cot, cot);
test_unary_scalar_expr!(Factorial, factorial);
test_unary_scalar_expr!(Ceil, ceil);
test_nary_scalar_expr!(Round, round, input);
test_nary_scalar_expr!(Round, round, input, decimal_places);
test_nary_scalar_expr!(Trunc, trunc, num);
test_nary_scalar_expr!(Trunc, trunc, num, precision);
test_unary_scalar_expr!(Exp, exp);
test_scalar_expr!(Nanvl, nanvl, x, y);
test_scalar_expr!(Iszero, iszero, input);

test_scalar_expr!(InitCap, initcap, string);
test_scalar_expr!(EndsWith, ends_with, string, characters);
Expand Down
13 changes: 13 additions & 0 deletions datafusion/functions/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,19 @@ macro_rules! make_stub_package {
};
}

macro_rules! make_function_scalar_inputs {
($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE);

arg.iter()
.map(|a| match a {
Some(a) => Some($FUNC(a)),
_ => None,
})
.collect::<$ARRAY_TYPE>()
}};
}

/// Invokes a function on each element of an array and returns the result as a new array
///
/// $ARG: ArrayRef
Expand Down
Loading

0 comments on commit cf1d5e1

Please sign in to comment.