diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs new file mode 100644 index 0000000000000..34b2ce5b8a7bc --- /dev/null +++ b/datafusion-examples/examples/simplify_udaf_expression.rs @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::{Field, Schema}; +use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; + +use std::{any::Any, sync::Arc}; + +use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch}; +use datafusion::error::Result; +use datafusion::{assert_batches_eq, prelude::*}; +use datafusion_common::cast::as_float64_array; +use datafusion_expr::{ + expr::{AggregateFunction, AggregateFunctionDefinition}, + function::AccumulatorArgs, + simplify::ExprSimplifyResult, + Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, +}; + +/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user +/// defined aggregate function with a different expression which is defined in the `simplify` method. 
+
+#[derive(Debug, Clone)]
+struct BetterAvgUdaf {
+    signature: Signature,
+}
+
+impl BetterAvgUdaf {
+    /// Create a new instance of the BetterAvgUdaf struct
+    fn new() -> Self {
+        Self {
+            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
+        }
+    }
+}
+
+impl AggregateUDFImpl for BetterAvgUdaf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "better_avg"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        unimplemented!("should not be invoked")
+    }
+
+    fn state_fields(
+        &self,
+        _name: &str,
+        _value_type: DataType,
+        _ordering_fields: Vec<Field>,
+    ) -> Result<Vec<Field>> {
+        unimplemented!("should not be invoked")
+    }
+
+    fn groups_accumulator_supported(&self) -> bool {
+        true
+    }
+
+    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+        unimplemented!("should not get here");
+    }
+    // Override `simplify` so that every call to this UDAF is substituted with the
+    // expression built below (this is why the accumulators above are never invoked).
+    fn simplify(
+        &self,
+        args: Vec<Expr>,
+        distinct: &bool,
+        filter: &Option<Box<Expr>>,
+        order_by: &Option<Vec<Expr>>,
+        null_treatment: &Option<NullTreatment>,
+        _info: &dyn SimplifyInfo,
+    ) -> Result<ExprSimplifyResult> {
+        // as an example we replace the UDAF with the built-in `Avg`, forwarding the
+        // DISTINCT / FILTER / ORDER BY / null-treatment modifiers of the original call
+        let expr = Expr::AggregateFunction(AggregateFunction {
+            func_def: AggregateFunctionDefinition::BuiltIn(
+                // yes it is the same Avg, `BetterAvgUdaf` was just a
+                // marketing pitch :)
+                datafusion_expr::aggregate_function::AggregateFunction::Avg,
+            ),
+            args,
+            distinct: *distinct,
+            filter: filter.clone(),
+            order_by: order_by.clone(),
+            null_treatment: *null_treatment,
+        });
+
+        Ok(ExprSimplifyResult::Simplified(expr))
+    }
+}
+
+// create local session context with an in-memory table
+fn create_context() -> Result<SessionContext> {
+    use datafusion::datasource::MemTable;
+    // define a schema.
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Float32, false),
+        Field::new("b", DataType::Float32, false),
+    ]));
+
+    // define data in two partitions
+    let batch1 = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
+            Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
+        ],
+    )?;
+    let batch2 = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Float32Array::from(vec![16.0])),
+            Arc::new(Float32Array::from(vec![2.0])),
+        ],
+    )?;
+
+    let ctx = SessionContext::new();
+
+    // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
+    let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
+    ctx.register_table("t", Arc::new(provider))?;
+    Ok(ctx)
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let ctx = create_context()?;
+
+    let better_avg = AggregateUDF::from(BetterAvgUdaf::new());
+    ctx.register_udaf(better_avg.clone());
+
+    let result = ctx
+        .sql("SELECT better_avg(a) FROM t group by b")
+        .await?
+        .collect()
+        .await?;
+    let expected = vec![
+        "+-----------------+",
+        "| better_avg(t.a) |",
+        "+-----------------+",
+        "| 7.5             |",
+        "+-----------------+",
+    ];
+
+    assert_batches_eq!(expected, &result);
+
+    let df = ctx.table("t").await?;
+    let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?;
+
+    let results = df.collect().await?;
+    let result = as_float64_array(results[0].column(0))?;
+
+    assert!((result.value(0) - 7.5).abs() < f64::EPSILON);
+    println!("The average of [2,4,8,16] is {}", result.value(0));
+
+    Ok(())
+}
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index 67c3b51ca3739..08e2b5653485d 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -19,11 +19,13 @@
 
 use crate::function::AccumulatorArgs;
 use crate::groups_accumulator::GroupsAccumulator;
+use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
 use crate::utils::format_state_name;
 use crate::{Accumulator, Expr};
 use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
 use arrow::datatypes::{DataType, Field};
 use datafusion_common::{not_impl_err, Result};
+use sqlparser::ast::NullTreatment;
 use std::any::Any;
 use std::fmt::{self, Debug, Formatter};
 use std::sync::Arc;
@@ -195,6 +197,21 @@ impl AggregateUDF {
     pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
         self.inner.create_groups_accumulator()
     }
+    /// Apply the per-UDAF rewrite by delegating to the wrapped implementation
+    ///
+    /// See [`AggregateUDFImpl::simplify`] for more details.
+    pub fn simplify(
+        &self,
+        args: Vec<Expr>,
+        distinct: &bool,
+        filter: &Option<Box<Expr>>,
+        order_by: &Option<Vec<Expr>>,
+        null_treatment: &Option<NullTreatment>,
+        info: &dyn SimplifyInfo,
+    ) -> Result<ExprSimplifyResult> {
+        self.inner
+            .simplify(args, distinct, filter, order_by, null_treatment, info)
+    }
 }
 
 impl<F> From<F> for AggregateUDF
@@ -354,6 +371,37 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
     fn aliases(&self) -> &[String] {
         &[]
     }
+
+    /// Optionally apply per-UDF simplification / rewrite rules.
+    ///
+    /// This can be used to apply function specific rewrite rules during
+    /// optimization (e.g. replacing a UDAF with an equivalent built-in
+    /// aggregate). The default implementation does nothing.
+    ///
+    /// Note that DataFusion handles simplifying arguments and "constant
+    /// folding" (e.g. `my_add(1,2) --> 3`), so there is no need to implement
+    /// such optimizations manually for specific UDFs.
+    ///
+    /// # Arguments
+    /// * `args`: the arguments of the function call
+    /// * `distinct`, `filter`, `order_by`, `null_treatment`: the modifiers of the call
+    /// * `info`: the [`SimplifyInfo`] handed in by the caller
+    ///
+    /// # Returns
+    /// [`ExprSimplifyResult`] indicating the result of the simplification.
+    /// NOTE: if the function cannot be simplified, the arguments *MUST* be
+    /// returned unmodified
+    fn simplify(
+        &self,
+        args: Vec<Expr>,
+        _distinct: &bool,
+        _filter: &Option<Box<Expr>>,
+        _order_by: &Option<Vec<Expr>>,
+        _null_treatment: &Option<NullTreatment>,
+        _info: &dyn SimplifyInfo,
+    ) -> Result<ExprSimplifyResult> {
+        Ok(ExprSimplifyResult::Original(args))
+    }
 }
 
 /// AggregateUDF that adds an alias to the underlying function. 
It is better to
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 4d7a207afb1b6..8dca729c3d3ea 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -32,7 +32,9 @@ use datafusion_common::{
     tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
 };
 use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
-use datafusion_expr::expr::{InList, InSubquery};
+use datafusion_expr::expr::{
+    AggregateFunction, AggregateFunctionDefinition, InList, InSubquery,
+};
 use datafusion_expr::simplify::ExprSimplifyResult;
 use datafusion_expr::{
     and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator,
@@ -1307,6 +1309,42 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
             },
 
+            // Rules for user defined aggregate functions: let the UDAF rewrite itself.
+            // The node is matched by value so that, when the UDAF returns `Original`,
+            // its fields are moved back into the rebuilt expression instead of cloned.
+            Expr::AggregateFunction(AggregateFunction {
+                func_def: AggregateFunctionDefinition::UDF(udaf),
+                args,
+                distinct,
+                filter,
+                order_by,
+                null_treatment,
+            }) => {
+                match udaf.simplify(
+                    args,
+                    &distinct,
+                    &filter,
+                    &order_by,
+                    &null_treatment,
+                    info,
+                )? {
+                    ExprSimplifyResult::Simplified(simplified) => {
+                        Transformed::yes(simplified)
+                    }
+                    ExprSimplifyResult::Original(args) => {
+                        let expr = Expr::AggregateFunction(AggregateFunction {
+                            func_def: AggregateFunctionDefinition::UDF(udaf),
+                            args,
+                            distinct,
+                            filter,
+                            order_by,
+                            null_treatment,
+                        });
+                        Transformed::no(expr)
+                    }
+                }
+            }
+
             //
             // Rules for Between
             //
@@ -3575,4 +3613,101 @@ mod tests {
 
         assert_eq!(simplify(expr), expected);
     }
+    #[test]
+    fn test_simplify_udaf() {
+        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify());
+        let aggregate_function_expr =
+            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
+                udaf.into(),
+                vec![],
+                false,
+                None,
+                None,
+                None,
+            ));
+
+        let expected = col("result_column");
+        assert_eq!(simplify(aggregate_function_expr), expected);
+
+        let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify());
+        let aggregate_function_expr =
+            Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
+                udaf.into(),
+                vec![],
+                false,
+                None,
+                None,
+                None,
+            ));
+
+        let expected = aggregate_function_expr.clone();
+        assert_eq!(simplify(aggregate_function_expr), expected);
+    }
+
+    /// A mock UDAF used by the UDAF simplification tests; the `simplify` flag
+    /// selects whether its `simplify` method rewrites the call or returns it as is
+    #[derive(Debug, Clone)]
+    struct SimplifyMockUdaf {
+        simplify: bool,
+    }
+
+    impl SimplifyMockUdaf {
+        /// make the `simplify` method return a new expression, `col("result_column")`
+        fn new_with_simplify() -> Self {
+            Self { simplify: true }
+        }
+        /// make the `simplify` method hand the arguments back unchanged
+        fn new_without_simplify() -> Self {
+            Self { simplify: false }
+        }
+    }
+
+    impl AggregateUDFImpl for SimplifyMockUdaf {
+        fn as_any(&self) -> &dyn std::any::Any {
+            self
+        }
+
+        fn name(&self) -> &str {
+            "mock_simplify"
+        }
+
+        fn signature(&self) -> &Signature {
+            unimplemented!()
+        }
+
+        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+            unimplemented!("not needed for tests")
+        }
+
+        fn accumulator(
+            &self,
+            _acc_args: function::AccumulatorArgs,
+        ) -> Result<Box<dyn Accumulator>> {
+            unimplemented!("not needed for tests")
+        }
+
+        fn groups_accumulator_supported(&self) -> bool {
+            unimplemented!("not needed for testing")
+        }
+
+        fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+            unimplemented!("not needed for testing")
+        }
+
+        fn simplify(
+            &self,
+            args: Vec<Expr>,
+            _distinct: &bool,
+            _filter: &Option<Box<Expr>>,
+            _order_by: &Option<Vec<Expr>>,
+            _null_treatment: &Option<NullTreatment>,
+            _info: &dyn SimplifyInfo,
+        ) -> Result<ExprSimplifyResult> {
+            if self.simplify {
+                Ok(ExprSimplifyResult::Simplified(col("result_column")))
+            } else {
+                Ok(ExprSimplifyResult::Original(args))
+            }
+        }
+    }
 }