From 2873fd083d2af39f85315eea837070ef28cd5be0 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Wed, 6 Mar 2024 06:17:12 +0800 Subject: [PATCH] Add a `ScalarUDFImpl::simplfy()` API, move `SimplifyInfo` et al to datafusion_expr (#9304) * first draft Signed-off-by: jayzhan211 * clippy Signed-off-by: jayzhan211 * add comments Signed-off-by: jayzhan211 * move to optimize rule Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * fix explain test Signed-off-by: jayzhan211 * move to simplifier Signed-off-by: jayzhan211 * pass with schema Signed-off-by: jayzhan211 * fix explain Signed-off-by: jayzhan211 * fix doc Signed-off-by: jayzhan211 * move to expr Signed-off-by: jayzhan211 * change simplify signature Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * fix doc Signed-off-by: jayzhan211 * fix doc Signed-off-by: jayzhan211 * Update datafusion/expr/src/udf.rs * Add backwards compatibile uses, inline FunctionSimplifier, rename to ExprSimplifyResult * Remove DFSchema from SimplifyInfo * Avoid requiring argument copies * Improve docs * fix link * fix doc test * Update datafusion/physical-expr/src/lib.rs * Change example simplify to always simplify its argument * Clarify comment --------- Signed-off-by: jayzhan211 Co-authored-by: Andrew Lamb --- datafusion-cli/Cargo.lock | 1 + datafusion-examples/examples/expr_api.rs | 5 +- datafusion-examples/examples/simple_udtf.rs | 3 +- .../core/src/datasource/listing/helpers.rs | 6 +- .../datasource/physical_plan/parquet/mod.rs | 3 +- .../physical_plan/parquet/row_filter.rs | 2 +- .../physical_plan/parquet/row_groups.rs | 2 +- datafusion/core/src/execution/context/mod.rs | 4 +- .../core/src/physical_optimizer/pruning.rs | 2 +- datafusion/core/src/test_util/parquet.rs | 5 +- datafusion/core/src/variable/mod.rs | 2 +- datafusion/core/tests/dataframe/mod.rs | 2 +- datafusion/core/tests/parquet/page_pruning.rs | 2 +- datafusion/core/tests/simplification.rs | 26 ++--- 
.../user_defined_scalar_functions.rs | 102 +++++++++++++++++- datafusion/expr/Cargo.toml | 1 + .../src/execution_props.rs | 4 +- datafusion/expr/src/lib.rs | 3 + .../context.rs => expr/src/simplify.rs} | 48 +++------ datafusion/expr/src/udf.rs | 39 +++++++ .../src/var_provider.rs | 0 datafusion/optimizer/src/decorrelate.rs | 3 +- .../simplify_expressions/expr_simplifier.rs | 70 ++++++++++-- .../optimizer/src/simplify_expressions/mod.rs | 5 +- .../simplify_expressions/simplify_exprs.rs | 5 +- .../src/simplify_expressions/utils.rs | 5 +- .../physical-expr/src/equivalence/ordering.rs | 2 +- .../src/equivalence/projection.rs | 2 +- .../src/equivalence/properties.rs | 2 +- datafusion/physical-expr/src/functions.rs | 2 +- datafusion/physical-expr/src/lib.rs | 8 +- datafusion/physical-expr/src/physical_expr.rs | 4 +- datafusion/physical-expr/src/planner.rs | 9 +- .../physical-expr/src/utils/guarantee.rs | 2 +- datafusion/wasmtest/src/lib.rs | 5 +- 35 files changed, 287 insertions(+), 99 deletions(-) rename datafusion/{physical-expr => expr}/src/execution_props.rs (96%) rename datafusion/{optimizer/src/simplify_expressions/context.rs => expr/src/simplify.rs} (75%) rename datafusion/{physical-expr => expr}/src/var_provider.rs (100%) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 0fe6606abc95..3afd26a6e7b1 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1234,6 +1234,7 @@ dependencies = [ "ahash", "arrow", "arrow-array", + "chrono", "datafusion-common", "paste", "sqlparser", diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 9739b44aafa0..5f9f3106e14d 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -24,15 +24,16 @@ use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::common::{DFField, DFSchema}; use datafusion::error::Result; -use 
datafusion::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; -use datafusion::physical_expr::execution_props::ExecutionProps; +use datafusion::optimizer::simplify_expressions::ExprSimplifier; use datafusion::physical_expr::{ analyze, create_physical_expr, AnalysisContext, ExprBoundaries, PhysicalExpr, }; use datafusion::prelude::*; use datafusion_common::{ScalarValue, ToDFSchema}; +use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{ColumnarValue, ExprSchemable, Operator}; /// This example demonstrates the DataFusion [`Expr`] API. diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index 09341fbf47fa..c68c21fab169 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -28,8 +28,9 @@ use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; use datafusion_common::{plan_err, ScalarValue}; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{Expr, TableType}; -use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; +use datafusion_optimizer::simplify_expressions::ExprSimplifier; use std::fs::File; use std::io::Seek; use std::path::Path; diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index eef25792d00a..c53e8df35de8 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -33,12 +33,10 @@ use arrow::{ use arrow_schema::Fields; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError}; +use datafusion_expr::execution_props::ExecutionProps; use 
datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; use datafusion_physical_expr::create_physical_expr; -use datafusion_physical_expr::execution_props::ExecutionProps; - -use futures::stream::{BoxStream, FuturesUnordered}; -use futures::{StreamExt, TryStreamExt}; +use futures::stream::{BoxStream, FuturesUnordered, StreamExt, TryStreamExt}; use log::{debug, trace}; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 12b62fd68068..2f3b151e7763 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -800,13 +800,14 @@ mod tests { ArrayRef, Date64Array, Int32Array, Int64Array, Int8Array, StringArray, StructArray, }; + use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; use arrow_schema::Fields; use datafusion_common::{assert_contains, FileType, GetExt, ScalarValue, ToDFSchema}; + use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{col, lit, when, Expr}; use datafusion_physical_expr::create_physical_expr; - use datafusion_physical_expr::execution_props::ExecutionProps; use chrono::{TimeZone, Utc}; use futures::StreamExt; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index c0e37a7150d9..5f89ff087f70 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -401,9 +401,9 @@ mod test { use super::*; use arrow::datatypes::Field; use datafusion_common::ToDFSchema; + use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{cast, col, lit, Expr}; use datafusion_physical_expr::create_physical_expr; - use 
datafusion_physical_expr::execution_props::ExecutionProps; use parquet::arrow::parquet_to_arrow_schema; use parquet::file::reader::{FileReader, SerializedFileReader}; use rand::prelude::*; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index fa9523a76380..ef2eb775e037 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -346,8 +346,8 @@ mod tests { use arrow::datatypes::Schema; use arrow::datatypes::{DataType, Field}; use datafusion_common::{Result, ToDFSchema}; + use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{cast, col, lit, Expr}; - use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; use parquet::arrow::arrow_to_parquet_schema; use parquet::arrow::async_reader::ParquetObjectReader; diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 2144cd3c7736..f29c9137f976 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -43,12 +43,12 @@ use datafusion_common::{ tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, }; use datafusion_execution::registry::SerializerRegistry; +pub use datafusion_expr::execution_props::ExecutionProps; +use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; -pub use datafusion_physical_expr::execution_props::ExecutionProps; -use datafusion_physical_expr::var_provider::is_system_variables; use parking_lot::RwLock; use std::collections::hash_map::Entry; use std::string::String; diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs 
index 05d2d852e057..d2126f90eca9 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -1341,10 +1341,10 @@ mod tests { datatypes::{DataType, TimeUnit}, }; use datafusion_common::{ScalarValue, ToDFSchema}; + use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::InList; use datafusion_expr::{cast, is_null, try_cast, Expr}; use datafusion_physical_expr::create_physical_expr; - use datafusion_physical_expr::execution_props::ExecutionProps; use std::collections::HashMap; use std::ops::{Not, Rem}; diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 1047c3dd4e48..6d0711610b5a 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -28,9 +28,10 @@ use crate::datasource::listing::{ListingTableUrl, PartitionedFile}; use crate::datasource::object_store::ObjectStoreUrl; use crate::datasource::physical_plan::{FileScanConfig, ParquetExec}; use crate::error::Result; -use crate::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; +use crate::logical_expr::execution_props::ExecutionProps; +use crate::logical_expr::simplify::SimplifyContext; +use crate::optimizer::simplify_expressions::ExprSimplifier; use crate::physical_expr::create_physical_expr; -use crate::physical_expr::execution_props::ExecutionProps; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::metrics::MetricsSet; use crate::physical_plan::ExecutionPlan; diff --git a/datafusion/core/src/variable/mod.rs b/datafusion/core/src/variable/mod.rs index 5ef165313ccf..475f7570a8ee 100644 --- a/datafusion/core/src/variable/mod.rs +++ b/datafusion/core/src/variable/mod.rs @@ -17,4 +17,4 @@ //! Variable provider for `@name` and `@@name` style runtime values. 
-pub use datafusion_physical_expr::var_provider::{VarProvider, VarType}; +pub use datafusion_expr::var_provider::{VarProvider, VarType}; diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index abe5fd29182e..305a7e69fdb2 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -48,13 +48,13 @@ use datafusion_common::{assert_contains, DataFusionError, ScalarValue, UnnestOpt use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; +use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ array_agg, avg, cast, col, count, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, scalar_subquery, sum, when, wildcard, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_physical_expr::var_provider::{VarProvider, VarType}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index d182986ebbdc..6ee4247eea36 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -28,9 +28,9 @@ use datafusion::physical_plan::metrics::MetricValue; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; use datafusion_common::{ScalarValue, Statistics, ToDFSchema}; +use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{col, lit, Expr}; use datafusion_physical_expr::create_physical_expr; -use datafusion_physical_expr::execution_props::ExecutionProps; use futures::StreamExt; use object_store::path::Path; diff --git a/datafusion/core/tests/simplification.rs b/datafusion/core/tests/simplification.rs index 5fe64ca5bf04..41457df02cfc 100644 --- a/datafusion/core/tests/simplification.rs 
+++ b/datafusion/core/tests/simplification.rs @@ -20,18 +20,17 @@ use arrow::datatypes::{DataType, Field, Schema}; use arrow_array::{ArrayRef, Int32Array}; use chrono::{DateTime, TimeZone, Utc}; -use datafusion::common::DFSchema; use datafusion::{error::Result, execution::context::ExecutionProps, prelude::*}; use datafusion_common::cast::as_int32_array; use datafusion_common::ScalarValue; +use datafusion_common::{DFSchemaRef, ToDFSchema}; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::{ expr, table_scan, BuiltinScalarFunction, Cast, ColumnarValue, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, ScalarUDF, Volatility, }; -use datafusion_optimizer::simplify_expressions::{ - ExprSimplifier, SimplifyExpressions, SimplifyInfo, -}; +use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions}; use datafusion_optimizer::{OptimizerContext, OptimizerRule}; use std::sync::Arc; @@ -42,7 +41,7 @@ use std::sync::Arc; /// objects or from some other implementation struct MyInfo { /// The input schema - schema: DFSchema, + schema: DFSchemaRef, /// Execution specific details needed for constant evaluation such /// as the current time for `now()` and [VariableProviders] @@ -51,11 +50,14 @@ struct MyInfo { impl SimplifyInfo for MyInfo { fn is_boolean_type(&self, expr: &Expr) -> Result { - Ok(matches!(expr.get_type(&self.schema)?, DataType::Boolean)) + Ok(matches!( + expr.get_type(self.schema.as_ref())?, + DataType::Boolean + )) } fn nullable(&self, expr: &Expr) -> Result { - expr.nullable(&self.schema) + expr.nullable(self.schema.as_ref()) } fn execution_props(&self) -> &ExecutionProps { @@ -63,12 +65,12 @@ impl SimplifyInfo for MyInfo { } fn get_data_type(&self, expr: &Expr) -> Result { - expr.get_type(&self.schema) + expr.get_type(self.schema.as_ref()) } } -impl From for MyInfo { - fn from(schema: DFSchema) -> Self { +impl From for MyInfo { + fn from(schema: DFSchemaRef) -> Self { Self { 
schema, execution_props: ExecutionProps::new(), @@ -81,13 +83,13 @@ impl From for MyInfo { /// a: Int32 (possibly with nulls) /// b: Int32 /// s: Utf8 -fn schema() -> DFSchema { +fn schema() -> DFSchemaRef { Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, false), Field::new("s", DataType::Utf8, false), ]) - .try_into() + .to_dfschema_ref() .unwrap() } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 0546ef59b1d8..982fb0464ed5 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -16,7 +16,9 @@ // under the License. use arrow::compute::kernels::numeric::add; -use arrow_array::{Array, ArrayRef, Float64Array, Int32Array, RecordBatch, UInt8Array}; +use arrow_array::{ + Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array, +}; use arrow_schema::DataType::Float64; use arrow_schema::{DataType, Field, Schema}; use datafusion::prelude::*; @@ -26,10 +28,13 @@ use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err, plan_err, ExprSchema, Result, ScalarValue, }; +use datafusion_expr::simplify::ExprSimplifyResult; +use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::{ create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; + use rand::{thread_rng, Rng}; use std::any::Any; use std::iter; @@ -514,6 +519,101 @@ async fn deregister_udf() -> Result<()> { Ok(()) } +#[derive(Debug)] +struct CastToI64UDF { + signature: Signature, +} + +impl CastToI64UDF { + fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for CastToI64UDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + 
"cast_to_i64" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + // Demonstrate simplifying a UDF + fn simplify( + &self, + mut args: Vec, + info: &dyn SimplifyInfo, + ) -> Result { + // DataFusion should have ensured the function is called with just a + // single argument + assert_eq!(args.len(), 1); + let arg = args.pop().unwrap(); + + // Note that Expr::cast_to requires an ExprSchema but simplify gets a + // SimplifyInfo so we have to replicate some of the casting logic here. + + let source_type = info.get_data_type(&arg)?; + let new_expr = if source_type == DataType::Int64 { + // the argument's data type is already the correct type + arg + } else { + // need to use an actual cast to get the correct type + Expr::Cast(datafusion_expr::Cast { + expr: Box::new(arg), + data_type: DataType::Int64, + }) + }; + // return the newly written argument to DataFusion + Ok(ExprSimplifyResult::Simplified(new_expr)) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!("Function should have been simplified prior to evaluation") + } +} + +#[tokio::test] +async fn test_user_defined_functions_cast_to_i64() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float32, false)])); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0]))], + )?; + + ctx.register_batch("t", batch)?; + + let cast_to_i64_udf = ScalarUDF::from(CastToI64UDF::new()); + ctx.register_udf(cast_to_i64_udf); + + let result = plan_and_collect(&ctx, "SELECT cast_to_i64(x) FROM t").await?; + + assert_batches_eq!( + &[ + "+------------------+", + "| cast_to_i64(t.x) |", + "+------------------+", + "| 1 |", + "| 2 |", + "| 3 |", + "+------------------+" + ], + &result + ); + + Ok(()) +} + #[derive(Debug)] struct TakeUDF { signature: Signature, diff --git a/datafusion/expr/Cargo.toml 
b/datafusion/expr/Cargo.toml index 6e430943cf5c..621a320230f2 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -40,6 +40,7 @@ ahash = { version = "0.8", default-features = false, features = [ ] } arrow = { workspace = true } arrow-array = { workspace = true } +chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } paste = "^1.0" sqlparser = { workspace = true } diff --git a/datafusion/physical-expr/src/execution_props.rs b/datafusion/expr/src/execution_props.rs similarity index 96% rename from datafusion/physical-expr/src/execution_props.rs rename to datafusion/expr/src/execution_props.rs index 20999ab8d3db..3401a94b2736 100644 --- a/datafusion/physical-expr/src/execution_props.rs +++ b/datafusion/expr/src/execution_props.rs @@ -24,14 +24,12 @@ use std::sync::Arc; /// Holds per-query execution properties and data (such as statement /// starting timestamps). /// -/// An [`ExecutionProps`] is created each time a [`LogicalPlan`] is +/// An [`ExecutionProps`] is created each time a `LogicalPlan` is /// prepared for execution (optimized). If the same plan is optimized /// multiple times, a new `ExecutionProps` is created each time. 
/// /// It is important that this structure be cheap to create as it is /// done so during predicate pruning and expression simplification -/// -/// [`LogicalPlan`]: datafusion_expr::LogicalPlan #[derive(Clone, Debug)] pub struct ExecutionProps { pub query_execution_start_time: DateTime, diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 8c73ae5ae709..a297f2dc7886 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -40,6 +40,7 @@ mod udwf; pub mod aggregate_function; pub mod conditional_expressions; +pub mod execution_props; pub mod expr; pub mod expr_fn; pub mod expr_rewriter; @@ -49,9 +50,11 @@ pub mod function; pub mod groups_accumulator; pub mod interval_arithmetic; pub mod logical_plan; +pub mod simplify; pub mod tree_node; pub mod type_coercion; pub mod utils; +pub mod var_provider; pub mod window_frame; pub mod window_state; diff --git a/datafusion/optimizer/src/simplify_expressions/context.rs b/datafusion/expr/src/simplify.rs similarity index 75% rename from datafusion/optimizer/src/simplify_expressions/context.rs rename to datafusion/expr/src/simplify.rs index 34f3908c7e42..536a01fa8571 100644 --- a/datafusion/optimizer/src/simplify_expressions/context.rs +++ b/datafusion/expr/src/simplify.rs @@ -19,11 +19,10 @@ use arrow::datatypes::DataType; use datafusion_common::{DFSchemaRef, DataFusionError, Result}; -use datafusion_expr::{Expr, ExprSchemable}; -use datafusion_physical_expr::execution_props::ExecutionProps; -#[allow(rustdoc::private_intra_doc_links)] -/// The information necessary to apply algebraic simplification to an +use crate::{execution_props::ExecutionProps, Expr, ExprSchemable}; + +/// Provides the information necessary to apply algebraic simplification to an /// [Expr]. See [SimplifyContext] for one concrete implementation. 
/// /// This trait exists so that other systems can plug schema @@ -46,35 +45,11 @@ pub trait SimplifyInfo { /// Provides simplification information based on DFSchema and /// [`ExecutionProps`]. This is the default implementation used by DataFusion /// -/// For example: -/// ``` -/// use arrow::datatypes::{Schema, Field, DataType}; -/// use datafusion_expr::{col, lit}; -/// use datafusion_common::{DataFusionError, ToDFSchema}; -/// use datafusion_physical_expr::execution_props::ExecutionProps; -/// use datafusion_optimizer::simplify_expressions::{SimplifyContext, ExprSimplifier}; -/// -/// // Create the schema -/// let schema = Schema::new(vec![ -/// Field::new("i", DataType::Int64, false), -/// ]) -/// .to_dfschema_ref().unwrap(); -/// -/// // Create the simplifier -/// let props = ExecutionProps::new(); -/// let context = SimplifyContext::new(&props) -/// .with_schema(schema); -/// let simplifier = ExprSimplifier::new(context); -/// -/// // Use the simplifier +/// # Example +/// See the `simplify_demo` in the [`expr_api` example] /// -/// // b < 2 or (1 > 3) -/// let expr = col("b").lt(lit(2)).or(lit(1).gt(lit(3))); -/// -/// // b < 2 -/// let simplified = simplifier.simplify(expr).unwrap(); -/// assert_eq!(simplified, col("b").lt(lit(2))); -/// ``` +/// [`expr_api` example]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/expr_api.rs +#[derive(Debug, Clone)] pub struct SimplifyContext<'a> { schema: Option, props: &'a ExecutionProps, @@ -132,3 +107,12 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> { self.props } } + +/// Was the expression simplified? +pub enum ExprSimplifyResult { + /// The function call was simplified to an entirely new Expr + Simplified(Expr), + /// the function call could not be simplified, and the arguments + /// are returned unmodified.
+ Original(Vec), +} diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 59e5a7772e02..5ad420b2f382 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,6 +17,7 @@ //! [`ScalarUDF`]: Scalar User Defined Functions +use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::ExprSchemable; use crate::{ ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction, @@ -161,6 +162,17 @@ impl ScalarUDF { self.inner.return_type_from_exprs(args, schema) } + /// Do the function rewrite + /// + /// See [`ScalarUDFImpl::simplify`] for more details. + pub fn simplify( + &self, + args: Vec, + info: &dyn SimplifyInfo, + ) -> Result { + self.inner.simplify(args, info) + } + /// Invoke the function on `args`, returning the appropriate result. /// /// See [`ScalarUDFImpl::invoke`] for more details. @@ -338,6 +350,33 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { fn monotonicity(&self) -> Result> { Ok(None) } + + /// Optionally apply per-UDF simplification / rewrite rules. + /// + /// This can be used to apply function specific simplification rules during + /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default + /// implementation does nothing. + /// + /// Note that DataFusion handles simplifying arguments and "constant + /// folding" (replacing a function call with constant arguments such as + /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such + /// optimizations manually for specific UDFs. + /// + /// # Arguments + /// * 'args': The arguments of the function + /// * 'info': The type, nullability, and execution information needed for simplification + /// + /// # Returns + /// [`ExprSimplifyResult`] indicating the result of the simplification NOTE + /// if the function cannot be simplified, the arguments *MUST* be returned + /// unmodified + fn simplify( + &self, + args: Vec, + _info: &dyn SimplifyInfo, + ) -> Result { + Ok(ExprSimplifyResult::Original(args)) + } } /// ScalarUDF that adds an alias to the underlying function.
It is better to diff --git a/datafusion/physical-expr/src/var_provider.rs b/datafusion/expr/src/var_provider.rs similarity index 100% rename from datafusion/physical-expr/src/var_provider.rs rename to datafusion/expr/src/var_provider.rs diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index fd548ba4948e..12e84a63ea15 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -18,7 +18,7 @@ use std::collections::{BTreeSet, HashMap}; use std::ops::Deref; -use crate::simplify_expressions::{ExprSimplifier, SimplifyContext}; +use crate::simplify_expressions::ExprSimplifier; use crate::utils::collect_subquery_cols; use datafusion_common::tree_node::{ @@ -26,6 +26,7 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{AggregateFunctionDefinition, Alias}; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_physical_expr::execution_props::ExecutionProps; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 6b5dd1b4681e..ef034a5ed711 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -38,6 +38,7 @@ use datafusion_common::{ use datafusion_common::{ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ and, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, Like, ScalarFunctionDefinition, Volatility, @@ -46,6 +47,40 @@ use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterva use 
datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; /// This structure handles API for expression simplification +/// +/// Provides simplification information based on DFSchema and +/// [`ExecutionProps`]. This is the default implementation used by DataFusion +/// +/// For example: +/// ``` +/// use arrow::datatypes::{Schema, Field, DataType}; +/// use datafusion_expr::{col, lit}; +/// use datafusion_common::{DataFusionError, ToDFSchema}; +/// use datafusion_expr::execution_props::ExecutionProps; +/// use datafusion_expr::simplify::SimplifyContext; +/// use datafusion_optimizer::simplify_expressions::ExprSimplifier; +/// +/// // Create the schema +/// let schema = Schema::new(vec![ +/// Field::new("i", DataType::Int64, false), +/// ]) +/// .to_dfschema_ref().unwrap(); +/// +/// // Create the simplifier +/// let props = ExecutionProps::new(); +/// let context = SimplifyContext::new(&props) +/// .with_schema(schema); +/// let simplifier = ExprSimplifier::new(context); +/// +/// // Use the simplifier +/// +/// // b < 2 or (1 > 3) +/// let expr = col("b").lt(lit(2)).or(lit(1).gt(lit(3))); +/// +/// // b < 2 +/// let simplified = simplifier.simplify(expr).unwrap(); +/// assert_eq!(simplified, col("b").lt(lit(2))); +/// ``` pub struct ExprSimplifier { info: S, /// Guarantees about the values of columns. This is provided by the user @@ -63,7 +98,7 @@ impl ExprSimplifier { /// instance of [`SimplifyContext`]. See /// [`simplify`](Self::simplify) for an example. 
/// - /// [`SimplifyContext`]: crate::simplify_expressions::context::SimplifyContext + /// [`SimplifyContext`]: datafusion_expr::simplify::SimplifyContext pub fn new(info: S) -> Self { Self { info, @@ -91,8 +126,12 @@ impl ExprSimplifier { /// use arrow::datatypes::DataType; /// use datafusion_expr::{col, lit, Expr}; /// use datafusion_common::Result; - /// use datafusion_physical_expr::execution_props::ExecutionProps; - /// use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyInfo}; + /// use datafusion_expr::execution_props::ExecutionProps; + /// use datafusion_expr::simplify::SimplifyContext; + /// use datafusion_expr::simplify::SimplifyInfo; + /// use datafusion_optimizer::simplify_expressions::ExprSimplifier; + /// use datafusion_common::DFSchema; + /// use std::sync::Arc; /// /// /// Simple implementation that provides `Simplifier` the information it needs /// /// See SimplifyContext for a structure that does this. @@ -192,9 +231,9 @@ impl ExprSimplifier { /// use datafusion_expr::{col, lit, Expr}; /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; - /// use datafusion_physical_expr::execution_props::ExecutionProps; - /// use datafusion_optimizer::simplify_expressions::{ - /// ExprSimplifier, SimplifyContext}; + /// use datafusion_expr::execution_props::ExecutionProps; + /// use datafusion_expr::simplify::SimplifyContext; + /// use datafusion_optimizer::simplify_expressions::ExprSimplifier; /// /// let schema = Schema::new(vec![ /// Field::new("x", DataType::Int64, false), @@ -251,9 +290,9 @@ impl ExprSimplifier { /// use datafusion_expr::{col, lit, Expr}; /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; - /// use datafusion_physical_expr::execution_props::ExecutionProps; - /// use datafusion_optimizer::simplify_expressions::{ - /// ExprSimplifier, SimplifyContext}; 
+ /// use datafusion_expr::execution_props::ExecutionProps; + /// use datafusion_expr::simplify::SimplifyContext; + /// use datafusion_optimizer::simplify_expressions::ExprSimplifier; /// /// let schema = Schema::new(vec![ /// Field::new("a", DataType::Int64, false), @@ -1264,6 +1303,19 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { out_expr.rewrite(self)? } + Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(udf), + args, + }) => match udf.simplify(args, info)? { + ExprSimplifyResult::Original(args) => { + Transformed::no(Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::UDF(udf), + args, + })) + } + ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr), + }, + // log Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Log), diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index a03dd767e911..5244f9a5af88 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License.
-pub mod context; pub mod expr_simplifier; mod guarantees; mod inlist_simplifier; @@ -23,6 +22,8 @@ mod regex; pub mod simplify_exprs; mod utils; -pub use context::*; +// backwards compatibility +pub use datafusion_expr::simplify::{SimplifyContext, SimplifyInfo}; + pub use expr_simplifier::*; pub use simplify_exprs::*; diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index ddb754a919bd..00d60d0a80dc 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -20,13 +20,14 @@ use std::sync::Arc; use datafusion_common::{DFSchema, DFSchemaRef, Result}; +use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::utils::merge_schema; -use datafusion_physical_expr::execution_props::ExecutionProps; use crate::{OptimizerConfig, OptimizerRule}; -use super::{ExprSimplifier, SimplifyContext}; +use super::ExprSimplifier; /// Optimizer Pass that simplifies [`LogicalPlan`]s by rewriting /// [`Expr`]`s evaluating constants and applying algebraic diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 4d3b123bace0..8952d5d79856 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -17,11 +17,10 @@ //! 
Utility functions for expression simplification -use crate::simplify_expressions::SimplifyInfo; use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::{ - expr::{Between, BinaryExpr, InList}, + expr::{Between, BinaryExpr, InList, ScalarFunction}, expr_fn::{and, bitwise_and, bitwise_or, concat_ws, or}, lit, BuiltinScalarFunction, Expr, Like, Operator, ScalarFunctionDefinition, }; diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index 64000937448e..c7cb9e5f530e 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -228,7 +228,6 @@ mod tests { use crate::equivalence::{ EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, }; - use crate::execution_props::ExecutionProps; use crate::expressions::Column; use crate::expressions::{col, BinaryExpr}; use crate::functions::create_physical_expr; @@ -236,6 +235,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::SortOptions; use datafusion_common::Result; + use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{BuiltinScalarFunction, Operator}; use itertools::Itertools; use std::sync::Arc; diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index 96c919667d84..13c3414d66b9 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -118,7 +118,6 @@ mod tests { output_schema, }; use crate::equivalence::EquivalenceProperties; - use crate::execution_props::ExecutionProps; use crate::expressions::{col, BinaryExpr, Literal}; use crate::functions::create_physical_expr; use crate::PhysicalSortExpr; @@ -126,6 +125,7 @@ mod tests { use arrow::datatypes::{DataType, Field, 
Schema}; use arrow_schema::{SortOptions, TimeUnit}; use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{BuiltinScalarFunction, Operator}; use itertools::Itertools; diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index f234a1fa08cd..890d0b49687a 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -1293,7 +1293,6 @@ mod tests { create_random_schema, create_test_params, create_test_schema, generate_table_for_eq_properties, is_table_same_after_sort, output_schema, }; - use crate::execution_props::ExecutionProps; use crate::expressions::{col, BinaryExpr, Column}; use crate::functions::create_physical_expr; use crate::PhysicalSortExpr; @@ -1301,6 +1300,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::{Fields, SortOptions, TimeUnit}; use datafusion_common::Result; + use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{BuiltinScalarFunction, Operator}; use itertools::Itertools; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index abc80a75c2b9..5838d9c74e20 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -30,7 +30,6 @@ //! an argument i32 is passed to a function that supports f64, the //! argument is automatically is coerced to f64. 
-use crate::execution_props::ExecutionProps; use crate::sort_properties::SortProperties; use crate::{ array_expressions, conditional_expressions, datetime_expressions, math_expressions, @@ -43,6 +42,7 @@ use arrow::{ }; use arrow_array::Array; use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::execution_props::ExecutionProps; pub use datafusion_expr::FuncMonotonicity; use datafusion_expr::{ type_coercion::functions::data_types, BuiltinScalarFunction, ColumnarValue, diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 125da4a2b9c2..b36e5d79bb44 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -24,7 +24,6 @@ pub mod conditional_expressions; pub mod crypto_expressions; pub mod datetime_expressions; pub mod equivalence; -pub mod execution_props; pub mod expressions; pub mod functions; pub mod intervals; @@ -42,9 +41,14 @@ pub mod udf; #[cfg(feature = "unicode_expressions")] pub mod unicode_expressions; pub mod utils; -pub mod var_provider; pub mod window; +// backwards compatibility +pub mod execution_props { + pub use datafusion_expr::execution_props::ExecutionProps; + pub use datafusion_expr::var_provider::{VarProvider, VarType}; +} + pub use aggregate::groups_accumulator::{GroupsAccumulatorAdapter, NullState}; pub use aggregate::AggregateExpr; pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 39b8de81af56..861a4ad02801 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -54,7 +54,7 @@ use itertools::izip; /// # use datafusion_common::DFSchema; /// # use datafusion_expr::{Expr, col, lit}; /// # use datafusion_physical_expr::create_physical_expr; -/// # use datafusion_physical_expr::execution_props::ExecutionProps; +/// # use datafusion_expr::execution_props::ExecutionProps; 
/// // For a logical expression `a = 1`, we can create a physical expression /// let expr = col("a").eq(lit(1)); /// // To create a PhysicalExpr we need 1. a schema @@ -74,7 +74,7 @@ use itertools::izip; /// # use datafusion_common::{assert_batches_eq, DFSchema}; /// # use datafusion_expr::{Expr, col, lit, ColumnarValue}; /// # use datafusion_physical_expr::create_physical_expr; -/// # use datafusion_physical_expr::execution_props::ExecutionProps; +/// # use datafusion_expr::execution_props::ExecutionProps; /// # let expr = col("a").eq(lit(1)); /// # let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); /// # let df_schema = DFSchema::try_from(schema.clone()).unwrap(); diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index bf279518d31d..858dbd30c124 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -16,19 +16,18 @@ // under the License. use crate::expressions::GetFieldAccessExpr; -use crate::var_provider::is_system_variables; use crate::{ - execution_props::ExecutionProps, expressions::{self, binary, like, Column, GetIndexedFieldExpr, Literal}, - functions, udf, - var_provider::VarType, - PhysicalExpr, + functions, udf, PhysicalExpr, }; use arrow::datatypes::Schema; use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, }; +use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction}; +use datafusion_expr::var_provider::is_system_variables; +use datafusion_expr::var_provider::VarType; use datafusion_expr::{ binary_expr, Between, BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Like, Operator, ScalarFunctionDefinition, TryCast, diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index c249af232bf5..e441fe8f4802 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ 
b/datafusion/physical-expr/src/utils/guarantee.rs @@ -421,9 +421,9 @@ impl<'a> ColOpLit<'a> { mod test { use super::*; use crate::create_physical_expr; - use crate::execution_props::ExecutionProps; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_common::ToDFSchema; + use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_fn::*; use datafusion_expr::{lit, Expr}; use itertools::Itertools; diff --git a/datafusion/wasmtest/src/lib.rs b/datafusion/wasmtest/src/lib.rs index 86e29420e8e6..a74cce72ac64 100644 --- a/datafusion/wasmtest/src/lib.rs +++ b/datafusion/wasmtest/src/lib.rs @@ -17,9 +17,10 @@ extern crate wasm_bindgen; use datafusion_common::{DFSchema, ScalarValue}; +use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::lit; -use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext}; -use datafusion_physical_expr::execution_props::ExecutionProps; +use datafusion_expr::simplify::SimplifyContext; +use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_sql::sqlparser::dialect::GenericDialect; use datafusion_sql::sqlparser::parser::Parser; use std::sync::Arc;