From b21bf9e8527de82e901c3a61127d63779f230163 Mon Sep 17 00:00:00 2001 From: Adam Curtis Date: Thu, 2 May 2024 20:54:54 -0400 Subject: [PATCH 1/2] fix: LogFunc simplify swaps arguments (#10360) * fix: LogFunc simplify swaps arguments * refactor tests with let else --- datafusion/expr/src/simplify.rs | 1 + datafusion/functions/src/math/log.rs | 50 ++++++++++++++++++++++++++-- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/simplify.rs b/datafusion/expr/src/simplify.rs index 6fae31b4a698c..ccf45ff0d0486 100644 --- a/datafusion/expr/src/simplify.rs +++ b/datafusion/expr/src/simplify.rs @@ -109,6 +109,7 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> { } /// Was the expression simplified? +#[derive(Debug)] pub enum ExprSimplifyResult { /// The function call was simplified to an entirely new Expr Simplified(Expr), diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index b828739126474..f451321ea1201 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -192,7 +192,7 @@ impl ScalarUDFImpl for LogFunc { } else { let args = match num_args { 1 => vec![number], - 2 => vec![number, base], + 2 => vec![base, number], _ => { return internal_err!( "Unexpected number of arguments in log::simplify" @@ -220,7 +220,13 @@ fn is_pow(func_def: &ScalarFunctionDefinition) -> bool { #[cfg(test)] mod tests { - use datafusion_common::cast::{as_float32_array, as_float64_array}; + use std::collections::HashMap; + + use datafusion_common::{ + cast::{as_float32_array, as_float64_array}, + DFSchema, + }; + use datafusion_expr::{execution_props::ExecutionProps, simplify::SimplifyContext}; use super::*; @@ -283,4 +289,44 @@ mod tests { } } } + #[test] + // Test log() simplification errors + fn test_log_simplify_errors() { + let props = ExecutionProps::new(); + let schema = + Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new()).unwrap()); + let context = SimplifyContext::new(&props).with_schema(schema); + // Expect 0 args to error + let _ = LogFunc::new().simplify(vec![], &context).unwrap_err(); + // Expect 3 args to error + let _ = LogFunc::new() + .simplify(vec![lit(1), lit(2), lit(3)], &context) + .unwrap_err(); + } + + #[test] + // Test that non-simplifiable log() expressions are unchanged after simplification + fn test_log_simplify_original() { + let props = ExecutionProps::new(); + let schema = + Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new()).unwrap()); + let context = SimplifyContext::new(&props).with_schema(schema); + // One argument with no simplifications + let result = LogFunc::new().simplify(vec![lit(2)], &context).unwrap(); + let ExprSimplifyResult::Original(args) = result else { + panic!("Expected ExprSimplifyResult::Original") + }; + assert_eq!(args.len(), 1); + assert_eq!(args[0], lit(2)); + // Two arguments with no simplifications + let result = LogFunc::new() + .simplify(vec![lit(2), lit(3)], &context) + .unwrap(); + let ExprSimplifyResult::Original(args) = result else { + panic!("Expected ExprSimplifyResult::Original") + }; + assert_eq!(args.len(), 2); + assert_eq!(args[0], lit(2)); + assert_eq!(args[1], lit(3)); + } } From 6480020e695ebbe2b81e8971c3ee0e9e7ec124b0 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 3 May 2024 05:52:04 -0400 Subject: [PATCH 2/2] Refine documentation for `Transformed::{update,map,transform})_data` (#10355) * Refine documentation for `Transformed::{update,map,transform})_data` * Update datafusion/common/src/tree_node.rs Co-authored-by: comphead --------- Co-authored-by: comphead --- datafusion/common/src/tree_node.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 43026f3a9206e..9d42f4fb1e0d4 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -625,17 +625,23 @@ impl Transformed { Self::new(data, false, TreeNodeRecursion::Continue) } - /// Applies the given `f` to the data of this [`Transformed`] object. + /// Applies an infallible `f` to the data of this [`Transformed`] object, + /// without modifying the `transformed` flag. pub fn update_data U>(self, f: F) -> Transformed { Transformed::new(f(self.data), self.transformed, self.tnr) } - /// Maps the data of [`Transformed`] object to the result of the given `f`. + /// Applies a fallible `f` (returns `Result`) to the data of this + /// [`Transformed`] object, without modifying the `transformed` flag. pub fn map_data Result>(self, f: F) -> Result> { f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr)) } - /// Maps the [`Transformed`] object to the result of the given `f`. + /// Applies a fallible transforming `f` to the data of this [`Transformed`] + /// object. + /// + /// The returned `Transformed` object has the `transformed` flag set if either + /// `self` or the return value of `f` have the `transformed` flag set. pub fn transform_data Result>>( self, f: F,