From 37bb5dcfaa245201618864d01c0824aaf8fa6d24 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 1 May 2024 13:20:46 -0400 Subject: [PATCH 1/4] Improve coerce API so it does not need DFSchema --- datafusion-examples/examples/expr_api.rs | 2 +- datafusion/core/src/test_util/parquet.rs | 2 +- .../optimizer/src/analyzer/type_coercion.rs | 99 +++++++++---------- .../simplify_expressions/expr_simplifier.rs | 20 +--- 4 files changed, 54 insertions(+), 69 deletions(-) diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 6e9c42480c32..2c1470a1d6ec 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -258,7 +258,7 @@ pub fn physical_expr(schema: &Schema, expr: Expr) -> Result { + pub(crate) schema: &'a DFSchema, } -impl TreeNodeRewriter for TypeCoercionRewriter { +impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { type Node = Expr; fn f_up(&mut self, expr: Expr) -> Result> { @@ -132,14 +130,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter { subquery, outer_ref_columns, }) => { - let new_plan = analyze_internal(&self.schema, &subquery)?; + let new_plan = analyze_internal(self.schema, &subquery)?; Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, }))) } Expr::Exists(Exists { subquery, negated }) => { - let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; + let new_plan = analyze_internal(self.schema, &subquery.subquery)?; Ok(Transformed::yes(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(new_plan), @@ -153,8 +151,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter { subquery, negated, }) => { - let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; - let expr_type = expr.get_type(&self.schema)?; + let new_plan = analyze_internal(self.schema, &subquery.subquery)?; + let expr_type = expr.get_type(self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!( "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery" @@ -165,32 +163,32 @@ impl TreeNodeRewriter for TypeCoercionRewriter { outer_ref_columns: subquery.outer_ref_columns, }; Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( - Box::new(expr.cast_to(&common_type, &self.schema)?), + Box::new(expr.cast_to(&common_type, self.schema)?), cast_subquery(new_subquery, &common_type)?, negated, )))) } Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( *expr, - &self.schema, + self.schema, )?))), Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::Like(Like { negated, @@ -199,8 +197,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter { escape_char, case_insensitive, }) => { - let left_type = expr.get_type(&self.schema)?; - let right_type = pattern.get_type(&self.schema)?; + let left_type = expr.get_type(self.schema)?; + let right_type = pattern.get_type(self.schema)?; let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { let op_name = if case_insensitive { "ILIKE" @@ -211,8 +209,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression" ) })?; - let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?); - let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?); + let expr = Box::new(expr.cast_to(&coerced_type, self.schema)?); + let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?); Ok(Transformed::yes(Expr::Like(Like::new( negated, expr, @@ -223,14 +221,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (left_type, right_type) = get_input_types( - &left.get_type(&self.schema)?, + &left.get_type(self.schema)?, &op, - &right.get_type(&self.schema)?, + &right.get_type(self.schema)?, )?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left.cast_to(&left_type, &self.schema)?), + Box::new(left.cast_to(&left_type, self.schema)?), op, - Box::new(right.cast_to(&right_type, &self.schema)?), + Box::new(right.cast_to(&right_type, self.schema)?), )))) } Expr::Between(Between { @@ -239,15 +237,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter { low, high, }) => { - let expr_type = expr.get_type(&self.schema)?; - let low_type = low.get_type(&self.schema)?; + let expr_type = expr.get_type(self.schema)?; + let low_type = low.get_type(self.schema)?; let low_coerced_type = comparison_coercion(&expr_type, &low_type) .ok_or_else(|| { DataFusionError::Internal(format!( "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression" )) })?; - let high_type = high.get_type(&self.schema)?; + let high_type = high.get_type(self.schema)?; let high_coerced_type = comparison_coercion(&expr_type, &low_type) .ok_or_else(|| { DataFusionError::Internal(format!( @@ -262,10 +260,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter { )) })?; Ok(Transformed::yes(Expr::Between(Between::new( - Box::new(expr.cast_to(&coercion_type, &self.schema)?), + Box::new(expr.cast_to(&coercion_type, self.schema)?), negated, - Box::new(low.cast_to(&coercion_type, &self.schema)?), - Box::new(high.cast_to(&coercion_type, &self.schema)?), + Box::new(low.cast_to(&coercion_type, self.schema)?), + Box::new(high.cast_to(&coercion_type, self.schema)?), )))) } Expr::InList(InList { @@ -273,10 +271,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter { list, negated, }) => { - let expr_data_type = expr.get_type(&self.schema)?; + let expr_data_type = expr.get_type(self.schema)?; let list_data_types = list .iter() - .map(|list_expr| list_expr.get_type(&self.schema)) + .map(|list_expr| list_expr.get_type(self.schema)) .collect::>>()?; let result_type = get_coerce_type_for_list(&expr_data_type, &list_data_types); @@ -286,11 +284,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { ), Some(coerced_type) => { // find the coerced type - let cast_expr = expr.cast_to(&coerced_type, &self.schema)?; + let cast_expr = expr.cast_to(&coerced_type, self.schema)?; let cast_list_expr = list .into_iter() .map(|list_expr| { - list_expr.cast_to(&coerced_type, &self.schema) + list_expr.cast_to(&coerced_type, self.schema) }) .collect::>>()?; Ok(Transformed::yes(Expr::InList(InList ::new( @@ -302,18 +300,17 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } } Expr::Case(case) => { - let case = coerce_case_expression(case, &self.schema)?; + let case = coerce_case_expression(case, self.schema)?; Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { ScalarFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( args, - &self.schema, + self.schema, fun.signature(), )?; - let new_expr = - coerce_arguments_for_fun(new_expr, &self.schema, &fun)?; + let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &fun)?; Ok(Transformed::yes(Expr::ScalarFunction( ScalarFunction::new_udf(fun, new_expr), ))) @@ -331,7 +328,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let new_expr = coerce_agg_exprs_for_signature( &fun, args, - &self.schema, + self.schema, &fun.signature(), )?; Ok(Transformed::yes(Expr::AggregateFunction( @@ -348,7 +345,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { AggregateFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( args, - &self.schema, + self.schema, fun.signature(), )?; Ok(Transformed::yes(Expr::AggregateFunction( @@ -375,14 +372,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter { null_treatment, }) => { let window_frame = - coerce_window_frame(window_frame, &self.schema, &order_by)?; + coerce_window_frame(window_frame, self.schema, &order_by)?; let args = match &fun { expr::WindowFunctionDefinition::AggregateFunction(fun) => { coerce_agg_exprs_for_signature( fun, args, - &self.schema, + self.schema, &fun.signature(), )? } @@ -495,7 +492,7 @@ fn coerce_frame_bound( // For example, ROWS and GROUPS frames use `UInt64` during calculations. fn coerce_window_frame( window_frame: WindowFrame, - schema: &DFSchemaRef, + schema: &DFSchema, expressions: &[Expr], ) -> Result { let mut window_frame = window_frame; @@ -531,7 +528,7 @@ fn coerce_window_frame( // Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion. // The above op will be rewrite to the binary op when creating the physical op. -fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchemaRef) -> Result { +fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result { let left_type = expr.get_type(schema)?; get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?; expr.cast_to(&DataType::Boolean, schema) @@ -615,7 +612,7 @@ fn coerce_agg_exprs_for_signature( .collect() } -fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { +fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { // Given expressions like: // // CASE a1 @@ -1238,7 +1235,7 @@ mod test { vec![Field::new("a", DataType::Int64, true)].into(), std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; + let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; @@ -1249,7 +1246,7 @@ mod test { vec![Field::new("a", DataType::Int64, true)].into(), std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; + let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; @@ -1260,7 +1257,7 @@ mod test { vec![Field::new("a", DataType::Int64, true)].into(), std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; + let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index fb5125f09769..4d7a207afb1b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -31,9 +31,7 @@ use datafusion_common::{ cast::{as_large_list_array, as_list_array}, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; -use datafusion_common::{ - internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{InList, InSubquery}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ @@ -208,14 +206,8 @@ impl ExprSimplifier { /// /// See the [type coercion module](datafusion_expr::type_coercion) /// documentation for more details on type coercion - /// - // Would be nice if this API could use the SimplifyInfo - // rather than creating an DFSchemaRef coerces rather than doing - // it manually. - // https://github.com/apache/datafusion/issues/3793 - pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { + pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite).data() } @@ -1686,7 +1678,7 @@ mod tests { sync::Arc, }; - use datafusion_common::{assert_contains, ToDFSchema}; + use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{interval_arithmetic::Interval, *}; use crate::simplify_expressions::SimplifyContext; @@ -1721,11 +1713,7 @@ mod tests { // should fully simplify to 3 < i (though i has been coerced to i64) let expected = lit(3i64).lt(col("i")); - // Would be nice if this API could use the SimplifyInfo - // rather than creating an DFSchemaRef coerces rather than doing - // it manually. - // https://github.com/apache/datafusion/issues/3793 - let expr = simplifier.coerce(expr, schema).unwrap(); + let expr = simplifier.coerce(expr, &schema).unwrap(); assert_eq!(expected, simplifier.simplify(expr).unwrap()); } From e89676359c334213f068e1c1a26d8af060df86bf Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 1 May 2024 11:22:14 -0400 Subject: [PATCH 2/4] Add `SessionContext::create_physical_expr()` and `SessionState::create_physical_expr()` --- datafusion-examples/examples/expr_api.rs | 34 ++-- datafusion/common/src/dfschema.rs | 29 +++ datafusion/core/src/execution/context/mod.rs | 104 ++++++++++- datafusion/core/tests/core_integration.rs | 3 + datafusion/core/tests/expr_api/mod.rs | 181 +++++++++++++++++++ datafusion/optimizer/src/analyzer/mod.rs | 5 + 6 files changed, 328 insertions(+), 28 deletions(-) create mode 100644 datafusion/core/tests/expr_api/mod.rs diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 2c1470a1d6ec..0082ed6eb9a9 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -25,9 +25,7 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::common::DFSchema; use datafusion::error::Result; use datafusion::optimizer::simplify_expressions::ExprSimplifier; -use datafusion::physical_expr::{ - analyze, create_physical_expr, AnalysisContext, ExprBoundaries, PhysicalExpr, -}; +use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries}; use datafusion::prelude::*; use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::execution_props::ExecutionProps; @@ -92,7 +90,8 @@ fn evaluate_demo() -> Result<()> { let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); // First, you make a "physical expression" from the logical `Expr` - let physical_expr = physical_expr(&batch.schema(), expr)?; + let df_schema = DFSchema::try_from(batch.schema())?; + let physical_expr = SessionContext::new().create_physical_expr(expr, &df_schema)?; // Now, you can evaluate the expression against the RecordBatch let result = physical_expr.evaluate(&batch)?; @@ -213,7 +212,7 @@ fn range_analysis_demo() -> Result<()> { // `date < '2020-10-01' AND date > '2020-09-01'` // As always, we need to tell DataFusion the type of column "date" - let schema = Schema::new(vec![make_field("date", DataType::Date32)]); + let schema = Arc::new(Schema::new(vec![make_field("date", DataType::Date32)])); // You can provide DataFusion any known boundaries on the values of `date` // (for example, maybe you know you only have data up to `2020-09-15`), but @@ -222,9 +221,13 @@ fn range_analysis_demo() -> Result<()> { let boundaries = ExprBoundaries::try_new_unbounded(&schema)?; // Now, we invoke the analysis code to perform the range analysis - let physical_expr = physical_expr(&schema, expr)?; - let analysis_result = - analyze(&physical_expr, AnalysisContext::new(boundaries), &schema)?; + let df_schema = DFSchema::try_from(schema)?; + let physical_expr = SessionContext::new().create_physical_expr(expr, &df_schema)?; + let analysis_result = analyze( + &physical_expr, + AnalysisContext::new(boundaries), + df_schema.as_ref(), + )?; // The results of the analysis is an range, encoded as an `Interval`, for // each column in the schema, that must be true in order for the predicate @@ -248,21 +251,6 @@ fn make_ts_field(name: &str) -> Field { make_field(name, DataType::Timestamp(TimeUnit::Nanosecond, tz)) } -/// Build a physical expression from a logical one, after applying simplification and type coercion -pub fn physical_expr(schema: &Schema, expr: Expr) -> Result> { - let df_schema = schema.clone().to_dfschema_ref()?; - - // Simplify - let props = ExecutionProps::new(); - let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(df_schema.clone())); - - // apply type coercion here to ensure types match - let expr = simplifier.coerce(expr, &df_schema)?; - - create_physical_expr(&expr, df_schema.as_ref(), &props) -} - /// This function shows how to use `Expr::get_type` to retrieve the DataType /// of an expression fn expression_type_demo() -> Result<()> { diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index b2a3de72356c..3686af90db17 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -125,6 +125,20 @@ impl DFSchema { } } + /// Return a reference to the inner Arrow [`Schema`] + /// + /// Note this does not have the qualifier information + pub fn as_arrow(&self) -> &Schema { + self.inner.as_ref() + } + + /// Return a reference to the inner Arrow [`SchemaRef`] + /// + /// Note this does not have the qualifier information + pub fn inner(&self) -> &SchemaRef { + &self.inner + } + /// Create a `DFSchema` from an Arrow schema where all the fields have a given qualifier pub fn new_with_metadata( qualified_fields: Vec<(Option, Arc)>, @@ -806,6 +820,21 @@ impl From<&DFSchema> for Schema { } } +/// Allow DFSchema to be converted into an Arrow `&Schema` +impl AsRef for DFSchema { + fn as_ref(&self) -> &Schema { + self.as_arrow() + } +} + +/// Allow DFSchema to be converted into an Arrow `&SchemaRef` (to clone, for +/// example) +impl AsRef for DFSchema { + fn as_ref(&self) -> &SchemaRef { + self.inner() + } +} + /// Create a `DFSchema` from an Arrow schema impl TryFrom for DFSchema { type Error = DataFusionError; diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index d83644597e78..837ab8a64343 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -70,13 +70,13 @@ use datafusion_common::{ config::{ConfigExtension, TableOptions}, exec_err, not_impl_err, plan_datafusion_err, plan_err, tree_node::{TreeNodeRecursion, TreeNodeVisitor}, - SchemaReference, TableReference, + DFSchema, SchemaReference, TableReference, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, var_provider::is_system_variables, - Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, + Expr, ExprSchemable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; use datafusion_sql::{ parser::{CopyToSource, CopyToStatement, DFParser}, @@ -86,15 +86,20 @@ use datafusion_sql::{ use async_trait::async_trait; use chrono::{DateTime, Utc}; +use datafusion_common::tree_node::TreeNode; use parking_lot::RwLock; use sqlparser::dialect::dialect_from_str; use url::Url; use uuid::Uuid; +use crate::physical_expr::PhysicalExpr; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; pub use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::FunctionRewrite; +use datafusion_expr::simplify::SimplifyInfo; +use datafusion_optimizer::simplify_expressions::ExprSimplifier; +use datafusion_physical_expr::create_physical_expr; mod avro; mod csv; @@ -510,6 +515,34 @@ impl SessionContext { } } + /// Create a [`PhysicalExpr`] from an [`Expr`] after applying type + /// coercion, and function rewrites. + /// + /// # Example + /// ``` + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field, Schema}; + /// # use datafusion::prelude::*; + /// # use datafusion_common::DFSchema; + /// // a = 1 (i64) + /// let expr = col("a").eq(lit(1i64)); + /// // provide type information that `a` is an Int32 + /// let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + /// let df_schema = DFSchema::try_from(schema).unwrap(); + /// // Create a PhysicalExpr. Note DataFusion automatically coerces (casts) `1i64` to `1i32` + /// let physical_expr = SessionContext::new() + /// .create_physical_expr(expr, &df_schema).unwrap(); + /// ``` + /// # See Also + /// * [`SessionState::create_physical_expr`] for a lower level API + pub fn create_physical_expr( + &self, + expr: Expr, + df_schema: &DFSchema, + ) -> Result> { + self.state.read().create_physical_expr(expr, df_schema) + } + // return an empty dataframe fn return_empty_dataframe(&self) -> Result { let plan = LogicalPlanBuilder::empty(false).build()?; @@ -1320,6 +1353,7 @@ pub enum RegisterFunction { /// Table user defined function Table(String, Arc), } + /// Execution context for registering data sources and executing queries. /// See [`SessionContext`] for a higher level API. /// @@ -1930,13 +1964,14 @@ impl SessionState { } } - /// Creates a physical plan from a logical plan. + /// Creates a physical [`ExecutionPlan`] plan from a [`LogicalPlan`]. /// /// Note: this first calls [`Self::optimize`] on the provided /// plan. /// - /// This function will error for [`LogicalPlan`]s such as catalog - /// DDL `CREATE TABLE` must be handled by another layer. + /// This function will error for [`LogicalPlan`]s such as catalog DDL like + /// `CREATE TABLE`, which do not have corresponding physical plans and must + /// be handled by another layer, typically [`SessionContext`]. pub async fn create_physical_plan( &self, logical_plan: &LogicalPlan, @@ -1947,6 +1982,36 @@ impl SessionState { .await } + /// Create a [`PhysicalExpr`] from an [`Expr`] after applying type + /// coercion, and function rewrites. + /// + /// Note: The expression is not [simplified] + /// + /// # See Also: + /// * [`SessionContext::create_physical_expr`] for a higher-level API + /// * [`create_physical_expr`] for a lower-level API + /// + /// [simplified]: datafusion_optimizer::simplify_expressions + pub fn create_physical_expr( + &self, + expr: Expr, + df_schema: &DFSchema, + ) -> Result> { + let simplifier = + ExprSimplifier::new(SessionSimpifyProvider::new(self, df_schema)); + // apply type coercion here to ensure types match + let mut expr = simplifier.coerce(expr, df_schema)?; + + // rewrite Exprs to functions if necessary + let config_options = self.config_options(); + for rewrite in self.analyzer.function_rewrites() { + expr = expr + .transform_up(|expr| rewrite.rewrite(expr, df_schema, config_options))? + .data; + } + create_physical_expr(&expr, df_schema, self.execution_props()) + } + /// Return the session ID pub fn session_id(&self) -> &str { &self.session_id @@ -2024,6 +2089,35 @@ impl SessionState { } } +struct SessionSimpifyProvider<'a> { + state: &'a SessionState, + df_schema: &'a DFSchema, +} + +impl<'a> SessionSimpifyProvider<'a> { + fn new(state: &'a SessionState, df_schema: &'a DFSchema) -> Self { + Self { state, df_schema } + } +} + +impl<'a> SimplifyInfo for SessionSimpifyProvider<'a> { + fn is_boolean_type(&self, expr: &Expr) -> Result { + Ok(expr.get_type(self.df_schema)? == DataType::Boolean) + } + + fn nullable(&self, expr: &Expr) -> Result { + expr.nullable(self.df_schema) + } + + fn execution_props(&self) -> &ExecutionProps { + self.state.execution_props() + } + + fn get_data_type(&self, expr: &Expr) -> Result { + expr.get_type(self.df_schema) + } +} + struct SessionContextProvider<'a> { state: &'a SessionState, tables: HashMap>, diff --git a/datafusion/core/tests/core_integration.rs b/datafusion/core/tests/core_integration.rs index befefb1d7ec5..f8ad8f1554b2 100644 --- a/datafusion/core/tests/core_integration.rs +++ b/datafusion/core/tests/core_integration.rs @@ -24,6 +24,9 @@ mod dataframe; /// Run all tests that are found in the `macro_hygiene` directory mod macro_hygiene; +/// Run all tests that are found in the `expr_api` directory +mod expr_api; + #[cfg(test)] #[ctor::ctor] fn init() { diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs new file mode 100644 index 000000000000..0dde7604cce2 --- /dev/null +++ b/datafusion/core/tests/expr_api/mod.rs @@ -0,0 +1,181 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::util::pretty::pretty_format_columns; +use arrow_array::builder::{ListBuilder, StringBuilder}; +use arrow_array::{ArrayRef, RecordBatch, StringArray, StructArray}; +use arrow_schema::{DataType, Field}; +use datafusion::prelude::*; +use datafusion_common::DFSchema; +/// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan +use std::sync::{Arc, OnceLock}; + +#[test] +fn test_eq() { + // id = '2' + evaluate_expr_test( + col("id").eq(lit("2")), + vec![ + "+-------+", + "| expr |", + "+-------+", + "| false |", + "| true |", + "| false |", + "+-------+", + ], + ); +} + +#[test] +fn test_eq_with_coercion() { + // id = 2 (need to coerce the 2 to '2' to evaluate) + evaluate_expr_test( + col("id").eq(lit(2i32)), + vec![ + "+-------+", + "| expr |", + "+-------+", + "| false |", + "| true |", + "| false |", + "+-------+", + ], + ); +} + +#[test] +fn test_get_field() { + // field access Expr::field() requires a rewrite to work + evaluate_expr_test( + col("props").field("a"), + vec![ + "+------------+", + "| expr |", + "+------------+", + "| 2021-02-01 |", + "| 2021-02-02 |", + "| 2021-02-03 |", + "+------------+", + ], + ); +} + +#[test] +fn test_nested_get_field() { + // field access Expr::field() requires a rewrite to work, test when it is + // not the root expression + evaluate_expr_test( + col("props") + .field("a") + .eq(lit("2021-02-02")) + .or(col("id").eq(lit(1))), + vec![ + "+-------+", + "| expr |", + "+-------+", + "| true |", + "| true |", + "| false |", + "+-------+", + ], + ); +} + +#[test] +fn test_list() { + // list access also requires a rewrite to work + evaluate_expr_test( + col("list").index(lit(1i64)), + vec![ + "+------+", "| expr |", "+------+", "| one |", "| two |", "| five |", + "+------+", + ], + ); +} + +#[test] +fn test_list_range() { + // range access also requires a rewrite to work + evaluate_expr_test( + col("list").range(lit(1i64), lit(2i64)), + vec![ + "+--------------+", + "| expr |", + "+--------------+", + "| [one] |", + "| [two, three] |", + "| [five] |", + "+--------------+", + ], + ); +} + +/// Converts the `Expr` to a `PhysicalExpr`, evaluates it against the provided +/// `RecordBatch` and compares the result to the expected result. +fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) { + let batch = test_batch(); + let df_schema = DFSchema::try_from(batch.schema()).unwrap(); + let physical_expr = SessionContext::new() + .create_physical_expr(expr, &df_schema) + .unwrap(); + + let result = physical_expr.evaluate(&batch).unwrap(); + let array = result.into_array(1).unwrap(); + let result = pretty_format_columns("expr", &[array]).unwrap().to_string(); + let actual_lines = result.lines().collect::>(); + + assert_eq!( + expected_lines, actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); +} + +static TEST_BATCH: OnceLock = OnceLock::new(); + +fn test_batch() -> RecordBatch { + TEST_BATCH + .get_or_init(|| { + let string_array: ArrayRef = Arc::new(StringArray::from(vec!["1", "2", "3"])); + + // { a: "2021-02-01" } { a: "2021-02-02" } { a: "2021-02-03" } + let struct_array: ArrayRef = Arc::from(StructArray::from(vec![( + Arc::new(Field::new("a", DataType::Utf8, false)), + Arc::new(StringArray::from(vec![ + "2021-02-01", + "2021-02-02", + "2021-02-03", + ])) as _, + )])); + + // ["one"] ["two", "three", "four"] ["five"] + let mut builder = ListBuilder::new(StringBuilder::new()); + builder.append_value([Some("one")]); + builder.append_value([Some("two"), Some("three"), Some("four")]); + builder.append_value([Some("five")]); + let list_array: ArrayRef = Arc::new(builder.finish()); + + RecordBatch::try_from_iter(vec![ + ("id", string_array), + ("props", struct_array), + ("list", list_array), + ]) + .unwrap() + }) + .clone() +} diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index fb0eb14da659..146841148d62 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -111,6 +111,11 @@ impl Analyzer { self.function_rewrites.push(rewrite); } + /// return the list of function rewrites in this analyzer + pub fn function_rewrites(&self) -> &[Arc] { + &self.function_rewrites + } + /// Analyze the logical plan by applying analyzer rules, and /// do necessary check and fail the invalid plans pub fn execute_and_check( From 88d05450b42dbc71cc9d00a603f828c541133dca Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 2 May 2024 07:40:36 -0400 Subject: [PATCH 3/4] Apply suggestions from code review Co-authored-by: Weston Pace --- datafusion/core/src/execution/context/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 837ab8a64343..2f98fb0810ce 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -516,7 +516,7 @@ impl SessionContext { } /// Create a [`PhysicalExpr`] from an [`Expr`] after applying type - /// coercion, and function rewrites. + /// coercion and function rewrites. /// /// # Example /// ``` @@ -2094,7 +2094,7 @@ struct SessionSimpifyProvider<'a> { df_schema: &'a DFSchema, } -impl<'a> SessionSimpifyProvider<'a> { +impl<'a> SessionSimplifyProvider<'a> { fn new(state: &'a SessionState, df_schema: &'a DFSchema) -> Self { Self { state, df_schema } } From 95d6739682547d8cf2760a48e2234c764f650886 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 2 May 2024 07:51:57 -0400 Subject: [PATCH 4/4] Add note on simplification --- datafusion/core/src/execution/context/mod.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 31d344d8b3ea..9ac0dd79353b 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -518,6 +518,10 @@ impl SessionContext { /// Create a [`PhysicalExpr`] from an [`Expr`] after applying type /// coercion and function rewrites. /// + /// Note: The expression is not [simplified] or otherwise optimized: `a = 1 + /// + 2` will not be simplified to `a = 3` as this is a more involved process. + /// See the [expr_api] example for how to simplify expressions. + /// /// # Example /// ``` /// # use std::sync::Arc; @@ -535,6 +539,9 @@ impl SessionContext { /// ``` /// # See Also /// * [`SessionState::create_physical_expr`] for a lower level API + /// + /// [simplified]: datafusion_optimizer::simplify_expressions + /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs pub fn create_physical_expr( &self, expr: Expr, @@ -1985,13 +1992,16 @@ impl SessionState { /// Create a [`PhysicalExpr`] from an [`Expr`] after applying type /// coercion, and function rewrites. /// - /// Note: The expression is not [simplified] + /// Note: The expression is not [simplified] or otherwise optimized: `a = 1 + /// + 2` will not be simplified to `a = 3` as this is a more involved process. + /// See the [expr_api] example for how to simplify expressions. /// /// # See Also: /// * [`SessionContext::create_physical_expr`] for a higher-level API /// * [`create_physical_expr`] for a lower-level API /// /// [simplified]: datafusion_optimizer::simplify_expressions + /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs pub fn create_physical_expr( &self, expr: Expr,