Skip to content

Commit

Permalink
Merge remote-tracking branch 'apache/main' into alamb/type_coercion
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed May 3, 2024
2 parents 5c1f2c4 + 6480020 commit 3105658
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 5 deletions.
12 changes: 9 additions & 3 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,17 +625,23 @@ impl<T> Transformed<T> {
Self::new(data, false, TreeNodeRecursion::Continue)
}

/// Applies the given `f` to the data of this [`Transformed`] object.
/// Applies an infallible `f` to the data of this [`Transformed`] object,
/// without modifying the `transformed` flag.
pub fn update_data<U, F: FnOnce(T) -> U>(self, f: F) -> Transformed<U> {
Transformed::new(f(self.data), self.transformed, self.tnr)
}

/// Maps the data of [`Transformed`] object to the result of the given `f`.
/// Applies a fallible `f` (returns `Result`) to the data of this
/// [`Transformed`] object, without modifying the `transformed` flag.
pub fn map_data<U, F: FnOnce(T) -> Result<U>>(self, f: F) -> Result<Transformed<U>> {
f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr))
}

/// Maps the [`Transformed`] object to the result of the given `f`.
/// Applies a fallible transforming `f` to the data of this [`Transformed`]
/// object.
///
/// The returned `Transformed` object has the `transformed` flag set if either
/// `self` or the return value of `f` have the `transformed` flag set.
pub fn transform_data<U, F: FnOnce(T) -> Result<Transformed<U>>>(
self,
f: F,
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> {
}

/// Was the expression simplified?
#[derive(Debug)]
pub enum ExprSimplifyResult {
/// The function call was simplified to an entirely new Expr
Simplified(Expr),
Expand Down
50 changes: 48 additions & 2 deletions datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ impl ScalarUDFImpl for LogFunc {
} else {
let args = match num_args {
1 => vec![number],
2 => vec![number, base],
2 => vec![base, number],
_ => {
return internal_err!(
"Unexpected number of arguments in log::simplify"
Expand Down Expand Up @@ -220,7 +220,13 @@ fn is_pow(func_def: &ScalarFunctionDefinition) -> bool {

#[cfg(test)]
mod tests {
use datafusion_common::cast::{as_float32_array, as_float64_array};
use std::collections::HashMap;

use datafusion_common::{
cast::{as_float32_array, as_float64_array},
DFSchema,
};
use datafusion_expr::{execution_props::ExecutionProps, simplify::SimplifyContext};

use super::*;

Expand Down Expand Up @@ -283,4 +289,44 @@ mod tests {
}
}
}
#[test]
// Test log() simplification errors
fn test_log_simplify_errors() {
let props = ExecutionProps::new();
let schema =
Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new()).unwrap());
let context = SimplifyContext::new(&props).with_schema(schema);
// Expect 0 args to error
let _ = LogFunc::new().simplify(vec![], &context).unwrap_err();
// Expect 3 args to error
let _ = LogFunc::new()
.simplify(vec![lit(1), lit(2), lit(3)], &context)
.unwrap_err();
}

#[test]
// Test that non-simplifiable log() expressions are unchanged after simplification
fn test_log_simplify_original() {
let props = ExecutionProps::new();
let schema =
Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new()).unwrap());
let context = SimplifyContext::new(&props).with_schema(schema);
// One argument with no simplifications
let result = LogFunc::new().simplify(vec![lit(2)], &context).unwrap();
let ExprSimplifyResult::Original(args) = result else {
panic!("Expected ExprSimplifyResult::Original")
};
assert_eq!(args.len(), 1);
assert_eq!(args[0], lit(2));
// Two arguments with no simplifications
let result = LogFunc::new()
.simplify(vec![lit(2), lit(3)], &context)
.unwrap();
let ExprSimplifyResult::Original(args) = result else {
panic!("Expected ExprSimplifyResult::Original")
};
assert_eq!(args.len(), 2);
assert_eq!(args[0], lit(2));
assert_eq!(args[1], lit(3));
}
}

0 comments on commit 3105658

Please sign in to comment.